<a href="https://colab.research.google.com/github/ajayvallabh/PytorchTutorial/blob/main/LayerNormalization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn

In [None]:
# One sentence of 2 words, each embedded in 3 dimensions, laid out as (batch, seq, embed).
inputs = torch.tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
B, S, E = inputs.size()
# Reorder the axes to (seq, batch, embed). permute moves the data with the axes;
# reshape would keep the flat memory order and scramble elements whenever B > 1.
inputs = inputs.permute(1, 0, 2)
inputs.size()  # 2 words, batch size 1, embedding dimension 3

torch.Size([2, 1, 3])

In [None]:
# gamma/beta hold one entry per element of the normalized trailing dims.
# NOTE(review): on a (seq, batch, embed) tensor the last two dims are batch and
# embedding, so this normalizes jointly across the batch; standard LayerNorm
# normalizes over the embedding dim only.
parameter_shape = inputs.shape[-2:]
gamma, beta = (
    nn.Parameter(torch.ones(parameter_shape)),   # scale, starts at 1
    nn.Parameter(torch.zeros(parameter_shape)),  # shift, starts at 0
)

In [None]:
gamma.shape, beta.shape  # both match parameter_shape

(torch.Size([1, 3]), torch.Size([1, 3]))

In [None]:
# Reduction axes counted from the end: -1, -2, ... one per normalized dim.
dims = list(range(-1, -len(parameter_shape) - 1, -1))

In [None]:
dims  # reduction axes: the last len(parameter_shape) dims, counted from the end

[-1, -2]

In [None]:
# Mean over the normalized dims; keepdim leaves size-1 axes so it broadcasts back.
mean = torch.mean(inputs, dim=dims, keepdim=True)
mean.shape

torch.Size([2, 1, 1])

In [None]:
epsilon = 1e-5
# Biased variance over the normalized dims, then std with epsilon for stability.
var = (inputs - mean).pow(2).mean(dim=dims, keepdim=True)
std = torch.sqrt(var + epsilon)
std

tensor([[[0.0817]],

        [[0.1886]]])

In [None]:
# Zero-mean, unit-variance version of inputs over the normalized dims.
y = torch.div(inputs - mean, std)
y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [None]:
out = beta + y * gamma  # learnable affine: scale by gamma, shift by beta

In [None]:
out  # gamma is all ones and beta all zeros here, so out equals y (now with a grad_fn)

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

# Layer Normalization Class

In [None]:
import torch
from torch import nn

class LayerNormalization(nn.Module):
    """Layer normalization with a learnable elementwise affine transform.

    The input is normalized over its trailing ``len(parameter_shape)``
    dimensions and then scaled and shifted::

        out = gamma * (x - mean) / sqrt(var + eps) + beta

    where ``mean`` and the biased ``var`` are taken over those trailing
    dimensions (kept as size-1 axes so they broadcast back over ``x``).

    Args:
        parameter_shape: Shape of the trailing dimensions to normalize over,
            e.g. ``x.size()[-2:]``. An ``int`` is treated as ``(int,)``.
            ``gamma`` and ``beta`` are created with this shape.
        eps: Constant added to the variance before the square root to avoid
            division by zero.
    """

    def __init__(self, parameter_shape, eps=1e-5):
        # Subclassing nn.Module registers gamma/beta as parameters, so they show
        # up in .parameters() and state_dict() and the module is callable.
        super().__init__()
        if isinstance(parameter_shape, int):
            parameter_shape = (parameter_shape,)
        self.parameter_shape = parameter_shape
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(parameter_shape))   # scale, init 1
        self.beta = nn.Parameter(torch.zeros(parameter_shape))   # shift, init 0

    def forward(self, input, *, verbose=True):
        """Normalize ``input`` and apply the affine transform.

        Args:
            input: Tensor whose trailing dimensions are expected to match
                ``parameter_shape`` (gamma/beta broadcast against them).
            verbose: If True (the default), print the intermediate mean,
                standard deviation, normalized tensor and output.

        Returns:
            Tensor with the same shape as ``input``.
        """
        # Reduce over the last len(parameter_shape) axes: [-1, -2, ...].
        dims = [-(i + 1) for i in range(len(self.parameter_shape))]
        # All statistics are computed from the argument itself.
        mean = input.mean(dim=dims, keepdim=True)
        if verbose:
            print(f"Mean \n ({mean.size()}): \n {mean}")
        centered = input - mean
        var = (centered ** 2).mean(dim=dims, keepdim=True)  # biased variance
        std = (var + self.eps).sqrt()
        if verbose:
            print(f"Standard Deviation \n ({std.size()}): \n {std}")
        y = centered / std
        if verbose:
            print(f"y \n ({y.size()}) = \n {y}")
        out = self.gamma * y + self.beta
        if verbose:
            print(f"out \n ({out.size()}) = \n {out}")
        return out

In [None]:
# Random (seq, batch, embed) tensor: 5 words per sentence, 3 sentences, 8-dim embeddings.
sentence_length, batch_size, embedding_dim = 5, 3, 8
inputs = torch.randn(sentence_length, batch_size, embedding_dim)

print(f"input \n ({inputs.size()}) = {inputs}")

input 
 (torch.Size([5, 3, 8])) = tensor([[[ 0.9547,  0.9566,  0.5345, -0.0652,  0.3878,  0.6297,  1.1684,
           0.1082],
         [-0.0361, -0.3104, -0.0898, -0.7286,  0.4172,  0.3980, -1.0196,
          -1.0880],
         [-0.0734, -0.7290,  0.4510,  0.1741, -0.5294,  0.8448,  0.1440,
          -1.6040]],

        [[ 0.1193, -0.1293,  0.5652, -0.4637,  0.1900,  0.1970,  0.2726,
           0.7118],
         [-1.3196, -0.0802, -0.2023, -0.5171, -0.8200,  1.4979,  0.8022,
          -1.1153],
         [-2.2353, -0.5812,  2.6042, -1.3106, -0.2852, -1.4256, -1.5638,
          -0.9875]],

        [[-0.3183,  0.5551,  0.1410, -0.8474, -0.0375,  2.3172,  0.5033,
           0.3971],
         [-0.4906, -0.6728, -0.0757, -1.1786, -1.2865, -2.0180, -0.3315,
          -1.1942],
         [-0.5160,  0.0414, -0.9908,  0.9247,  2.0534, -0.4055, -0.8212,
          -1.0479]],

        [[-1.3799, -0.9166, -1.7157,  0.0675,  0.8621, -0.4258,  0.1571,
           0.4993],
         [ 1.2234,  0.1930,  0

In [None]:
layer_norm = LayerNormalization(parameter_shape=inputs.shape[-2:])  # normalize over (batch, embed)

In [None]:
out = layer_norm.forward(input=inputs)  # prints mean, std, y and out along the way

Mean 
 (torch.Size([5, 1, 1])): 
 tensor([[[ 0.0373]],

        [[-0.2532]],

        [[-0.2208]],

        [[ 0.0046]],

        [[ 0.0673]]])
Standard Deviation 
 (torch.Size([5, 1, 1])): 
 tensor([[[0.6939]],

        [[1.0390]],

        [[0.9868]],

        [[0.8438]],

        [[0.8855]]])
y 
 (torch.Size([5, 3, 8])) = 
 tensor([[[ 1.3220,  1.3247,  0.7165, -0.1477,  0.5050,  0.8536,  1.6299,
           0.1022],
         [-0.1058, -0.5010, -0.1832, -1.1037,  0.5474,  0.5198, -1.5231,
          -1.6216],
         [-0.1596, -1.1043,  0.5962,  0.1971, -0.8167,  1.1636,  0.1538,
          -2.3652]],

        [[ 0.3585,  0.1192,  0.7876, -0.2026,  0.4265,  0.4333,  0.5061,
           0.9287],
         [-1.0264,  0.1665,  0.0489, -0.2540, -0.5455,  1.6852,  1.0158,
          -0.8297],
         [-1.9076, -0.3157,  2.7500, -1.0177, -0.0308, -1.1284, -1.2613,
          -0.7067]],

        [[-0.0988,  0.7863,  0.3666, -0.6350,  0.1857,  2.5720,  0.7338,
           0.6262],
         [-0.273