In [6]:
import torch
import torch.nn as nn

In [7]:
class DNN(nn.Module):
    """Stack of Linear+GELU blocks with optional additive shortcut connections.

    Args:
        layer_sizes: sequence of feature sizes; consecutive pairs give the
            (in_features, out_features) of each Linear layer.
        use_shortcut: when truthy, a block's input is added to its output
            whenever the two tensors have identical shapes.
    """

    def __init__(self,layer_sizes,use_shortcut):
        super().__init__()
        self.use_shortcut = use_shortcut

        # Pair every size with its successor: (s0, s1), (s1, s2), ...
        blocks = []
        for in_features, out_features in zip(layer_sizes[:-1], layer_sizes[1:]):
            blocks.append(
                nn.Sequential(nn.Linear(in_features, out_features), nn.GELU())
            )
        self.layers = nn.ModuleList(blocks)

    def forward(self,x):
        """Feed x through every block in order and return the final activation."""
        for block in self.layers:
            block_out = block(x)
            # The shortcut is only applied when shapes match exactly; otherwise
            # the block output simply replaces x.
            add_shortcut = self.use_shortcut and block_out.shape == x.shape
            x = x + block_out if add_shortcut else block_out
        return x

            

In [12]:
# Seed BEFORE building the model: nn.Linear draws its initial weights from
# the global RNG at construction time, so seeding afterwards would leave the
# weights (and therefore the gradients printed below) different on every run.
torch.manual_seed(123)

layer_sizes = [3,3,3,3,3,1]

dnn = DNN(layer_sizes,use_shortcut=False)

x = torch.tensor([[1.,0.,1.]])

In [13]:
def print_gradients(model,x):
    """Run one forward/backward pass and print each weight's mean |gradient|.

    The loss is nn.MSELoss between model(x) and a fixed [[0.]] target.

    Args:
        model: nn.Module to inspect.
        x: input tensor passed to model.

    Side effects:
        Resets and then repopulates the .grad of the model's parameters, and
        prints one line per parameter whose name contains 'weight'.
    """
    # backward() accumulates into .grad; clear leftovers so that calling this
    # function again on the same model reports fresh, not summed, gradients.
    model.zero_grad()

    output = model(x)
    target = torch.tensor([[0.]])

    loss = nn.MSELoss()(output,target)

    loss.backward()

    for name,param in model.named_parameters():
        # A parameter that received no gradient has .grad of None; skip it
        # instead of failing on None.abs().
        if 'weight' in name and param.grad is not None:
            print(f'{name} has gradient mean {param.grad.abs().mean().item()}')

In [14]:
# Per-layer mean absolute weight gradient for the current `dnn`.
print_gradients(model=dnn, x=x)

layers.0.0.weight has gradient mean 3.8826343370601535e-05
layers.1.0.weight has gradient mean 6.894793477840722e-05
layers.2.0.weight has gradient mean 8.831745071802288e-05
layers.3.0.weight has gradient mean 0.0011646426282823086
layers.4.0.weight has gradient mean 0.006721900776028633


In [15]:
## now add skip connection

# Re-seed before constructing the model: otherwise its initial weights depend
# on how much randomness earlier cells consumed, and the gradients printed
# here are not reproducible from run to run.
torch.manual_seed(123)

dnn = DNN(layer_sizes,use_shortcut=True)

print_gradients(dnn,x)


layers.0.0.weight has gradient mean 0.23126482963562012
layers.1.0.weight has gradient mean 0.2373097687959671
layers.2.0.weight has gradient mean 0.3484981656074524
layers.3.0.weight has gradient mean 0.1335914433002472
layers.4.0.weight has gradient mean 1.823801040649414


In [None]:
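## With skip connections the gradients no longer vanish: they remain substantial in every layer, including the earliest ones, instead of shrinking by orders of magnitude toward the input as they did without shortcuts.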