### Layer Normalization

- Layer normalization helps stabilize training: it keeps activations in a consistent range, which mitigates problems such as vanishing gradients and makes training more efficient.
- With layer normalization, each sample's activations are rescaled across the feature (last) dimension to have mean 0 and standard deviation 1, before a learnable scale and shift are applied. This keeps the distribution of each layer's input stable and is commonly described as reducing internal covariate shift.
- What exactly is internal covariate shift?
     - Internal covariate shift is the change in the distribution of a layer's inputs during training, caused by updates to the weights of the preceding layers.
     - This is problematic because each layer has to keep readjusting its weights to compensate for the shifting distribution.
     - That slows down training and makes it less efficient.
     - Layer normalization counteracts this by standardizing the input to each layer.
     - Because the standardized activations stay in a well-behaved range, this also helps against vanishing gradients.
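
As a rough illustration of the standardization step described above, the sketch below normalizes a small hand-picked tensor row by row. The tensor values and variable names are made up for this example, and the learnable scale and shift are left out.

```python
import torch

# Hypothetical activations: 2 samples, 4 features each, on very different scales
x = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                  [10.0, 20.0, 30.0, 40.0]])

# Statistics are taken per sample, over the feature (last) dimension
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)  # divide by N

# Standardize each row: subtract its mean, divide by its standard deviation
x_norm = (x - mean) / torch.sqrt(var)

print(x_norm)
print(x_norm.mean(dim=-1))
print(x_norm.std(dim=-1, unbiased=False))
```

The second row is just the first multiplied by 10, so both rows end up with essentially the same normalized values. Each sample is standardized on its own, regardless of its original scale or of the other samples in the batch.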


In [3]:
import torch


class LayerNorm(torch.nn.Module):

    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5  # small constant that guards against division by zero
        # scale and shift are learnable, so they are registered as Parameters;
        # plain tensors would never be updated by an optimizer
        self.scale = torch.nn.Parameter(torch.ones(emb_dim))
        self.shift = torch.nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        # statistics are computed per sample, over the last (feature) dimension
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True, unbiased=False)  # unbiased=False divides by N instead of N-1
        # normalize, then scale and shift the normalized values
        return self.scale * (x - mean) / (std + self.eps) + self.shift
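
One way to sanity-check a hand-written layer norm like the one above is to compare it against PyTorch's built-in `torch.nn.LayerNorm`. The sketch below uses a standalone helper, `manual_layer_norm` (a name made up for this example), and assumes the default scale of 1 and shift of 0. The built-in computes `sqrt(var + eps)` while the class above adds `eps` to the standard deviation, so the two should agree only up to a small tolerance.

```python
import torch


def manual_layer_norm(x, eps=1e-5):
    # Hypothetical helper mirroring the forward pass above, without scale/shift
    mean = x.mean(dim=-1, keepdim=True)
    std = x.std(dim=-1, keepdim=True, unbiased=False)
    return (x - mean) / (std + eps)


torch.manual_seed(123)
x = torch.randn(2, 5)

# Built-in layer norm; its weight is 1 and its bias is 0 at initialization
builtin = torch.nn.LayerNorm(5, eps=1e-5)
with torch.no_grad():
    expected = builtin(x)

actual = manual_layer_norm(x)

# Largest element-wise gap between the two implementations
max_diff = (actual - expected).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
print(torch.allclose(actual, expected, atol=1e-4))
```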

In [4]:
torch.manual_seed(123)  # fix the seed so the example is reproducible
batch_example = torch.randn(2, 5)  # 2 samples with 5 features each
layer = torch.nn.Sequential(torch.nn.Linear(5, 6), torch.nn.ReLU())
out = layer(batch_example)
print(out)


tensor([[0.2260, 0.3470, 0.0000, 0.2216, 0.0000, 0.0000],
        [0.2133, 0.2394, 0.0000, 0.5198, 0.3297, 0.0000]],
       grad_fn=<ReluBackward0>)


In [5]:
ln = LayerNorm(5)
out = ln(batch_example)
print(out)

tensor([[ 0.5528,  1.0693, -0.0223,  0.2656, -1.8654],
        [ 0.9087, -1.3767, -0.9564,  1.1304,  0.2940]])


In [10]:
# verify that each row of the normalized output has mean 0 and std 1
mean = out.mean(dim=-1,keepdim=True)
std = out.std(dim=-1,keepdim=True,unbiased=False)

print(f"mean: {mean}\n std: {std}")

mean: tensor([[-2.9802e-08],
        [ 0.0000e+00]])
 std: tensor([[1.0000],
        [1.0000]])
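
The mean of `-2.9802e-08` is not exactly zero only because of floating-point rounding. Below is a small sketch of two ways to make that easier to see. It assumes the same seed and input shape as above, and it uses the built-in `torch.nn.functional.layer_norm` as a stand-in for the `LayerNorm` class so that the snippet runs on its own.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(123)
x = torch.randn(2, 5)

# Stand-in for the custom LayerNorm: normalize over the last dimension of size 5
out = F.layer_norm(x, normalized_shape=(5,), eps=1e-5)

mean = out.mean(dim=-1, keepdim=True)
std = out.std(dim=-1, keepdim=True, unbiased=False)

# Compare against the expected values with a tolerance instead of by eye
print(torch.allclose(mean, torch.zeros_like(mean), atol=1e-6))
print(torch.allclose(std, torch.ones_like(std), atol=1e-3))

# Turn off scientific notation so tiny values are printed in fixed-point form
torch.set_printoptions(sci_mode=False)
print(mean)
```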
