In [1]:
import torch
from torch import nn

class LayerNormalization(nn.Module):
    """Layer normalization over the trailing ``parameters_shape`` dimensions.

    Each slice spanning the last ``len(parameters_shape)`` dimensions of the
    input is shifted to zero mean and scaled to unit (biased) variance, then
    transformed element-wise by the learnable ``gamma`` (scale, initialised to
    ones) and ``beta`` (shift, initialised to zeros).

    Subclassing ``nn.Module`` registers ``gamma`` and ``beta`` as parameters, so
    they show up in ``.parameters()`` / ``.state_dict()`` and follow ``.to()``,
    and the instance can be called directly (``layer(x)``) as well as through
    ``layer.forward(x)``.

    Args:
        parameters_shape: Shape of the trailing dimensions to normalize over.
            May be an int (a single trailing dimension) or any sequence of
            ints such as a tuple, list or ``torch.Size``.
        eps: Constant added to the variance before the square root, for
            numerical stability.
        verbose: When True (the default), ``forward`` prints the intermediate
            mean, standard deviation, normalized values and final output.
    """

    def __init__(self, parameters_shape, eps=1e-5, verbose=True):
        super().__init__()  # must run before any nn.Parameter is assigned
        if isinstance(parameters_shape, int):
            # Accept a bare int as shorthand for a single trailing dimension.
            parameters_shape = (parameters_shape,)
        self.parameters_shape = tuple(parameters_shape)
        self.eps = eps
        self.verbose = verbose
        self.gamma = nn.Parameter(torch.ones(self.parameters_shape))
        self.beta = nn.Parameter(torch.zeros(self.parameters_shape))

    def forward(self, input):
        """Normalize ``input`` over its trailing dimensions.

        Args:
            input: Tensor whose trailing dimensions match ``parameters_shape``.

        Returns:
            Tensor of the same shape as ``input``.
        """
        # Negative indices of the trailing dimensions, e.g. [-1] or [-1, -2].
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        mean = input.mean(dim=dims, keepdim=True)
        centered = input - mean
        # Biased variance: mean of squared deviations (divides by N, not N - 1).
        var = (centered ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        y = centered / std
        out = self.gamma * y + self.beta
        if self.verbose:
            print(f"Mean \n ({mean.size()}): \n {mean}")
            print(f"Standard Deviation \n ({std.size()}): \n {std}")
            print(f"y \n ({y.size()}) = \n {y}")
            print(f"out \n ({out.size()}) = \n {out}")
        return out

In [2]:
# Fix the RNG seed so the sample tensor (and every statistic printed from it
# further down) is identical after Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

batch_size = 3
sentence_length = 5
embedding_dim = 8
# Random sample laid out as (sentence_length, batch_size, embedding_dim).
inputs = torch.randn(sentence_length, batch_size, embedding_dim)

print(f"input \n ({inputs.size()}) = \n {inputs}")

input 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-1.1095,  1.7194,  0.5795, -0.4288,  0.1096, -0.4119, -0.7873,
           0.4664],
         [-0.0127,  1.3077,  0.4556, -0.2021, -0.1953, -2.7125, -1.4945,
          -1.7005],
         [ 1.5420,  0.5728,  1.0544,  0.1280, -0.4420, -0.3686, -0.0539,
          -0.3755]],

        [[-1.4820,  1.5107, -0.9410, -0.8733,  0.9753, -1.9216, -0.0901,
           0.0429],
         [-0.5327,  0.0507, -0.8683, -1.3644,  1.9466, -1.7801,  0.7542,
          -0.3274],
         [ 0.5975, -0.9272, -0.9207,  1.0137, -0.6774, -1.4493,  0.8024,
           0.4387]],

        [[-0.2682, -1.5475,  0.3288, -0.5275, -0.1829, -0.5980,  0.9688,
          -0.0569],
         [-0.9837,  0.7308, -0.5428,  0.2721, -0.6919,  0.6033, -1.2605,
           1.1337],
         [-0.5892, -0.2567,  0.9135,  1.0962, -0.1753, -0.9857, -1.0709,
           0.2183]],

        [[-0.6400, -0.6171, -0.0481, -0.9367,  1.0695,  0.1658, -0.7029,
          -0.1045],
         [ 0.8413,  0.2713, 

In [3]:
# Normalize over the last dimension of `inputs` only: the slice keeps just the
# final entry of the shape, still as a torch.Size.
layer_norm = LayerNormalization(inputs.shape[-1:])


In [4]:
# Run the layer on the sample tensor; forward() prints the intermediate mean,
# standard deviation, normalized values and final output as it computes them.
out = layer_norm.forward(inputs)


Mean 
 (torch.Size([5, 3, 1])): 
 tensor([[[ 0.0172],
         [-0.5693],
         [ 0.2571]],

        [[-0.3474],
         [-0.2652],
         [-0.1403]],

        [[-0.2354],
         [-0.0924],
         [-0.1062]],

        [[-0.2267],
         [ 0.2248],
         [-0.0031]],

        [[-0.2302],
         [-0.0641],
         [-0.4670]]])
Standard Deviation 
 (torch.Size([5, 3, 1])): 
 tensor([[[0.8459],
         [1.2189],
         [0.6874]],

        [[1.1076],
         [1.1165],
         [0.8895]],

        [[0.6840],
         [0.8306],
         [0.7546]],

        [[0.6057],
         [0.7086],
         [0.6178]],

        [[1.1878],
         [0.8677],
         [0.9743]]])
y 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-1.3320,  2.0125,  0.6648, -0.5273,  0.1093, -0.5072, -0.9511,
           0.5310],
         [ 0.4566,  1.5398,  0.8408,  0.3013,  0.3068, -1.7582, -0.7590,
          -0.9280],
         [ 1.8691,  0.4592,  1.1598, -0.1879, -1.0171, -0.9103, -0.4525,
          -0.9203]],



In [5]:
# Pooled statistics over the first slice of the output (index 0 of dim 0).
# torch.std applies Bessel's correction by default, whereas forward() scales
# by the biased variance, so the std reported here sits slightly above 1.
torch.mean(out[0]), torch.std(out[0])

(tensor(0., grad_fn=<MeanBackward0>), tensor(1.0215, grad_fn=<StdBackward0>))