In [1]:
import torch
from torch import nn

In [2]:
# Toy batch in (batch, sequence, embedding) layout: 1 sentence, 2 tokens, 3 features.
inputs = torch.tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
B, S, E = inputs.size()
# Swap the batch and sequence axes to obtain (sequence, batch, embedding).
# permute moves whole axes; reshape only reinterprets the flat buffer, which
# happens to give the same result when B == 1 but would mix tokens from
# different sentences as soon as B > 1.
inputs = inputs.permute(1, 0, 2)
inputs.size()

torch.Size([2, 1, 3])

In [3]:
# Learnable scale (gamma, starts at 1) and shift (beta, starts at 0), shaped
# like the trailing two axes of `inputs`.
parameter_shape = inputs.shape[-2:]
gamma = nn.Parameter(torch.ones(*parameter_shape))
beta = nn.Parameter(torch.zeros(*parameter_shape))

In [4]:
# Both parameters share the shape of the normalised axes.
gamma.shape, beta.shape

(torch.Size([1, 3]), torch.Size([1, 3]))

In [5]:
# Negative indices of the trailing axes to normalise over: -1, -2, ...,
# one per entry of `parameter_shape`.
dims = list(range(-1, -len(parameter_shape) - 1, -1))

In [6]:
# Show the axes that will be reduced over (negative indices count from the last axis).
dims

[-1, -2]

In [7]:
# Average over the axes in `dims`; keepdim=True leaves them as size-1 axes so
# the result broadcasts against `inputs` in the following cells.
mean = torch.mean(inputs, dim=dims, keepdim=True)
mean.shape

torch.Size([2, 1, 1])

In [8]:
# Show the computed means; the reduced axes are kept as size 1 (keepdim=True).
mean

tensor([[[0.2000]],

        [[0.2333]]])

In [9]:
epsilon = 1e-5  # added to the variance before the square root
# Mean squared deviation from `mean`, reduced over the same axes as the mean.
var = torch.mean((inputs - mean).pow(2), dim=dims, keepdim=True)
std = torch.sqrt(var + epsilon)
std

tensor([[[0.0817]],

        [[0.1886]]])

In [10]:
# Standardise: subtract the mean, then divide by the (epsilon-stabilised) std.
y = torch.div(inputs - mean, std)
y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [11]:
# Learnable affine transform: scale by gamma, then shift by beta.
out = torch.add(gamma * y, beta)

In [12]:
# Show the scaled-and-shifted result (gamma * y + beta).
out

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

## Class

In [13]:
import torch
from torch import nn

class LayerNormalization(nn.Module):
    """Layer normalisation over the trailing ``len(parameters_shape)`` axes.

    Each slice spanned by the trailing axes is shifted to zero mean and scaled
    to unit (population) variance, then passed through a learnable affine
    transform ``gamma * y + beta``.

    Subclassing ``nn.Module`` registers ``gamma`` and ``beta`` as parameters
    (so ``.parameters()``, optimisers and ``.to(device)`` see them) and makes
    the instance callable; calling ``.forward(x)`` directly still works.

    Args:
        parameters_shape: Shape of the trailing axes to normalise over; must
            support ``len()`` (e.g. a tuple or ``torch.Size``). It is also the
            shape of ``gamma`` and ``beta``.
        eps: Constant added to the variance before the square root so the
            division never sees a zero standard deviation.
    """

    def __init__(self, parameters_shape, eps=1e-5):
        super().__init__()
        self.parameters_shape = parameters_shape
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(parameters_shape))   # scale, starts at 1
        self.beta = nn.Parameter(torch.zeros(parameters_shape))   # shift, starts at 0

    def forward(self, input):
        """Normalise ``input`` and apply the affine transform.

        Prints the intermediate mean, standard deviation, normalised values
        and final output for inspection.

        Args:
            input: Tensor whose trailing axes match ``parameters_shape``.

        Returns:
            Tensor of the same shape as ``input``.
        """
        # Negative indices of the trailing axes to reduce over: -1, -2, ...
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        # Use the `input` argument, not a notebook-global tensor, so the layer
        # works on whatever is passed in.
        mean = input.mean(dim=dims, keepdim=True)
        print(f"Mean \n ({mean.size()}): \n {mean}")
        var = ((input - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        print(f"Standard Deviation \n ({std.size()}): \n {std}")
        y = (input - mean) / std
        print(f"y \n ({y.size()}) = \n {y}")
        out = self.gamma * y + self.beta
        print(f"out \n ({out.size()}) = \n {out}")
        return out

In [14]:
# Seed the generator so the random demo batch, and every statistic printed
# from it, is identical on each Restart & Run All.
torch.manual_seed(0)

batch_size = 3
sentence_length = 5
embedding_dim = 8
# Random batch laid out as (sentence_length, batch_size, embedding_dim).
inputs = torch.randn(sentence_length, batch_size, embedding_dim)

print(f"input \n ({inputs.size()}) = \n {inputs}")

input 
 (torch.Size([5, 3, 8])) = 
 tensor([[[ 1.2043e+00,  9.7202e-01, -2.2918e+00, -1.0965e+00,  1.1829e+00,
           4.9697e-01,  3.1783e-01, -7.7466e-01],
         [ 6.6251e-01,  1.6642e+00,  1.1274e-01,  8.6929e-01,  6.7051e-01,
          -4.4236e-02,  8.6410e-01,  1.9468e+00],
         [ 4.2376e-01, -4.4193e-01, -1.2858e+00, -4.4292e-01,  3.5721e-01,
          -3.6774e-01, -1.6730e-01,  5.3514e-01]],

        [[-2.7082e-02,  6.2048e-01,  1.8167e+00,  2.1104e+00,  1.5519e+00,
           3.8018e-01,  7.0886e-02,  8.6067e-01],
         [-4.7546e-01, -9.8065e-01, -9.5750e-01,  1.5417e-01, -1.3559e-02,
           7.7145e-01, -1.4337e+00,  1.6206e+00],
         [ 2.9650e+00, -2.0598e+00, -7.8759e-01,  4.4774e-01,  1.4893e+00,
          -2.1081e+00,  8.3458e-01, -7.9242e-01]],

        [[ 1.1495e+00,  5.1507e-04, -4.4897e-01, -7.1660e-01, -8.9793e-01,
          -1.3095e+00, -8.1508e-01, -6.4028e-02],
         [ 1.5405e+00, -5.5251e-01,  1.0124e+00,  9.9804e-01, -1.0709e+00,
          

In [15]:
# Normalise over the last axis only, so gamma and beta have one entry per
# element of that axis.
layer_norm = LayerNormalization(parameters_shape=inputs.shape[-1:])

In [16]:
# Run the layer on the random batch; forward() also prints the intermediate
# mean, standard deviation, normalised values and final output.
out = layer_norm.forward(inputs)

Mean 
 (torch.Size([5, 3, 1])): 
 tensor([[[ 0.0014],
         [ 0.8432],
         [-0.1737]],

        [[ 0.9230],
         [-0.1643],
         [-0.0014]],

        [[-0.3878],
         [ 0.6220],
         [-0.0475]],

        [[ 0.2818],
         [-0.3927],
         [ 0.1707]],

        [[ 0.3641],
         [-0.7379],
         [ 0.6484]]])
Standard Deviation 
 (torch.Size([5, 3, 1])): 
 tensor([[[1.1831],
         [0.6400],
         [0.5657]],

        [[0.7602],
         [0.9480],
         [1.6523]],

        [[0.7092],
         [1.0042],
         [0.8793]],

        [[0.8898],
         [1.2780],
         [0.8513]],

        [[0.4089],
         [0.9914],
         [0.7112]]])
y 
 (torch.Size([5, 3, 8])) = 
 tensor([[[ 1.0168,  0.8204, -1.9383, -0.9280,  0.9987,  0.4189,  0.2675,
          -0.6560],
         [-0.2824,  1.2828, -1.1414,  0.0407, -0.2699, -1.3867,  0.0326,
           1.7243],
         [ 1.0561, -0.4742, -1.9658, -0.4759,  0.9385, -0.3430,  0.0113,
           1.2530]],



In [17]:
# Sanity check on the first slice along axis 0: mean should be ~0 and std ~1.
# NOTE: torch.std applies Bessel's correction by default, whereas the layer
# divides by the population std, so the value is close to but not exactly 1.
torch.mean(out[0]), torch.std(out[0])

(tensor(9.9341e-09, grad_fn=<MeanBackward0>),
 tensor(1.0215, grad_fn=<StdBackward0>))