In [1]:
import sys
import os
sys.path.append(os.path.abspath("..")) 

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.datasets import make_moons

In [3]:
from micrograd.engine import Value
from micrograd.nn import MLP

# Dataset

In [4]:
X, y = make_moons(n_samples=100, noise=0.1)
y = y * 2 - 1  # make y be -1 or 1

# --- PyTorch Version ---
X_torch = torch.tensor(X, dtype=torch.float32)
y_torch = torch.tensor(y, dtype=torch.float32).view(-1, 1)

# Pytorch Model

In [5]:
class TorchMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

model_torch = TorchMLP()

In [7]:
scores = model_torch(X_torch)
losses = torch.relu(1 + -y_torch * scores)
data_loss = losses.mean()
reg_loss = 1e-4 * sum((p ** 2).sum() for p in model_torch.parameters())
loss_t = data_loss + reg_loss
loss_t.backward()

torch_grads = model_torch.fc1.weight.grad.detach().numpy()

# Micrograd Model

In [8]:
model_micro = MLP(2, [16, 16, 1])

def micro_loss(X, y):
    inputs = [list(map(Value, xrow)) for xrow in X]
    scores = list(map(model_micro, inputs))
    losses = [(1 + -yi * score).relu() for yi, score in zip(y, scores)]
    data_loss = sum(losses) * (1.0 / len(losses))
    reg_loss = 1e-4 * sum((p * p for p in model_micro.parameters()))
    total_loss = data_loss + reg_loss
    return total_loss

total_loss = micro_loss(X, y)
model_micro.zero_grad()
total_loss.backward()

extracting gradients

In [11]:
W0_grads = np.array([[p.grad for p in neuron.w] for neuron in model_micro.layers[0].neurons])

# Compare

In [13]:
abs_diff = np.abs(torch_grads - W0_grads)
max_abs_error = abs_diff.max()
mean_abs_error = abs_diff.mean()

print("Max absolute error:", max_abs_error)
print("Mean absolute error:", mean_abs_error)

Max absolute error: 0.34151939979542156
Mean absolute error: 0.09357949932870627
