In [1]:
import torch
import torch.nn as nn

In [2]:
# Seed first so the sampled batch (and any layer weights initialised afterwards)
# are reproducible on a top-to-bottom run.
torch.manual_seed(123)

# Toy input: 2 examples, each with 5 features drawn from a standard normal.
batch_example = torch.randn((2, 5))
print(batch_example)

tensor([[-0.1115,  0.1204, -0.3696, -0.2404, -1.1969],
        [ 0.2093, -0.9724, -0.7550,  0.3239, -0.1085]])


In [None]:
# One fully connected layer mapping 5 input features to 6 neurons, followed by a ReLU
# (negative pre-activations are clipped to 0, visible as zeros in the output).
# The Linear weights are drawn from the global RNG seeded earlier, so this output is
# only reproducible when the notebook is run top to bottom.
layer = nn.Sequential(
    nn.Linear(in_features=5, out_features=6),
    nn.ReLU(),
)

out = layer(batch_example)
print(out)

tensor([[0.2260, 0.3470, 0.0000, 0.2216, 0.0000, 0.0000],
        [0.2133, 0.2394, 0.0000, 0.5198, 0.3297, 0.0000]],
       grad_fn=<ReluBackward0>)


In [4]:
# Per-example statistics over the last (feature) dimension. keepdim=True keeps the
# reduced axis, so the results have shape (2, 1) rather than (2,) and broadcast
# row by row against `out` (shape (2, 6)) in the next cell.
mean = torch.mean(out, dim=-1, keepdim=True)
# torch.var defaults to the unbiased estimator (divides by n - 1).
var = torch.var(out, dim=-1, keepdim=True)

print("Mean:\n", mean)
print("Variance:\n", var)

Mean:
 tensor([[0.1324],
        [0.2170]], grad_fn=<MeanBackward1>)
Variance:
 tensor([[0.0231],
        [0.0398]], grad_fn=<VarBackward0>)


In [None]:
# Normalize each row of `out` (one example) to zero mean and unit variance: subtract
# the per-row mean and divide by the per-row standard deviation. `mean` and `var`
# come from the previous cell and are taken over the last (feature) dimension.
# NOTE: there is no epsilon here, so a row whose ReLU outputs are all equal (e.g. all
# zeros) would have var == 0 and produce NaNs; the LayerNorm class below adds eps.
out_norm = (out - mean) / torch.sqrt(var)

# Verify the result. Store the check under new names instead of overwriting
# `mean`/`var`: overwriting them meant a re-run of this cell subtracted ~0 and
# divided by ~1, silently returning `out` un-normalized.
mean_norm = out_norm.mean(dim=-1, keepdim=True)
var_norm = out_norm.var(dim=-1, keepdim=True)
print("Normalized layer outputs:\n", out_norm)
print("Mean:\n", mean_norm)
print("Variance:\n", var_norm)

Normalized layer outputs:
 tensor([[ 0.6159,  1.4126, -0.8719,  0.5872, -0.8719, -0.8719],
        [-0.0189,  0.1121, -1.0876,  1.5173,  0.5647, -1.0876]],
       grad_fn=<DivBackward0>)
Mean:
 tensor([[-5.9605e-08],
        [ 1.9868e-08]], grad_fn=<MeanBackward1>)
Variance:
 tensor([[1.0000],
        [1.0000]], grad_fn=<VarBackward0>)


In [6]:
## Why layer normalization? It keeps the inputs to each layer at a consistent scale
## (zero mean, unit variance per example), which helps avoid vanishing/exploding
## gradients and stabilizes training. This is especially useful in deep networks,
## where the distribution of activations can change significantly as the data passes
## through many layers.

In [7]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension, with a learnable scale and shift.

    Each vector along the last axis is normalized to zero mean and unit
    (population) variance, then transformed as ``scale * x_hat + shift``.

    Args:
        emb_dim: Expected size of the input's last dimension; also the length of
            the ``scale`` and ``shift`` parameter vectors.
        eps: Small constant added to the variance before the square root, so a
            constant (zero-variance) input does not cause a division by zero.
            Defaults to 1e-5, the value previously hard-coded here.
    """

    def __init__(self, emb_dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Both parameters have shape (emb_dim,), so they broadcast across any
        # leading dimensions (batch, tokens, ...) of the input.
        self.scale = nn.Parameter(torch.ones(emb_dim))   # learnable gain, starts as identity
        self.shift = nn.Parameter(torch.zeros(emb_dim))  # learnable bias, starts at zero

    def forward(self, x):
        """Normalize ``x`` along its last dimension and apply the learned affine map."""
        mean = x.mean(dim=-1, keepdim=True)
        # Population variance (divide by n, no Bessel's correction). The statistics
        # are over the feature dimension, not the batch; the biased estimator matches
        # the standard LayerNorm definition used by torch.nn.LayerNorm.
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        # Element-wise affine transform; scale/shift broadcast over leading dims.
        return self.scale * norm_x + self.shift

In [8]:
# Apply the custom LayerNorm to the raw batch (5 features per example).
ln = LayerNorm(emb_dim=5)
out_ln = ln(batch_example)

# Sanity check: each row should now have ~0 mean and ~1 population variance
# (unbiased=False, matching how LayerNorm computes its own variance).
mean = torch.mean(out_ln, dim=-1, keepdim=True)
var = torch.var(out_ln, dim=-1, unbiased=False, keepdim=True)

print("Mean:\n", mean)
print("Variance:\n", var)

Mean:
 tensor([[-2.9802e-08],
        [ 0.0000e+00]], grad_fn=<MeanBackward1>)
Variance:
 tensor([[1.0000],
        [1.0000]], grad_fn=<VarBackward0>)
