# In [6]:
from typing import Union, List

import torch
from torch import nn

# In [28]:
class LayerNorm(nn.Module):
    r"""
    A PyTorch implementation of Layer Normalization.

    Layer normalization :math:`\text{LN}` normalizes the input :math:`X` over its
    trailing ``normalized_shape`` dimensions, then optionally applies a learned
    element-wise scale :math:`\gamma` and shift :math:`\beta`.

    * When :math:`X \in \mathbb{R}^{B \times C}` is a batch of embeddings, where
      :math:`B` is the batch size and :math:`C` is the number of features,
      :math:`\gamma \in \mathbb{R}^{C}` and :math:`\beta \in \mathbb{R}^{C}`:

      .. math::
          \text{LN}(X) = \gamma
              \frac{X - \mathbb{E}_{C}[X]}{\sqrt{\text{Var}_{C}[X] + \epsilon}}
              + \beta

    * When :math:`X \in \mathbb{R}^{L \times B \times C}` is a batch of sequences
      of embeddings, where :math:`B` is the batch size, :math:`C` is the number of
      channels and :math:`L` is the sequence length,
      :math:`\gamma \in \mathbb{R}^{C}` and :math:`\beta \in \mathbb{R}^{C}`:

      .. math::
          \text{LN}(X) = \gamma
              \frac{X - \mathbb{E}_{C}[X]}{\sqrt{\text{Var}_{C}[X] + \epsilon}}
              + \beta

    * When :math:`X \in \mathbb{R}^{B \times C \times H \times W}` is a batch of
      image representations, where :math:`B` is the batch size, :math:`C` is the
      number of channels, :math:`H` is the height and :math:`W` is the width
      (not a widely used scenario),
      :math:`\gamma \in \mathbb{R}^{C \times H \times W}` and
      :math:`\beta \in \mathbb{R}^{C \times H \times W}`:

      .. math::
          \text{LN}(X) = \gamma
              \frac{X - \mathbb{E}_{C,H,W}[X]}{\sqrt{\text{Var}_{C,H,W}[X] + \epsilon}}
              + \beta

    :math:`\text{Var}` is the biased (population) variance, matching
    ``torch.nn.LayerNorm``.
    """

    def __init__(self,
                 normalized_shape: Union[int, List[int], torch.Size],
                 eps: float = 1e-5,
                 apply_affine: bool = True) -> None:
        r"""
        * ``normalized_shape`` :math:`S` is the shape of the elements (excluding
          the leading/batch dimensions). The input should then be
          :math:`X \in \mathbb{R}^{* \times S[0] \times S[1] \times \dots \times S[n]}`.
          An ``int`` is treated as the one-dimensional shape ``[normalized_shape]``.
        * ``eps`` is :math:`\epsilon`, used in :math:`\sqrt{\text{Var}[X] + \epsilon}`
          for numerical stability.
        * ``apply_affine`` is whether to scale (:math:`\gamma`) and shift
          (:math:`\beta`) the normalized value.
        """
        super().__init__()

        # Store the shape as a torch.Size: this makes ``len()`` work for an int
        # argument, and makes the equality check in ``forward`` (against a slice
        # of ``x.shape``, itself a torch.Size) succeed for list arguments too.
        if isinstance(normalized_shape, int):
            normalized_shape = torch.Size([normalized_shape])
        elif not isinstance(normalized_shape, torch.Size):
            normalized_shape = torch.Size(normalized_shape)

        self.normalized_shape = normalized_shape
        self.eps = eps
        self.apply_affine = apply_affine

        if apply_affine:
            # gamma, initialized to the identity scale
            self.scale = nn.Parameter(torch.ones(normalized_shape))
            # beta, initialized to no shift
            self.shift = nn.Parameter(torch.zeros(normalized_shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Normalize ``x`` over its trailing ``normalized_shape`` dimensions.

        ``x`` is a tensor of shape ``[*, S[0], S[1], ..., S[n]]``, where ``*``
        can be any number of leading dimensions. For example, in an NLP task
        this would be ``[seq_len, batch_size, features]``.

        Returns a tensor of the same shape as ``x``.

        Raises ``AssertionError`` if the trailing dimensions of ``x`` do not
        match ``normalized_shape``.
        """
        shape = self.normalized_shape
        assert shape == x.shape[-len(shape):], (
            f"expected trailing dimensions {tuple(shape)}, "
            f"got input of shape {tuple(x.shape)}"
        )

        # Reduce over the last len(shape) dimensions, e.g. (-2, -1) for a 2-D shape.
        dims = tuple(range(-len(shape), 0))
        mean = x.mean(dim=dims, keepdim=True)
        # Biased (population) variance, as in the LN definition and nn.LayerNorm;
        # torch.var/var_mean default to the unbiased (N - 1) estimator.
        var = ((x - mean) ** 2).mean(dim=dims, keepdim=True)

        x_hat = (x - mean) / torch.sqrt(var + self.eps)

        if self.apply_affine:
            x_hat = self.scale * x_hat + self.shift
        return x_hat
