<a href="https://colab.research.google.com/github/QasimWani/simple-transformer/blob/main/transformers/postional_encodings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Implements different types of positional encoding strategies:

# 1. Fixed encodings - Encodes sinusoidal positions
# 2. Learned embeddings - Learns positions
# 3. RoPE embeddings - Encodes absolute and relative positions through rotation matrices

import torch
import torch.nn as nn
import math

In [None]:
class FixedPositionalEncodings(nn.Module):
  '''
  Fixed positional encodings were introduced in the "Attention Is All You Need" paper.

  Unlike RNNs and LSTMs, where token order is baked into the model architecture, a transformer
  has to learn that token 1 is next to token 2. If we do not apply positional encodings to the network,
  the model learns a bag-of-words representation and loses all the order/positional info. See LDA and TF-IDF.
  1. NLP Classification: https://github.com/QasimWani/Classification/blob/master/Xenophobia/xenophobia%20classifier.ipynb
  2. HitcHiqe: https://github.com/QasimWani/hitchHiqe/blob/121fa21390a41d7e71c58d1f8002b1b7cefa0c2e/middleware/recommender.js#L294

  We want the model to learn the relative positions and produce semantically different outputs for [cat, sat, on, the, mat] and [sat, cat, on, the, mat].

  This is where positional encodings come into play. Fixed positional encodings compute frequencies that control how fast the encoding oscillates across token positions
  for each element of the embedding (d_embed), so the model can learn the relative positions. A faster oscillation separates nearby positions more sharply, while a slower one changes gradually across the sequence.

  Design choice #1 - apply a single sin/cos function of the position to encode relative position. Tokens that are next to one another will be closer in representation, and tokens that are farther apart might be farther. Range [-1, 1].
  The problem is heavy repetition across the embedding dimension. For a given position, every single value in d_embed is equal, so the encoding is a low-rank representation and the model doesn't have enough info to learn from.

  Design choice #2 - instead of applying a single sin/cos, we apply the sine function at even indices of d_embed and the cosine function at odd indices of d_embed.
  While this reduces repetition, 50% of all values are still identical: sin/cos is a function of range(0, max_len) only, so for each token position the sin value repeats across all even indices and the cos value repeats across all odd indices.

  Design choice #3 - control the oscillation. To further let the model learn relative positions and reduce repetition across the embedding dimension, d_embed, we use a scaling factor that changes how fast the sine and cosine functions oscillate across d_embed.
  Scaling factor: e^(idx * (- log(10k) / d_embed)), where idx is in range(0, d_embed, step=2).
  By the laws of exponents this is equivalent to 10k ^ (-idx / d_embed), where idx is in range(0, d_embed, step=2).
  Note: idx steps by 2 because each frequency is shared by one sin slot and one cos slot, so only d_embed / 2 frequencies are needed.

  At the earlier indices of d_embed the oscillation is fast, e.g. with d_embed = 768: exp(-0 * log(10k) / d_embed) >> exp(-766 * log(10k) / d_embed).
  You can think of it as the initial positions encode fine-grained position info, while the latter half might encode more general/coarse positional info.


  Pros:
  1. Encoding both relative and absolute positions to tokens is how we're able to solve the bag of words problem.
  2. Because it's fixed rather than learned, it can be computed for any position, so it extrapolates by design
  3. No parameters to learn. Fixed based on max_seq_len and d_embed
  4. Bounded values in [-1, 1] help prevent gradient overflows

  Cons:
  1. The transformer needs to learn sin/cos relationships to understand the relative encodings. That is a harder task than baking the relation in
  directly (see RoPE). For example, if a sentence is reordered yet carries the same semantic information, the positional encodings added to each word change.
  In particular, 'I ate an apple by the bench' gets vastly different positional encodings than 'by the bench I ate an apple'.
  Being able to bake in both absolute and relative positions is important, but forcing the model to untangle complex sin/cos
  relationships from the original embedding vector complicates learning.
  2. Way less interpretable than learned embeddings or geometric approaches like RoPE.
  3. Only applied once per input. Further attention layers don't have positional information baked in d_embed
  '''

  def __init__(self, max_len, d_embed):
    super().__init__()
    assert d_embed % 2 == 0, "Embedding dimension needs to be even, otherwise the sin/cos slices will cause a dim mismatch"

    positional_encodings = torch.zeros(max_len, d_embed) # (max_len, d_embed)
    positions = torch.arange(0, max_len).unsqueeze(1).float() # (max_len, 1)

    div_term = torch.exp(torch.arange(0, d_embed, 2) * (-math.log(1e4) / d_embed)) # (d_embed / 2,); equivalent to: 1e4 ** (-torch.arange(0, d_embed, 2) / d_embed)

    positional_encodings[:, 0::2] = torch.sin(positions * div_term) # even indices; positions * div_term broadcasts to (max_len, d_embed / 2)
    positional_encodings[:, 1::2] = torch.cos(positions * div_term) # odd indices; (max_len, d_embed / 2)

    self.register_buffer('positional_encodings', positional_encodings)


  def forward(self, x: torch.Tensor) -> torch.Tensor:
    '''
    x.shape = batch_size, seq_len, d_embed
    '''
    seq_len = x.size(1)
    # Note: seq_len <= max_len. To prevent mismatch errors, slice the positional encodings
    # down to the seq_len of the current input before adding them.
    return x + self.positional_encodings[:seq_len, :].unsqueeze(0)
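
To make the design choices above concrete, here is a minimal plain-Python sketch (no torch, so it runs anywhere). The helper names `sinusoidal_encoding` and `dot` are illustrative and not part of this notebook. Under those assumptions it checks two things: the exp/log scaling factor matches the power form 10k ^ (-idx / d_embed), and the dot product between two encodings depends only on the offset between the two positions, which is the relative signal the model has to pick up.

```python
import math


def sinusoidal_encoding(position, d_embed):
    """One row of the fixed encoding table, written with plain Python floats."""
    row = []
    for idx in range(0, d_embed, 2):
        freq = math.exp(idx * (-math.log(1e4) / d_embed))  # same scaling factor as div_term
        row.append(math.sin(position * freq))  # even index of d_embed
        row.append(math.cos(position * freq))  # odd index of d_embed
    return row


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


d_embed = 16  # illustrative size

# 1. The exp/log form equals the power form 10k ** (-idx / d_embed).
for idx in range(0, d_embed, 2):
    exp_form = math.exp(idx * (-math.log(1e4) / d_embed))
    pow_form = 1e4 ** (-idx / d_embed)
    assert math.isclose(exp_form, pow_form, rel_tol=1e-9)

# 2. sin(a)sin(b) + cos(a)cos(b) = cos(a - b), so the dot product of two
#    encodings only depends on the offset between the positions.
offset = 3
sim_early = dot(sinusoidal_encoding(2, d_embed), sinusoidal_encoding(2 + offset, d_embed))
sim_late = dot(sinusoidal_encoding(40, d_embed), sinusoidal_encoding(40 + offset, d_embed))
print(math.isclose(sim_early, sim_late, abs_tol=1e-9))  # True
```

The relative information is present in the encodings, but the model still has to discover it through the trig identity, which is the first con listed in the docstring.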


In [None]:
class LearnedPositionalEncodings(nn.Module):
  '''
  Instead of fixing the positional encodings, why not just learn them? A simple weight matrix (a lookup table with no bias, hence an embedding)
  learns the encodings over training.

  Pros:
  1. Super simple
  2. May learn flexible encodings that may not be baked in by some feature engineered design.
  3. Fine for models with small context lengths

  Cons:
  1. Parameter count can be massive - max_len x d_embed is huge for a large context window.
  2. No extrapolation. The model has no learned encodings for positions beyond max_len.
  3. No inductive bias - the model has to learn this representation from scratch. It may memorize data statistics, making it prone to data noise/imbalance
  '''

  def __init__(self, max_len, d_embed):
    super().__init__()

    self.positional_encodings = nn.Embedding(max_len, d_embed)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    batch_size, seq_len, d_embed = x.shape
    positions = torch.arange(0, seq_len, device=x.device)
    return x + self.positional_encodings(positions).unsqueeze(0)
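
The two main cons above (parameter count and no extrapolation) can be sketched without torch. `TinyLearnedPositions` and `learned_position_params` are hypothetical helpers standing in for `nn.Embedding(max_len, d_embed)`, and the model sizes are made-up examples rather than measurements of any real model.

```python
import random


class TinyLearnedPositions:
    """Hypothetical stand-in for nn.Embedding(max_len, d_embed): one row per position."""

    def __init__(self, max_len, d_embed, seed=0):
        rng = random.Random(seed)
        # In a real model these rows are trained; random values stand in for them here.
        self.table = [[rng.gauss(0.0, 1.0) for _ in range(d_embed)] for _ in range(max_len)]

    def num_parameters(self):
        return len(self.table) * len(self.table[0])

    def lookup(self, seq_len):
        if seq_len > len(self.table):
            raise IndexError(f"seq_len={seq_len} exceeds max_len={len(self.table)}")
        return self.table[:seq_len]


def learned_position_params(max_len, d_embed):
    """Parameter count of a learned position table: one d_embed vector per position."""
    return max_len * d_embed


small = TinyLearnedPositions(max_len=8, d_embed=4)
print(small.num_parameters())  # 32

# Illustrative (made-up) sizes: the table grows linearly with the context window.
print(learned_position_params(2048, 768))     # 1572864
print(learned_position_params(131072, 4096))  # 536870912

# No extrapolation: there is simply no row for a position past max_len.
try:
    small.lookup(9)
    extrapolation_failed = False
except IndexError:
    extrapolation_failed = True
print(extrapolation_failed)  # True
```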

In [None]:
class RotaryPositionalEmbeddings(nn.Module):
  '''
  One of the major challenges with fixed encodings like the one proposed in "Attention Is All You Need" is that the model needs to learn trig functions to truly
  understand the positional relationship. For example, 'I ate an apple by the bench' will have vastly different positional encodings than
  'by the bench I ate an apple', and the model needs to understand why.

  To solve this we use a rotation matrix, so that the dot product between the rotated query and key vectors depends only on the distance between token positions m and n.
  A rotation matrix is defined as:
  [cos -sin]
  [sin  cos]

  This comes from polar coordinates: on the unit circle the hypotenuse is 1, so rotating the basis vector [1, 0] (x0, x1)
  by some angle theta moves it to cos theta along x0 and sin theta along x1. In other words, T([1, 0]) = [cos theta, sin theta].
  Similarly, rotating the basis vector [0, 1] (x0, x1) by theta gives T([0, 1]) = [-sin theta, cos theta]. These two images are the columns of the rotation matrix.

  One nice property of a rotation is that it preserves vector lengths, and the dot product between two rotated vectors depends only on the difference between their
  rotation angles. So we can encode relative positions into our vectors implicitly. We also want to bake in absolute positional information. This is done by scaling
  theta by the token's position in seq_len, the same position * frequency product used in `FixedPositionalEncodings`.

  An important distinction between RoPE and other positional encoding techniques is that we apply it directly to the Q and K projections so we can bake in the
  relative property. Otherwise we would capture only the absolute positional information, which is not much different from the `FixedPositionalEncodings` formulation.

  Formula: Attention_m,n = (X_m W_q R_theta^m) @ (X_n W_k R_theta^n).T = X_m W_q R_theta^(m-n) (X_n W_k).T, where m and n are two positions, so the score depends only on the offset m - n

  Pros:
  1. Relative position emerges directly from dot product (no learning trig identities needed)
  2. Applied at every attention layer, not just the first/pre-processing. This means that each layer of the network has positional info baked in.
  3. No parameters to learn
  4. Extrapolates to longer sequences by continuing to rotate by position index and frequency term

  Cons:
  1. Complex, especially the implementation - some variants use a half rotation to avoid the sparsity of an explicit `rotation_matrix`
  2. Frequency needs to be tuned for very long context lengths. NTK scaling and position interpolation help

  '''
  def __init__(self, max_len, d_embed):
    super().__init__()
    # earlier indices in d_embed receive higher frequencies while later indices receive lower ones, meaning they rotate much more slowly
    # the context window can be extended at inference through NTK scaling or position interpolation (a constant scaling factor, tr_seq_len / ts_seq_len)
    freqs = 1.0 / (1e4 ** (torch.arange(0, d_embed, 2).float() / d_embed)) # d_embed / 2

    self.register_buffer('frequency', freqs)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    '''
    x.shape = batch_size, seq_len, d_embed
    '''
    # NOTE: Phase dictates positions in seq_len while frequency dictates rate of oscillation in d_embed
    batch_size, seq_len, d_embed = x.shape

    # pair up adjacent features: index 0 of the last dimension is x0 and index 1 is x1 of each 2D rotation
    x = x.reshape(batch_size, seq_len, d_embed // 2, 2)

    positions = torch.arange(0, seq_len, device=x.device).unsqueeze(-1) # seq_len, 1
    angles = positions * self.frequency # seq_len, d_embed / 2 (the buffer is registered as 'frequency' in __init__)

    # Create rotation matrix components
    # [cos -sin] [x0]   [x0 * cos  -  x1 * sin]
    # [sin  cos] [x1] = [x0 * sin  +  x1 * cos]
    cos = torch.cos(angles).reshape(1, seq_len, d_embed // 2) # 1, seq_len, d_embed / 2
    sin = torch.sin(angles).reshape(1, seq_len, d_embed // 2) # 1, seq_len, d_embed / 2

    x_rotated = torch.stack([
        x[..., 0] * cos - x[..., 1] * sin,
        x[..., 0] * sin + x[..., 1] * cos
    ], dim=-1) # batch_size, seq_len, d_embed / 2, 2

    return x_rotated.view(batch_size, seq_len, d_embed)


# usage - apply to q and k, not to embeddings directly!
def apply_rope(Q, K, max_len, d_embed):
  rope = RotaryPositionalEmbeddings(max_len, d_embed) # NOTE: in production, initialize just once
  Q_rot = rope(Q)
  K_rot = rope(K)
  # attention uses rotated q,k
  scores = torch.matmul(Q_rot, K_rot.transpose(-2, -1))
  return scores
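
The relative-position property claimed in the docstring can be checked with a small plain-Python sketch. `rope_rotate` and `dot` are illustrative helpers (not part of this notebook) that rotate consecutive (x0, x1) pairs the same way as `RotaryPositionalEmbeddings.forward`, and the query/key values are arbitrary. Assuming that pairing, the score between a rotated query and key should stay the same when both positions shift by the same amount and change when the offset changes.

```python
import math


def rope_rotate(vec, position, base=1e4):
    """Rotate each consecutive (x0, x1) pair of vec by position * frequency."""
    d_embed = len(vec)
    out = []
    for idx in range(0, d_embed, 2):
        angle = position * (1.0 / base ** (idx / d_embed))
        c, s = math.cos(angle), math.sin(angle)
        x0, x1 = vec[idx], vec[idx + 1]
        # [cos -sin] [x0]
        # [sin  cos] [x1]
        out.extend([x0 * c - x1 * s, x0 * s + x1 * c])
    return out


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


# Arbitrary example vectors standing in for one query row and one key row.
q = [0.3, -1.2, 0.7, 0.5]
k = [1.1, 0.4, -0.6, 0.9]

score_a = dot(rope_rotate(q, 5), rope_rotate(k, 2))      # offset 3
score_b = dot(rope_rotate(q, 105), rope_rotate(k, 102))  # offset 3, both shifted by 100
score_c = dot(rope_rotate(q, 5), rope_rotate(k, 4))      # offset 1

print(math.isclose(score_a, score_b, abs_tol=1e-9))  # True: same offset, same score
print(abs(score_a - score_c) > 1e-3)                 # True: different offset, different score
```

Rotation also leaves the vector norm unchanged, so RoPE adjusts the angle between Q and K without rescaling either of them.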
