This project demonstrates the effect of residual (shortcut) connections on gradient flow in deep fully connected neural networks using PyTorch. After a single backward pass, it compares the mean absolute gradient of each layer's weight matrix with and without residual connections.

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [39]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class NeuralNetwork(nn.Module):
    """Bias-free fully connected network with optional residual connections.

    The network is a stack of ``nn.Linear(..., bias=False)`` layers. Every
    layer except the last is followed by a ReLU; the last layer is a plain
    linear map.

    Args:
        shortcut_connections: If True, each hidden layer whose output shape
            equals its input shape computes ``relu(layer(x)) + x`` instead of
            ``relu(layer(x))``. The output layer never gets a shortcut.
        layer_sizes: Layer widths from input features to output features
            (keyword-only, at least two entries). The default
            ``(10, 10, 10, 10, 10, 1)`` builds four 10->10 hidden layers
            followed by a 10->1 output layer.

    Raises:
        ValueError: If ``layer_sizes`` has fewer than two entries.
    """

    def __init__(self, shortcut_connections=False, *, layer_sizes=(10, 10, 10, 10, 10, 1)):
        super().__init__()
        sizes = list(layer_sizes)
        if len(sizes) < 2:
            raise ValueError("layer_sizes needs at least an input and an output width")
        self.shortcut_connections = shortcut_connections
        # Layers are created front to back, so under a fixed seed the default
        # sizes consume the RNG exactly like the original hard-coded list.
        # This keeps the weight init (and therefore the printed gradients)
        # reproducible.
        self.layers = nn.ModuleList(
            [
                nn.Linear(in_features, out_features, bias=False)
                for in_features, out_features in zip(sizes[:-1], sizes[1:])
            ]
        )

    def forward(self, x):
        """Run ``x`` through the hidden layers (ReLU, optional shortcut), then the output layer."""
        for layer in self.layers[:-1]:
            out = F.relu(layer(x))
            # A residual sum is only defined when the shapes agree. For the
            # default 10-wide hidden layers this is always the case.
            if self.shortcut_connections and out.shape == x.shape:
                out = out + x
            x = out
        # Output layer: linear only, no activation and no shortcut.
        return self.layers[-1](x)


def output(x, shortcut_connections=True):
    """Train-step probe: print the mean absolute gradient of every weight matrix.

    Builds a fresh ``NeuralNetwork``, runs one forward pass on ``x``,
    computes an MSE loss against an all-zero target and back-propagates once.
    It then prints ``mean(|grad|)`` for each weight parameter. Nothing is
    returned, so a notebook cell ending in ``output(...)`` shows only the
    printed lines.

    Args:
        x: Input batch. With the default network its last dimension must be 10.
        shortcut_connections: Whether the probed network uses residual
            connections. Defaults to True (the original behavior); pass False
            to print the no-shortcut baseline for comparison.
    """
    model = NeuralNetwork(shortcut_connections=shortcut_connections)
    out = model(x)
    # An all-zero target with the same shape as the output. For a batch of 1
    # this equals the old fixed [[0.]]; for larger batches it avoids
    # MSELoss's target/input size-mismatch broadcasting warning.
    targets = torch.zeros_like(out)
    criterion = nn.MSELoss()
    loss = criterion(out, targets)
    loss.backward()
    for name, param in model.named_parameters():
        if "weight" in name:
            print(f"{name} has an average gradient of: {param.grad.abs().mean().item()}")

In [40]:
# Seed the global RNG first: both the sample input below and the model's
# weight init inside output() draw from it, in that order.
torch.random.manual_seed(123)
sample_input = torch.randn((1, 10))  # one example with 10 features
output(x=sample_input)

layers.0.weight has an average gradient of: 0.010994031094014645
layers.1.weight has an average gradient of: 0.006011884659528732
layers.2.weight has an average gradient of: 0.011236266233026981
layers.3.weight has an average gradient of: 0.008820828050374985
layers.4.weight has an average gradient of: 0.10310874879360199
