In [27]:
import torch
import torch.nn as nn
import numpy as np
from PIL import Image

class Net(nn.Module):
    """Three-layer conv trunk followed by a two-layer MLP head producing ``n_actions`` outputs.

    Args:
        input_shape: Shape of one observation as ``(channels, height, width)``,
            e.g. ``inputs.shape``. Only ``input_shape[0]`` (channels) is needed by
            the conv trunk; the full shape is used to size the first linear layer.
            If ``input_shape`` does not have exactly three entries, the flattened
            conv size falls back to the originally hard-coded 1600.
        n_actions: Number of outputs of the final linear layer.
    """

    # Flattened conv-trunk size the network was originally hard-coded for
    # (64 channels x 5 x 5, which is what a 72x72 input produces).
    _DEFAULT_CONV_OUT = 1600

    def __init__(self, input_shape, n_actions) -> None:
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=7, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        # Size the head from the real conv output instead of a magic number, so
        # the net works for any input resolution the conv trunk accepts.
        self.fc = nn.Sequential(
            nn.Linear(self._conv_out_size(input_shape), 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
        )

    def _conv_out_size(self, input_shape) -> int:
        """Return the number of features the conv trunk emits for one observation."""
        if len(input_shape) != 3:
            return self._DEFAULT_CONV_OUT
        # One dummy forward pass; no_grad avoids building an autograd graph.
        with torch.no_grad():
            dummy = torch.zeros(1, *input_shape)
            return int(self.conv(dummy).numel())

    def forward(self, x):
        """Map an observation to ``n_actions`` outputs.

        Args:
            x: A single observation ``(C, H, W)`` or a batch ``(N, C, H, W)``.

        Returns:
            Tensor of shape ``(n_actions,)`` for a single observation, or
            ``(N, n_actions)`` for a batch.
        """
        # Flatten only the trailing (C, H, W) dims so a leading batch dim survives;
        # a plain .flatten() would merge the whole batch into one vector.
        conv_out = self.conv(x).flatten(start_dim=-3)
        return self.fc(conv_out)



In [28]:
# Stack the frame files `input1` … `input{N_FRAMES}` into one float32 tensor
# scaled to [0, 1]; the frame axis becomes the conv input-channel axis
# (see `Net(inputs.shape, ...)`).
# NOTE(review): assumes every file decodes to a single-channel image of the
# same size — RGB files would add a trailing channel axis; confirm.
N_FRAMES = 4


def _load_frame(path):
    """Read one image file as a float32 array, closing the file afterwards."""
    with Image.open(path) as img:
        return np.asarray(img, dtype=np.float32)


# np.stack raises a clear error if the frames differ in shape.
inputs = torch.from_numpy(
    np.stack([_load_frame(f'input{i}') for i in range(1, N_FRAMES + 1)]) / 255
)

In [29]:
# Seed torch's RNG so the randomly initialised weights — and every value
# printed below — are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

N_ACTIONS = 9  # outputs of the final linear layer (Net's `n_actions`)
net = Net(inputs.shape, N_ACTIONS)

# Run the two halves of the network separately so the intermediate conv
# activation stays inspectable as `conv_out`.
conv_out = net.conv(inputs)
fc_out = net.fc(conv_out.flatten())

In [30]:
# fc_out is a non-leaf tensor (it was produced by net.fc, so it has a grad_fn);
# autograd does not keep .grad on non-leaf tensors unless retain_grad() is
# called, and we want to inspect fc_out.grad after backward().
fc_out.retain_grad()

In [31]:
# The single largest network output, used as the prediction for the loss below.
y_pred = fc_out.max()
y_pred

tensor(0.0354, grad_fn=<MaxBackward1>)

In [32]:
# Regression target for y_pred.
# NOTE(review): 0.03 appears to be an arbitrary value chosen for this gradient
# experiment — confirm.
TARGET = 0.03

# Squared error; (y_pred - TARGET)**2 equals (TARGET - y_pred)**2, and a Python
# float combines directly with the 0-d tensor, so no torch.tensor() wrapper is needed.
loss = (y_pred - TARGET) ** 2
loss

tensor(2.9314e-05, grad_fn=<PowBackward0>)

In [33]:
# Backpropagate the scalar loss: fills .grad on the network's parameters and,
# because retain_grad() was called on it, on fc_out as well.
loss.backward()

In [34]:
# Gradient of the loss w.r.t. each network output. max() routes the gradient
# only to the selected entry, so every other entry is 0; the non-zero entry is
# d/dy (0.03 - y)^2 = 2 * (y_pred - 0.03).
fc_out.grad

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0108, 0.0000, 0.0000, 0.0000, 0.0000])

In [35]:
# Raw network outputs; the largest entry is the one max() selected and the only
# one that received a non-zero gradient above.
fc_out

tensor([-0.0452, -0.0183, -0.0334, -0.0116,  0.0354, -0.0456, -0.0242, -0.0441,
        -0.0029], grad_fn=<AddBackward0>)

In [70]:
# How many weights of the first linear layer (shape 512 x 1600) received an exactly
# zero gradient. Presumably these are entries where either the conv-trunk ReLU
# output or the hidden-layer ReLU was inactive — TODO confirm.
(net.fc[0].weight.grad == 0).sum()


tensor(651620)