In [1]:
import torch
import torch.nn as nn

In [2]:
class LayerNormalization(nn.Module):
  """Layer normalization over the trailing dimensions of a tensor.

  The last ``len(parameters_shape)`` dimensions of the input are normalized
  to zero mean and (approximately) unit variance, then scaled by the
  learnable ``gamma`` and shifted by the learnable ``beta``.

  Args:
    parameters_shape: Shape of the trailing dimensions to normalize over
      (e.g. ``inputs.size()[-1:]``). A bare ``int`` is treated as a
      one-element shape.
    eps: Small constant added to the variance before the square root so the
      division stays finite when the variance is zero.
  """

  def __init__(self, parameters_shape, eps=1e-5):
    # Must run before any nn.Parameter is assigned so that gamma/beta are
    # registered and show up in .parameters(), .state_dict(), .to(), etc.
    super().__init__()
    if isinstance(parameters_shape, int):
      parameters_shape = (parameters_shape,)
    self.parameters_shape = parameters_shape
    self.eps = eps
    self.gamma = nn.Parameter(torch.ones(parameters_shape))   # learnable scale
    self.beta = nn.Parameter(torch.zeros(parameters_shape))   # learnable shift

  def forward(self, input):
    """Normalize ``input`` over its trailing ``len(parameters_shape)`` dims.

    Args:
      input: Tensor whose trailing dimensions broadcast against
        ``parameters_shape``.

    Returns:
      Tensor with the same shape as ``input``.
    """
    # Negative indices of the trailing dims, e.g. [-1] or [-1, -2].
    dims = [-(i + 1) for i in range(len(self.parameters_shape))]
    # Use the argument itself, never a same-looking global from the notebook.
    mean = input.mean(dim=dims, keepdim=True)
    # Biased (population) variance: mean of squared deviations.
    var = ((input - mean) ** 2).mean(dim=dims, keepdim=True)
    std = (var + self.eps).sqrt()
    y = (input - mean) / std
    out = self.gamma * y + self.beta
    return out

In [3]:
# Seed the global RNG so that Restart Kernel -> Run All yields the same
# random tensor (and therefore the same downstream statistics) every time.
torch.manual_seed(0)

batch_size = 3
sentence_length = 5
embedding_dim = 8
# Sequence-first layout: (sentence_length, batch_size, embedding_dim).
inputs = torch.randn(sentence_length, batch_size, embedding_dim)

print(f"input \n ({inputs.size()}) = \n {inputs}")

input 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-1.1305,  0.7766,  0.2488, -0.7696,  0.3130,  0.9095, -1.0269,
           0.7433],
         [ 0.0445,  1.0059,  0.6081, -0.0995, -2.3801,  1.3989,  0.9593,
          -0.8575],
         [-0.4124,  0.7781,  0.3412,  0.1573, -1.0477,  0.2032, -0.0979,
           0.2657]],

        [[-1.5440,  2.1601,  0.7988,  0.3979, -0.2143,  1.3551, -1.5014,
          -0.1860],
         [ 0.5978,  1.2435,  0.7521, -1.6968, -0.2841,  0.7680,  0.9688,
          -0.6426],
         [ 0.3264, -0.8967, -0.1409,  0.9390,  0.2855, -3.0673, -1.4926,
          -0.7302]],

        [[ 0.1264,  0.8977, -0.5362, -1.4786, -0.1422, -0.3115,  1.4912,
          -1.5385],
         [-0.1230,  1.5032,  0.4378, -0.2139,  1.5059, -1.5982,  0.2533,
          -0.0602],
         [ 0.6296,  0.6556, -0.3112, -2.3175, -0.4679,  1.1303, -0.2192,
           1.0089]],

        [[ 0.0177, -0.7093,  2.2480,  1.0673,  1.5850,  0.4815, -2.6331,
           0.7035],
         [ 0.2420,  0.8561, 

In [4]:
# Normalize over the last axis only: slicing the shape with [-1:] keeps it as
# a one-element shape rather than a bare int.
layer_norm = LayerNormalization(inputs.shape[-1:])

In [5]:
# Run the normalization by calling `forward` explicitly on the random batch;
# the result is kept in `out` for the sanity check in the next cell.
out = layer_norm.forward(inputs)

In [6]:
# Sanity check on the first slice along dim 0. The mean should be ~0. The std
# is computed over the whole slice with Tensor.std's default (Bessel-corrected)
# estimator, while the layer normalizes each vector with the biased variance,
# so the value shown is slightly above 1 rather than exactly 1.
first_slice = out[0]
first_slice.mean(), first_slice.std()

(tensor(-1.2418e-08, grad_fn=<MeanBackward0>),
 tensor(1.0215, grad_fn=<StdBackward0>))