In [1]:
import torch
import torch.nn as nn

In [3]:
# Toy batch: the literal is nested as 1 x 2 x 3, unpacked below as (B, S, E).
inputs = torch.tensor([[[0.2, 0.2, 0.3], [0.5, 0.1, 0.1]]])
B, S, E = inputs.size()
# Move to an (S, B, E) layout. permute() actually swaps the two leading axes;
# reshape(S, B, E) would merely reinterpret the flat data and is only correct
# when B == 1 (as it happens to be here).
inputs = inputs.permute(1, 0, 2)
inputs.size()

torch.Size([2, 1, 3])

In [5]:
# Learnable scale (gamma, ones) and shift (beta, zeros) covering the trailing
# two axes of `inputs`.
parameter_shape = inputs.shape[-2:]
gamma = nn.Parameter(torch.ones(*parameter_shape))
beta = nn.Parameter(torch.zeros(*parameter_shape))

In [6]:
# Both parameters should share the shape of the normalized axes.
gamma.shape, beta.shape

(torch.Size([1, 3]), torch.Size([1, 3]))

In [7]:
# Negative axis indices, one per entry of parameter_shape, counted from the last axis.
dims = list(range(-1, -len(parameter_shape) - 1, -1))

In [8]:
# Display the reduction axes built from parameter_shape.
dims

[-1, -2]

In [9]:
# Average over the axes in `dims`; keepdim=True leaves them as size-1 axes so
# the result broadcasts against `inputs`.
mean = torch.mean(inputs, dim=dims, keepdim=True)
mean.shape

torch.Size([2, 1, 1])

In [10]:
# Display the means computed over `dims` (reduced axes kept as size 1).
mean

tensor([[[0.2333]],

        [[0.2333]]])

In [11]:
# Biased variance (mean of squared deviations) over `dims`; epsilon keeps the
# square root away from zero.
epsilon = 1e-5
var = torch.mean((inputs - mean).pow(2), dim=dims, keepdim=True)
std = torch.sqrt(var + epsilon)
std

tensor([[[0.0472]],

        [[0.1886]]])

In [12]:
# Standardize: subtract the mean and divide by the standard deviation.
y = torch.div(inputs - mean, std)
y

tensor([[[-0.7055, -0.7055,  1.4110]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [13]:
# Apply the learnable scale and shift to the standardized values.
out = y * gamma + beta

In [14]:
# Display the result; gamma is all ones and beta all zeros, so the values
# match y, but the tensor now carries a grad_fn from the parameters.
out

tensor([[[-0.7055, -0.7055,  1.4110]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

In [18]:
import torch
import torch.nn as nn


class LayerNormalization(nn.Module):
    """Layer normalization over the trailing axes of a tensor.

    The last ``len(parameter_shape)`` axes of the input are standardized
    (zero mean, unit variance, using the biased variance) and then scaled by
    the learnable ``gamma`` and shifted by the learnable ``beta``.

    Subclassing ``nn.Module`` registers ``gamma`` and ``beta`` as parameters,
    so they are visible to ``.parameters()``, optimizers and ``.to(device)``.

    Args:
        parameter_shape: Shape of the trailing axes to normalize over; also
            the shape of ``gamma`` and ``beta``.
        eps: Small constant added to the variance before the square root for
            numerical stability.
    """

    def __init__(self, parameter_shape, eps=1e-5):
        super().__init__()
        self.parameters_shape = parameter_shape
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(self.parameters_shape))
        self.beta = nn.Parameter(torch.zeros(self.parameters_shape))

    def forward(self, input):
        """Normalize ``input`` and print each intermediate tensor.

        Args:
            input: Tensor whose trailing axes broadcast against
                ``parameter_shape``.

        Returns:
            Tensor with the same shape as ``input``.
        """
        # Negative indices of the axes to reduce over, e.g. [-1, -2].
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        # Use the argument, not a same-looking global, so the layer works on
        # whatever tensor the caller passes in.
        mean = input.mean(dim=dims, keepdim=True)
        print(f"Mean \n ({mean.size()}): \n {mean}")
        var = ((input - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        print(f"Standard Deviation \n ({std.size()}): \n {std}")
        y = (input - mean) / std
        print(f"y \n ({y.size()}) = \n {y}")
        out = self.gamma * y + self.beta
        print(f"out \n ({out.size()}) = \n {out}")
        return out

In [19]:
# Seed the generator so the random batch (and every tensor derived from it)
# is identical on each Restart & Run All. Outputs recorded before the seed
# was added will differ from a fresh run.
torch.manual_seed(0)

batch_size = 3
sentence_length = 5
embedding_dim = 8
# Random batch laid out as (sentence_length, batch_size, embedding_dim).
inputs = torch.randn(sentence_length, batch_size, embedding_dim)

print(f"input \n ({inputs.size()}) = \n {inputs}")

input 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-0.3561,  0.2991, -1.4702, -0.6417,  0.3707, -0.1030, -0.1629,
          -0.4121],
         [ 0.9399,  1.2281, -0.9997, -0.9131, -1.3375,  0.2015,  0.6699,
          -0.2099],
         [-0.3110, -0.0057,  1.9998, -1.8400,  0.3293, -0.5319,  0.7293,
           1.8712]],

        [[ 1.5988,  0.0834,  0.0667,  0.3253, -0.0074,  1.4089, -0.1865,
           0.5196],
         [ 1.4813, -0.8957, -0.4193, -1.9212,  1.6492,  1.3907, -0.1602,
           1.2010],
         [ 0.0153,  0.1080, -2.0106, -1.6527, -0.9883,  0.0354, -0.3848,
          -0.4003]],

        [[ 0.7585,  1.0235,  2.9801,  0.2266,  1.6051,  0.1240, -0.5304,
          -0.7618],
         [-1.6509, -1.4932, -0.9333,  0.6900, -0.1239,  2.1946,  1.0553,
           1.4852],
         [-0.1474,  1.1900,  0.3185,  1.4951,  1.3868,  0.3076,  0.9422,
          -1.4083]],

        [[ 0.5537,  1.2516,  0.1959, -1.0799, -0.4477, -0.1156,  0.3852,
          -0.1058],
         [ 0.9783,  0.2999, 

In [20]:
# Normalize over the trailing two axes of `inputs`.
layer_norm = LayerNormalization(parameter_shape=inputs.shape[-2:])

In [22]:
# Run the layer on the random batch; forward() also prints the intermediate
# mean, standard deviation, y and out tensors.
out = layer_norm.forward(inputs)

Mean 
 (<built-in method size of Tensor object at 0x0000020E94691BC0>): 
 tensor([[[-0.0273]],

        [[ 0.0357]],

        [[ 0.4472]],

        [[ 0.2266]],

        [[-0.0804]]])
Standard Deviation 
 (torch.Size([5, 1, 1])): 
 tensor([[[0.9480]],

        [[1.0406]],

        [[1.1606]],

        [[0.8611]],

        [[1.0569]]])
y 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-3.4675e-01,  3.4436e-01, -1.5220e+00, -6.4808e-01,  4.1987e-01,
          -7.9805e-02, -1.4303e-01, -4.0589e-01],
         [ 1.0203e+00,  1.3243e+00, -1.0257e+00, -9.3437e-01, -1.3820e+00,
           2.4137e-01,  7.3547e-01, -1.9253e-01],
         [-2.9924e-01,  2.2781e-02,  2.1383e+00, -1.9121e+00,  3.7621e-01,
          -5.3222e-01,  7.9816e-01,  2.0027e+00]],

        [[ 1.5021e+00,  4.5862e-02,  2.9779e-02,  2.7834e-01, -4.1411e-02,
           1.3196e+00, -2.1348e-01,  4.6500e-01],
         [ 1.3891e+00, -8.9502e-01, -4.3727e-01, -1.8805e+00,  1.5505e+00,
           1.3021e+00, -1.8826e-01,  1.1198e+00],
    