# Fast Decaying Positional Encoding

$$ {PE}_{pos, 2i} = \frac{2i}{d_{\text{model}}} \cdot \frac{\sin(pos)}{\sqrt{pos}}$$
$$ {PE}_{pos, 2i+1} = \frac{2i+1}{d_{\text{model}}} \cdot \frac{\cos(pos)}{\sqrt{pos}}$$

In [2]:
import torch
from torch import nn


def has_duplicate_rows(matrix):
    """
    To check if there are same row in the given matrix
    :param matrix: torch.Tensor
    :return: True if given matrix has duplicate rows, False if not.
    """
    unique_rows = torch.unique(matrix, dim=1)
    return unique_rows.shape[1] < matrix.shape[1]


class FastDecayingPositionalEncoding(nn.Module):
    """
    Fast Decaying Positional Encoding
    """
    def __init__(self, d_model: int = 512, dropout: float = 0, max_len: int = 200):
        super(FastDecayingPositionalEncoding, self).__init__()

        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.unsqueeze(torch.arange(start=1, end=max_len + 1), dim=1)  # Shape is torch.Size([max_len, 1])

        term = torch.exp(
            torch.log(torch.arange(start=1, end=d_model + 1, step=2)) -
            torch.log(torch.Tensor([d_model])) -
            (0.5 * position)
        )  # Shape is torch.Size([max_len, d_model / 2])

        pe[:, 0::2] = term * torch.sin(position)  # pe_pos_2i
        pe[:, 1::2] = term * torch.cos(position)  # pe_pos_2i+1

        pe = torch.unsqueeze(pe, dim=0)  # Shape is torch.Size([1, max_len, d_model])

        if has_duplicate_rows(pe):
            raise Warning(
                'The positional encoding matrix contains invalid encoding rows, which may result in the loss of positional information. Please reduce the value of the "max_len" parameter.')

        self.register_buffer(name='pe', tensor=pe)  # Buffer

    def forward(self, x: torch.Tensor):
        # x.shape = torch.Size([max_len, d_model])
        x = x + self.pe[:, :x.size(1)].requires_grad_(False)  # Prevent from calculating gradients.
        return self.dropout(x)  # Stochastically dropout elements


pe = FastDecayingPositionalEncoding().pe
print(f'Shape is {pe.shape}')
print(pe)

Shape is torch.Size([1, 200, 512])
tensor([[[ 9.9683e-04,  6.4006e-04,  2.9905e-03,  ...,  3.2579e-01,
           5.0938e-01,  3.2707e-01],
         [ 6.5334e-04, -2.9901e-04,  1.9600e-03,  ..., -1.5219e-01,
           3.3386e-01, -1.5279e-01],
         [ 6.1500e-05, -4.3144e-04,  1.8450e-04,  ..., -2.1960e-01,
           3.1427e-02, -2.2047e-01],
         ...,
         [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -1.0089e-43,
          -8.4078e-45, -1.0089e-43],
         [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -2.9427e-44,
          -5.4651e-44, -2.9427e-44],
         [-0.0000e+00,  0.0000e+00, -0.0000e+00,  ...,  1.8217e-44,
          -3.2230e-44,  1.8217e-44]]])
