<a href="https://colab.research.google.com/github/AyoubMDL/transformers_from_scratch/blob/main/layer_normalization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn   

In [4]:
# Toy example laid out as (batch, seq, emb): 1 sequence of 2 tokens, 3-dim embeddings.
# torch.tensor is the recommended factory; the torch.Tensor(...) constructor is legacy.
inputs = torch.tensor([[[0.1, 0.6, 0.8], [0.4, 0.5, 0.1]]], dtype=torch.float32)
inputs.shape

torch.Size([1, 2, 3])

In [10]:
# Unpack the (batch, seq, emb) dimensions of the toy tensor.
batch_size, seq_length, emb_size = inputs.shape

In [11]:
# Move to (seq_length, batch_size, emb_size) by swapping the first two axes.
# NOTE: `reshape` would only reinterpret the underlying memory and scramble
# tokens across sequences whenever batch_size > 1; it happened to give the same
# result here only because batch_size == 1. Run this cell once (re-running
# swaps the axes back).
inputs = inputs.transpose(0, 1)
inputs.shape

torch.Size([2, 1, 3])

In [12]:
# Shape of gamma/beta: the trailing two axes of the (seq, batch, emb) tensor,
# i.e. one scale/shift per (batch, emb) position.
parameter_shape = inputs.shape[-2:]
parameter_shape

torch.Size([1, 3])

In [13]:
# Learnable affine parameters, initialised to the identity transform
# (shift 0, scale 1), with one entry per position in parameter_shape.
beta = nn.Parameter(torch.zeros(*parameter_shape))
gamma = nn.Parameter(torch.ones(*parameter_shape))

In [14]:
# Show both parameters: all ones / all zeros, both tracked by autograd.
(gamma, beta)

(Parameter containing:
 tensor([[1., 1., 1.]], requires_grad=True),
 Parameter containing:
 tensor([[0., 0., 0.]], requires_grad=True))

In [15]:
gamma.size()  # same torch.Size as parameter_shape

torch.Size([1, 3])

In [16]:
# Reduction axes: the last len(parameter_shape) dimensions, counted from the end.
dims = list(range(-1, -len(parameter_shape) - 1, -1))
dims

[-1, -2]

In [18]:
# Mean over the reduced axes for each sequence position; keepdim so it broadcasts.
mean = torch.mean(inputs, dim=dims, keepdim=True)
mean

tensor([[[0.5000]],

        [[0.3333]]])

In [20]:
mean.size()  # reduced axes kept as size-1 dims

torch.Size([2, 1, 1])

In [21]:
# Population (biased) variance over the same axes, then a numerically safe std.
epsilon = 1e-5
var = (inputs - mean).pow(2).mean(dim=dims, keepdim=True)
std = torch.sqrt(var + epsilon)
std

tensor([[[0.2944]],

        [[0.1700]]])

In [22]:
# Standardise: zero mean and (approximately, because of epsilon) unit variance.
y = (inputs - mean).div(std)
y

tensor([[[-1.3587,  0.3397,  1.0190]],

        [[ 0.3922,  0.9804, -1.3726]]])

In [23]:
out = beta + gamma * y  # learnable shift + scale

In [24]:
out  # equals y here, since gamma is all ones and beta all zeros

tensor([[[-1.3587,  0.3397,  1.0190]],

        [[ 0.3922,  0.9804, -1.3726]]], grad_fn=<AddBackward0>)

In [36]:
class LayerNormalization(nn.Module):
    """Layer normalization over the trailing ``parameters_shape`` dimensions.

    For an input whose last ``len(parameters_shape)`` dimensions match
    ``parameters_shape``, computes
    ``gamma * (x - mean) / sqrt(var + eps) + beta``, where the mean and the
    biased (population) variance are taken over those trailing dimensions.

    Args:
        parameters_shape: Shape of the normalized trailing dimensions; also the
            shape of the learnable ``gamma`` (initialised to 1) and ``beta``
            (initialised to 0). An ``int`` is treated as a 1-D shape, as in
            ``nn.LayerNorm``.
        eps: Small constant added to the variance for numerical stability.
    """

    def __init__(self, parameters_shape, eps=1e-5):
        super().__init__()
        if isinstance(parameters_shape, int):
            parameters_shape = (parameters_shape,)
        self.parameters_shape = torch.Size(parameters_shape)
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(self.parameters_shape))
        self.beta = nn.Parameter(torch.zeros(self.parameters_shape))

    def forward(self, input):
        """Normalize ``input`` and apply the learned affine transform.

        Args:
            input: Tensor whose trailing dimensions equal ``parameters_shape``.

        Returns:
            Tensor with the same shape as ``input``.
        """
        # Reduce over the last len(parameters_shape) axes: [-1, -2, ...].
        # Uses this module's own shape and its argument, never notebook globals.
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        mean = input.mean(dim=dims, keepdim=True)
        centered = input - mean
        # Biased (population) variance, matching nn.LayerNorm.
        var = (centered ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        y = centered / std
        return self.gamma * y + self.beta

In [37]:
# Fix the RNG so the random inputs, and every output derived from them, are
# reproducible under Restart & Run All.
torch.manual_seed(42)

batch_size = 5
seq_length = 4
emb_dim = 6

# Layout: (seq_length, batch_size, emb_dim).
inputs = torch.randn(seq_length, batch_size, emb_dim)
inputs.shape

torch.Size([4, 5, 6])

In [38]:
# Normalise over the trailing (batch_size, emb_dim) axes, mirroring the manual
# walk-through. NOTE(review): transformer LayerNorm usually normalises emb_dim only.
ln = LayerNormalization(parameters_shape=inputs.shape[-2:])

In [39]:
out = ln(input=inputs)  # Module.__call__ forwards the keyword to forward()

In [40]:
out  # normalised tensor, same layout as inputs

tensor([[[-8.0180e-01,  1.6604e+00, -3.8528e-01, -1.3219e+00, -2.6988e+00,
          -1.5344e+00],
         [-1.5279e+00,  2.9466e-02,  5.5665e-01, -6.5321e-01,  9.3418e-01,
          -9.5731e-02],
         [ 9.3529e-01,  6.7010e-01,  1.2162e+00,  9.3067e-04,  7.4975e-01,
           9.5986e-01],
         [ 1.0366e+00,  2.4387e-01, -1.0799e+00, -8.2493e-01, -4.3667e-01,
          -6.7650e-01],
         [ 1.5142e-01,  1.1248e+00, -2.8562e-01,  9.3403e-01,  1.5441e-01,
           9.6482e-01]],

        [[ 7.8759e-02,  1.2339e-01,  6.3382e-01,  3.8799e-02, -1.8610e-01,
           7.9666e-01],
         [-2.4091e-01, -6.3703e-01, -3.1361e-01,  8.7458e-01, -6.1336e-01,
           5.5686e-01],
         [-3.7986e-01, -1.5462e-03,  1.2705e+00, -8.2688e-01, -5.0351e-01,
          -2.6533e+00],
         [ 3.9346e-01,  3.0484e+00, -4.1172e-01, -3.8634e-01,  1.7481e+00,
          -6.7884e-01],
         [ 1.6978e-01,  6.1485e-01, -7.6977e-01, -1.0446e+00, -1.2272e+00,
           5.2658e-01]],

      

In [41]:
out.size()  # layer norm preserves the input shape

torch.Size([4, 5, 6])