In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class TinyNet(nn.Module):
    """Hand-initialised 2 -> 3 -> 3 -> 2 ReLU MLP for tracing one training step.

    Every weight is overwritten with a fixed constant and every bias is zeroed
    in ``__init__``, so the forward pass and all gradients are deterministic
    and can be checked by hand.

    Each call to ``forward`` stores the pre- and post-ReLU activations of every
    layer in ``self.activations``. When autograd is recording, those tensors
    also have ``retain_grad()`` enabled, so after ``loss.backward()`` their
    ``.grad`` attribute holds d(loss)/d(activation).

    Args:
        verbose: if True (the default), ``forward`` prints every intermediate
            activation and the softmax of the logits.
    """

    def __init__(self, verbose=True):
        super().__init__()
        self.fc1  = nn.Linear(2, 3, bias=True)
        self.fc2  = nn.Linear(3, 3, bias=True)
        self.fc3  = nn.Linear(3, 2, bias=True)
        self.relu = nn.ReLU()
        self.verbose = verbose
        # Filled in by forward(); empty until the first call so callers can
        # always read the attribute.
        self.activations = {}

        # ——— replace the random default init with fixed, hand-picked values ———
        # copy_() writes into the existing Parameter storage (instead of
        # rebinding `.data`); no_grad() is required for in-place writes to
        # leaf tensors that require grad.
        with torch.no_grad():
            self.fc1.weight.copy_(torch.tensor([
                [ 0.2, -0.2],
                [-0.2,  0.2],
                [-0.2, -0.2],
            ]))
            self.fc1.bias.zero_()

            self.fc2.weight.copy_(torch.tensor([
                [ 0.5, -0.5,  0.5],
                [ 0.5,  0.5,  0.5],
                [ 0.5,  0.5, -0.5],
            ]))
            self.fc2.bias.zero_()

            self.fc3.weight.copy_(torch.tensor([
                [-0.8,  0.8,  0.8],
                [ 0.8, -0.8,  0.8],
            ]))
            self.fc3.bias.zero_()

    def forward(self, x):
        """Run the network, record intermediate activations and return logits.

        Args:
            x: input features. A tensor with more than two dims is flattened to
                ``(x.size(0), -1)``; after that its last dim must be 2 to match
                ``fc1``.

        Returns:
            The raw fc3 output (logits, last dim 2), before any softmax.
        """
        # Flatten extra dims so fc1 sees (batch, features). reshape (unlike
        # view) also accepts non-contiguous inputs.
        if x.dim() > 2:
            x = x.reshape(x.size(0), -1)

        x_fc1_pre = self.fc1(x)           # fc1 pre-activation
        x_fc1 = self.relu(x_fc1_pre)      # fc1 post-ReLU
        x_fc2_pre = self.fc2(x_fc1)       # fc2 pre-activation
        x_fc2 = self.relu(x_fc2_pre)      # fc2 post-ReLU
        x_fc3_pre = self.fc3(x_fc2)       # fc3 logits

        # stash for later inspection (e.g. their .grad after backward)
        self.activations = {
            'fc1_pre':  x_fc1_pre,
            'fc1_post': x_fc1,
            'fc2_pre':  x_fc2_pre,
            'fc2_post': x_fc2,
            'fc3_pre':  x_fc3_pre,
        }

        # retain_grad() raises on tensors that don't require grad (e.g. under
        # torch.no_grad() / inference_mode()), so only request it when
        # autograd is actually tracking the tensor. It only has to be called
        # before backward(), so doing it after the whole pass is equivalent.
        for act in self.activations.values():
            if act.requires_grad:
                act.retain_grad()

        if self.verbose:
            print("After fc1_pre:\n", self._as_numpy(x_fc1_pre))
            print("After fc1_relu:\n", self._as_numpy(x_fc1))
            print("After fc2_pre:\n", self._as_numpy(x_fc2_pre))
            print("After fc2_relu:\n", self._as_numpy(x_fc2))
            print("After fc3_pre (logits):\n", self._as_numpy(x_fc3_pre))
            print("Softmax probs:\n",
                  self._as_numpy(nn.functional.softmax(x_fc3_pre, dim=-1)))

        return x_fc3_pre

    @staticmethod
    def _as_numpy(t):
        """Return *t* as a detached CPU NumPy array for printing (works for GPU tensors too)."""
        return t.detach().cpu().numpy()

# — set up: the fixed-weight model, cross-entropy on its raw logits, and plain
#   SGD with lr=1 so the effect of a single update is easy to read off
model     = TinyNet()
criterion = nn.CrossEntropyLoss()
opt       = optim.SGD(model.parameters(), lr=1)

print("=== Starting parameters ===")
for heading, layer in (
    ("FC1 weights:\n", model.fc1),
    ("FC2 weights:\n", model.fc2),
    ("FC3 weights:\n", model.fc3),
):
    print(heading, layer.weight.data)
print("———————————————\n")

# one sample with two features, labelled as class 0
inp = torch.tensor([[0.3, 0.5]], dtype=torch.float32)  # shape (1,2)
target = torch.tensor([0])                             # shape (1,)

print("=== Iteration 0: forward/backward ===")
out   = model(inp)
loss  = criterion(out, target)
loss.backward()

# every stashed activation had retain_grad(), so .grad is populated now
print("\n=== Gradients of every intermediate activation ===")
for name, act in model.activations.items():
    print(f"{name}.grad =\n", act.grad)

opt.step()

# same layers as above, listed output-first to follow the backward direction
print("\n=== After SGD step ===")
for heading, layer in (
    ("FC3 weights:\n", model.fc3),
    ("FC2 weights:\n", model.fc2),
    ("FC1 weights:\n", model.fc1),
):
    print(heading, layer.weight.data)

=== Starting parameters ===
FC1 weights:
 tensor([[ 0.2000, -0.2000],
        [-0.2000,  0.2000],
        [-0.2000, -0.2000]])
FC2 weights:
 tensor([[ 0.5000, -0.5000,  0.5000],
        [ 0.5000,  0.5000,  0.5000],
        [ 0.5000,  0.5000, -0.5000]])
FC3 weights:
 tensor([[-0.8000,  0.8000,  0.8000],
        [ 0.8000, -0.8000,  0.8000]])
———————————————

=== Iteration 0: forward/backward ===
After fc1_pre:
 [[-0.04  0.04 -0.16]]
After fc1_relu:
 [[0.   0.04 0.  ]]
After fc2_pre:
 [[-0.02  0.02  0.02]]
After fc2_relu:
 [[0.   0.02 0.02]]
After fc3_pre (logits):
 [[0.032 0.   ]]
Softmax probs:
 [[0.5079993  0.49200067]]

=== Gradients of every intermediate activation ===
fc1_pre.grad =
 tensor([[ 0.0000, -0.3936,  0.0000]])
fc1_post.grad =
 tensor([[-0.3936, -0.3936, -0.3936]])
fc2_pre.grad =
 tensor([[ 0.0000e+00, -7.8720e-01, -1.2016e-08]])
fc2_post.grad =
 tensor([[ 7.8720e-01, -7.8720e-01, -1.2016e-08]])
fc3_pre.grad =
 tensor([[-0.4920,  0.4920]])

=== After SGD step ===
FC3 weigh