In [9]:
import torch
import torch.nn as nn

In [None]:
inputs = torch.Tensor([[[0.2, 0.1, 0.3],[0.5,0.1,0.1]]])
B,S,E = inputs.size() #batch, sequence, embedding
inputs = inputs.reshape(S,B,E) # done as transformer expects sequence first
inputs.size()

torch.Size([2, 1, 3])

In [12]:
parameter_shape = inputs.size()[-2:]
gamma = nn.Parameter(torch.ones(parameter_shape))
beta = nn.Parameter(torch.zeros(parameter_shape))

In [None]:
len(parameter_shape) # number of dimensions to normalize over
parameter_shape #shape for gamma and beta

(2, torch.Size([1, 3]))

In [13]:
gamma.size(), beta.size()

(torch.Size([1, 3]), torch.Size([1, 3]))

In [16]:
dims = [-(i+1) for i in range(len(parameter_shape))]
dims

[-1, -2]

In [18]:
mean = inputs.mean(dim=dims, keepdim=True)
mean.size(), mean

(torch.Size([2, 1, 1]),
 tensor([[[0.2000]],
 
         [[0.2333]]]))

In [19]:
var = ((inputs - mean)**2).mean(dim=dims, keepdim=True)
epsilon = 1e-5
std = (var + epsilon).sqrt()
std

tensor([[[0.0817]],

        [[0.1886]]])

In [21]:
y = (inputs - mean) / std
y.size(), y

(torch.Size([2, 1, 3]),
 tensor([[[ 0.0000, -1.2238,  1.2238]],
 
         [[ 1.4140, -0.7070, -0.7070]]]))

In [23]:
out = gamma * y + beta
out.size(), out

(torch.Size([2, 1, 3]),
 tensor([[[ 0.0000, -1.2238,  1.2238]],
 
         [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>))

In [36]:
class LayerNormalization(nn.Module):
    def __init__(self, parameter_shape, eps=1e-5):
        super().__init__()
        self.parameter_shape = parameter_shape
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(parameter_shape))
        self.beta = nn.Parameter(torch.zeros(parameter_shape))
        
    def forward(self,x): # x is input tensor
        dims = [-(i+1) for i in range(len(self.parameter_shape))]
        mean = x.mean(dim=dims, keepdim=True)
        var = ((x - mean)**2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        y = (x - mean) / std
        out = self.gamma * y + self.beta
        return out

In [37]:
batch_size = 3    
sentence_length = 5
embedding_dim = 8

x = torch.randn(sentence_length, batch_size, embedding_dim)
x.size(),x


(torch.Size([5, 3, 8]),
 tensor([[[ 0.9808,  1.1174, -1.1774,  0.4539, -1.3003, -0.5163, -0.4770,
            1.9120],
          [-0.6183,  0.0120, -1.4875,  1.0908,  0.0713, -0.0275,  1.6774,
           -1.7446],
          [ 0.2669, -1.0810,  0.7801, -1.5209,  0.2631, -1.2286,  0.2108,
            1.7517]],
 
         [[-1.7932,  1.5784,  0.0443, -2.3657, -0.5700, -0.5194,  0.2626,
            0.0749],
          [-0.3007, -0.2718, -0.1874,  0.7200,  0.2329,  2.0341, -0.8840,
           -1.2228],
          [-0.8771, -0.2332,  1.7321,  1.3999, -0.3612, -0.3592,  1.1140,
           -0.4511]],
 
         [[-0.1374,  2.3371, -0.1653,  0.7964, -1.3572, -0.5025,  1.4807,
           -0.2293],
          [ 1.1909, -0.0310,  0.5169,  0.4907,  0.1419,  0.9483, -0.5425,
            0.2517],
          [ 1.7080,  0.0215,  0.1316,  0.7889, -1.4767,  0.8710, -0.4763,
            0.5713]],
 
         [[-0.7740,  1.5294, -0.7550,  1.5821,  1.3493, -1.2489,  0.4472,
           -1.2390],
          [-1.246

In [39]:
layer_norm = LayerNormalization(x.size()[-2:])
out = layer_norm.forward(x)
out.size(), out

(torch.Size([5, 3, 8]),
 tensor([[[ 0.9287,  1.0549, -1.0648,  0.4420, -1.1783, -0.4541, -0.4178,
            1.7888],
          [-0.5484,  0.0338, -1.3512,  1.0303,  0.0886, -0.0026,  1.5721,
           -1.5887],
          [ 0.2693, -0.9758,  0.7433, -1.3820,  0.2658, -1.1121,  0.2174,
            1.6408]],
 
         [[-1.6560,  1.5473,  0.0897, -2.1999, -0.4939, -0.4458,  0.2971,
            0.1188],
          [-0.2380, -0.2106, -0.1304,  0.7317,  0.2689,  1.9802, -0.7923,
           -1.1141],
          [-0.7857, -0.1739,  1.6933,  1.3776, -0.2955, -0.2936,  1.1060,
           -0.3809]],
 
         [[-0.5043,  2.3142, -0.5361,  0.5593, -1.8937, -0.9202,  1.3387,
           -0.6091],
          [ 1.0087, -0.3831,  0.2409,  0.2111, -0.1862,  0.7324, -0.9658,
           -0.0611],
          [ 1.5977, -0.3233, -0.1980,  0.5508, -2.0298,  0.6443, -0.8903,
            0.3029]],
 
         [[-0.7451,  1.6823, -0.7250,  1.7378,  1.4926, -1.2456,  0.5419,
           -1.2352],
          [-1.242