In [1]:
import torch
from torch import nn

In [4]:
# Toy batch laid out as (B, S, E) = (1, 2, 3).
inputs = torch.tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
B, S, E = inputs.size()
# Swap the first two axes to obtain an (S, B, E) layout. `permute` moves whole
# axes; `reshape(S, B, E)` would merely reinterpret the flat buffer and would
# interleave values from different rows as soon as B > 1.
inputs = inputs.permute(1, 0, 2)
inputs.size()

torch.Size([2, 1, 3])

In [5]:
# Affine parameters: one scale (gamma) and one shift (beta) entry for every
# position in the trailing two axes of `inputs`.
parameter_shape = inputs.shape[-2:]
gamma = nn.Parameter(torch.ones(*parameter_shape))
beta = nn.Parameter(torch.zeros(*parameter_shape))

In [6]:
# Both parameters should share the same shape.
gamma.shape, beta.shape

(torch.Size([1, 3]), torch.Size([1, 3]))

In [7]:
# Negative indices of the trailing axes covered by `parameter_shape`:
# -1, -2, ... down to -len(parameter_shape).
dims = list(range(-1, -len(parameter_shape) - 1, -1))

In [8]:
# Show the axes that the mean/variance reductions below will collapse.
dims

[-1, -2]

In [9]:
# Average over the axes in `dims`; keepdim=True leaves them as size-1 axes so
# the result broadcasts against `inputs` in the following cells.
mean = torch.mean(inputs, dim=dims, keepdim=True)
mean.shape

torch.Size([2, 1, 1])

In [10]:
# Inspect the per-slice mean computed over `dims`.
mean

tensor([[[0.2000]],

        [[0.2333]]])

In [11]:
# Population variance (mean of squared deviations) over the same axes.
var = torch.mean((inputs - mean).pow(2), dim=dims, keepdim=True)
# Small constant added under the square root so the later division by `std`
# cannot hit zero.
epsilon = 1e-5
std = torch.sqrt(var + epsilon)
std

tensor([[[0.0817]],

        [[0.1886]]])

In [12]:
# Standardize: subtract the mean, then divide by the standard deviation.
y = torch.div(inputs - mean, std)
y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [13]:
# Apply the learnable scale and shift to the standardized values.
out = y * gamma + beta

In [14]:
# Display the result of scaling `y` by gamma and shifting by beta.
out

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

In [16]:
import torch
from torch import nn

class LayerNormalization(nn.Module):
  """Layer normalization over the trailing ``len(parameter_shape)`` axes.

  The input is standardized (zero mean, unit variance) across those axes and
  then scaled by ``gamma`` and shifted by ``beta``, both learnable and both of
  shape ``parameter_shape``.

  Args:
    parameter_shape: Shape of ``gamma`` and ``beta``. Its length selects how
      many trailing axes of the input are normalized together; it must
      broadcast against the input.
    eps: Constant added to the variance before the square root so the
      division never hits zero.
  """

  def __init__(self, parameter_shape, eps=1e-5) -> None:
    # Subclassing nn.Module registers gamma/beta as parameters, so they show
    # up in .parameters() and follow .to(device) calls.
    super().__init__()
    self.parameter_shape = parameter_shape
    self.eps = eps
    self.gamma = nn.Parameter(torch.ones(parameter_shape))
    self.beta = nn.Parameter(torch.zeros(parameter_shape))

  def forward(self, input):
    """Normalize ``input`` and apply the learnable affine transform.

    Every statistic is computed from the ``input`` argument itself, never
    from notebook-level variables, so the result depends only on what the
    caller passes in. Intermediate tensors are printed for inspection.

    Args:
      input: Tensor whose trailing axes are broadcast-compatible with
        ``parameter_shape``.

    Returns:
      Tensor with the same shape as ``input``.
    """
    # Trailing axes to reduce over: -1, -2, ..., -len(parameter_shape).
    dims = [-(i + 1) for i in range(len(self.parameter_shape))]
    # keepdim=True keeps reduced axes as size 1 so they broadcast below.
    mean = input.mean(dim=dims, keepdim=True)
    print(f"Mean \n ({mean.size()}): \n {mean}")
    var = ((input - mean) ** 2).mean(dim=dims, keepdim=True)
    std = (var + self.eps).sqrt()
    print(f"Standard Deviation \n ({std.size()}): \n {std}")
    y = (input - mean) / std
    print(f"y \n ({y.size()}) = \n {y}")
    out = self.gamma * y  + self.beta
    print(f"out \n ({out.size()}) = \n {out}")
    return out

In [18]:
# Dimensions of a larger random example, laid out as
# (sentence_length, batch_size, embedding_dim).
sentence_length, batch_size, embedding_dim = 5, 3, 8
inputs = torch.randn((sentence_length, batch_size, embedding_dim))

print(f"input \n ({inputs.size()}) = \n {inputs}")
     

input 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-1.5279,  1.1531, -0.6580,  0.7171,  1.9809,  1.6145,  1.8791,
           0.6054],
         [-1.8898, -0.1667, -0.0573, -0.3673, -1.3410, -0.0506, -1.4536,
           1.1558],
         [ 1.1351,  1.1780,  0.3772,  1.3130, -0.8884, -0.4953,  0.6955,
          -1.8598]],

        [[-0.7588,  0.0780,  0.6861, -0.3111,  1.2979, -1.1623, -1.4031,
          -1.8136],
         [ 0.9753,  0.8771, -0.0466,  0.0878, -0.1609,  0.9619,  0.1215,
          -1.0098],
         [ 0.8472,  0.0972, -0.2107,  0.7850, -0.2391,  0.1192, -1.3483,
          -0.4069]],

        [[-0.6441, -1.2946,  1.2806, -2.5657,  0.3272,  0.9019,  0.0767,
           0.8301],
         [ 0.1343,  0.4219,  1.7110, -1.1128, -0.0116,  0.1387,  0.1799,
           0.2853],
         [-0.7093,  0.6286,  1.2543,  0.9525, -0.7191, -0.3856,  1.1232,
           0.2345]],

        [[-0.4925, -0.0510,  0.4828, -0.0206,  2.8138,  0.0149,  0.9851,
          -0.6275],
         [ 0.8363,  0.1461, 