In [1]:

import torch
from torch import nn

In [2]:

# Toy input with three axes, unpacked as B, S, E below (1 x 2 x 3 here).
# torch.tensor is the recommended factory for building a tensor from literal data.
inputs = torch.tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
B, S, E = inputs.size()
# Swap the first two axes to get an (S, B, E) layout.
# transpose() moves the axes themselves; reshape(S, B, E) would only reinterpret
# the flat memory, which happens to give the same result when B == 1 but mixes
# rows from different batch entries as soon as B > 1.
inputs = inputs.transpose(0, 1)
inputs.size()

torch.Size([2, 1, 3])

In [3]:

# The learnable scale/shift cover the last two axes of `inputs`.
parameter_shape = inputs.shape[-2:]
# gamma starts at one and beta at zero, so the affine step is initially the identity.
gamma = nn.Parameter(torch.ones(parameter_shape))
beta = nn.Parameter(torch.zeros(parameter_shape))

In [4]:
# Both parameters share the shape stored in parameter_shape.
gamma.shape, beta.shape

(torch.Size([1, 3]), torch.Size([1, 3]))

In [5]:
# Negative indices of the trailing axes to normalize over: -1, -2, ...,
# one per dimension of parameter_shape.
dims = [-axis for axis in range(1, len(parameter_shape) + 1)]

In [6]:

# Show the axes that the mean and variance will be reduced over.
dims

[-1, -2]

In [7]:
# Mean over the axes in `dims`; keepdim=True keeps the reduced axes as size 1
# so the result broadcasts against `inputs`.
mean = torch.mean(inputs, dim=dims, keepdim=True)
mean.shape

torch.Size([2, 1, 1])

In [8]:
# Show the mean computed over the axes in `dims`.
mean

tensor([[[0.2000]],

        [[0.2333]]])

In [9]:
# Small constant added to the variance before the square root so that the
# standard deviation is never exactly zero.
epsilon = 1e-5
# Biased (population) variance: the mean of squared deviations over `dims`.
var = (inputs - mean).pow(2).mean(dim=dims, keepdim=True)
std = torch.sqrt(var + epsilon)
std

tensor([[[0.0817]],

        [[0.1886]]])

In [10]:

# Standardize: centre on the mean, then scale by the epsilon-stabilized std.
y = torch.div(inputs - mean, std)
y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [11]:

# Learnable affine step; with gamma all ones and beta all zeros it leaves y unchanged.
out = y * gamma + beta

In [12]:

# Show the result of the affine step; it carries a grad_fn because gamma and
# beta are nn.Parameter tensors.
out

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

Layer normalization as a class

In [13]:

import torch
from torch import nn

class LayerNormalization(nn.Module):
    """Layer normalization over the trailing dimensions of a tensor.

    The input is standardized (zero mean, unit biased variance) over its last
    ``len(parameters_shape)`` axes and then passed through a learnable
    element-wise affine transform ``gamma * y + beta``.

    Subclassing ``nn.Module`` registers ``gamma`` and ``beta`` as parameters,
    so they show up in ``.parameters()`` and ``state_dict()`` and follow
    ``.to(device)``.

    Args:
        parameters_shape: Shape of the trailing axes to normalize over; also
            the shape of ``gamma`` and ``beta``.
        eps: Constant added to the variance before the square root to avoid
            division by zero.
    """

    def __init__(self, parameters_shape, eps=1e-5):
        super().__init__()
        self.parameters_shape = parameters_shape
        self.eps = eps
        # Identity affine transform at initialization: scale 1, shift 0.
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, input):
        """Normalize ``input`` and print each intermediate result.

        Args:
            input: Tensor whose trailing axes match ``parameters_shape``.

        Returns:
            Tensor with the same shape as ``input``.
        """
        # Negative indices of the trailing axes: -1, -2, ...
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        # Use the argument, not a notebook-level global, so the layer works on
        # whatever tensor it is called with.
        mean = input.mean(dim=dims, keepdim=True)
        print(f"Mean \n ({mean.size()}): \n {mean}")
        # Biased (population) variance over the normalized axes.
        var = ((input - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        print(f"Standard Deviation \n ({std.size()}): \n {std}")
        y = (input - mean) / std
        print(f"y \n ({y.size()}) = \n {y}")
        out = self.gamma * y + self.beta
        print(f"out \n ({out.size()}) = \n {out}")
        return out
     

In [14]:
# Fix the RNG seed so the random inputs, and every output derived from them,
# are the same on each Restart & Run All.
torch.manual_seed(0)

batch_size = 3
sentence_length = 5
embedding_dim = 8
# Random inputs laid out as (sentence_length, batch_size, embedding_dim).
inputs = torch.randn(sentence_length, batch_size, embedding_dim)

print(f"input \n ({inputs.size()}) = \n {inputs}")

input 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-0.8423, -0.5826,  0.2249, -0.1391, -1.2348,  0.2448,  0.2864,
          -0.0456],
         [-0.6692, -0.8063, -0.4026, -0.5808,  0.1779, -1.6537,  0.5476,
          -0.2312],
         [-0.8894,  3.3253,  1.4620, -2.1106,  0.7160,  0.3866, -1.2333,
           0.4488]],

        [[ 0.3725,  0.2372,  1.4064,  0.9999,  0.9835,  0.1841, -0.2830,
          -1.7932],
         [ 1.1372,  1.7271, -1.3094, -1.7460, -0.4164,  1.8422, -1.7017,
          -0.9184],
         [ 0.6853,  0.1925, -1.2026, -1.4036, -0.0206,  0.0083, -1.9377,
           0.8763]],

        [[-0.4065, -0.4807, -1.0862, -0.4599, -0.4289,  0.2740, -0.4741,
          -1.1042],
         [-0.0955,  0.3661,  1.9871, -1.7051,  0.5756,  1.1697,  1.9461,
           0.4230],
         [ 0.1214, -0.2026, -0.0295,  0.0292,  0.9460,  1.5448,  1.2466,
          -0.5285]],

        [[ 0.4787,  2.8194, -0.5490,  0.4843, -1.8543, -0.0533, -0.4118,
          -2.0390],
         [-0.3916,  1.1560, 

In [15]:

# Normalize over the last axis only, so gamma and beta have one entry per
# element of that axis.
layer_norm = LayerNormalization(inputs.shape[-1:])

In [16]:

# Run the layer on the random inputs; forward() also prints the intermediate
# mean, standard deviation, normalized tensor and final output.
out = layer_norm.forward(inputs)

Mean 
 (torch.Size([5, 3, 1])): 
 tensor([[[-0.2610],
         [-0.4523],
         [ 0.2632]],

        [[ 0.2634],
         [-0.1732],
         [-0.3503]],

        [[-0.5208],
         [ 0.5834],
         [ 0.3909]],

        [[-0.1406],
         [ 0.3206],
         [ 0.2767]],

        [[ 0.0310],
         [ 0.3761],
         [-0.2641]]])
Standard Deviation 
 (torch.Size([5, 3, 1])): 
 tensor([[[0.5294],
         [0.6206],
         [1.5918]],

        [[0.9297],
         [1.4189],
         [0.9660]],

        [[0.4063],
         [1.1148],
         [0.7026]],

        [[1.4285],
         [0.6724],
         [0.8133]],

        [[1.0210],
         [0.8566],
         [1.0760]]])
y 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-1.0979, -0.6073,  0.9177,  0.2303, -1.8392,  0.9554,  1.0339,
           0.4069],
         [-0.3496, -0.5705,  0.0801, -0.2071,  1.0155, -1.9360,  1.6113,
           0.3563],
         [-0.7240,  1.9237,  0.7531, -1.4912,  0.2845,  0.0776, -0.9401,
           0.1166]],



In [17]:


# Sanity check on the first sequence position, with all batch rows pooled.
# The mean is ~0. The std is slightly above 1 because Tensor.std() applies
# Bessel's correction by default, whereas forward() normalizes with the biased
# variance: sqrt(24 / 23) ~= 1.0215 for the 3 x 8 = 24 pooled values.
out[0].mean(), out[0].std()

(tensor(4.9671e-09, grad_fn=<MeanBackward0>),
 tensor(1.0215, grad_fn=<StdBackward0>))