In [1]:
import torch
from torch import nn

In [5]:
# Toy activations: one batch entry (B=1), two sequence positions (S=2),
# three features per position (E=3).
inputs = torch.tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
B, S, E = inputs.size()
# Swap the batch and sequence axes to get a (S, B, E) layout. transpose moves
# whole feature vectors; reshape(S, B, E) would merely re-slice the flat buffer
# and mix vectors from different positions as soon as B > 1.
inputs = inputs.transpose(0, 1)
inputs.size()

torch.Size([2, 1, 3])

In [11]:
# Learnable scale (all ones) and shift (all zeros), shaped like the last two
# axes of `inputs`.
parameter_shape = inputs.shape[-2:]
gamma = nn.Parameter(torch.ones(*parameter_shape))
beta = nn.Parameter(torch.zeros(*parameter_shape))

In [12]:
gamma.size(), beta.size()  # both match the trailing (batch, embedding) axes of inputs

(torch.Size([1, 3]), torch.Size([1, 3]))

In [13]:
# Negative indices of the trailing axes covered by parameter_shape, last axis
# first; these are the axes the statistics are reduced over.
dims = [-offset for offset in range(1, len(parameter_shape) + 1)]

In [14]:
dims  # show the reduction axes (the last two axes of inputs)

[-1, -2]

In [15]:
# Average over the axes in `dims`; keepdim=True keeps them as size-1 axes so
# the result broadcasts against `inputs` in the next steps.
mean = inputs.mean(dims, keepdim=True)
mean.shape

torch.Size([2, 1, 1])

In [16]:
mean  # one mean per index of the leading axis

tensor([[[0.2000]],

        [[0.2333]]])

In [17]:
# Population variance (mean of squared deviations) over the same axes;
# epsilon is added before the square root so the later division stays finite
# when the variance is zero.
epsilon = 1e-5
var = (inputs - mean).pow(2).mean(dims, keepdim=True)
std = (epsilon + var).sqrt()
std

tensor([[[0.0817]],

        [[0.1886]]])

In [18]:
# Standardize: centre on the mean, then scale by the epsilon-stabilized
# standard deviation.
y = (inputs - mean).div(std)
y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [19]:
# Affine step: scale the normalized values by gamma and shift them by beta
# (both broadcast across the leading axis of y).
out = gamma * y + beta

In [20]:
out  # identical to y for now, because gamma is all ones and beta is all zeros

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

## Class: a reusable `LayerNormalization` layer

In [24]:
import torch
from torch import nn

class LayerNormalization(nn.Module):
    """Normalize a tensor over its trailing axes, then scale and shift it.

    The last ``len(parameters_shape)`` axes of the input are standardized to
    zero mean and unit population variance. The result is multiplied by the
    learnable ``gamma`` and offset by the learnable ``beta``.

    Subclassing ``nn.Module`` registers ``gamma`` and ``beta`` so they appear
    in ``parameters()`` / ``state_dict()`` and makes the layer callable;
    calling ``forward`` directly keeps working.

    Args:
        parameters_shape: Shape of ``gamma`` and ``beta``. Its length decides
            how many trailing axes are reduced. A bare ``int`` is treated as
            a one-element shape.
        eps: Constant added to the variance before taking the square root.
    """

    def __init__(self, parameters_shape, eps=1e-5):
        super().__init__()
        # forward() calls len() on the shape, so wrap a bare int.
        if isinstance(parameters_shape, int):
            parameters_shape = (parameters_shape,)
        self.parameters_shape = parameters_shape
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, input):
        """Normalize ``input`` and apply the learnable affine transform.

        Prints the mean, standard deviation, normalized tensor and final
        output along the way.

        Args:
            input: Tensor whose trailing axes broadcast against
                ``parameters_shape``.

        Returns:
            ``gamma * normalized + beta``.
        """
        # Trailing axes to reduce over, expressed as negative indices.
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        # Use the argument, not a notebook-level global, so the layer
        # normalizes whatever tensor it is actually given.
        mean = input.mean(dim=dims, keepdim=True)
        print(f"Mean \n ({mean.size()}): \n {mean}")
        # Population variance: mean of squared deviations.
        var = ((input - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        print(f"Standard Deviation \n ({std.size()}): \n {std}")
        y = (input - mean) / std
        print(f"y \n ({y.size()}) = \n {y}")
        out = self.gamma * y + self.beta
        print(f"out \n ({out.size()}) = \n {out}")
        return out

In [25]:
# Random activations laid out as (sentence_length, batch_size, embedding_dim).
sentence_length, batch_size, embedding_dim = 5, 3, 8
inputs = torch.randn((sentence_length, batch_size, embedding_dim))

print(f"input \n ({inputs.size()}) = \n {inputs}")

input 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-0.4488, -1.0678,  1.0113, -0.2998,  0.1561, -0.2554, -0.7329,
           2.2027],
         [ 0.5831, -0.5719,  0.7693, -0.7760, -0.4094,  0.1170,  0.9339,
           0.5682],
         [ 0.3259, -0.7332, -0.5762,  0.0039, -0.3591, -0.2478, -0.9618,
          -0.5951]],

        [[ 0.3982,  1.2086,  2.1620,  0.3764, -0.0401, -0.1729,  1.9406,
          -1.3515],
         [-0.0076,  0.5875, -1.3514, -2.0623,  0.8407,  0.3298, -1.0934,
           0.9744],
         [ 0.1474, -0.1491, -0.6619, -0.5229,  0.4628, -0.3913,  0.0224,
           0.7849]],

        [[-0.6184, -0.8950,  0.3012,  1.0319,  0.8634,  0.5250,  0.8652,
          -0.2641],
         [ 0.3830, -0.7671, -0.8941, -0.0045, -1.3029,  0.5734,  0.7833,
          -0.7207],
         [-1.3157,  0.7573,  0.8301, -0.2392,  1.4416, -0.6874, -1.4176,
          -0.9247]],

        [[ 0.6812,  0.2190,  1.0782,  0.7558,  0.3581,  1.8396, -1.5581,
          -0.5959],
         [ 0.5889,  0.0339, 

In [26]:
# gamma/beta span only the last axis of inputs, so the statistics are taken
# over that axis alone.
layer_norm = LayerNormalization(inputs.size()[-1:])

In [27]:
# Run the layer on the random batch; forward() prints each intermediate
# tensor and returns the normalized output.
out = layer_norm.forward(inputs)

Mean 
 (torch.Size([5, 3, 1])): 
 tensor([[[ 0.0707],
         [ 0.1518],
         [-0.3929]],

        [[ 0.5651],
         [-0.2228],
         [-0.0385]],

        [[ 0.2262],
         [-0.2437],
         [-0.1944]],

        [[ 0.3472],
         [-0.0137],
         [ 0.9010]],

        [[-0.5066],
         [ 0.2137],
         [-0.0023]]])
Standard Deviation 
 (torch.Size([5, 3, 1])): 
 tensor([[[0.9933],
         [0.6178],
         [0.3890]],

        [[1.0900],
         [1.0600],
         [0.4647]],

        [[0.6864],
         [0.7262],
         [1.0102]],

        [[0.9732],
         [0.7215],
         [1.0240]],

        [[0.7261],
         [0.4401],
         [1.2250]]])
y 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-0.5230, -1.1462,  0.9470, -0.3730,  0.0860, -0.3283, -0.8090,
           2.1464],
         [ 0.6982, -1.1714,  0.9996, -1.5018, -0.9084, -0.0563,  1.2659,
           0.6741],
         [ 1.8480, -0.8749, -0.4711,  1.0203,  0.0870,  0.3730, -1.4626,
          -0.5197]],



In [28]:
# Sanity check on the first slice of the output (3 x 8 = 24 values).
# Tensor.std() defaults to the sample (Bessel-corrected) estimator, whereas
# forward() normalizes with the population variance, so the pooled value is
# about sqrt(24 / 23) ~= 1.0215 rather than exactly 1.
out[0].mean(), out[0].std()

(tensor(0., grad_fn=<MeanBackward0>), tensor(1.0215, grad_fn=<StdBackward0>))