In [1]:
import numpy as np
import math
import torch
import torch.nn as nn


In [3]:
# One batch holding two 3-dimensional token embeddings: (batch, sequence, embedding).
inputs = torch.tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
B, S, E = inputs.size()
# Swap the batch and sequence axes to get (sequence, batch, embedding).
# transpose actually moves the axes; reshape would only reinterpret the flat
# element order and mix tokens from different sequences whenever B > 1.
inputs = inputs.transpose(0, 1)
inputs.size()

torch.Size([2, 1, 3])

In [4]:
# The learnable scale (gamma) and shift (beta) span the last two axes of `inputs`.
parameter_shape = inputs.shape[-2:]
gamma = nn.Parameter(torch.ones(parameter_shape))   # scale starts at 1
beta = nn.Parameter(torch.zeros(parameter_shape))   # shift starts at 0
gamma.shape, beta.shape

(torch.Size([1, 3]), torch.Size([1, 3]))

In [5]:
# Negative indices of the trailing axes covered by parameter_shape: -1, -2, ...
dims = list(range(-1, -len(parameter_shape) - 1, -1))
dims

[-1, -2]

In [6]:
# Average over the normalized axes; keepdim=True keeps them as size-1 axes so
# the result broadcasts against `inputs`.
mean = torch.mean(inputs, dim=dims, keepdim=True)
mean.shape, mean

(torch.Size([2, 1, 1]),
 tensor([[[0.2000]],
 
         [[0.2333]]]))

In [8]:
# Small constant that keeps the square root (and the later division) away from zero.
epsilon = 1e-5
# Variance as the mean of squared deviations over the normalized axes.
var = torch.mean((inputs - mean).pow(2), dim=dims, keepdim=True)
std = torch.sqrt(var + epsilon)
std

tensor([[[0.0817]],

        [[0.1886]]])

In [11]:
# Standardize: subtract the mean, then divide by the (epsilon-stabilized) std.
y = torch.div(inputs - mean, std)
y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [12]:
# Apply the learnable affine transform: scale by gamma, then shift by beta.
output = torch.add(gamma * y, beta)
output

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

In [13]:
EPSILON = 1e-5  # default stability term added to the variance before the sqrt


class LayerNormalization(nn.Module):
    """Layer normalization over the trailing dimensions of the input.

    The input is standardized (zero mean, unit variance) over its last
    ``len(parameter_shape)`` dimensions and then transformed element-wise
    with a learnable scale ``gamma`` and shift ``beta``.

    Args:
        parameter_shape: Sizes of the trailing dimensions to normalize over.
            May be any sequence of ints (e.g. ``torch.Size``) or a single
            int, which is treated as a one-element shape.
        eps: Value added to the variance for numerical stability.
            Defaults to the module-level ``EPSILON``.
    """

    def __init__(self, parameter_shape, eps=EPSILON):
        super().__init__()
        # Accept a bare int for the common "last dimension only" case.
        if isinstance(parameter_shape, int):
            parameter_shape = (parameter_shape,)
        self.parameter_shape = parameter_shape
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(parameter_shape))   # scale, starts at 1
        self.beta = nn.Parameter(torch.zeros(parameter_shape))   # shift, starts at 0

    def forward(self, inputs):
        """Normalize ``inputs`` and apply the learnable affine transform.

        Args:
            inputs: Tensor whose trailing dimensions broadcast against
                ``parameter_shape``.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        # Trailing axes to reduce over, e.g. (-2, -1) for a 2-D parameter shape.
        dims = tuple(range(-len(self.parameter_shape), 0))
        mean = inputs.mean(dim=dims, keepdim=True)
        centered = inputs - mean  # computed once, reused for variance and output
        # Variance as the mean of squared deviations (no Bessel correction).
        var = centered.pow(2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        return self.gamma * (centered / std) + self.beta
    

In [14]:
# Fix the RNG so the random demo input (and therefore `out`) is reproducible.
torch.manual_seed(0)

batch_size = 3
sentence_length = 5
embedding_dim = 8
inputs = torch.randn(sentence_length, batch_size, embedding_dim)

# Passing the last two sizes means the statistics are taken jointly over the
# (batch, embedding) axes for each sequence position.
layer_norm = LayerNormalization(inputs.size()[-2:])

# Call the module itself rather than .forward() so nn.Module.__call__ runs
# any registered hooks.
out = layer_norm(inputs)
out

tensor([[[ 0.5019,  0.5692,  1.0045, -0.1776,  0.6505, -1.2011, -1.3317,
          -0.3965],
         [ 0.0139,  0.0189, -1.2381, -1.2170, -0.0283,  2.0039,  0.7841,
          -1.8420],
         [ 0.9267,  1.5421, -0.9312, -0.7507,  0.4604,  1.1893, -1.0179,
           0.4667]],

        [[ 0.3465,  0.2655,  1.5916, -1.0523, -0.2505, -1.1716,  1.7419,
          -0.6568],
         [ 1.2620,  0.9945,  0.3676,  2.1013, -1.0636,  0.0699, -0.7911,
          -0.5464],
         [ 0.0454, -0.7463,  0.7628, -1.5027, -0.7360,  0.7448, -0.8616,
          -0.9148]],

        [[-1.1033, -0.5286,  0.6993,  2.3888,  1.0329, -1.3896, -0.0937,
          -1.3551],
         [-0.2563,  1.0367, -0.3741,  0.7118, -0.1101,  0.0823,  0.0869,
          -1.7167],
         [-1.7790,  0.6672, -0.3518, -0.4921,  1.4216,  0.3318,  0.5153,
           0.5758]],

        [[ 2.1925, -2.2792,  0.2460, -0.7205,  0.4874, -0.8875, -0.3072,
           0.1986],
         [ 0.9012, -0.9346,  1.3415,  1.4760,  0.7331, -0.5098, 