Import Libraries

In [1]:
import torch
import torch.nn as nn
import tiktoken

Example of layer normalization

In [2]:
torch.manual_seed(123)
batch_example = torch.randn(2,5)
layer = nn.Sequential(nn.Linear(5,6),nn.ReLU())
out = layer(batch_example)
print(out)

tensor([[0.2260, 0.3470, 0.0000, 0.2216, 0.0000, 0.0000],
        [0.2133, 0.2394, 0.0000, 0.5198, 0.3297, 0.0000]],
       grad_fn=<ReluBackward0>)


In [3]:
mean = out.mean(dim=-1, keepdim = True)
variance = out.var(dim = -1, keepdim =True)
print("Mean",mean)
print("variance",variance)

Mean tensor([[0.1324],
        [0.2170]], grad_fn=<MeanBackward1>)
variance tensor([[0.0231],
        [0.0398]], grad_fn=<VarBackward0>)


In [4]:
out_norm = (out -mean) / torch.sqrt(variance + 1e-5)
mean = out_norm.mean(dim=-1, keepdim = True)
variance = out_norm.var(dim = -1, keepdim =True)
print("Mean after normalization",mean)
print("variance after normalization",variance)

Mean after normalization tensor([[0.],
        [0.]], grad_fn=<MeanBackward1>)
variance after normalization tensor([[0.9996],
        [0.9997]], grad_fn=<VarBackward0>)


Implementation of Layer normalization class

In [5]:
class LayerNorm(nn.Module):
    def __init__(self,emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim = True)
        variance = x.var(dim = -1, keepdim =True)
        norm_x = (x-mean) / torch.sqrt(variance + self.eps)
        return self.scale * norm_x + self.shift
    


In [6]:
ln = LayerNorm(emb_dim = 5)
out_ln = ln(batch_example)
mean_ln = out_ln.mean(dim = -1, keepdim = True)
variance_ln = out_ln.var(dim = -1, keepdim = True, unbiased = False)
print("Mean after layer normalization", mean_ln)
print("Variance after layer normalization", variance_ln)



Mean after layer normalization tensor([[-1.4901e-08],
        [ 2.3842e-08]], grad_fn=<MeanBackward1>)
Variance after layer normalization tensor([[0.8000],
        [0.8000]], grad_fn=<VarBackward0>)
