In [63]:
import torch
import torch.nn as nn

In [64]:
inputs = torch.tensor([[[0.2, 0.1, 0.6, 0.8, 1.0], [0.5, 0.7, 1.2, 0.9, 0.2]]])
batch_size, sequence_length, embedding_dim = inputs.size()
inputs = inputs.reshape(sequence_length, batch_size, embedding_dim)
inputs.size()

torch.Size([2, 1, 5])

In [65]:
parameter_shape = inputs.size()[-2:]
gamma = nn.Parameter(torch.ones(parameter_shape)) 
beta = nn.Parameter(torch.zeros(parameter_shape))

In [66]:
gamma.size(), beta.size()

(torch.Size([1, 5]), torch.Size([1, 5]))

In [67]:
dims = [-(i+1) for i in range(len(parameter_shape))]

In [68]:
dims

[-1, -2]

In [69]:
mean = inputs.mean(dim=dims, keepdim=True)
mean.size()

torch.Size([2, 1, 1])

In [70]:
mean

tensor([[[0.5400]],

        [[0.7000]]])

In [71]:
var = ((inputs - mean)**2).mean(dim=dims, keepdim=True)
epsilon = 1e-5
std = (var + epsilon).sqrt()
std

tensor([[[0.3441]],

        [[0.3406]]])

In [72]:
y = (inputs - mean)/std

In [73]:
y

tensor([[[-0.9881, -1.2787,  0.1744,  0.7556,  1.3368]],

        [[-0.5872,  0.0000,  1.4680,  0.5872, -1.4680]]])

In [74]:
out = gamma*y + beta
out

tensor([[[-0.9881, -1.2787,  0.1744,  0.7556,  1.3368]],

        [[-0.5872,  0.0000,  1.4680,  0.5872, -1.4680]]],
       grad_fn=<AddBackward0>)

In [94]:
class LayerNormalization():
    def __init__(self, parameter_shape, epsilon=1e-5):
        self.parameters_shape = parameter_shape
        self.epsilon = epsilon
        self.gamma = nn.Parameter(torch.ones(parameter_shape)) 
        self.beta = nn.Parameter(torch.zeros(parameter_shape)) 
        
    def forward(self):
        dims = [-(i+1) for i in range(len(self.parameters_shape))]
        mean = inputs.mean(dim=dims, keepdim=True)
        print(f"Mean \n ({mean.size()}): \n {mean}")
        var = ((inputs - mean)**2).mean(dim=dims, keepdim=True)
        std = (var + self.epsilon).sqrt()
        print(f"Standard Deviation \n ({std.size()}): \n {std}")
        y = (inputs - mean)/std
        print(f"y \n ({y.size()}): \n {y}") 
        out = self.gamma*y + beta
        print(f"out \n ({out.size()}): \n {out}")
        return out

In [95]:
batch_size=3
sequence_length = 5
embedding_size = 8

inputs = torch.randn(sequence_length, batch_size, embedding_dim)
print(f"inputs \n ({inputs.size()}) = \n {inputs}")

inputs 
 (torch.Size([5, 3, 5])) = 
 tensor([[[-0.6284, -1.4106, -0.5916,  0.6018,  0.5532],
         [ 1.5317,  1.1787,  0.1882,  0.4501, -1.6880],
         [-0.2660, -0.2961,  0.4096,  1.1616,  2.0313]],

        [[-0.9347,  0.2804, -1.0159, -0.8371, -2.1791],
         [-0.0864, -0.7040, -0.6551, -1.2941, -0.0737],
         [ 0.6322,  1.7580,  0.2413,  0.3812,  0.2824]],

        [[-0.6021, -0.7671,  0.6160,  0.5653,  0.2687],
         [ 1.7863, -2.2471, -0.2779,  0.0600, -0.8889],
         [ 0.5550,  0.1710,  0.6645, -0.1573,  0.2674]],

        [[ 0.2417,  0.5487, -0.5366, -0.8463,  2.2895],
         [-1.4366,  0.5802,  0.8936, -1.9977, -1.1401],
         [-0.9619,  1.9332, -0.2121, -0.5633, -0.9769]],

        [[ 1.9679,  0.3920,  0.2808,  0.0888,  0.8088],
         [-1.0344,  0.7729,  0.3008,  0.6770, -0.1856],
         [-0.3712, -1.0949, -1.4734, -0.5244,  0.0584]]])


In [96]:
layer_norm = LayerNormalization(inputs.size()[-2:])

In [98]:
out = layer_norm.forward()

Mean 
 (torch.Size([5, 1, 1])): 
 tensor([[[ 0.2150]],

        [[-0.2803]],

        [[ 0.0009]],

        [[-0.1456]],

        [[ 0.0442]]])
Standard Deviation 
 (torch.Size([5, 1, 1])): 
 tensor([[[1.0111]],

        [[0.9180]],

        [[0.8835]],

        [[1.1763]],

        [[0.8475]]])
y 
 (torch.Size([5, 3, 5])): 
 tensor([[[-0.8342, -1.6078, -0.7978,  0.3825,  0.3344],
         [ 1.3022,  0.9531, -0.0266,  0.2325, -1.8822],
         [-0.4758, -0.5056,  0.1925,  0.9363,  1.7964]],

        [[-0.7128,  0.6108, -0.8013, -0.6065, -2.0684],
         [ 0.2113, -0.4616, -0.4082, -1.1044,  0.2251],
         [ 0.9940,  2.2204,  0.5682,  0.7206,  0.6130]],

        [[-0.6825, -0.8693,  0.6962,  0.6388,  0.3031],
         [ 2.0209, -2.5444, -0.3156,  0.0668, -1.0071],
         [ 0.6272,  0.1925,  0.7511, -0.1790,  0.3016]],

        [[ 0.3293,  0.5903, -0.3324, -0.5957,  2.0701],
         [-1.0974,  0.6170,  0.8834, -1.5744, -0.8454],
         [-0.6939,  1.7672, -0.0565, -0.3550, -0.7