# Layer Normalization in Transformers

In [1]:
import torch
from torch import nn

In [8]:
inputs=torch.Tensor([[[0.2,0.1,0.3],[0.5,0.1,0.1]]])
B,S,E=inputs.size()
inputs=inputs.reshape(S,B,E)
inputs.size()

torch.Size([2, 1, 3])

In [9]:
parameter_shape=inputs.size()[-2:]
gamma=nn.Parameter(torch.ones(parameter_shape))
beta=nn.Parameter(torch.zeros(parameter_shape))

In [11]:
gamma.size(),beta.size()

(torch.Size([1, 3]), torch.Size([1, 3]))

In [12]:
dims=[-(i+1) for i in range(len(parameter_shape))]

In [13]:
dims

[-1, -2]

In [14]:
mean=inputs.mean(dim=dims,keepdim=True)
mean.size()

torch.Size([2, 1, 1])

In [15]:
var=((inputs-mean)**2).mean(dim=dims,keepdim=True)
epsilon=1e-5
std=(var+epsilon).sqrt()
std

tensor([[[0.0817]],

        [[0.1886]]])

In [16]:
y=(inputs-mean)/std
y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [17]:
out=gamma*y+beta

In [18]:
out

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

# class


In [21]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_sequence_length):
        super().__init__()
        self.max_sequence_length = max_sequence_length
        self.d_model = d_model

    def forward(self):
        even_i = torch.arange(0, self.d_model, 2).float()
        denominator = torch.pow(10000, even_i/self.d_model)
        position = torch.arange(self.max_sequence_length).reshape(self.max_sequence_length, 1)
        even_PE = torch.sin(position / denominator)
        odd_PE = torch.cos(position / denominator)
        stacked = torch.stack([even_PE, odd_PE], dim=2)
        PE = torch.flatten(stacked, start_dim=1, end_dim=2)
        return PE

In [22]:
pe = PositionalEncoding(d_model=6, max_sequence_length=10)
pe.forward()

tensor([[ 0.0000,  1.0000,  0.0000,  1.0000,  0.0000,  1.0000],
        [ 0.8415,  0.5403,  0.0464,  0.9989,  0.0022,  1.0000],
        [ 0.9093, -0.4161,  0.0927,  0.9957,  0.0043,  1.0000],
        [ 0.1411, -0.9900,  0.1388,  0.9903,  0.0065,  1.0000],
        [-0.7568, -0.6536,  0.1846,  0.9828,  0.0086,  1.0000],
        [-0.9589,  0.2837,  0.2300,  0.9732,  0.0108,  0.9999],
        [-0.2794,  0.9602,  0.2749,  0.9615,  0.0129,  0.9999],
        [ 0.6570,  0.7539,  0.3192,  0.9477,  0.0151,  0.9999],
        [ 0.9894, -0.1455,  0.3629,  0.9318,  0.0172,  0.9999],
        [ 0.4121, -0.9111,  0.4057,  0.9140,  0.0194,  0.9998]])