In [1]:
import torch
from torch import nn

class PositionwiseNormalization(nn.Module):
    """
    Layer normalization applied independently at every position.

    Each position's feature vector (the trailing ``len(parameters_shape)``
    dimensions of the input) is shifted to zero mean and scaled to unit
    variance, then passed through a learnable affine transform
    ``gamma * y + beta``. This keeps activation magnitudes consistent for
    downstream layers such as the feed-forward network.

    The class subclasses ``nn.Module`` so that ``gamma`` and ``beta`` are
    registered as parameters (visible to ``.parameters()``, optimizers and
    ``.to()``) and so the object can be called directly as ``norm(x)``.
    Calling ``norm.forward(x)`` explicitly still works.
    """

    def __init__(self, parameters_shape, eps=1e-5, verbose=True):
        """
        Initialize the normalization layer.

        Parameters
        ----------
        parameters_shape : int or sequence of int (e.g. ``torch.Size``)
            Shape of the learnable ``gamma``/``beta`` parameters. It must
            match the trailing dimensions of the input; normalization is
            computed over exactly those dimensions. A bare ``int`` is
            treated as ``(int,)``.
        eps : float
            Small constant added to the variance before the square root,
            preventing division by zero (numerical instability).
        verbose : bool
            If True (the default), ``forward`` prints the mean, standard
            deviation, normalized tensor and final output as it computes them.

        Learnable parameters
        --------------------
        gamma : initialized to ones, i.e. the layer starts with no scaling.
        beta  : initialized to zeros, i.e. the layer starts with no shift.
        """
        super().__init__()
        # Accept a plain int for the common 1-D case; len() in forward needs a sequence.
        if isinstance(parameters_shape, int):
            parameters_shape = (parameters_shape,)
        self.parameters_shape = torch.Size(parameters_shape)
        self.eps = eps
        self.verbose = verbose
        self.gamma = nn.Parameter(torch.ones(self.parameters_shape))
        self.beta = nn.Parameter(torch.zeros(self.parameters_shape))

    def forward(self, input):
        """
        Apply positionwise (layer) normalization to ``input``.

        Parameters
        ----------
        input : torch.Tensor
            Tensor whose trailing dimensions equal ``parameters_shape``
            (for example word embeddings of shape ``(..., embedding_dim)``).

        Returns
        -------
        out : torch.Tensor
            Normalized tensor with the same shape as ``input``.
        """
        # Reduce over the last len(parameters_shape) dims: [-1], [-1, -2], ...
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        mean = input.mean(dim=dims, keepdim=True)
        # Biased (population) variance, as in standard layer normalization.
        var = ((input - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        y = (input - mean) / std
        out = self.gamma * y + self.beta
        if self.verbose:
            print(f"Mean \n ({mean.size()}): \n {mean}")
            print(f"Standard Deviation \n ({std.size()}): \n {std}")
            print(f"y \n ({y.size()}) = \n {y}")
            print(f"out \n ({out.size()}) = \n {out}")
        return out

In [7]:
# Toy input for the demo, laid out as (sentence_length, batch_size, embedding_dim).
# Seed first so the random tensor (and every printed statistic) is reproducible
# under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

batch_size = 3
sentence_length = 5
embedding_dim = 3  # NOTE: equals batch_size, so axes are easy to confuse in printed shapes
inputs = torch.randn(sentence_length, batch_size, embedding_dim)

In [8]:
# Trailing (embedding) slice of the input shape, e.g. torch.Size([embedding_dim]).
check = inputs.shape[-1:]

In [9]:
# Show the trailing (embedding-dim) slice of the input shape.
check

torch.Size([3])

In [10]:
# Build the layer to normalize over the trailing (embedding) dimension of the demo tensor.
positionwise_norm = PositionwiseNormalization(parameters_shape=inputs.shape[-1:])

In [11]:
# Run the layer on the demo tensor; forward() prints mean, std, y and out as it goes.
output = positionwise_norm.forward(input=inputs)

Mean 
 (torch.Size([5, 3, 1])): 
 tensor([[[ 0.1668],
         [-0.2364],
         [ 0.1095]],

        [[ 0.5149],
         [-0.7531],
         [ 0.3106]],

        [[-0.9827],
         [ 0.1498],
         [ 0.5688]],

        [[-0.8799],
         [-0.0753],
         [ 0.2611]],

        [[-0.0795],
         [ 0.2803],
         [-0.2529]]])
Standard Deviation 
 (torch.Size([5, 3, 1])): 
 tensor([[[1.1969],
         [0.5374],
         [2.0000]],

        [[0.5590],
         [0.2884],
         [0.1639]],

        [[0.6091],
         [0.3854],
         [0.7997]],

        [[0.8282],
         [0.6363],
         [1.2907]],

        [[0.7299],
         [0.5181],
         [0.5707]]])
y 
 (torch.Size([5, 3, 3])) = 
 tensor([[[ 1.3403, -1.0609, -0.2794],
         [ 1.2630, -1.1825, -0.0805],
         [ 0.2595, -1.3337,  1.0742]],

        [[ 1.3953, -0.4984, -0.8970],
         [ 1.3147, -0.2062, -1.1085],
         [ 1.3893, -0.4670, -0.9223]],

        [[ 1.4082, -0.8169, -0.5912],
         [-