In [1]:
import torch
import torch.nn as nn

In [2]:
# Toy activations: a 3x6 tensor drawn from a standard normal distribution.
# NOTE(review): no RNG seed is set in the visible cells, so the values shown
# below will differ on every fresh run.
outputs = torch.randn(3,6)
outputs

tensor([[ 0.3990,  0.3727,  1.7358, -1.7019, -0.0588,  0.1244],
        [ 0.6678,  0.8347,  0.5851,  1.1308, -0.3296,  0.4999],
        [ 0.7508,  0.5212,  1.5106, -0.6940, -1.0242,  0.1640]])

In [3]:
# Per-row mean over the last dimension; keepdim=True keeps a trailing axis of
# size 1 so the result broadcasts against `outputs` in later cells.
mean = outputs.mean( dim=-1, keepdim=True )
mean

tensor([[0.1452],
        [0.5648],
        [0.2048]])

In [4]:
# Per-row standard deviation over the last dimension.
# Tensor.std() uses the Bessel-corrected (unbiased, n-1) estimator by default,
# whereas the LayerNorm class in this notebook uses var(..., unbiased=False).
sd = outputs.std( dim=-1, keepdim=True)
sd

tensor([[1.1046],
        [0.4916],
        [0.9406]])

In [5]:
# Standardize each row: subtract its mean and divide by its standard deviation
# (both are (3, 1) and broadcast across the 6 columns).
normalized_outputs = (outputs - mean)/sd
normalized_outputs

tensor([[ 0.2298,  0.2059,  1.4401, -1.6723, -0.1847, -0.0188],
        [ 0.2096,  0.5491,  0.0413,  1.1513, -1.8193, -0.1320],
        [ 0.5805,  0.3365,  1.3884, -0.9555, -1.3066, -0.0433]])

In [6]:
# Sanity check: every row mean should be ~0 (tiny non-zero values are float round-off).
normalized_outputs.mean( dim=-1, keepdim=True)

tensor([[0.0000e+00],
        [3.9736e-08],
        [1.9868e-08]])

In [7]:
# Sanity check: every row standard deviation should be 1 after standardization.
normalized_outputs.std( dim=-1, keepdim=True)

tensor([[1.0000],
        [1.0000],
        [1.0000]])

In [8]:
# Turn off scientific notation when printing tensors; this only affects
# tensors printed after this cell runs.
torch.set_printoptions( sci_mode=False )

In [9]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable affine map.

    Each slice along the last dimension is shifted to zero mean and scaled to
    unit variance (using the biased, divide-by-n variance), then transformed
    element-wise as ``scale * normalized + shift``.

    Args:
        emb_dim: Size of the last dimension of the inputs; determines the
            length of the learnable ``scale`` and ``shift`` vectors.
        eps: Small constant added to the variance before the square root so
            that a zero-variance slice does not cause a division by zero.
            Defaults to ``1e-5``.
    """

    def __init__(self, emb_dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Initialised to the identity transform: scale = 1, shift = 0.
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        """Normalize ``x`` along its last dimension and apply scale/shift.

        Args:
            x: Input tensor; its last dimension is expected to have size
                ``emb_dim`` so that ``scale`` and ``shift`` broadcast over it.

        Returns:
            A tensor with the same shape as ``x``.
        """
        mean = x.mean(dim=-1, keepdim=True)
        # unbiased=False -> divide by n rather than n-1 (no Bessel correction).
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift