In [2]:
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encodings.

    For a position ``pos`` and an even feature index ``2i`` the table is

        PE[pos, 2i]     = sin(pos / 10000 ** (2i / d_model))
        PE[pos, 2i + 1] = cos(pos / 10000 ** (2i / d_model))

    The class inherits from nn.Module but registers no parameters or
    buffers; the table is rebuilt on every call to ``forward``.

    """

    def __init__(self, d_model: int, max_sequence_length: int):
        """
        Store the dimensions of the encoding table.

        Parameters
        ----------
        d_model : integer
            number of features per position (number of columns of the
            encoding table).
        max_sequence_length : integer
            number of positions to encode (number of rows of the
            encoding table).

        """

        super().__init__()
        self.max_sequence_length = max_sequence_length
        self.d_model = d_model

    def forward(self) -> torch.Tensor:

        """
        Compute the positional encoding table.

        Sine values occupy the even columns and cosine values the odd
        columns; each sin/cos pair shares the same frequency.

        Returns
        -------
        PE: tensor
            float tensor of shape (max_sequence_length, d_model),
            one row per position.

        """
        # One frequency per sin/cos pair: exponents 0, 2, 4, ... over d_model.
        even_i = torch.arange(0, self.d_model, 2).float()
        denominator = torch.pow(10000, even_i / self.d_model)

        # Column vector of positions so the division broadcasts to
        # (max_sequence_length, number_of_frequencies).
        position = torch.arange(self.max_sequence_length).reshape(self.max_sequence_length, 1)
        angles = position / denominator

        # Stack on a new last axis and flatten it away so that sin and cos
        # alternate column by column: sin_0, cos_0, sin_1, cos_1, ...
        stacked = torch.stack([torch.sin(angles), torch.cos(angles)], dim=2)
        PE = torch.flatten(stacked, start_dim=1, end_dim=2)

        # For odd d_model the interleaved table has d_model + 1 columns
        # (the last sine has a surplus cosine partner); trim it so the
        # width always equals d_model. This is a no-op for even d_model.
        return PE[:, :self.d_model]

In [3]:
# Build a small table (10 positions, 6 features) to inspect the values.
positional_encoding = PositionalEncoding(d_model=6, max_sequence_length=10)
# Call the module instance rather than .forward() directly so that
# nn.Module's call machinery (including any registered hooks) runs.
positional_encoding()

tensor([[ 0.0000,  1.0000,  0.0000,  1.0000,  0.0000,  1.0000],
        [ 0.8415,  0.5403,  0.0464,  0.9989,  0.0022,  1.0000],
        [ 0.9093, -0.4161,  0.0927,  0.9957,  0.0043,  1.0000],
        [ 0.1411, -0.9900,  0.1388,  0.9903,  0.0065,  1.0000],
        [-0.7568, -0.6536,  0.1846,  0.9828,  0.0086,  1.0000],
        [-0.9589,  0.2837,  0.2300,  0.9732,  0.0108,  0.9999],
        [-0.2794,  0.9602,  0.2749,  0.9615,  0.0129,  0.9999],
        [ 0.6570,  0.7539,  0.3192,  0.9477,  0.0151,  0.9999],
        [ 0.9894, -0.1455,  0.3629,  0.9318,  0.0172,  0.9999],
        [ 0.4121, -0.9111,  0.4057,  0.9140,  0.0194,  0.9998]])