Layer Normalization = [x' + output]

In [1]:
import torch
from torch import nn

In [2]:
# Toy batch laid out as (batch, sequence, embedding).
inputs = torch.tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
B, S, E = inputs.size()
# Swap the batch and sequence axes to get (sequence, batch, embedding).
# permute() moves whole axes; reshape() would merely reinterpret the flat
# memory and only gives the same values when B == 1.
inputs = inputs.permute(1, 0, 2)
inputs.size()

torch.Size([2, 1, 3])

In [3]:
# The learnable scale (gamma) and shift (beta) take the shape of the last
# two axes of `inputs`, i.e. (batch, embedding) after the reshape above.
parameter_shape = inputs.shape[-2:]
gamma = nn.Parameter(torch.ones(parameter_shape))   # identity scale
beta = nn.Parameter(torch.zeros(parameter_shape))   # zero shift

In [4]:
# Both parameters should share `parameter_shape`.
gamma.shape, beta.shape

(torch.Size([1, 3]), torch.Size([1, 3]))

In [5]:
# Negative indices of the trailing axes to reduce over: -1, -2, ...,
# one per entry in `parameter_shape`.
dims = list(range(-1, -len(parameter_shape) - 1, -1))

In [6]:
# Show the list of reduction axes built in the previous cell.
dims

[-1, -2]

In [8]:
# Average over the trailing axes listed in `dims`; keepdim=True keeps them
# as size-1 axes so the result broadcasts against `inputs`.
mean = torch.mean(inputs, dim=dims, keepdim=True)
mean.shape

torch.Size([2, 1, 1])

In [9]:
# Population variance: the mean of squared deviations over the same axes.
var = torch.mean((inputs - mean).pow(2), dim=dims, keepdim=True)
# epsilon keeps the square root (and the later division) away from zero.
epsilon = 1e-5
std = torch.sqrt(var + epsilon)
std

tensor([[[0.0817]],

        [[0.1886]]])

In [10]:
# Standardise: subtract the mean, then divide by the standard deviation.
y = torch.div(inputs - mean, std)
y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [11]:
# Learnable affine transform: scale by gamma, then shift by beta.
out = torch.add(torch.mul(gamma, y), beta)

In [12]:
# Display the scaled-and-shifted result computed in the previous cell.
out

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

In [17]:
class LayerNormalization():
  """Layer normalization over the trailing dimensions of a tensor.

  The input is standardised (zero mean, unit variance) across its last
  ``len(parameters_shape)`` dimensions, then scaled by ``gamma`` and shifted
  by ``beta``. Both are learnable and start as ones and zeros respectively.
  """

  def __init__(self, parameters_shape, eps=1e-5):
    """Create the scale and shift parameters.

    Args:
      parameters_shape: Shape of the trailing dimensions to normalise over;
        ``gamma`` and ``beta`` are created with this shape.
      eps: Small constant added to the variance before taking the square
        root, so the division in ``forward`` never hits zero.
    """
    # Keep the constructor argument itself so the instance does not depend
    # on any similarly named variable in the surrounding scope.
    self.parameter_shape = parameters_shape
    self.eps = eps
    self.gamma = nn.Parameter(torch.ones(parameters_shape))
    self.beta = nn.Parameter(torch.zeros(parameters_shape))

  def forward(self, input):
    """Normalise ``input`` and apply the learnable affine transform.

    The mean, standard deviation, normalised tensor and final output are
    printed together with their sizes.

    Args:
      input: Tensor whose trailing dimensions are broadcast-compatible with
        ``gamma`` and ``beta``.

    Returns:
      The normalised tensor after scaling by ``gamma`` and shifting by
      ``beta``.
    """
    # Reduce over the last len(parameter_shape) axes: -1, -2, ...
    dims = [-(i + 1) for i in range(len(self.parameter_shape))]
    # All statistics come from the argument, never from an outer variable.
    mean = input.mean(dim=dims, keepdim=True)
    print(f"Mean \n ({mean.size()}): \n  {mean}")
    # Population variance: the mean of squared deviations.
    var = ((input - mean) ** 2).mean(dim=dims, keepdim=True)
    std = (var + self.eps).sqrt()
    print(f"Standard Deviation \n ({std.size()}): \n {std}")
    y = (input - mean) / std
    print(f"y \n ({y.size()}): \n {y}")
    out = self.gamma * y + self.beta
    print(f"Out \n ({out.size()}): \n {out}")
    return out

In [14]:
# Dimensions of the random demo batch, laid out (sequence, batch, embedding).
sentence_length, batch_size, embedding_dim = 5, 3, 8
inputs = torch.randn((sentence_length, batch_size, embedding_dim))

print(f"Input \n ({inputs.size()}): \n {inputs}")

Input 
 (torch.Size([5, 3, 8])): 
 tensor([[[-1.3552,  0.1461,  1.1616, -0.9327,  0.1313,  0.3248,  1.1430,
           1.0944],
         [-2.7218, -0.3241, -0.8396, -0.6377,  1.6304,  0.1386, -1.4228,
           0.9892],
         [ 0.2353, -0.2593,  0.6480,  0.5153, -0.4421, -0.2659,  0.4064,
          -1.3005]],

        [[-0.9104, -0.7916,  1.0208, -0.1001, -1.8537,  0.5363, -1.6168,
          -0.0342],
         [ 0.8325,  0.8022, -0.4090,  0.6711, -1.8140,  0.5973, -0.2788,
           1.0930],
         [ 0.3963,  0.4892,  1.4349,  1.6694, -0.8572,  1.4201, -0.4561,
          -0.7536]],

        [[ 0.0875, -1.3106, -1.0665, -1.4315,  1.5776,  0.1664,  0.7757,
          -0.0661],
         [ 1.1323,  0.6356,  2.0684,  1.0474,  0.7554, -0.9763,  0.4924,
          -1.1262],
         [ 0.3728, -1.0713,  0.0627,  0.3622,  0.2820, -1.3792,  0.2210,
          -0.9726]],

        [[ 0.9175, -1.3852, -0.1717, -1.7987, -0.4800,  1.5262,  0.1802,
          -0.1888],
         [-0.3621, -0.5866,  

In [18]:
# Normalise over the last two axes of `inputs` (batch and embedding).
layer_norm = LayerNormalization(parameters_shape=inputs.shape[-2:])

In [19]:
# Run the forward pass on the random demo batch; forward() also prints the
# mean, standard deviation, normalised tensor and output along the way.
out = layer_norm.forward(inputs)

Mean 
 (torch.Size([5, 1, 1])): 
  tensor([[[-0.0807]],

        [[ 0.0453]],

        [[ 0.0266]],

        [[ 0.2011]],

        [[ 0.6337]]])
Standard Deviation 
 (torch.Size([5, 1, 1])): 
 tensor([[[0.9955]],

        [[1.0091]],

        [[0.9685]],

        [[0.8332]],

        [[0.7973]]])
y 
 (torch.Size([5, 3, 8])): 
 tensor([[[-1.2803e+00,  2.2785e-01,  1.2480e+00, -8.5583e-01,  2.1294e-01,
           4.0737e-01,  1.2293e+00,  1.1805e+00],
         [-2.6530e+00, -2.4444e-01, -7.6229e-01, -5.5947e-01,  1.7189e+00,
           2.2028e-01, -1.3482e+00,  1.0747e+00],
         [ 3.1748e-01, -1.7942e-01,  7.3203e-01,  5.9868e-01, -3.6301e-01,
          -1.8606e-01,  4.8931e-01, -1.2253e+00]],

        [[-9.4705e-01, -8.2934e-01,  9.6664e-01, -1.4407e-01, -1.8818e+00,
           4.8649e-01, -1.6471e+00, -7.8757e-02],
         [ 7.8002e-01,  7.4999e-01, -4.5020e-01,  6.2014e-01, -1.8425e+00,
           5.4701e-01, -3.2116e-01,  1.0382e+00],
         [ 3.4786e-01,  4.3984e-01,  1.3770e