In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np



In [2]:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc0 = nn.Linear(3, 3)
        self.fc1 = nn.Linear(3, 4)
        self.fc2 = nn.Linear(4, 5)
        self.fc3 = nn.Linear(5, 1)

    def forward(self, x):
        f0 = self.fc0(x)
        h1 = F.relu(f0)
        f1 = self.fc1(h1)
        h2 = F.relu(f1)
        f2 = self.fc2(h2)
        h3 = F.relu(f2)
        f3 = self.fc3(h3)
        x = f3
        
        return x


input = torch.randn(3)
target = torch.randn(1)

net = Net()
pred = net(input)

criterion = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)

#get gradients
optimizer.zero_grad()
loss = criterion(pred, target)
loss.backward()

tensor([-2.7897,  0.4034, -1.1643]) tensor([0.8723]) tensor([0.5053], grad_fn=<ViewBackward0>)


In [3]:
#manual forward pass

layers_input = []
layers_preactivations = []
layers_outputs = []

layers_weights = []
layers_biases = []

ws = [net.fc0.weight, net.fc1.weight, net.fc2.weight, net.fc3.weight]
bs = [net.fc0.bias, net.fc1.bias, net.fc2.bias, net.fc3.bias]

input_next = input
for k in range(4):
    layers_input.append(input_next)
    w = ws[k]
    b = bs[k]
    f = torch.matmul(w, input_next) + b
    h = F.relu(f)
    input_next = h

    layers_preactivations.append(f)
    layers_outputs.append(h)

    layers_weights.append(w)
    layers_biases.append(b)

    #print(f, fs[k])
    #print(h, hs[k]) if k <3 else None

print('Net output my / pytorch:')
print(layers_outputs[-1], pred)

Net output my / pytorch:
tensor([0.5053], grad_fn=<ReluBackward0>) tensor([0.5053], grad_fn=<ViewBackward0>)


In [4]:
# manual backward pass

grads = []
grads_ws = []
grads_b2 = []

for k in [3,2,1,0]:

    if k == 3:
        grad = 2*(layers_outputs[k] - target) #mse loss derivative, d(x-y)^2
    else:
        grad = torch.matmul(layers_weights[k+1].T, grads[-1]) #dl/df(k) = omega_(k+1)^T * dl/df(k+1), grads[-1] is the previous layer's gradient
        grad = grad * (layers_preactivations[k] > 0).float() #for relu, zero out gradients where preactivation is negative

    grads.append(grad)

    grad_w = torch.matmul(grad.view(-1,1), layers_input[k].view(1,-1)) #dl/dw(k) = dl/df(k) * h(k). tensor.view(-1, 1) makes it a column vector instead of a row vector

    grad_b = grad

    grads_ws.append(grad_w)
    grads_b2.append(grad_b)


grads = grads[::-1]
grads_ws = grads_ws[::-1]
grads_b2 = grads_b2[::-1]


for k in range(4):
    print(f'===layer {k=}===')
    my_weight_grad = grads_ws[k]
    my_bias_grad = grads_b2[k]
    torch_weight_grad = layers_weights[k].grad
    torch_bias_grad = layers_biases[k].grad

    my_grad_np = my_weight_grad.detach().numpy()
    torch_grad_np = torch_weight_grad.detach().numpy()

    print('grad my/torch:')
    print(my_grad_np)
    print(torch_grad_np)

    my_grad_np = my_bias_grad.detach().numpy()
    torch_grad_np = torch_bias_grad.detach().numpy()

    print('bias my/torch::')
    print(my_grad_np)
    print(torch_grad_np)




===layer k=0===
grad my/torch:
[[ 0.04533107 -0.00655478  0.01891906]
 [ 0.13311084 -0.01924756  0.0555542 ]
 [-0.          0.         -0.        ]]
[[ 0.04533107 -0.00655478  0.01891906]
 [ 0.13311084 -0.01924756  0.0555542 ]
 [-0.          0.         -0.        ]]
bias my/torch::
[-0.01624922 -0.04771445  0.        ]
[-0.01624922 -0.04771445  0.        ]
===layer k=1===
grad my/torch:
[[-0.13762619 -0.13665056 -0.        ]
 [-0.         -0.         -0.        ]
 [-0.11894695 -0.11810374 -0.        ]
 [ 0.          0.          0.        ]]
[[-0.13762619 -0.13665056 -0.        ]
 [ 0.          0.          0.        ]
 [-0.11894695 -0.11810374 -0.        ]
 [ 0.          0.          0.        ]]
bias my/torch::
[-0.08385601 -0.         -0.0724747   0.        ]
[-0.08385601  0.         -0.0724747   0.        ]
===layer k=2===
grad my/torch:
[[-0.         -0.         -0.         -0.        ]
 [-0.         -0.         -0.         -0.        ]
 [-0.13080193 -0.         -0.05091183 -0.      