In [1]:
import contextlib

import numpy as np
import torch
from torch import nn


class NET(nn.Module):
    """Toy model for probing how ``torch.no_grad`` affects gradient accumulation.

    A single bias-free 2 -> 2 ``nn.Linear`` layer is applied twice in
    ``forward``: once to the input, and once to the element-wise square of
    that result. The weight is fixed to ``[[1, 2], [3, 4]]`` (float32) so the
    resulting gradients are deterministic for a given input.
    """

    def __init__(self, out_dim=11):
        """Build the layer with fixed weights.

        Args:
            out_dim: Unused. Accepted only so existing ``NET(out_dim=...)``
                calls keep working; the layer is always 2 -> 2.
        """
        super().__init__()
        self.fc = nn.Linear(2, 2, bias=False)
        # Overwrite the random default init with fixed float32 weights.
        self.fc.weight = nn.Parameter(
            torch.tensor([[1., 2.], [3., 4.]], dtype=torch.float32))

    def print(self):
        """Print the layer's weight and its accumulated ``.grad``."""
        # Inside a method body, ``print`` still resolves to the builtin:
        # class attributes are not part of a method's name lookup.
        print('weights')
        print(self.fc.weight)
        print('\ngrads')
        print(self.fc.weight.grad)

    def forward(self, x, nograd=False):
        """Apply ``fc`` to ``x``, then again to the square of that output.

        Args:
            x: Input tensor; its last dimension must be 2 (``nn.Linear(2, 2)``).
            nograd: If True, the second application runs under
                ``torch.no_grad()``, so ``y`` is excluded from autograd and
                contributes nothing to ``fc.weight.grad``.

        Returns:
            Tuple ``(x, y)`` where ``x = fc(x)`` and ``y = fc(x ** 2)``.
        """
        x = self.fc(x)
        # nullcontext leaves the caller's grad mode untouched when nograd is
        # False (unlike enable_grad, which would force grad on).
        grad_ctx = torch.no_grad() if nograd else contextlib.nullcontext()
        with grad_ctx:
            y = self.fc(x ** 2)
        return x, y


In [2]:
# Fixed seed so Restart & Run All reproduces the same batch (and therefore
# the same gradients in every case below).
SEED = 0
rng = np.random.default_rng(SEED)
# 4 samples x 2 features in [0, 1), rounded to 2 decimals for readable printouts.
batch = torch.tensor(np.round(rng.random((4, 2)), 2)).float()

In [3]:
# Case 1 (baseline): y is computed under no_grad(), so only the
# x = self.fc(x) path should contribute to fc.weight.grad.
# NOTE(review): each term squares the *sum* of the residuals, not the sum of
# squared residuals; harmless for this gradient experiment, but confirm intent.
net1 = NET()
xr, yr = net1(batch, nograd=True)
loss = (batch - xr).sum() ** 2 + (batch - yr).sum() ** 2
loss.backward()
net1.print()

weights
Parameter containing:
tensor([[1., 2.],
        [3., 4.]], requires_grad=True)

grads
tensor([[31.4116, 48.5452],
        [31.4116, 48.5452]])


In [4]:
# Case 2: same forward pass as case 1, but the yr term is dropped from the loss.
# Gradients identical to case 1 show that, under no_grad(), the yr term
# contributed nothing there.
net2 = NET()
xr, yr = net2(batch, nograd=True)
loss = (batch - xr).sum() ** 2
loss.backward()
net2.print()

weights
Parameter containing:
tensor([[1., 2.],
        [3., 4.]], requires_grad=True)

grads
tensor([[31.4116, 48.5452],
        [31.4116, 48.5452]])


In [5]:
# Case 3: no no_grad(); both the xr and yr terms now feed fc.weight.grad,
# so the gradients differ from cases 1 and 2.
net3 = NET()
xr, yr = net3(batch, nograd=False)
loss = (batch - xr).sum() ** 2 + (batch - yr).sum() ** 2
loss.backward()
net3.print()

weights
Parameter containing:
tensor([[1., 2.],
        [3., 4.]], requires_grad=True)

grads
tensor([[10282.8691, 30702.7285],
        [26194.2266, 57988.7422]])


In [6]:
# Case 4: gradient from the yr term alone, with grad enabled.
# Expectation: case 2 + case 4 == case 3, i.e. the yr term accounts for the
# whole difference between cases 2 and 3.
net4 = NET()
xr, yr = net4(batch, nograd=False)
loss = (batch - yr).sum() ** 2
loss.backward()
net4.print()

weights
Parameter containing:
tensor([[1., 2.],
        [3., 4.]], requires_grad=True)

grads
tensor([[10251.4580, 30654.1836],
        [26162.8145, 57940.1953]])


In [7]:
# Check that grad(case 2) + grad(case 4) matches grad(case 3) elementwise.
# Use isclose rather than ==: in case 3, autograd sums both upstream gradients
# at xr before the weight-grad matmul, so rounding can differ in the last bits.
torch.isclose(net4.fc.weight.grad + net2.fc.weight.grad, net3.fc.weight.grad)

tensor([[True, True],
        [True, True]])