### GPT architecture part 4: Shortcut connections

In [2]:
import torch
import torch.nn as nn

class GELU(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(torch.sqrt(torch.tensor(2.0 / torch.pi)) * (x + 0.044715 * torch.pow(x, 3))))

In [3]:
class ExampleDeepNeuralNetwork(nn.Module):
    def __init__(self, layer_sizes, use_shortcut):
        super().__init__()
        self.use_shortcut = use_shortcut
        self.layers = nn.ModuleList([
            nn.Sequential(nn.Linear(layer_sizes[0], layer_sizes[1]), GELU()),
            nn.Sequential(nn.Linear(layer_sizes[1], layer_sizes[2]), GELU()),
            nn.Sequential(nn.Linear(layer_sizes[2], layer_sizes[3]), GELU()),
            nn.Sequential(nn.Linear(layer_sizes[3], layer_sizes[4]), GELU()),
            nn.Sequential(nn.Linear(layer_sizes[4], layer_sizes[5]), GELU())
        ])
    
    def forward(self, x):
        for layer in self.layers:
            # Compute the output of the current layer
            layer_output = layer(x)
            # Check if shortcut can be applied
            if self.use_shortcut and x.shape == layer_output.shape:
                x = x + layer_output
            else:
                x = layer_output
        return x

In [4]:
### This code implements a deep neural network with 5 layers, each consisting of a Linear layer and a GELU activation function.
### In the forward pass, we iteratively pass the input through the layers and optionally add the shortcut connections
### if self.use_shortcut is set to True and the layer's input and output shapes match.
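### The shape check in the forward pass can be illustrated with a minimal sketch. It assumes the layer sizes [3, 3, 3, 3, 3, 1]
### and uses PyTorch's built-in nn.GELU (the exact form, not the tanh approximation above) as a stand-in, which does not affect shapes.
### The names shortcut_applied and can_add are hypothetical and exist only for this illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
layer_sizes = [3, 3, 3, 3, 3, 1]

# Stand-in layers (assumption): built-in nn.GELU replaces the custom GELU class
layers = [
    nn.Sequential(nn.Linear(n_in, n_out), nn.GELU())
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
]

x = torch.tensor([[1., 0., -1.]])
shortcut_applied = []  # hypothetical bookkeeping list, one entry per layer
for layer in layers:
    layer_output = layer(x)
    # The shortcut is only possible when input and output shapes match
    can_add = x.shape == layer_output.shape
    shortcut_applied.append(can_add)
    x = x + layer_output if can_add else layer_output

print(shortcut_applied)  # [True, True, True, True, False]
print(x.shape)           # torch.Size([1, 1])
```

### The last layer maps 3 values to 1, so its input and output shapes differ and the shortcut is skipped there.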

In [None]:
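### Let us use this code to first initialize a neural network without shortcut connections. Here, each layer will be initialized
### such that it accepts an example with three input values and returns three output values. The last layer returns a single output value.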

In [5]:
layer_sizes = [3, 3, 3, 3, 3, 1]
sample_input = torch.tensor([[1., 0., -1.]])
torch.manual_seed(123)
model_without_shortcut = ExampleDeepNeuralNetwork(layer_sizes, use_shortcut=False)

In [6]:
### Next, we implement a function that computes the gradients in the model's backward pass.

In [7]:
def print_gradients(model, x):
    # Forward pass
    output = model(x)
    target = torch.tensor([[0.]])
    
    # Calculate loss based on how close the target and outputs are
    loss_fn = nn.MSELoss()
    loss = loss_fn(output, target)
    
    # Backward pass to calculate the gradients
    loss.backward()
    
    for name, param in model.named_parameters():
        if 'weight' in name:
            # Print the mean absolute gradient of the weights
            print(f"{name} has gradient mean of {param.grad.abs().mean().item()}")

In [8]:
### In the preceding code, we specify a loss function that computes how close the model output is to a user-specified target.
### Then, when calling loss.backward(), PyTorch computes the loss gradient for each layer in the model.
### We can iterate through the weight parameters via model.named_parameters().
### Suppose we have a 3x3 weight parameter matrix for a given layer.
### In that case, this layer will have 3x3 gradient values, and we print the mean absolute gradient of these 3x3 gradient values
### to obtain a single gradient value per layer, which makes it easier to compare the gradients between layers.
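### As a small illustration of this reduction, the sketch below applies it to a made-up 3x3 gradient matrix. The values are hypothetical
### and chosen so that the signed mean cancels to zero, which is one reason to take absolute values before averaging.

```python
import torch

# Hypothetical 3x3 gradient values for one layer's weight matrix (not from the model above)
grad = torch.tensor([[ 1.0, -1.0,  1.0],
                     [-1.0,  1.0, -1.0],
                     [ 1.5, -1.5,  0.0]])

# Same reduction as in print_gradients: one scalar per layer
mean_abs_grad = grad.abs().mean().item()

# The signed mean lets positive and negative entries cancel out
signed_mean = grad.mean().item()

print(mean_abs_grad)  # 1.0
print(signed_mean)    # 0.0
```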

In [9]:
print_gradients(model_without_shortcut, sample_input)

layers.0.0.weight has gradient mean of 0.00020173587836325169
layers.1.0.weight has gradient mean of 0.0001201116101583466
layers.2.0.weight has gradient mean of 0.0007152041653171182
layers.3.0.weight has gradient mean of 0.001398873864673078
layers.4.0.weight has gradient mean of 0.005049646366387606


In [10]:
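### The output shows that the gradients shrink from roughly 5e-3 at the last layer (layers.4) to roughly 2e-4 at the first layer (layers.0),
### which is the vanishing gradient problem.
### Let's now instantiate a model with skip connections and see how it compares.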

In [12]:
torch.manual_seed(123)
model_with_shortcut = ExampleDeepNeuralNetwork(layer_sizes, use_shortcut=True)
print_gradients(model_with_shortcut, sample_input)

layers.0.0.weight has gradient mean of 0.22169792652130127
layers.1.0.weight has gradient mean of 0.20694106817245483
layers.2.0.weight has gradient mean of 0.32896995544433594
layers.3.0.weight has gradient mean of 0.2665732502937317
layers.4.0.weight has gradient mean of 1.3258541822433472


In [13]:
### As the output shows, the last layer (layers.4) still has a larger gradient than the other layers.
### However, the gradient value stabilizes as we progress toward the first layer (layers.0) and does not shrink to a vanishingly small value.
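### One rough way to quantify this is to compare the last-layer and first-layer values from the two printouts above.
### The sketch below copies the printed numbers by hand. The list and variable names are hypothetical, and the ratio is only an
### informal summary for this single input and seed, not a general measure.

```python
# Mean absolute weight gradients copied from the printed outputs above (layers.0 to layers.4)
grads_without_shortcut = [
    0.00020173587836325169,
    0.0001201116101583466,
    0.0007152041653171182,
    0.001398873864673078,
    0.005049646366387606,
]
grads_with_shortcut = [
    0.22169792652130127,
    0.20694106817245483,
    0.32896995544433594,
    0.2665732502937317,
    1.3258541822433472,
]

# Ratio of the last layer's gradient to the first layer's gradient
ratio_without = grads_without_shortcut[-1] / grads_without_shortcut[0]
ratio_with = grads_with_shortcut[-1] / grads_with_shortcut[0]

print(f"without shortcut: {ratio_without:.1f}")  # without shortcut: 25.0
print(f"with shortcut:    {ratio_with:.1f}")     # with shortcut:    6.0
```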

In [14]:
### In conclusion, shortcut connections are important for overcoming the limitations posed by the vanishing gradient problem
### in deep neural networks.
### They are a core building block of very large models such as LLMs, and they will help facilitate more effective training by
### ensuring consistent gradient flow across layers when we train the GPT model.