In [1]:
import torch
from torch import nn

In [2]:
# Toy batch laid out as (B, S, E): the nesting of the literal gives sizes
# (1, 2, 3).
inputs = torch.tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
B, S, E = inputs.size()
# Swap the first two axes to obtain (S, B, E). `transpose` moves each axis
# together with its data; `reshape(S, B, E)` would only re-chunk the flat
# buffer, which matches a real axis swap solely when B == 1.
inputs = inputs.transpose(0, 1)
inputs.size()

torch.Size([2, 1, 3])

In [3]:
# The learnable scale (gamma, initialised to ones) and shift (beta, initialised
# to zeros) take the shape of the last two axes of `inputs`.
parameter_shape = inputs.shape[-2:]
gamma = nn.Parameter(torch.ones(*parameter_shape))
beta = nn.Parameter(torch.zeros(*parameter_shape))

In [4]:
# Both parameters should report the same shape as `parameter_shape`.
gamma.shape, beta.shape

(torch.Size([1, 3]), torch.Size([1, 3]))

In [5]:
# Negative indices of the trailing axes covered by `parameter_shape`,
# counted from the last axis backwards: -1, -2, ...
dims = list(range(-1, -len(parameter_shape) - 1, -1))

In [6]:
# Display the axes that the mean and variance below are reduced over.
dims

[-1, -2]

In [7]:
# Average over the normalised axes; keepdim=True keeps them as size-1 axes so
# the result broadcasts against `inputs` in the following cells.
mean = torch.mean(inputs, dim=dims, keepdim=True)
mean.shape, mean

(torch.Size([2, 1, 1]),
 tensor([[[0.2000]],
 
         [[0.2333]]]))

In [8]:
# Small constant added to the variance so the square root (and the later
# division by it) stays away from zero.
epsilon = 1e-5
# Variance as the mean of squared deviations over the normalised axes
# (divides by N, no Bessel correction).
var = torch.mean((inputs - mean).pow(2), dim=dims, keepdim=True)
std = torch.sqrt(var + epsilon)
std

tensor([[[0.0817]],

        [[0.1886]]])

In [9]:
# Standardise: centre on the mean, then scale by the epsilon-stabilised
# standard deviation.
y = (inputs - mean).div(std)
y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [11]:
# Apply the learnable affine transform: scale by gamma, then shift by beta.
out = torch.add(torch.mul(gamma, y), beta)
out

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

In [12]:
import torch
from torch import nn

class LayerNormalization(nn.Module):
    """Layer normalisation over the trailing axes of a tensor.

    The input is standardised (zero mean, unit variance) over its last
    ``len(parameters_shape)`` axes and then passed through a learnable
    element-wise affine transform ``gamma * y + beta``.

    Args:
        parameters_shape: Shape of the trailing axes to normalise over; also
            the shape of ``gamma`` and ``beta``. A single int is treated as a
            one-element shape.
        eps: Constant added to the variance before the square root, for
            numerical stability.
    """

    def __init__(self, parameters_shape, eps=1e-5):
        # Subclassing nn.Module registers gamma/beta so that .parameters(),
        # .to(device) and optimisers can see them.
        super().__init__()
        if isinstance(parameters_shape, int):
            # forward() takes len() of the shape, so wrap a bare int.
            parameters_shape = (parameters_shape,)
        self.parameters_shape = parameters_shape
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, input):
        """Normalise ``input`` and apply the affine transform.

        Prints the intermediate mean, standard deviation, normalised values
        and final output for inspection.

        Args:
            input: Tensor whose trailing axes match ``parameters_shape``.

        Returns:
            Tensor with the same shape as ``input``.
        """
        # Negative indices of the trailing axes being normalised: -1, -2, ...
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        # Use the `input` argument throughout; reading a global named `inputs`
        # would ignore whatever tensor the caller actually passed in.
        mean = input.mean(dim=dims, keepdim=True)
        print(f"Mean \n ({mean.size()}): \n {mean}")
        # Variance as the mean of squared deviations (divides by N).
        var = ((input - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        print(f"Standard Deviation \n ({std.size()}): \n {std}")
        y = (input - mean) / std
        print(f"y \n ({y.size()}) = \n {y}")
        out = self.gamma * y + self.beta
        print(f"out \n ({out.size()}) = \n {out}")
        return out

In [13]:
batch_size = 3
sentence_length = 5
embedding_dim = 8
# Draw the example from a locally seeded generator so that Restart & Run All
# reproduces the same tensor without touching the global RNG state.
generator = torch.Generator().manual_seed(0)
inputs = torch.randn(sentence_length, batch_size, embedding_dim, generator=generator)

print(f"input \n ({inputs.size()}) = \n {inputs}")

input 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-1.2955e+00,  2.0290e+00, -6.5452e-01,  1.0763e+00, -6.2481e-01,
           8.2222e-01, -1.4164e+00,  9.2760e-01],
         [-1.8057e-01,  3.1844e-01, -9.9258e-01,  9.2753e-01, -2.8802e-01,
          -1.5813e+00,  1.1965e+00,  7.6290e-01],
         [-2.1411e-01, -4.3660e-01, -4.7695e-01, -1.6136e-01, -8.1425e-01,
          -9.3295e-01, -1.7188e-01, -1.5605e-01]],

        [[-7.4638e-01, -1.6250e+00, -9.0585e-01,  1.4066e+00,  2.1901e-01,
           4.5005e-01, -2.0215e+00, -1.0158e+00],
         [-2.2032e-01, -5.0617e-01,  2.4300e+00,  2.6013e-01,  1.5337e+00,
           4.3804e-01, -5.1026e-01,  3.0526e-01],
         [-7.3901e-01,  6.0443e-01, -7.4369e-02,  3.9546e-01,  2.0363e+00,
           9.0536e-01,  3.1408e-01, -1.4544e+00]],

        [[ 1.9478e+00, -4.9073e-01,  1.4940e+00,  9.3957e-01, -6.9595e-01,
           1.0196e+00,  2.1303e-01, -1.1830e+00],
         [-1.4072e-01,  1.2112e+00,  4.8439e-01,  3.0406e-01, -4.8478e-01,
          

In [14]:
# Normalise over the last axis only: the parameter shape is the trailing
# one-element slice of the input shape.
layer_norm = LayerNormalization(inputs.shape[-1:])
out = layer_norm.forward(inputs)


Mean 
 (torch.Size([5, 3, 1])): 
 tensor([[[ 0.1080],
         [ 0.0204],
         [-0.4205]],

        [[-0.5298],
         [ 0.4663],
         [ 0.2485]],

        [[ 0.4055],
         [-0.1434],
         [ 0.0787]],

        [[ 0.3010],
         [ 0.3854],
         [ 0.4310]],

        [[ 0.2585],
         [ 0.2049],
         [ 0.4593]]])
Standard Deviation 
 (torch.Size([5, 3, 1])): 
 tensor([[[1.1844],
         [0.9065],
         [0.2878]],

        [[1.0677],
         [0.9641],
         [0.9849]],

        [[1.0490],
         [0.9506],
         [0.7404]],

        [[1.0073],
         [0.4700],
         [0.9460]],

        [[1.1545],
         [1.1647],
         [0.5655]]])
y 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-1.1850,  1.6220, -0.6438,  0.8175, -0.6187,  0.6031, -1.2871,
           0.6920],
         [-0.2217,  0.3288, -1.1175,  1.0008, -0.3402, -1.7669,  1.2975,
           0.8192],
         [ 0.7172, -0.0559, -0.1961,  0.9005, -1.3680, -1.7805,  0.8639,
           0.9189]],



In [15]:
# Sanity check on the first slice of the output: the mean should be ~0.
# torch.std applies Bessel's correction by default, whereas the normalisation
# divides by N, so the reported std comes out slightly above 1.
torch.mean(out[0]), torch.std(out[0])

(tensor(9.9341e-09, grad_fn=<MeanBackward0>),
 tensor(1.0215, grad_fn=<StdBackward0>))