In [1]:
import torch
import torch.nn as nn

In [2]:
inputs = torch.Tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
print(f"Inputs : {inputs}")
B, S, E = inputs.shape
inputs = inputs.reshape(S, B, E)
inputs.shape

Inputs : tensor([[[0.2000, 0.1000, 0.3000],
         [0.5000, 0.1000, 0.1000]]])


torch.Size([2, 1, 3])

In [3]:
inputs

tensor([[[0.2000, 0.1000, 0.3000]],

        [[0.5000, 0.1000, 0.1000]]])

In [4]:
parameter_shape = inputs.shape[-2:]

gamma = nn.Parameter(torch.ones(parameter_shape))
beta = nn.Parameter(torch.zeros(parameter_shape))

In [5]:
gamma.shape, beta.shape

(torch.Size([1, 3]), torch.Size([1, 3]))

In [6]:
len(parameter_shape)

2

In [7]:
dims = [-(i+1) for i in range(len(parameter_shape))]
dims

[-1, -2]

In [8]:
mean = inputs.mean(dim=dims, keepdim=True)
mean

tensor([[[0.2000]],

        [[0.2333]]])

In [9]:
var = ((inputs - mean) ** 2).mean(dim=dims, keepdim=True)
epsilon = 1e-5
std = (var + epsilon).sqrt()
std

tensor([[[0.0817]],

        [[0.1886]]])

In [10]:
y = (inputs - mean) / std
y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [11]:
out = gamma * y + beta
out

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

# All in One Class

In [12]:
import torch
import torch.nn as nn

class LayerNormalization(nn.Module):

    def __init__(self, parameter_shape, eps = 1e-5):
        super(LayerNormalization, self).__init__()
        self.parameter_shape = parameter_shape
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(parameter_shape))
        self.beta = nn.Parameter(torch.zeros(parameter_shape))

    def forward(self, inputs):

        dims = [-(i + 1) for i in range(len(self.parameter_shape))]
        mean = inputs.mean(dim = dims, keepdim = True)
        # print(f"Mean : {mean}\n")
        print(f"Mean shape : {mean.shape}\n")

        var = ( (inputs - mean) ** 2 ).mean(dim=dims, keepdim=True)
        # print(f"Variance : {var}\n")
        print(f"Variance shape : {var.shape}\n")

        std = (var + self.eps).sqrt()
        # print(f"Standard Deviation : {std}\n")
        print(f"Standard Deviation shape : {std.shape}\n")

        y = (inputs - mean) / std
        # print(f"y : {y}\n")
        print(f"y shape : {y.shape}\n")

        out = self.gamma * y + self.beta
        # print(f"Out : {out}")
        print(f"Out shape : {out.shape}")

        return out



In [13]:
batch_size = 3
sentence_length = 5
embedding_dim = 8 
inputs = torch.randn(sentence_length, batch_size, embedding_dim)

print(f"Input Size : {inputs.size()}")

Input Size : torch.Size([5, 3, 8])


In [14]:
layer_norm = LayerNormalization(inputs.size()[-1:])

In [15]:
out = layer_norm(inputs)

Mean shape : torch.Size([5, 3, 1])

Variance shape : torch.Size([5, 3, 1])

Standard Deviation shape : torch.Size([5, 3, 1])

y shape : torch.Size([5, 3, 8])

Out shape : torch.Size([5, 3, 8])


In [16]:
out[0].mean(), out[0].std()

(tensor(-3.7253e-09, grad_fn=<MeanBackward0>),
 tensor(1.0215, grad_fn=<StdBackward0>))