# Positional Encoding

During positional encoding, we want to add information to each token in the context window based on its position within the window. The model can then use this information to learn potential relationships between tokens based on their relative positions. The authors propose to do this by defining $PE$, a $(\text{max\_seq\_len} \times d_{model})$ tensor containing positional information for each token, and then adding $PE$ to the tensor representing the context window. $PE$ is defined according to the following formula:

$$
\begin{aligned}
PE_{(pos, 2i)} &= \sin(pos / 10000^{2i / d_{model}}), \\
PE_{(pos, 2i+1)} &= \cos(pos / 10000^{2i / d_{model}}) \\
\end{aligned}
$$
$$
\begin{aligned}
\text{where } pos &\in \{ n \in \mathbb{Z} \mid 0 \leq n < \text{max\_seq\_len} \}, \\
                i &\in \{ n \in \mathbb{Z} \mid 0 \leq n < d_{model} / 2 \}
\end{aligned}
$$

To make things easier to implement (or at least it's more obvious to me this way), we'll make a slight change to the formula so that a single column index $i$ covers both the even and the odd case:
$$
PE_{(pos, i)} =
\begin{cases}
\sin(pos / 10000^{2\lfloor\frac{i}{2}\rfloor / d_{model}}), & \text{if } i \text{ is even}, \\
\cos(pos / 10000^{2\lfloor\frac{i}{2}\rfloor / d_{model}}), & \text{if } i \text{ is odd}
\end{cases}
$$

In [3]:
import torch

In [5]:
# Embedding dimensionality; this is the number of columns in the encoding table.
d_model = 512

In [6]:
# Per-column denominator 10000^(2*floor(i/2)/d_model). Flooring i/2 makes each
# even/odd column pair share the same exponent, hence the same denominator.
bias = torch.full(size=(d_model,), fill_value=10_000, dtype=torch.float)
indices = torch.arange(d_model).div(2, rounding_mode="floor").mul(2).div(d_model)
divisor = torch.pow(bias, indices)
divisor

tensor([1.0000e+00, 1.0000e+00, 1.0366e+00, 1.0366e+00, 1.0746e+00, 1.0746e+00,
        1.1140e+00, 1.1140e+00, 1.1548e+00, 1.1548e+00, 1.1971e+00, 1.1971e+00,
        1.2409e+00, 1.2409e+00, 1.2864e+00, 1.2864e+00, 1.3335e+00, 1.3335e+00,
        1.3824e+00, 1.3824e+00, 1.4330e+00, 1.4330e+00, 1.4855e+00, 1.4855e+00,
        1.5399e+00, 1.5399e+00, 1.5963e+00, 1.5963e+00, 1.6548e+00, 1.6548e+00,
        1.7154e+00, 1.7154e+00, 1.7783e+00, 1.7783e+00, 1.8434e+00, 1.8434e+00,
        1.9110e+00, 1.9110e+00, 1.9810e+00, 1.9810e+00, 2.0535e+00, 2.0535e+00,
        2.1288e+00, 2.1288e+00, 2.2067e+00, 2.2067e+00, 2.2876e+00, 2.2876e+00,
        2.3714e+00, 2.3714e+00, 2.4582e+00, 2.4582e+00, 2.5483e+00, 2.5483e+00,
        2.6416e+00, 2.6416e+00, 2.7384e+00, 2.7384e+00, 2.8387e+00, 2.8387e+00,
        2.9427e+00, 2.9427e+00, 3.0505e+00, 3.0505e+00, 3.1623e+00, 3.1623e+00,
        3.2781e+00, 3.2781e+00, 3.3982e+00, 3.3982e+00, 3.5227e+00, 3.5227e+00,
        3.6517e+00, 3.6517e+00, 3.7855e+

In [7]:
max_seq_len = 1024
# Positions as a column vector so that dividing by the row of per-column
# denominators broadcasts to a (max_seq_len, d_model) grid of sin/cos arguments.
pos = torch.arange(max_seq_len, dtype=torch.float).unsqueeze(1)
freqs = torch.div(pos, divisor)
freqs

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.0000e+00, 9.6466e-01,  ..., 1.0746e-04, 1.0366e-04,
         1.0366e-04],
        [2.0000e+00, 2.0000e+00, 1.9293e+00,  ..., 2.1492e-04, 2.0733e-04,
         2.0733e-04],
        ...,
        [1.0210e+03, 1.0210e+03, 9.8492e+02,  ..., 1.0972e-01, 1.0584e-01,
         1.0584e-01],
        [1.0220e+03, 1.0220e+03, 9.8588e+02,  ..., 1.0982e-01, 1.0594e-01,
         1.0594e-01],
        [1.0230e+03, 1.0230e+03, 9.8685e+02,  ..., 1.0993e-01, 1.0605e-01,
         1.0605e-01]])

In [8]:
# Fill the table column-wise: sine in the even columns, cosine in the odd ones.
PE = torch.zeros(max_seq_len, d_model, dtype=torch.float)
PE[:, ::2] = freqs[:, ::2].sin()
PE[:, 1::2] = freqs[:, 1::2].cos()
PE

tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
          0.0000e+00,  1.0000e+00],
        [ 8.4147e-01,  5.4030e-01,  8.2186e-01,  ...,  1.0000e+00,
          1.0366e-04,  1.0000e+00],
        [ 9.0930e-01, -4.1615e-01,  9.3641e-01,  ...,  1.0000e+00,
          2.0733e-04,  1.0000e+00],
        ...,
        [ 1.7612e-02, -9.9984e-01, -9.9954e-01,  ...,  9.9399e-01,
          1.0564e-01,  9.9440e-01],
        [-8.3182e-01, -5.5504e-01, -5.4457e-01,  ...,  9.9398e-01,
          1.0575e-01,  9.9439e-01],
        [-9.1649e-01,  4.0007e-01,  3.7906e-01,  ...,  9.9396e-01,
          1.0585e-01,  9.9438e-01]])

In [17]:
from torch import nn

class PositionalEncode(nn.Module):
    """Fixed (non-learned) sinusoidal positional-encoding table.

    The table is computed once at construction time as

        PE[pos, i] = sin(pos / 10000^(2*floor(i/2) / d_model))   if i is even
        PE[pos, i] = cos(pos / 10000^(2*floor(i/2) / d_model))   if i is odd

    and returned unchanged by every call to the module.
    """

    def __init__(self, max_seq_length: int = 1024, d_model: int = 512) -> None:
        """
        Precompute the sinusoidal positional encodings.

        Parameters:
            max_seq_length (int): Maximum sequence length, i.e. the number of
                rows (positions) in the table.
            d_model (int): Dimensionality of the model embeddings, i.e. the
                number of columns in the table.
        """
        super().__init__()
        # Create position indices: pos = [0, 1, ..., max_seq_length-1].
        # This must use the constructor argument, not a same-named global.
        pos_indices = torch.arange(max_seq_length, dtype=torch.float32)

        # Create dimension indices: dim = [0, 1, ..., d_model-1]
        dim_indices = torch.arange(d_model, dtype=torch.float32)

        # Compute the scaling exponent: 2 * floor(dim/2) / d_model, so each
        # even/odd column pair shares one wavelength.
        exponent = ((dim_indices // 2) * 2) / d_model

        # Compute the denominator term: 10000^(exponent)
        div_term = torch.pow(10000, exponent)

        # Compute the sin/cos arguments: pos / div_term. The column vector of
        # positions broadcasts against the row of denominators.
        angle_rates = pos_indices.unsqueeze(1) / div_term

        # Initialize the positional encoding matrix and apply sine to even
        # indices and cosine to odd indices.
        pos_encoding = torch.zeros_like(angle_rates)
        pos_encoding[:, 0::2] = torch.sin(angle_rates[:, 0::2])
        pos_encoding[:, 1::2] = torch.cos(angle_rates[:, 1::2])

        # Register as a buffer so the table follows the module across
        # .to()/.cuda()/.float() calls. It is non-persistent because it is
        # fully determined by the constructor arguments, which also keeps
        # state_dict() unchanged.
        self.register_buffer("position_encoding", pos_encoding, persistent=False)

    def forward(self) -> torch.Tensor:
        """
        Return the precomputed positional encodings.

        Returns:
            torch.Tensor: A tensor of shape (max_seq_length, d_model)
                          containing the positional encodings.
        """
        return self.position_encoding

In [18]:
# Build the module with its default sizes and call it to display the table.
PositionalEncode()()

tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
          0.0000e+00,  1.0000e+00],
        [ 8.4147e-01,  5.4030e-01,  8.2186e-01,  ...,  1.0000e+00,
          1.0366e-04,  1.0000e+00],
        [ 9.0930e-01, -4.1615e-01,  9.3641e-01,  ...,  1.0000e+00,
          2.0733e-04,  1.0000e+00],
        ...,
        [ 1.7612e-02, -9.9984e-01, -9.9954e-01,  ...,  9.9399e-01,
          1.0564e-01,  9.9440e-01],
        [-8.3182e-01, -5.5504e-01, -5.4457e-01,  ...,  9.9398e-01,
          1.0575e-01,  9.9439e-01],
        [-9.1649e-01,  4.0007e-01,  3.7906e-01,  ...,  9.9396e-01,
          1.0585e-01,  9.9438e-01]])