In [1]:
import torch
import torch.nn as nn

In [2]:
inputs = torch.tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])

In [6]:
inputs.size()

torch.Size([1, 2, 3])

In [7]:
B, S, E = inputs.size()

In [9]:
inputs = inputs.reshape(S, B, E)
inputs.size()

torch.Size([2, 1, 3])

In [10]:
parameter_shape = inputs.size()[-2:]

In [11]:
parameter_shape

torch.Size([1, 3])

In [12]:
gamma = nn.Parameter(torch.ones(parameter_shape))
beta = nn.Parameter(torch.zeros(parameter_shape))

In [13]:
gamma.size(), beta.size()

(torch.Size([1, 3]), torch.Size([1, 3]))

In [19]:
dims = [-(i + 1) for i in range(len(parameter_shape))]

In [20]:
dims

[-1, -2]

In [23]:
mean = inputs.mean(dim = dims, keepdim = True)
mean, mean.size()

(tensor([[[0.2000]],
 
         [[0.2333]]]),
 torch.Size([2, 1, 1]))

In [24]:
var = ((inputs - mean) ** 2).mean(dim = dims, keepdim = True)
var

tensor([[[0.0067]],

        [[0.0356]]])

In [27]:
epsilon = 1e-5
std = (var + epsilon).sqrt()
std, std.shape

(tensor([[[0.0817]],
 
         [[0.1886]]]),
 torch.Size([2, 1, 1]))

In [28]:
y = (inputs - mean) / std
y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [29]:
out = gamma + y + beta
out

tensor([[[ 1.0000, -0.2238,  2.2238]],

        [[ 2.4140,  0.2930,  0.2930]]], grad_fn=<AddBackward0>)

In [31]:
class LayerNormalization(nn.Module):
  def __init__(self, parameter_shape, eps = 1e-5):
    super().__init__()
    self.parameter_shape = parameter_shape
    self.eps = eps
    self.gamma = nn.Parameter(torch.ones(parameter_shape))
    self.beta = nn.Parameter(torch.zeros(parameter_shape))

  def forward(self, input):
    dims = [-(i + 1) for i in range(len(self.parameter_shape))]
    mean = inputs.mean(dim = dims, keepdim = True)
    print(f"Mean: {mean}")
    var = ((inputs - mean) ** 2).mean(dim = dims, keepdim = True)
    print(f"Variance: {var}")
    std = (var + self.eps).sqrt()
    print(f"Standard Deviation: {std}")
    y = (input - mean) / std
    print(f"y: {y}")
    out = self.gamma + y + self.beta
    print(f"Output: {out}")
    return out

In [32]:
batch_size = 3
sentence_length = 5
embedding_dim = 8
inputs = torch.randn(batch_size, sentence_length, embedding_dim)

In [33]:
inputs.size()

torch.Size([3, 5, 8])

In [35]:
layer_norm = LayerNormalization(inputs.size()[-2:])

In [36]:
out = layer_norm.forward(inputs)

Mean: tensor([[[ 0.1025]],

        [[-0.0097]],

        [[-0.0390]]])
Variance: tensor([[[1.0868]],

        [[1.1108]],

        [[0.8458]]])
Standard Deviation: tensor([[[1.0425]],

        [[1.0540]],

        [[0.9197]]])
y: tensor([[[ 1.5185, -0.4671, -0.6410,  0.0300, -0.6721,  0.0434,  0.4537,
           0.0533],
         [-1.0392, -0.3043, -0.2463,  2.2190, -1.6740,  2.2868, -0.3453,
           0.3740],
         [-1.1315, -1.7738,  0.1885,  0.0312,  0.5953,  0.5943, -0.3127,
           1.4310],
         [ 1.7040, -0.4264, -2.0953,  0.8415,  0.4915,  0.3447, -0.3584,
           0.1684],
         [-0.3013,  1.2851, -0.0718, -0.9507,  0.7643, -0.7543, -0.4715,
          -1.3817]],

        [[ 0.6041,  0.2496, -0.6657,  0.1420,  0.4974, -0.0629, -0.5487,
           0.1565],
         [ 0.9447,  0.2794, -1.1507,  2.6832, -2.9042, -0.2152,  0.0098,
           0.2530],
         [-1.0519, -0.6256,  0.8330,  0.6377,  1.8624,  0.1823, -0.3917,
          -0.8698],
         [-1.1111, -1.2