In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Tuple, Optional

In [14]:
## STEP 1: Set up the configuration for the different hyperparameters

In [35]:
emb_size = 10             # Dimensionality of each token's embedding vector
num_attention_heads = 2   # Number of attention heads
qkv_vec_dim = 8           # Total width of the Q/K/V projections (all heads combined)

# Integer division would silently drop dimensions if the projection width did
# not split evenly across the heads, so verify divisibility before dividing.
assert qkv_vec_dim % num_attention_heads == 0, 'qkv_vec_dim must be divisible by num_attention_heads'
head_dim = qkv_vec_dim // num_attention_heads

# RoPE views adjacent feature pairs as complex numbers, so each head needs an
# even number of features.
assert head_dim % 2 == 0, 'Head dimension must be an even number'

In [4]:
max_seq_len = 128 # Number of positions for which RoPE rotation factors are precomputed
rope_theta = 10000.0 # Base angle for RoPE frequency calculation
rms_norm_eps = 1e-5 # Epsilon for RMSNorm
using_attention_bias = False # Whether to use bias in Q, K, V, O
normalising_qk = True # Whether to apply L2 normalisation to Q & K before attention

In [21]:
## STEP 2: Get the vector embeddings & position IDs of each token in the sequence.

In [16]:
# Sample input dimensions used throughout the walkthrough
batch_size = 2       # Number of sequences in the batch
sequence_length = 9  # Number of tokens in each sequence

In [17]:
# Fix the RNG seed so the random embeddings (and everything derived from them)
# come out the same on every Restart & Run All.
torch.manual_seed(0)

# Create a matrix representing the embedding vectors for each token in the sequence
embeddings = torch.randn(batch_size, sequence_length, emb_size)
    # Shape: (batch_size, sequence_length, emb_size)
    # ShapeEx: (2, 9, 10)

In [20]:
# Every sequence in the batch is numbered 0 .. sequence_length - 1
position_ids = torch.arange(sequence_length)[None, :].repeat(batch_size, 1)
    # Shape: (batch_size, sequence_length)
    # ShapeEx: (2, 9)

In [19]:
# Sanity check: one row of positions per sequence in the batch
print(position_ids)

tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8],
        [0, 1, 2, 3, 4, 5, 6, 7, 8]])


In [30]:
## STEP 3: Create the attention mask

In [25]:
# Causal mask: -inf strictly above the main diagonal (future tokens), 0 on and below it
attention_mask = torch.full((sequence_length, sequence_length), float('-inf')).triu(diagonal=1)
    # .triu(diagonal=1) keeps the entries above the main diagonal and sets the main diagonal & below to 0
    # Shape: (sequence_length, sequence_length)
    # ShapeEx: (9, 9)

In [23]:
# Sanity check: 0 on and below the main diagonal, -inf above it
print(attention_mask)

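tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.]])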


In [26]:
# Prepend singleton batch and head axes, then broadcast the batch axis.
# The head axis stays at size 1 so the mask broadcasts over all heads when added to the scores.
attention_mask = attention_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
    # Shape: (sequence_length, sequence_length)
    #   -> (1, 1, sequence_length, sequence_length)
    #   -> (batch_size, 1, sequence_length, sequence_length)
    # ShapeEx: (9, 9) -> (1, 1, 9, 9) -> (2, 1, 9, 9)

In [31]:
## STEP 4: Create Q, K, V, O projection matrices 

In [39]:
# Q: Query of each token in the sequence
# K: Key of each token in the sequence
# V: Value or information of each token in the sequence

In [37]:
# Three identically configured input projections, created in Q, K, V order
q_proj, k_proj, v_proj = (
    nn.Linear(emb_size, qkv_vec_dim, bias=using_attention_bias)
    for _ in range(3)
)
    # Each maps emb_size -> qkv_vec_dim
    # ShapeEx: 10 -> 8

# The output projection maps the concatenated heads back to the embedding size,
# i.e. the reverse direction of v_proj
o_proj = nn.Linear(qkv_vec_dim, emb_size, bias=using_attention_bias)
    # Maps qkv_vec_dim -> emb_size
    # ShapeEx: 8 -> 10

In [42]:
# Sanity check: in_features should be emb_size and out_features qkv_vec_dim
print(q_proj)

Linear(in_features=10, out_features=8, bias=False)


In [40]:
## STEP 5: Compute Q, K, V

In [41]:
# Project the same embeddings through each of the three input projections
query_states, key_states, value_states = (
    projection(embeddings) for projection in (q_proj, k_proj, v_proj)
)
    # Shape: (batch_size, sequence_length, qkv_vec_dim)
    # ShapeEx: (2, 9, 8)

In [48]:
# Split the last axis into (heads, head_dim), then move the head axis in front
# of the sequence axis so that each head attends independently.
query_states, key_states, value_states = (
    states.view(batch_size, sequence_length, num_attention_heads, head_dim).transpose(1, 2)
    for states in (query_states, key_states, value_states)
)
    # Shape: (batch_size, sequence_length, qkv_vec_dim)
    #   -> (batch_size, sequence_length, num_attention_heads, head_dim)
    #   -> (batch_size, num_attention_heads, sequence_length, head_dim)
    # ShapeEx: (2, 9, 8) -> (2, 9, 2, 4) -> (2, 2, 9, 4)

In [49]:
## STEP 6: Compute RoPE complex frequencies

In [50]:
# Rotary Positional Embeddings (RoPE) applies rotations to the Q & K vectors based on the position of the tokens,
# injecting relative positional information into Q & K before computing their dot product.

# RoPE represents embeddings in complex number space and rotates the embeddings by an angle proportional to the tokens' position.

In [54]:
def get_rope_freqs_cis(emb_dim, max_seq_len, base=10000.0, device=None):
    """Precompute the complex RoPE rotation factors for every position.

    Args:
        emb_dim: Number of features being rotated (the per-head dimension in
            this notebook). One frequency is produced per pair of features.
        max_seq_len: Number of positions (0 .. max_seq_len - 1) to precompute.
        base: Base of the geometric frequency progression.
        device: Device on which the tensors are created.

    Returns:
        Complex tensor of shape (max_seq_len, emb_dim // 2) whose entry
        [p, i] is cos(p * f_i) + j * sin(p * f_i), with
        f_i = 1 / base ** (2 * i / emb_dim).
    """
    # One exponent per feature pair: 0/emb_dim, 2/emb_dim, 4/emb_dim, ...
    pair_exponents = torch.arange(0, emb_dim, 2, device=device).float() / emb_dim
    inv_freq = 1.0 / (base ** pair_exponents)
        # Shape: (emb_dim // 2, )
        # ShapeEx: (2, )

    positions = torch.arange(max_seq_len, device=device).type_as(inv_freq)
        # Shape: (max_seq_len, )
        # ShapeEx: (128, )

    # Rotation angle for every (position, frequency) combination
    angles = torch.outer(positions, inv_freq)
        # Shape: (max_seq_len, emb_dim // 2)
        # ShapeEx: (128, 2)

    # Unit-magnitude complex numbers e^{j * angle}
    return torch.complex(torch.cos(angles), torch.sin(angles))
        # Shape: (max_seq_len, emb_dim // 2), complex
        # ShapeEx: (128, 2), complex

In [None]:
def apply_rope_to_qk(
    q_vecs: torch.Tensor,    # Shape: (batch, num_attention_heads, sequence_length, head_dim); ShapeEx: (2, 2, 9, 4)
    k_vecs: torch.Tensor,    # Shape: (batch, num_attention_heads, sequence_length, head_dim); ShapeEx: (2, 2, 9, 4)
    freqs_cis: torch.Tensor, # Shape: (max_seq_len, head_dim // 2), complex
    position_ids: Optional[torch.Tensor] = None,  # Shape: (batch, sequence_length) or (1, sequence_length)
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate the query and key vectors according to their token positions (RoPE).

    Adjacent feature pairs of Q and K are viewed as complex numbers and
    multiplied by the precomputed unit-magnitude rotation factors.

    Args:
        q_vecs: Queries, shape (batch, num_attention_heads, sequence_length, head_dim).
        k_vecs: Keys, same shape as ``q_vecs``.
        freqs_cis: Complex rotation factors, shape (max_seq_len, head_dim // 2).
        position_ids: Integer positions of each token. When omitted, positions
            0 .. sequence_length - 1 are used for every sequence, so the
            function no longer depends on a notebook-global variable.

    Returns:
        Tuple ``(rotated_q, rotated_k)`` with the same shapes and dtypes as the inputs.
    """
    # 1. Select the rotation factors for the positions present in the sequence.
    if position_ids is None:
        sequence_length = q_vecs.shape[-2]
        position_ids = torch.arange(sequence_length, device=freqs_cis.device).unsqueeze(0)
            # Shape: (1, sequence_length) - broadcasts over the batch
    #    Advanced indexing picks one row of freqs_cis per position id.
    rotations = freqs_cis[position_ids]
        # Shape: (batch or 1, sequence_length, head_dim // 2), complex
        # ShapeEx: (2, 9, 2), complex

    # 2. Add an extra dimension for broadcasting across attention heads
    rotations = rotations[:, None, :, :]
        # Shape: (batch or 1, 1, sequence_length, head_dim // 2), complex
        # ShapeEx: (2, 1, 9, 2), complex

    # 3. Reshape Q & K to view adjacent pairs as complex numbers
        # q_vecs Shape: (batch, num_attention_heads, sequence_length, head_dim)
        #   -> (batch, num_attention_heads, sequence_length, head_dim // 2, 2)
        #   -> (batch, num_attention_heads, sequence_length, head_dim // 2), complex
        # ShapeEx: (2, 2, 9, 4) -> (2, 2, 9, 2, 2) -> (2, 2, 9, 2)
    #    The computation is done in float32 and cast back at the end.
    q_vecs_complex = torch.view_as_complex(q_vecs.float().reshape(*q_vecs.shape[:-1], -1, 2))
    k_vecs_complex = torch.view_as_complex(k_vecs.float().reshape(*k_vecs.shape[:-1], -1, 2))

    # 4. Apply the RoPE rotation to Q & K using element-wise complex multiplication
    rotated_q_vecs_complex = q_vecs_complex * rotations
    rotated_k_vecs_complex = k_vecs_complex * rotations

    # 5. Convert Q & K from complex numbers back to real numbers. Basically reversal of Step (3)
        # rotated_q_vecs_complex Shape: (batch, num_attention_heads, sequence_length, head_dim // 2), complex
        #   -> (batch, num_attention_heads, sequence_length, head_dim // 2, 2)
        #   -> (batch, num_attention_heads, sequence_length, head_dim)
        # ShapeEx: (2, 2, 9, 2) -> (2, 2, 9, 2, 2) -> (2, 2, 9, 4)
    rotated_q_vecs_real = torch.view_as_real(rotated_q_vecs_complex).flatten(3)
    rotated_k_vecs_real = torch.view_as_real(rotated_k_vecs_complex).flatten(3)

    return rotated_q_vecs_real.type_as(q_vecs), rotated_k_vecs_real.type_as(k_vecs)

In [None]:
# Precompute the rotation factors for every position up to max_seq_len
freqs_cis = get_rope_freqs_cis(head_dim, max_seq_len, base=rope_theta, device=embeddings.device)
    # Shape: (max_seq_len, head_dim // 2), complex 
    # ShapeEx: (128, 2)

# Rotate Q & K according to each token's position
query_states_rope, key_states_rope = apply_rope_to_qk(query_states, key_states, freqs_cis)
    # Shape: (batch, num_attention_heads, sequence_length, head_dim)
    # ShapeEx: (2, 2, 9, 4)

In [59]:
# Optional L2 Normalization
# To be applied to Q & K after RoPE but before the attention score calculation (the matrix multiplication)
class SimpleL2Norm(nn.Module):
    """Parameter-free normalisation over the last dimension.

    Each vector is divided by the square root of the mean of its squared
    entries (plus ``eps``), i.e. by its root-mean-square over the last axis.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        # Small constant added under the square root to avoid division by zero
        self.eps = eps

    def forward(self, x):
        """Return ``x`` rescaled by the reciprocal root-mean-square of its last axis."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.eps)
        return x * inverse_rms

In [None]:
# Start from the RoPE-rotated states; normalise them only when the option is enabled
query_states_final, key_states_final = query_states_rope, key_states_rope
if normalising_qk:
    qk_norm = SimpleL2Norm()
    query_states_final, key_states_final = (
        qk_norm(states) for states in (query_states_rope, key_states_rope)
    )

In [None]:
# Attention scores are Q @ K^T: the last two axes of K must be swapped so that
# (.., sequence_length, head_dim) @ (.., head_dim, sequence_length) yields one
# score per (query token, key token) pair.
attn_weights = torch.matmul(query_states_final, key_states_final.transpose(-2, -1))
    # Shape: (batch_size, num_attn_heads, sequence_length, sequence_length)
    # ShapeEx: (2, 2, 9, 9)

# Scale the scores by 1 / sqrt(head_dim)
scaling_factor = 1.0 / math.sqrt(head_dim)
attn_weights = attn_weights * scaling_factor

# Adding -inf to future positions gives them zero weight after the softmax
attn_weights = attn_weights + attention_mask
attn_weights = F.softmax(attn_weights, dim=-1).to(query_states.dtype)
    # Shape: (batch_size, num_attn_heads, sequence_length, sequence_length)
    # ShapeEx: (2, 2, 9, 9)

In [None]:
# Weighted sum of the value vectors for every query token
attn_output = torch.matmul(attn_weights, value_states)
    # Shape: (batch_size, num_attn_heads, sequence_length, head_dim)
    # ShapeEx: (2, 2, 9, 4)

# Move the head axis back next to head_dim and merge the heads.
# .contiguous() is needed because .view() cannot reinterpret the transposed layout.
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, sequence_length, qkv_vec_dim)
    # Shape: (batch_size, num_attn_heads, sequence_length, head_dim)
    #   -> (batch_size, sequence_length, num_attn_heads, head_dim)
    #   -> (batch_size, sequence_length, num_attn_heads * head_dim) = (batch_size, sequence_length, qkv_vec_dim)
    # ShapeEx: (2, 2, 9, 4) -> (2, 9, 2, 4) -> (2, 9, 8)

# Project the merged heads back to the embedding size
final_attn_output = o_proj(attn_output)
    # Shape: (batch_size, sequence_length, emb_size)
    # ShapeEx: (2, 9, 10)