In [1]:
import torch
from torch import nn

In [3]:
inputs = torch.Tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
B, S, E = inputs.size()
inputs = inputs.reshape(S, B, E)
inputs.size()

torch.Size([2, 1, 3])

In [5]:
parameter_shape = inputs.size()[-2:]
gamma = nn.Parameter(torch.ones(parameter_shape))
beta =  nn.Parameter(torch.zeros(parameter_shape))

In [7]:
gamma.size(), beta.size()

(torch.Size([1, 3]), torch.Size([1, 3]))

In [9]:
dims = [-(i + 1) for i in range(len(parameter_shape))]

In [11]:
dims

[-1, -2]

In [13]:
mean = inputs.mean(dim=dims, keepdim=True)
mean.size()

torch.Size([2, 1, 1])

In [15]:
mean

tensor([[[0.2000]],

        [[0.2333]]])

In [17]:
var = ((inputs - mean) ** 2).mean(dim=dims, keepdim=True)
epsilon = 1e-5
std = (var + epsilon).sqrt()
std

tensor([[[0.0817]],

        [[0.1886]]])

In [19]:
y = (inputs - mean) / std
y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [21]:
out = gamma * y + beta

In [23]:
out

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

# Class

In [26]:
import torch
from torch import nn

class LayerNormalization():
    def __init__(self, parameters_shape, eps=1e-5):
        self.parameters_shape=parameters_shape
        self.eps=eps
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta =  nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, input):
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        mean = inputs.mean(dim=dims, keepdim=True)
        print(f"Mean \n ({mean.size()}): \n {mean}")
        var = ((inputs - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        print(f"Standard Deviation \n ({std.size()}): \n {std}")
        y = (inputs - mean) / std
        print(f"y \n ({y.size()}) = \n {y}")
        out = self.gamma * y  + self.beta
        print(f"out \n ({out.size()}) = \n {out}")
        return out

In [30]:
batch_size = 3
sentence_length = 5
embedding_dim = 8 
inputs = torch.randn(sentence_length, batch_size, embedding_dim)

print(f"input \n ({inputs.size()}) = \n {inputs}")

input 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-0.5929,  0.8070,  1.2866, -0.7476,  1.6183, -2.0066, -0.5785,
          -0.6217],
         [-1.0465, -0.1537,  0.7020,  0.5350, -0.6339, -0.5315, -0.4200,
          -0.7636],
         [ 2.2307,  0.1536, -1.3108,  0.2446, -0.2525,  0.4528, -0.4131,
           0.2782]],

        [[ 0.5693, -0.1028, -2.2118,  1.2083,  0.2966,  0.0773, -0.7978,
           0.2959],
         [-0.4342,  0.0784, -0.2928, -0.3826, -0.4195,  0.2838, -0.2065,
           0.8414],
         [-0.2832, -0.2807, -0.6178,  1.1359, -0.9651,  1.6227,  0.4663,
           0.3503]],

        [[ 0.1048,  1.1746, -1.5787, -0.1117, -0.6763,  0.1312,  0.1596,
           0.1958],
         [ 0.4526,  0.6923, -0.3785, -0.3931, -0.5342,  3.0932,  1.9361,
          -0.4073],
         [-0.0826,  1.9822, -1.5889,  0.2188, -1.5652, -0.0931, -1.8068,
          -0.6986]],

        [[-1.7234, -1.6653,  0.3319, -0.2120,  0.9567, -1.8425, -0.7789,
          -0.1924],
         [-0.1095, -0.8203, 

In [32]:
layer_norm = LayerNormalization(inputs.size()[-1:])

In [34]:
out = layer_norm.forward(inputs)

Mean 
 (torch.Size([5, 3, 1])): 
 tensor([[[-0.1044],
         [-0.2890],
         [ 0.1729]],

        [[-0.0831],
         [-0.0665],
         [ 0.1785]],

        [[-0.0751],
         [ 0.5576],
         [-0.4543]],

        [[-0.6407],
         [-0.1429],
         [-0.3596]],

        [[-0.0153],
         [-0.0275],
         [ 0.0459]]])
Standard Deviation 
 (torch.Size([5, 3, 1])): 
 tensor([[[1.1454],
         [0.5781],
         [0.9391]],

        [[0.9650],
         [0.4181],
         [0.8277]],

        [[0.7407],
         [1.2380],
         [1.1762]],

        [[0.9726],
         [0.5155],
         [0.9870]],

        [[0.9808],
         [0.6874],
         [0.6612]]])
y 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-0.4264,  0.7957,  1.2144, -0.5615,  1.5040, -1.6607, -0.4139,
          -0.4516],
         [-1.3104,  0.2341,  1.7143,  1.4254, -0.5965, -0.4194, -0.2265,
          -0.8210],
         [ 2.1912, -0.0205, -1.5800,  0.0763, -0.4531,  0.2980, -0.6240,
           0.1121]],



In [36]:
out[0].mean(), out[0].std()

(tensor(4.9671e-09, grad_fn=<MeanBackward0>),
 tensor(1.0215, grad_fn=<StdBackward0>))