In [1]:
import torch 
import torch.nn as nn

In [2]:
class Layer_Normalization(nn.Module):
    """Layer normalization over the last dimension with a learnable affine.

    Each slice along the last axis is shifted to zero mean and scaled to unit
    (biased) variance, then multiplied elementwise by ``scale`` and offset by
    ``shift``.

    Args:
        emb_dim: Length of the ``scale`` and ``shift`` parameter vectors; the
            last dimension of the inputs must broadcast against it.
        eps: Small constant added to the variance before taking the square
            root, so the division stays finite when a slice has zero
            variance. Defaults to 1e-5.
    """

    def __init__(self, emb_dim, eps=1e-5):
        super().__init__()

        self.eps = eps
        # Start as the identity affine transform: scale = 1, shift = 0.
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        """Normalize ``x`` along its last dimension and apply scale/shift.

        Args:
            x: Input tensor; statistics are computed over ``dim=-1``.

        Returns:
            The normalized tensor, with the same shape as ``x`` when its last
            dimension equals ``emb_dim``.
        """
        x_mean = x.mean(dim=-1, keepdim=True)
        # unbiased=False divides by N rather than N - 1 (population variance).
        x_var = x.var(dim=-1, keepdim=True, unbiased=False)
        x_norm = (x - x_mean) / torch.sqrt(x_var + self.eps)
        return self.scale * x_norm + self.shift

In [3]:
# Seed PyTorch's global RNG so the sample below is identical on every run.
torch.manual_seed(123)

# Example batch: 2 rows of 5 values drawn from a standard normal.
batch_eg = torch.randn((2, 5))
batch_eg

tensor([[-0.1115,  0.1204, -0.3696, -0.2404, -1.1969],
        [ 0.2093, -0.9724, -0.7550,  0.3239, -0.1085]])

In [4]:
# Push the example batch through the layer, then report the per-row mean and
# biased variance of the result over the last dimension.
ln = Layer_Normalization(emb_dim=5)
output = ln(batch_eg)

mean = torch.mean(output, dim=-1, keepdim=True)
var = torch.var(output, dim=-1, keepdim=True, unbiased=False)

print(f"Mean:\n{mean}\nVarience:\n{var}")

Mean:
tensor([[-2.9802e-08],
        [ 0.0000e+00]], grad_fn=<MeanBackward1>)
Varience:
tensor([[1.0000],
        [1.0000]], grad_fn=<VarBackward0>)
