In [19]:
import torch
from torch import nn

In [20]:
# Toy batch laid out as (batch, sequence, embedding) = (1, 2, 3).
inputs = torch.tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
B, S, E = inputs.size()
# Switch to a sequence-first (S, B, E) layout by swapping the first two axes.
# reshape(S, B, E) would only reinterpret the flat buffer, which matches a real
# axis swap solely when B == 1; transpose keeps every token's embedding intact
# for any batch size. contiguous() keeps the dense memory layout reshape gave.
inputs = inputs.transpose(0, 1).contiguous()
inputs.size()

torch.Size([2, 1, 3])

In [21]:
# Learnable scale (gamma, ones) and shift (beta, zeros), shaped like the last
# two axes of `inputs`.
parameter_shape = inputs.shape[-2:]
gamma = nn.Parameter(torch.ones(*parameter_shape))
beta = nn.Parameter(torch.zeros(*parameter_shape))

In [22]:
# Both parameters should share the same shape.
gamma.shape, beta.shape

(torch.Size([1, 3]), torch.Size([1, 3]))

In [23]:
# Negative indices of the trailing axes covered by parameter_shape: -1, -2, ...
dims = list(range(-1, -len(parameter_shape) - 1, -1))

In [24]:
# Axes to average over, counted backwards from the last dimension.
dims

[-1, -2]

In [25]:
# Average over the axes in `dims`; keepdim=True leaves size-1 axes in their
# place so the result broadcasts against `inputs`.
mean = torch.mean(inputs, dim=dims, keepdim=True)
mean.shape

torch.Size([2, 1, 1])

In [26]:
# One mean per leading index; the reduced axes are kept with size 1.
mean

tensor([[[0.2000]],

        [[0.2333]]])

In [27]:
# Small constant added to the variance so the square root (and the later
# division by it) stays well-defined when the variance is zero.
epsilon = 1e-5
# Population variance: mean of squared deviations over the same axes as `mean`.
var = (inputs - mean).pow(2).mean(dim=dims, keepdim=True)
std = torch.sqrt(var + epsilon)
std

tensor([[[0.0817]],

        [[0.1886]]])

In [28]:
# Standardize: subtract the mean and divide by the epsilon-stabilized std.
y = torch.div(inputs - mean, std)
y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [29]:
# Apply the learnable scale (gamma) and shift (beta) to the normalized values.
out = torch.add(torch.mul(gamma, y), beta)

In [30]:
# gamma is all ones and beta all zeros here, so the values equal y; the
# result now carries a grad_fn because the parameters are involved.
out

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

# **Layer Normalization (LN) Class**

In [31]:
import torch
from torch import nn

class LayerNormalization(nn.Module):
    """Layer normalization over the trailing ``len(parameters_shape)`` axes.

    Each slice spanned by the trailing axes is shifted to zero mean and scaled
    to unit (population) variance, then transformed element-wise with the
    learnable ``gamma`` (scale, initialised to ones) and ``beta`` (shift,
    initialised to zeros).

    Subclassing ``nn.Module`` registers ``gamma`` and ``beta`` so that they are
    visible to ``parameters()``, ``state_dict()`` and ``.to(device)``.

    Args:
        parameters_shape: Shape of the trailing axes to normalize over (for
            example ``inputs.size()[-1:]``). A single int is treated as a
            one-element shape.
        eps: Constant added to the variance before the square root to avoid
            division by zero.
    """

    def __init__(self, parameters_shape, eps=1e-5):
        super().__init__()
        # Accept a bare int for the common "normalize the last axis" case, so
        # that len(self.parameters_shape) in forward() always works.
        if isinstance(parameters_shape, int):
            parameters_shape = (parameters_shape,)
        self.parameters_shape = parameters_shape
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, input):
        """Normalize ``input`` and apply the learnable scale and shift.

        Prints the intermediate mean, standard deviation, normalized values and
        final output for inspection.

        Args:
            input: Tensor whose trailing axes broadcast against
                ``parameters_shape``.

        Returns:
            Tensor with the same shape as ``input``.
        """
        # Negative indices of the trailing axes to reduce over: -1, -2, ...
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        # Use the argument itself, not a same-named global from the notebook
        # namespace, so the layer works on whatever tensor it is given.
        mean = input.mean(dim=dims, keepdim=True)
        print(f"Mean \n ({mean.size()}): \n {mean}")
        # Population variance (divides by N), stabilized with eps.
        var = ((input - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        print(f"Standard Deviation \n ({std.size()}): \n {std}")
        y = (input - mean) / std
        print(f"y \n ({y.size()}) = \n {y}")
        out = self.gamma * y  + self.beta
        print(f"out \n ({out.size()}) = \n {out}")
        return out

In [32]:
# Seed the global RNG so the random batch, and every statistic derived from
# it, is identical after Restart Kernel -> Run All.
torch.manual_seed(0)

batch_size = 3
sentence_length = 5
embedding_dim = 8
# Sequence-first layout: (sentence_length, batch_size, embedding_dim).
inputs = torch.randn(sentence_length, batch_size, embedding_dim)

print(f"input \n ({inputs.size()}) = \n {inputs}")

input 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-5.4975e-02,  5.5059e-01,  1.9035e+00, -2.5530e-02, -1.1070e-01,
          -6.4873e-01, -1.8326e-01,  4.7748e-01],
         [ 1.1644e+00, -3.7115e-01,  1.4769e+00,  8.3417e-01, -9.3825e-01,
           1.4302e+00, -1.0945e+00, -9.3712e-01],
         [-6.6506e-01,  1.2799e+00,  3.1502e-01,  8.9894e-01,  9.4523e-01,
           3.4939e-01,  3.8966e-01,  4.9867e-01]],

        [[ 7.0444e-01, -1.3195e+00, -1.7230e+00,  1.3689e+00, -1.0682e+00,
           4.0957e-02, -3.4629e-02,  2.2270e-01],
         [-1.0340e+00, -1.7327e+00, -1.2065e-03,  1.1593e+00,  1.6773e+00,
           6.0251e-01,  3.0296e-01, -1.4294e-01],
         [-4.4778e-01,  2.0294e-01,  4.7119e-01, -1.4540e+00, -7.8283e-01,
           2.9834e-01,  8.6286e-02,  1.4181e-01]],

        [[ 1.6875e-01,  1.3425e+00, -1.7077e+00,  2.5511e+00, -3.7425e-01,
          -2.4037e+00,  5.2075e-01, -1.8169e-01],
         [ 1.2385e+00, -1.2919e+00, -3.2251e-01,  4.6074e-01, -2.0185e-01,
          

In [33]:
# Normalize over the last axis only (one gamma/beta entry per position on it).
layer_norm = LayerNormalization(parameters_shape=inputs.shape[-1:])


In [34]:
# Run the layer on the random batch; forward() also prints its intermediates.
out = layer_norm.forward(input=inputs)

Mean 
 (torch.Size([5, 3, 1])): 
 tensor([[[ 0.2386],
         [ 0.1956],
         [ 0.5015]],

        [[-0.2260],
         [ 0.1039],
         [-0.1855]],

        [[-0.0105],
         [-0.2041],
         [ 0.3568]],

        [[-0.0789],
         [-0.0992],
         [ 0.1212]],

        [[-0.0232],
         [ 0.3845],
         [ 0.3445]]])
Standard Deviation 
 (torch.Size([5, 3, 1])): 
 tensor([[[0.7220],
         [1.0646],
         [0.5468]],

        [[0.9915],
         [1.0379],
         [0.6156]],

        [[1.4780],
         [0.7801],
         [1.0535]],

        [[0.6497],
         [1.3018],
         [0.9597]],

        [[0.5170],
         [0.8197],
         [1.0750]]])
y 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-0.4066,  0.4322,  2.3062, -0.3658, -0.4838, -1.2290, -0.5843,
           0.3310],
         [ 0.9101, -0.5324,  1.2036,  0.5999, -1.0650,  1.1597, -1.2118,
          -1.0640],
         [-2.1333,  1.4235, -0.3410,  0.7269,  0.8115, -0.2781, -0.2045,
          -0.0051]],



In [35]:
# Sanity check on the first sequence position, pooled over its 3 x 8 values.
# NOTE(review): forward() normalizes each row with a population std (divide
# by N), while .std() here pools all 24 values and appears to use the
# Bessel-corrected estimator -- consistent with a result slightly above 1
# (sqrt(24 / 23) ~= 1.0215) rather than exactly 1.
out[0].mean(), out[0].std()

(tensor(3.9736e-08, grad_fn=<MeanBackward0>),
 tensor(1.0215, grad_fn=<StdBackward0>))