In [1]:
from importlib.metadata import version
import torch

print("TORCH VERSION :", version("torch"))
device = "cuda" if torch.cuda.is_available() else 'mps' if torch.backend.mps.is_available() else 'cpu'
print('GPU  : ', device)

TORCH VERSION : 2.2.1
GPU  :  cuda


In [2]:
import torch.nn as nn

In [3]:
class ResidualConnectionLN(nn.Module):
    def __init__(self,normalized_shape:int,dropout:float=0.1):
        super().__init__()
        
        self.norm = nn.LayerNorm(normalized_shape)
        self.dropout = nn.Dropout(dropout)

    def forward(self,x,sublayer):
        return x + self.dropout(sublayer(self.norm(x)))
        

In [4]:
## Short Circuit

In [70]:

class DummyNW(nn.Module):
    def __init__(self,dim:torch.Tensor,target=torch.tensor([0.]),short_circuit=True):
        super().__init__()
        self.short_circuit = short_circuit
        self.layers = nn.ModuleList([nn.Sequential(nn.Linear(dim[i],dim[i+1]),nn.GELU()) for i in range(len(dim)-1)])

    def forward(self,x):
        for layer in self.layers:
            out = layer(x)
            if self.short_circuit:
                x = x + out
            else:
                x = out
        return x

In [94]:
# dim = [3, 3, 3, 3, 3, 1]  
dim = [100,100,100,100,100,100,100,100,1]
x = torch.rand(32,100,100)

In [101]:

def get_grad(model,x):
    # Forward pass
    output = model(x)
    target = torch.tensor([[0.]])
    
    # Calculate loss based on how close the target
    # and output are
    loss = nn.MSELoss()
    loss = loss(output, target)
    
    # Backward pass to calculate the gradients
    loss.backward()
    
    for name, param in model.named_parameters():
        if 'weight' in name:
            # Print the mean absolute gradient of the weights
            print(f"{name} has gradient mean of {param.grad.abs().mean().item()}")

In [102]:
# Next, let's print the gradient values with short_circuit connections:
get_grad( DummyNW(dim),x)

layers.0.0.weight has gradient mean of 0.027854595333337784
layers.1.0.weight has gradient mean of 0.027477923780679703
layers.2.0.weight has gradient mean of 0.02715429663658142
layers.3.0.weight has gradient mean of 0.027854714542627335
layers.4.0.weight has gradient mean of 0.03017064556479454
layers.5.0.weight has gradient mean of 0.032704733312129974
layers.6.0.weight has gradient mean of 0.030148791149258614
layers.7.0.weight has gradient mean of 1.4742094278335571


In [103]:

get_grad(DummyNW(dim,short_circuit=False),x)

layers.0.0.weight has gradient mean of 1.1193534277254003e-07
layers.1.0.weight has gradient mean of 8.519539562712453e-08
layers.2.0.weight has gradient mean of 9.869211226032348e-08
layers.3.0.weight has gradient mean of 2.509366083813802e-07
layers.4.0.weight has gradient mean of 7.714727985330683e-07
layers.5.0.weight has gradient mean of 3.048964117624564e-06
layers.6.0.weight has gradient mean of 1.2048118151142262e-05
layers.7.0.weight has gradient mean of 0.0005355386529117823


In [None]:
## we can see based on the output above, short_circuit connections prevent the gradients from vanishing in the early layers (towards layer.0)