In [1]:
import torch
from torch import nn

In [2]:
# Toy batch: B=1 sequence of S=2 tokens, each with an E=3 dimensional embedding.
inputs = torch.tensor([[[0.2,0.1,0.3],[0.5,0.1,0.1]]])
B,S,E = inputs.size()
# Swap the batch and sequence axes to get (S, B, E).
# permute moves whole axes; reshape would only re-chunk the flat data, which
# coincides with a real axis swap only while B == 1.
inputs = inputs.permute(1,0,2)
inputs.size()

torch.Size([2, 1, 3])

In [3]:
# The trailing two dimensions of `inputs` define the shape of the learnable
# scale/shift tensors and the axes that get normalised in the cells below.
parameter_shape = inputs.shape[-2:]
parameter_shape

torch.Size([1, 3])

In [4]:
# Learnable scale (starts at 1) and shift (starts at 0): together they make the
# affine step an identity until training changes them.
gamma = nn.Parameter(torch.ones(size=parameter_shape))
beta = nn.Parameter(torch.zeros(size=parameter_shape))

In [5]:
# Negative indices of the trailing axes covered by `parameter_shape`
# (-1, -2, ... one per entry), used as the reduction axes below.
dims = [-k for k in range(1, len(parameter_shape) + 1)]
dims

[-1, -2]

In [6]:
# Average over the normalised axes; keepdim=True leaves size-1 axes in place so
# the result broadcasts against `inputs` in the following cells.
mean = torch.mean(inputs, dim=dims, keepdim=True)
mean.shape

torch.Size([2, 1, 1])

In [7]:
# Inspect the values of the mean tensor computed in the previous cell.
mean

tensor([[[0.2000]],

        [[0.2333]]])

In [8]:
# Small constant added to the variance so the square root (and the later
# division) stays well-defined when the variance is zero.
eps = 1e-5
# Population variance: mean of squared deviations over the normalised axes.
var = (inputs - mean).pow(2).mean(dim=dims, keepdim=True)
std = torch.sqrt(var + eps)
std

tensor([[[0.0817]],

        [[0.1886]]])

In [9]:
# Standardise: subtract the mean, then divide by the (eps-stabilised) std.
y = (inputs - mean).div(std)
y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [10]:
# Affine step: scale by gamma, then shift by beta (broadcast over the leading axis).
out = torch.add(torch.mul(gamma, y), beta)
out

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

In [16]:
class LayerNormalization(nn.Module):
  """Layer normalisation over the trailing dimensions of a tensor.

  The last ``len(parameters_shape)`` axes of the input are standardised
  (zero mean, unit population variance, stabilised by ``eps``) and then
  scaled by ``gamma`` and shifted by ``beta``.

  Subclasses ``nn.Module`` so that ``gamma`` and ``beta`` are registered as
  parameters (visible to ``.parameters()``, ``.to(...)`` and ``state_dict()``).

  Args:
    parameters_shape: shape of ``gamma``/``beta``; its length selects how many
      trailing axes are normalised.
    eps: constant added to the variance before taking the square root.
  """

  def __init__(self,parameters_shape,eps=1e-5):
    super().__init__()
    self.parameters_shape = parameters_shape
    self.eps = eps
    self.gamma = nn.Parameter(torch.ones(parameters_shape))   # scale, starts at 1
    self.beta = nn.Parameter(torch.zeros(parameters_shape))   # shift, starts at 0

  def forward(self,input):
    """Normalise ``input`` and apply the learnable scale and shift.

    Prints the mean, standard deviation, normalised values and output as a
    step-by-step trace.

    Args:
      input: tensor whose trailing axes broadcast against ``parameters_shape``.

    Returns:
      Tensor with the same shape as ``input``.
    """
    # Negative indices of the trailing axes to reduce over: -1, -2, ...
    dims = [-(i+1) for i in range(len(self.parameters_shape))]
    # Use the argument, not a same-named notebook global, so the layer works
    # on whatever tensor the caller passes in.
    mean = input.mean(dim=dims,keepdim=True)
    print(f"Mean \n ({mean.size()}): \n {mean}")
    var = ((input - mean) ** 2).mean(dim=dims, keepdim=True)
    std = (var + self.eps).sqrt()
    print(f"Standard Deviation \n ({std.size()}): \n {std}")
    y = (input - mean) / std
    print(f"y \n ({y.size()}) = \n {y}")
    out = self.gamma * y  + self.beta
    print(f"out \n ({out.size()}) = \n {out}")
    return out

In [17]:
# Dimensions of the random demo batch.
batch_size, sentence_length, embedding_dim = 3, 5, 8
# Standard-normal samples laid out as (sentence_length, batch_size, embedding_dim).
inputs = torch.randn((sentence_length, batch_size, embedding_dim))

print(f'input \n {inputs.size()}\n{inputs}')

input 
 torch.Size([5, 3, 8])
tensor([[[-0.0186,  0.7067,  1.0500,  0.7115,  0.0642, -0.8604,  0.1520,
          -0.1059],
         [-0.2389,  0.4846,  0.7832,  0.2635, -2.3290, -0.6923, -0.5436,
           0.6404],
         [ 1.6757, -0.0886, -1.8967, -1.8074,  1.1558,  0.9557,  0.4913,
           0.6863]],

        [[-0.5653,  2.0701, -0.0212, -0.1861, -0.3762, -1.4399,  0.5570,
          -0.8231],
         [-0.5038, -1.8305, -1.0819, -0.3513, -1.4983, -0.1878,  0.1926,
           0.2190],
         [ 0.0149, -0.0478,  0.5331, -0.4340, -1.2153, -0.2537, -0.1833,
          -0.5333]],

        [[ 0.9365, -1.3430,  1.3763,  2.0535, -0.9733,  1.6686,  0.1056,
          -0.7949],
         [-1.0997,  0.3988,  1.0758, -1.1210, -0.4448,  0.2521,  0.2425,
           0.2141],
         [ 0.1275,  0.2695, -1.3084,  1.1996, -0.0646, -0.4031, -0.2442,
           0.7629]],

        [[ 0.7058, -0.5934,  0.5076,  0.4231, -0.8268, -0.1855,  0.0362,
           0.5244],
         [-0.1671, -0.1527,  1.042

In [18]:
# Normalise over the last axis only: gamma/beta take the shape of that axis.
layer_norm = LayerNormalization(inputs.shape[-1:])

In [19]:
# Run the hand-written layer norm; forward() prints the mean, standard
# deviation, normalised values and final output as it goes.
out = layer_norm.forward(inputs)

Mean 
 (torch.Size([5, 3, 1])): 
 tensor([[[ 0.2124],
         [-0.2040],
         [ 0.1465]],

        [[-0.0981],
         [-0.6303],
         [-0.2649]],

        [[ 0.3787],
         [-0.0603],
         [ 0.0424]],

        [[ 0.0739],
         [ 0.0768],
         [-0.0249]],

        [[-0.2303],
         [-0.0177],
         [-0.4468]]])
Standard Deviation 
 (torch.Size([5, 3, 1])): 
 tensor([[[0.5619],
         [0.9520],
         [1.2485]],

        [[0.9852],
         [0.7147],
         [0.4708]],

        [[1.2258],
         [0.7176],
         [0.7101]],

        [[0.5289],
         [0.7952],
         [0.7926]],

        [[0.9893],
         [0.9807],
         [0.5643]]])
y 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-0.4111,  0.8796,  1.4904,  0.8881, -0.2638, -1.9092, -0.1075,
          -0.5665],
         [-0.0366,  0.7233,  1.0370,  0.4911, -2.2320, -0.5129, -0.3567,
           0.8869],
         [ 1.2248, -0.1883, -1.6365, -1.5650,  0.8084,  0.6481,  0.2762,
           0.4323]],



In [20]:
# Sanity check on the first slice of the output. The mean is ~0. The std reads
# ~1.02 rather than exactly 1 because each length-8 row was normalised with the
# population variance, while std() here pools all 24 values and by default
# applies Bessel's correction: sqrt(24/23) ~= 1.0215.
torch.mean(out[0]), torch.std(out[0])

(tensor(-2.4835e-09, grad_fn=<MeanBackward0>),
 tensor(1.0215, grad_fn=<StdBackward0>))