In [None]:
from tinygrad import Tensor, nn
from tinygrad.nn import optim
import numpy as np
import matplotlib.pyplot as plt

# Generate data: n_samples points (q1, q2), each coordinate drawn uniformly
# from [-1, 1) with a fixed seed so runs are reproducible.
np.random.seed(42)
n_samples = 1000
q1 = np.random.uniform(-1, 1, n_samples).astype(np.float32)
q2 = np.random.uniform(-1, 1, n_samples).astype(np.float32)

# Target function the network learns (X1 is nonlinear in q2):
#   X1 = q1 + q2**2 - 1
#   X2 = q1 - q2 + 1
X1 = q1 + q2**2 - 1
X2 = q1 - q2 + 1

# Pack inputs and targets as (n_samples, 2) tensors:
# input columns are (q1, q2), target columns are (X1, X2).
X_train = Tensor(np.column_stack((q1, q2)))
y_train = Tensor(np.column_stack((X1, X2)))

# Define network
class TinyNet:
    """Small fully connected network: three ReLU hidden layers plus a linear output.

    Args:
        in_features: size of the input's last dimension (default 2, i.e. (q1, q2)).
        hidden: width of each of the three hidden layers (default 8).
        out_features: size of the output's last dimension (default 2, i.e. (X1, X2)).

    The defaults reproduce the original 2 -> 8 -> 8 -> 8 -> 2 architecture.
    Layers are kept as separate attributes ``l1``..``l4`` so callers can
    reference their ``weight``/``bias`` tensors directly.
    """

    def __init__(self, in_features: int = 2, hidden: int = 8, out_features: int = 2):
        self.l1 = nn.Linear(in_features, hidden)
        self.l2 = nn.Linear(hidden, hidden)
        self.l3 = nn.Linear(hidden, hidden)
        self.l4 = nn.Linear(hidden, out_features)

    def __call__(self, x):
        """Run the forward pass; ``x``'s last dimension must equal ``in_features``."""
        x = self.l1(x).relu()
        x = self.l2(x).relu()
        x = self.l3(x).relu()
        # No activation on the output layer: this is a regression head.
        return self.l4(x)

# Initialize model
model = TinyNet()

# Every trainable tensor, so the optimizer updates all four layers.
params = [model.l1.weight, model.l1.bias,
          model.l2.weight, model.l2.bias,
          model.l3.weight, model.l3.bias,
          model.l4.weight, model.l4.bias]

# Use tinygrad's optimizer
optimizer = optim.SGD(params, lr=0.01)

# Full-batch training loop: each epoch uses all samples at once.
epochs = 1000
log_every = 100
loss_history = []

# optimizer.step() is run with Tensor.training enabled (the original code
# enabled it before each step; NOTE(review): recent tinygrad versions assert
# on this — confirm for the pinned version). Enable it once for the whole loop
# instead of re-assigning it every epoch, and restore the previous value
# afterwards so later inference does not run in training mode.
prev_training = Tensor.training
Tensor.training = True
try:
    for epoch in range(epochs):
        # Forward pass
        output = model(X_train)

        # Loss (MSE)
        loss = ((output - y_train) ** 2).mean()

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Materialize the loss once per epoch and reuse it for history/logging.
        loss_value = loss.item()
        loss_history.append(loss_value)

        if epoch % log_every == 0:
            print(f"Epoch {epoch}, Loss: {loss_value:.6f}")
finally:
    Tensor.training = prev_training

# Plot the per-epoch training loss on a logarithmic y-axis.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(loss_history)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Loss')
ax.set_yscale('log')
ax.grid(True, alpha=0.3)
plt.show()

# Test on a single point. The expected value must use the same target the
# network was trained on: X1 = q1 + q2**2 - 1, X2 = q1 - q2 + 1.
q1_test, q2_test = 0.5, 0.3
test_input = Tensor([[q1_test, q2_test]])
output = model(test_input)
prediction = output.numpy()[0]  # convert once; reused below
expected = [q1_test + q2_test**2 - 1, q1_test - q2_test + 1]

print(f"\nTest: [{q1_test}, {q2_test}]")
print(f"Output: {prediction}")
print(f"Expected: {expected}")
print(f"Error: {np.abs(prediction - expected)}")