In [2]:
import torch
from torch import nn

In [5]:
inputs = torch.Tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
B, S, E = inputs.size()
inputs = inputs.reshape(S,B,E)
inputs.size()

torch.Size([2, 1, 3])

In [6]:
parameter_shape = inputs.size()[-2:]
# the lr parameters
gamma = nn.Parameter(torch.ones(parameter_shape))
beta = nn.Parameter(torch.zeros(parameter_shape))

In [7]:
gamma.size(), beta.size()

(torch.Size([1, 3]), torch.Size([1, 3]))

In [8]:
dims = [-(i+1)  for i in range(len(parameter_shape))]

In [9]:
dims

[-1, -2]

In [10]:
mean = inputs.mean(dim = dims, keepdim = True)
mean.size()

torch.Size([2, 1, 1])

In [11]:
mean

tensor([[[0.2000]],

        [[0.2333]]])

In [12]:
var = ((inputs- mean) **2).mean(dim = dims, keepdim = True)
epsilon = 1e-5
std = (var + epsilon).sqrt()
std


tensor([[[0.0817]],

        [[0.1886]]])

In [13]:
 y = (inputs- mean) / std
 y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [14]:
out = gamma *y + beta
out

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

**Wrappping all in a Class**

In [24]:
import torch
from torch import nn

class Layer_Normalisation():
  def __init__(self, parameters_shape, eps=1e-5):
    self.parameters_shape = parameters_shape
    self.eps = eps
    self.gamma = nn.Parameter(torch.ones(parameters_shape))
    self.beta = nn.Parameter(torch.zeros(parameters_shape))

  def forward(self, input):
    dims = [-(i+1) for i in range(len(self.parameters_shape))]
    mean = inputs.mean(dim=dims, keepdim= True)
    print(f"Mean \n ({mean.size()}): \n {mean}")

    var = ((inputs - mean)**2).mean(dim=dims, keepdim = True)
    std = (var+self.eps).sqrt()
    print(f"Standard Deviation \n ({std.size()}): \n {std}")

    y = (inputs-mean)/std
    print(f"y \n ({y.size()}) = \n {y}")

    out = self.gamma*y + self.beta
    print(f"out \n ({out.size()}) = \n {out}")
    return out


In [25]:
batch_size = 3
sentence_length = 5
embedding_dim = 8
inputs = torch.randn(sentence_length, batch_size, embedding_dim)

print(f"input \n ({inputs.size()}) = \n {inputs}")

input 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-1.1078e-01, -1.9829e+00, -1.8426e-01,  1.3140e-01, -1.1800e-01,
           5.7795e-01,  1.6297e+00,  5.7278e-01],
         [-1.1299e+00,  1.8223e+00, -3.7062e-01, -7.4907e-01, -1.0050e+00,
          -1.1695e-01, -6.0395e-01, -2.2386e+00],
         [ 3.1496e-01,  4.7026e-01, -1.5214e-02, -7.4039e-01, -2.6066e-01,
           5.2195e-01,  4.7541e-01,  6.8642e-01]],

        [[ 2.6270e+00,  1.3458e+00, -1.8171e+00,  1.0486e+00,  1.5161e+00,
           8.5030e-02, -1.3066e-01, -1.4093e-01],
         [-8.9171e-01,  6.5222e-01,  1.3927e-01,  1.4040e+00,  9.9691e-01,
          -1.1896e+00, -6.3867e-01, -3.9875e-01],
         [-1.7499e+00, -6.5021e-01, -3.9976e-01,  3.6300e-02, -7.0304e-01,
          -4.0278e-01, -4.3826e-02,  6.1389e-02]],

        [[ 9.5991e-01, -8.3396e-01, -3.1445e-01,  1.8404e-01,  1.6657e+00,
          -8.3050e-01,  2.4855e-01,  2.6059e-01],
         [-2.2159e+00,  6.4653e-02,  8.6292e-02, -3.7711e-01, -1.0700e+00,
          

In [28]:
layer_norm = Layer_Normalisation(inputs.size()[-2:])

In [29]:
out = layer_norm.forward(inputs)

Mean 
 (torch.Size([5, 1, 1])): 
 tensor([[[-0.1010]],

        [[ 0.0315]],

        [[-0.1482]],

        [[ 0.3094]],

        [[-0.1209]]])
Standard Deviation 
 (torch.Size([5, 1, 1])): 
 tensor([[[0.9284]],

        [[1.0394]],

        [[0.9122]],

        [[1.2455]],

        [[0.7916]]])
y 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-0.0106, -2.0271, -0.0897,  0.2503, -0.0183,  0.7313,  1.8642,
           0.7257],
         [-1.1084,  2.0716, -0.2905, -0.6981, -0.9738, -0.0172, -0.5418,
          -2.3026],
         [ 0.4480,  0.6153,  0.0924, -0.6888, -0.1720,  0.6710,  0.6209,
           0.8481]],

        [[ 2.4971,  1.2645, -1.7785,  0.9785,  1.4283,  0.0515, -0.1560,
          -0.1659],
         [-0.8882,  0.5972,  0.1037,  1.3205,  0.9288, -1.1748, -0.6448,
          -0.4139],
         [-1.7138, -0.6559, -0.4149,  0.0046, -0.7067, -0.4178, -0.0725,
           0.0288]],

        [[ 1.2148, -0.7517, -0.1822,  0.3643,  1.9886, -0.7479,  0.4350,
           0.4482],
         [-2.266

In [30]:
out[0].mean(), out[0].std()

(tensor(2.4835e-09, grad_fn=<MeanBackward0>),
 tensor(1.0215, grad_fn=<StdBackward0>))