In [17]:
import torch
from torch import nn


class BasicNeuralNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 1, bias=False)
        self.fc2 = nn.Linear(1, 2, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

In [18]:
torch.manual_seed(123)
model = BasicNeuralNet()
x = torch.tensor([0.5])
output = model(x)
output

tensor([-0.0068,  0.1013], grad_fn=<SqueezeBackward4>)

In [24]:
print("fc1 weight:", model.fc1.weight)
print("fc1 bias:", model.fc1.bias)
print("fc2 weight:", model.fc2.weight)
print("fc2 bias:", model.fc2.bias)

fc1 weight: Parameter containing:
tensor([[-0.4078]], requires_grad=True)
fc1 bias: None
fc2 weight: Parameter containing:
tensor([[ 0.0331],
        [-0.4967]], requires_grad=True)
fc2 bias: None


In [29]:
target = torch.tensor([1., 0.])
target

tensor([1., 0.])

In [37]:
loss = torch.nn.MSELoss()
loss = loss(output, target)
print(loss)
loss.backward()

tensor(0.5119, grad_fn=<MseLossBackward0>)


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

In [41]:
# Print the gradients of the weights of the first layer
print("fc1 weight grad:", model.fc1.weight.grad)

# Print the gradients of the weights of the second layer
print("fc2 weight grad:", model.fc2.weight.grad)

fc1 weight grad: tensor([[-0.0418]])
fc2 weight grad: tensor([[ 0.2053],
        [-0.0206]])


In [49]:
# MSE calculation ✅
(1/2)*((-0.0068-1)**2 + (0.1013-0)**2)

0.511953965

In [51]:
# fc2 weight0 gradient calculation ✅
-1.0068*0.5*-0.4078

0.20528651999999997

In [54]:
# fc2 weight1 gradient calculation ✅
(0.5*-0.4078)*0.1013

-0.02065507

In [55]:
# fc1 weight gradient calculation ✅
0.5*0.0331*-1.0068 + 0.5*-0.4967*0.1013

-0.041820394999999996