In [1]:
import torch
from torch import nn

In [2]:
# Toy batch laid out as (batch, sequence, embedding): 1 sentence, 2 tokens, 3 features.
inputs = torch.tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
B, S, E = inputs.size()
# Swap the batch and sequence axes to get (sequence, batch, embedding).
# reshape(S, B, E) only gives the right answer while B == 1: for larger batches
# it re-chunks the flat data and mixes tokens from different sentences, whereas
# transpose moves whole embedding vectors to their new position.
inputs = inputs.transpose(0, 1)
inputs.size()

torch.Size([2, 1, 3])

In [3]:
# Learnable scale (gamma) and shift (beta) spanning the last two axes of
# `inputs`. Starting at ones and zeros makes the affine step an identity.
parameter_shape = inputs.shape[-2:]
gamma = nn.Parameter(torch.ones(*parameter_shape))
beta = nn.Parameter(torch.zeros(*parameter_shape))

In [4]:
# Confirm both affine parameters took the shape of the last two axes of `inputs`.
gamma.size(), beta.size()

(torch.Size([1, 3]), torch.Size([1, 3]))

In [5]:
# Negative indices of the axes to normalize over: one per entry of
# `parameter_shape`, counted backwards from the last axis.
dims = list(range(-1, -len(parameter_shape) - 1, -1))
dims

[-1, -2]

In [6]:
# Average over the normalized axes; keepdim leaves them as size-1 axes so the
# result broadcasts against `inputs` in the following cells.
mean = torch.mean(inputs, dim=dims, keepdim=True)
mean.shape

torch.Size([2, 1, 1])

In [7]:
# Show the means: one value per index of the leading axis of `inputs`.
mean

tensor([[[0.2000]],

        [[0.2333]]])

In [8]:
# Small constant added under the square root so the divisor is never zero.
epsilon = 1e-5
# Population variance: mean of squared deviations over the normalized axes.
var = (inputs - mean).pow(2).mean(dim=dims, keepdim=True)
std = torch.sqrt(var + epsilon)
std

tensor([[[0.0817]],

        [[0.1886]]])

In [9]:
# Standardize: subtract the mean, then divide by the epsilon-stabilized std.
y = torch.div(inputs - mean, std)
y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [10]:
# Affine step: scale by gamma, then shift by beta (an identity at initialization).
out = torch.add(torch.mul(gamma, y), beta)
out

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

In [11]:
class LayerNormalization(nn.Module):
    """Layer normalization over the trailing ``len(parameters_shape)`` axes.

    The input is standardized with the mean and population variance taken over
    its last ``len(parameters_shape)`` axes, then scaled by ``gamma`` and
    shifted by ``beta``. Subclassing ``nn.Module`` registers both parameters,
    so they show up in ``.parameters()`` and ``.state_dict()``.

    Args:
        parameters_shape: Shape of ``gamma`` and ``beta``; its length decides
            how many trailing axes of the input are normalized.
        eps: Constant added to the variance before the square root so the
            divisor is never zero.
    """

    def __init__(self, parameters_shape, eps=1e-5):
        super().__init__()
        self.parameters_shape = parameters_shape
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, input):
        """Normalize ``input`` and print each intermediate tensor.

        Args:
            input: Tensor whose trailing axes broadcast against
                ``parameters_shape``.

        Returns:
            Tensor of the same shape as ``input``: ``gamma * y + beta``, where
            ``y`` is the standardized input.
        """
        # Use the argument itself; reading a notebook-level `inputs` global here
        # would normalize whatever tensor was last assigned to that name.
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        mean = input.mean(dim=dims, keepdim=True)
        print(f'Mean \n ({mean.size()}): \n {mean}')
        # Population variance (divides by N, no Bessel correction).
        var = ((input - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        print(f'Standard Deviation \n ({std.size()}): \n {std}')
        y = (input - mean) / std
        print(f'y \n ({y.size()}) = \n {y}')
        out = self.gamma * y + self.beta
        print(f'out \n ({out.size()}) = \n {out}')
        return out

In [12]:
# Fix the RNG state so the random batch, and every output derived from it,
# is identical on each Restart & Run All.
torch.manual_seed(0)

batch_size = 3
sentence_length = 5
embedding_size = 8
# Random activations laid out as (sentence_length, batch_size, embedding_size).
inputs = torch.randn(sentence_length, batch_size, embedding_size)

print(f'input \n ({inputs.size()}) = \n {inputs}')

input 
 (torch.Size([5, 3, 8])) = 
 tensor([[[ 8.0393e-01,  1.8093e+00, -3.5793e-01,  2.8944e-01,  1.4446e-01,
          -7.0954e-01, -4.4128e-01,  9.3870e-01],
         [-3.4135e-01,  1.0489e+00, -1.1933e+00,  3.1396e+00,  1.0701e+00,
           9.7081e-01, -1.0244e+00,  2.1878e-02],
         [ 6.2214e-01,  3.3450e-01,  1.6957e-03, -5.2464e-02, -2.0262e+00,
           2.8799e-01, -8.9059e-01,  3.9550e-01]],

        [[ 2.2467e-01, -2.2753e-01, -1.6972e-01,  1.8350e+00, -6.7885e-02,
           2.9816e+00, -1.2675e+00,  1.7948e+00],
         [-1.5649e-01,  4.7979e-01,  1.3002e+00, -5.8206e-01,  1.3290e+00,
          -4.5023e-01, -1.3883e-01, -6.2965e-01],
         [-1.3336e-01, -2.7789e-01, -7.0325e-01,  9.7761e-01, -3.2154e-01,
           7.5650e-02, -3.9790e-01,  2.8715e-01]],

        [[ 7.9584e-01,  1.9072e+00,  1.3734e+00, -2.7367e-01, -1.3535e+00,
          -4.9655e-01, -1.7122e+00,  1.2417e+00],
         [-1.0906e-01, -1.8074e-02, -8.0288e-01, -8.9429e-01,  9.2369e-02,
          

In [13]:
# Normalize over the last axis only: the parameter shape is the size of the
# final dimension of `inputs`.
layer_norm = LayerNormalization(parameters_shape=inputs.shape[-1:])

In [14]:
# Run the layer norm on the random batch; forward() prints the mean, standard
# deviation, normalized tensor and final output as it goes.
out = layer_norm.forward(inputs)

Mean 
 (torch.Size([5, 3, 1])): 
 tensor([[[ 0.3096],
         [ 0.4615],
         [-0.1659]],

        [[ 0.6379],
         [ 0.1440],
         [-0.0617]],

        [[ 0.1853],
         [-0.0418],
         [-0.3802]],

        [[-0.4465],
         [-0.2084],
         [-0.4525]],

        [[-0.2650],
         [ 0.1627],
         [ 0.3341]]])
Standard Deviation 
 (torch.Size([5, 3, 1])): 
 tensor([[[0.7878],
         [1.3196],
         [0.8231]],

        [[1.3212],
         [0.7503],
         [0.4828]],

        [[1.2505],
         [0.5554],
         [1.0402]],

        [[0.6970],
         [1.0599],
         [0.5062]],

        [[0.7978],
         [1.0787],
         [1.2525]]])
y 
 (torch.Size([5, 3, 8])) = 
 tensor([[[ 0.6274,  1.9036, -0.8474, -0.0256, -0.2097, -1.2937, -0.9532,
           0.7985],
         [-0.6084,  0.4451, -1.2540,  2.0294,  0.4612,  0.3859, -1.1260,
          -0.3332],
         [ 0.9575,  0.6080,  0.2037,  0.1379, -2.2602,  0.5515, -0.8804,
           0.6821]],



In [15]:
# Sanity check on the first slice of the output (3 * 8 = 24 values).
# The mean is ~0 as expected. The std is not exactly 1 because Tensor.std()
# divides by N - 1 by default while the layer norm above divides by N, so the
# value shown is about sqrt(24 / 23) ≈ 1.0215.
out[0].mean(), out[0].std()

(tensor(9.9341e-09, grad_fn=<MeanBackward0>),
 tensor(1.0215, grad_fn=<StdBackward0>))