In [2]:
import torch
from bokeh.util.terminal import dim
from torch import nn

$$
\mu = \frac{1}{d}\sum_{i=1}^{d} x_i,\quad
\sigma = \sqrt{\frac{1}{d}\sum_{i=1}^{d}(x_i-\mu)^2+\varepsilon},\quad
y = \gamma \frac{x-\mu}{\sigma} + \beta
$$

In [3]:
class CustomLayerNorm(nn.Module):
    """Layer Normalization computed over the trailing ``normalized_shape`` dims.

    For every slice spanning the last ``len(normalized_shape)`` dimensions of
    the input, the layer subtracts the mean, divides by
    ``sqrt(biased_variance + eps)`` and, when ``elementwise_affine`` is set,
    applies a learnable per-element scale (gamma) and shift (beta).
    """

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        """
        Custom implementation of Layer Normalization

        Parameters:
        - normalized_shape: int or iterable of ints; the sizes of the trailing
          input dimensions to be normalized
        - eps: small constant added to the variance for numerical stability,
          avoiding division by zero
        - elementwise_affine: whether to use learnable scale (γ) and shift (β) parameters

        Raises:
        - ValueError: if normalized_shape is empty
        """
        super().__init__()

        # Accept a bare int as shorthand for a single trailing dimension
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        if not self.normalized_shape:
            raise ValueError("normalized_shape must contain at least one dimension")
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            # Learnable scale parameter γ, initialized to 1
            self.gamma = nn.Parameter(torch.ones(*self.normalized_shape))
            # Learnable shift parameter β, initialized to 0
            self.beta = nn.Parameter(torch.zeros(*self.normalized_shape))
        else:
            # Do not use learnable parameters
            self.register_parameter('gamma', None)
            self.register_parameter('beta', None)

    def forward(self, x):
        """Normalize ``x`` over its trailing ``normalized_shape`` dimensions.

        Parameters:
        - x: tensor whose trailing dimensions equal ``self.normalized_shape``

        Returns:
        - tensor with the same shape as ``x``

        Raises:
        - ValueError: if the trailing dimensions of ``x`` do not match
          ``self.normalized_shape``
        """
        num_norm_dims = len(self.normalized_shape)
        trailing_shape = tuple(x.shape[x.dim() - num_norm_dims:]) if x.dim() >= num_norm_dims else None
        if trailing_shape != self.normalized_shape:
            raise ValueError(
                f"Expected input with trailing dimensions {self.normalized_shape}, "
                f"but got input of shape {tuple(x.shape)}"
            )

        # Reduce over every normalized dimension, not just the last one, so a
        # multi-dimensional normalized_shape is handled jointly.
        reduce_dims = tuple(range(-num_norm_dims, 0))

        mean = x.mean(dim=reduce_dims, keepdim=True)

        # Biased variance estimation (dividing by n rather than n - 1):
        # (1/n) * ((x1-mean)**2 + (x2-mean)**2 + ... + (xn-mean)**2)
        variance = x.var(dim=reduce_dims, keepdim=True, unbiased=False)

        # eps is added inside the square root, so the denominator is never zero
        x_normalized = (x - mean) / torch.sqrt(variance + self.eps)

        # Apply learnable parameters
        if self.elementwise_affine:
            output = self.gamma * x_normalized + self.beta
        else:
            output = x_normalized

        return output

In [5]:
# Random input batch laid out as (batch_size, seq_len, features)
batch_size = 16
seq_len = 10
features = 512
x = torch.randn((batch_size, seq_len, features))

# Baseline statistics of the raw tensor x along its last dimension.
# NOTE: the "Output ..." labels below are printed for x itself; no
# normalization has been applied in this cell.
print("Input shape:", x.size())
print("Output shape:", x.size())
print("Output mean:", torch.mean(x, dim=-1))
print("Output variance:", torch.var(x, dim=-1, unbiased=False))

Input shape: torch.Size([16, 10, 512])
Output shape: torch.Size([16, 10, 512])
Output mean: tensor([[-0.0428, -0.1191,  0.0106,  0.0515, -0.0009,  0.0283,  0.0030, -0.0164,
          0.0352,  0.0226],
        [-0.0105, -0.0069, -0.0636,  0.0281,  0.0270,  0.0480,  0.0503, -0.0773,
         -0.0629, -0.0011],
        [ 0.0155,  0.0082,  0.0065,  0.0031, -0.0079, -0.0725, -0.0073,  0.0387,
          0.0152, -0.0176],
        [ 0.0451,  0.0473, -0.0177, -0.0199, -0.0386, -0.0192,  0.0093,  0.0217,
          0.0896, -0.0568],
        [-0.0107, -0.0661, -0.0542, -0.0853, -0.0643, -0.0033,  0.0394,  0.0396,
         -0.0225, -0.0367],
        [-0.0533,  0.0006, -0.0409,  0.0179,  0.0194,  0.0383, -0.0092, -0.0456,
          0.0392,  0.0621],
        [-0.0260,  0.0313,  0.0134, -0.0275,  0.0539,  0.0291, -0.0080,  0.0757,
         -0.0484, -0.0394],
        [ 0.0757, -0.0352,  0.0891, -0.0484,  0.0247,  0.0398, -0.0232, -0.0031,
         -0.0483,  0.0014],
        [-0.0440,  0.0032, -0.0549, 

In [6]:
# Build the custom layer for the feature dimension and run x through it
custom_norm = CustomLayerNorm(normalized_shape=features)
output = custom_norm(x)

# Report shapes plus per-position statistics of the result along the last dimension
print("Input shape:", x.size())
print("Output shape:", output.size())
print("Output mean:", torch.mean(output, dim=-1))
print("Output variance:", torch.var(output, dim=-1, unbiased=False))

Input shape: torch.Size([16, 10, 512])
Output shape: torch.Size([16, 10, 512])
Output mean: tensor([[ 1.3039e-08, -2.4214e-08,  1.4435e-08,  1.3737e-08,  1.0245e-08,
         -6.9849e-09,  5.5879e-09, -3.7253e-09,  2.0489e-08, -7.4506e-09],
        [ 5.5879e-09, -7.4506e-09, -4.6566e-09, -6.0536e-09,  7.4506e-09,
         -1.8626e-09, -1.1176e-08,  1.0245e-08, -1.3039e-08,  8.3819e-09],
        [ 1.1176e-08,  1.4901e-08, -5.5879e-09,  1.8626e-09, -1.8626e-09,
          8.3819e-09, -9.3132e-09, -6.5193e-09,  9.3132e-10, -7.4506e-09],
        [ 8.8476e-09,  1.0245e-08,  2.1420e-08,  1.8626e-09, -1.8626e-09,
         -7.4506e-09, -3.7253e-09,  3.7253e-09, -5.5879e-09,  0.0000e+00],
        [-9.3132e-10,  1.0245e-08,  7.4506e-09,  1.3039e-08, -7.6834e-09,
          7.4506e-09,  2.9802e-08, -1.2107e-08,  9.3132e-10,  7.4506e-09],
        [ 1.7928e-08, -1.3039e-08, -3.7253e-09,  6.5193e-09, -4.6566e-10,
         -6.0536e-09,  7.4506e-09,  1.8626e-09,  1.8161e-08, -1.1176e-08],
        [ 0.00