In [1]:
import torch
from torch import nn

In [2]:
inputs = torch.Tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
B, S, E = inputs.size()
inputs = inputs.reshape(S, B, E)
inputs.size()

torch.Size([2, 1, 3])

In [3]:
parameter_shape = inputs.size()[-2:]
gamma = nn.Parameter(torch.ones(parameter_shape))
beta =  nn.Parameter(torch.zeros(parameter_shape))

In [4]:
gamma.size(), beta.size()

(torch.Size([1, 3]), torch.Size([1, 3]))

In [5]:
dims = [-(i + 1) for i in range(len(parameter_shape))]
dims

[-1, -2]

In [6]:
mean = inputs.mean(dim=dims, keepdim=True)
mean.size()

torch.Size([2, 1, 1])

In [7]:
mean

tensor([[[0.2000]],

        [[0.2333]]])

In [8]:
var = ((inputs - mean) ** 2).mean(dim=dims, keepdim=True)
epsilon = 1e-5
std = (var + epsilon).sqrt()
std

tensor([[[0.0817]],

        [[0.1886]]])

In [9]:
y = (inputs - mean) / std
y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [10]:
out = gamma * y + beta

In [11]:
out  

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

### Class

In [12]:
import torch
from torch import nn

class LayerNormalization():
    def __init__(self, parameters_shape, eps=1e-5):
        self.parameters_shape=parameters_shape
        self.eps=eps
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta =  nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, input):
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        mean = inputs.mean(dim=dims, keepdim=True)
        print(f"Mean \n ({mean.size()}): \n {mean}")
        var = ((inputs - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        print(f"Standard Deviation \n ({std.size()}): \n {std}")
        y = (inputs - mean) / std
        print(f"y \n ({y.size()}) = \n {y}")
        out = self.gamma * y  + self.beta
        print(f"out \n ({out.size()}) = \n {out}")
        return out

In [13]:
batch_size = 3
sentence_length = 5
embedding_dim = 8 
inputs = torch.randn(sentence_length, batch_size, embedding_dim)

print(f"input \n ({inputs.size()}) = \n {inputs}")

input 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-1.3511, -0.4669, -1.2817,  1.3054, -0.5906,  0.6535, -1.2926,
          -0.3531],
         [-0.6564,  0.7057,  0.2373, -1.3062,  1.2162,  0.3847, -1.6895,
          -0.1612],
         [ 1.5076,  0.0189,  1.1641, -0.4137, -0.0999, -1.2410,  3.2914,
          -0.8798]],

        [[ 0.2357, -0.6373, -0.3988,  0.5782, -0.7712,  1.7202, -0.3875,
           0.6682],
         [-0.8660,  0.5063,  1.0012,  0.1694, -1.3982, -2.0103, -1.5528,
           0.4724],
         [ 1.0830,  0.2584,  0.0521, -0.7756,  0.0717, -0.6323,  1.3583,
          -2.1740]],

        [[-0.0580, -1.6516, -0.6851, -0.0530,  0.1656,  0.6189, -0.6057,
           1.6657],
         [ 0.7941,  0.3811,  0.1423,  0.7261,  1.2150,  1.1154, -0.9907,
          -0.5858],
         [ 1.7863,  0.3841,  0.3413,  0.2212,  0.4596,  2.1709,  0.6000,
           0.3923]],

        [[-0.2713, -1.4483,  0.9031, -0.4365, -1.7948,  0.7539,  1.9296,
          -1.4880],
         [ 0.7107,  0.9716, 

In [14]:
layer_norm = LayerNormalization(inputs.size()[-1:])

In [15]:
out = layer_norm.forward(inputs)

Mean 
 (torch.Size([5, 3, 1])): 
 tensor([[[-0.4221],
         [-0.1587],
         [ 0.4184]],

        [[ 0.1259],
         [-0.4597],
         [-0.0948]],

        [[-0.0754],
         [ 0.3497],
         [ 0.7945]],

        [[-0.2315],
         [ 0.2579],
         [ 0.0772]],

        [[-0.2978],
         [-0.3594],
         [ 0.5861]]])
Standard Deviation 
 (torch.Size([5, 3, 1])): 
 tensor([[[0.9039],
         [0.9352],
         [1.3945]],

        [[0.7889],
         [1.0593],
         [1.0457]],

        [[0.9158],
         [0.7408],
         [0.6975]],

        [[1.2488],
         [0.5605],
         [1.0008]],

        [[1.2506],
         [0.7316],
         [1.0340]]])
y 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-1.0277, -0.0495, -0.9510,  1.9112, -0.1864,  1.1900, -0.9630,
           0.0764],
         [-0.5322,  0.9243,  0.4234, -1.2270,  1.4700,  0.5810, -1.6368,
          -0.0027],
         [ 0.7810, -0.2865,  0.5347, -0.5968, -0.3717, -1.1900,  2.0603,
          -0.9310]],



In [16]:
out[0].mean(), out[0].std()

(tensor(2.4835e-09, grad_fn=<MeanBackward0>),
 tensor(1.0215, grad_fn=<StdBackward0>))