In [1]:
# Llama attention mechanism: a step-by-step walkthrough

In [2]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Tuple, Optional

# Report the compute environment. The availability flag is printed
# unconditionally so a CPU-only run explicitly says "False" instead of
# printing nothing; the GPU count is only meaningful when CUDA is present.
cuda_available = torch.cuda.is_available()
print("CUDA available:", cuda_available)
if cuda_available:
    print("Number of GPUs:", torch.cuda.device_count())

CUDA available: True
Number of GPUs: 1


In [4]:
# --- Configuration --- #
# Toy-sized Llama attention hyperparameters; the later cells read these names.

seed = 42
# Seed PyTorch's RNG so the random inputs and the randomly initialised
# projection weights created in later cells are reproducible across runs.
torch.manual_seed(seed)

hidden_size = 128
# Dimension of the embedding vector for each token.
# Example: every token (like "dog") is represented by a vector of 128 numbers.

num_attention_heads = 16
# How many query (Q) attention heads run in parallel.
# hidden_size (128) is split evenly across these heads:
# each head sees 128 / 16 = 8 dimensions.

num_key_value_heads = 4
# Grouped-Query Attention (GQA): instead of 16 separate Key/Value heads,
# we create only 4 K/V heads (each of size 8), and each K/V head is shared
# by num_attention_heads // num_key_value_heads = 4 query heads.
# → This saves memory and computation while keeping good performance.

head_dim = hidden_size // num_attention_heads
# Size of each head's Q, K and V vectors: 128 // 16 = 8.

# The integer division above silently drops any remainder, and GQA needs the
# Q heads to split evenly into K/V groups, so fail fast on invalid configs.
assert hidden_size % num_attention_heads == 0, "hidden_size must be divisible by num_attention_heads"
assert num_attention_heads % num_key_value_heads == 0, "num_attention_heads must be divisible by num_key_value_heads"

max_position_embeddings = 256
# Maximum sequence length (number of tokens) the model can process at once.
# Longer inputs must be truncated or split into several windows.

rope_theta = 10000.0
# Base for Rotary Position Embeddings (RoPE).
# Each pair of dimensions i rotates at frequency rope_theta ** (-2i / head_dim),
# so a larger theta gives lower frequencies (longer wavelengths) in the later
# dimensions, i.e. positional information changes more slowly with distance.

rms_norm_eps = 1e-5
# Tiny constant added inside RMSNorm to avoid division by zero.
# Keeps training and inference numerically stable.

attention_bias = False
# Whether the linear layers that produce Q, K, V and the output projection
# include a bias term. Usually kept False for efficiency.

attention_dropout = 0.0
# Dropout probability applied to attention weights (regularization in training).
# 0.0 disables it; dropout is also inactive in eval/inference mode.

use_qk_norm = True
# Whether to normalize the Q and K vectors before computing attention scores.
# This keeps dot products in a stable range and avoids extreme attention weights.

In [5]:
# --- Sample input setup ---
# Random token embeddings plus the position ids and causal mask that go with them.

batch_size = 2
# Number of independent sequences (context windows) processed in parallel.
# Example: 2 separate sentences.

sequence_length = 10
# Number of tokens in each sequence (the length of the context window).
assert sequence_length <= max_position_embeddings, "sequence_length exceeds max_position_embeddings"

hidden_states = torch.randn(batch_size, sequence_length, hidden_size)
# Random embeddings standing in for real token embeddings.
# Shape = (batch_size, sequence_length, hidden_size) = (2, 10, 128):
# - 2 sequences
# - each sequence has 10 tokens
# - each token is represented by a 128-dimensional vector

# --- Position IDs creation ---

position_ids = torch.arange(0, sequence_length).unsqueeze(0).repeat(batch_size, 1)
# Step by step:
#   torch.arange(0, sequence_length) → shape (10,)   [0, 1, ..., 9]
#   .unsqueeze(0)                    → shape (1, 10) [[0, 1, ..., 9]]
#   .repeat(batch_size, 1)           → shape (2, 10) one row per sequence:
#                                      [[0, 1, ..., 9],
#                                       [0, 1, ..., 9]]
# Every token needs a position id; both sequences start at position 0
# because they are independent windows.

# --- Causal attention mask ---
# Goal: each token may only see itself and earlier tokens (no looking into the future).
# Additive convention: 0 = allowed, -inf = blocked.

attention_mask = torch.triu(
    torch.full((sequence_length, sequence_length), float("-inf"), dtype=hidden_states.dtype),
    diagonal=1,
)
# Step 1: square (seq x seq) matrix with the same dtype as the activations.
# triu(diagonal=1) keeps -inf strictly above the diagonal (future tokens, blocked)
# and zeroes the diagonal and everything below it (current/past tokens, allowed).

attention_mask = attention_mask.unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq, seq)
# Step 2: add batch and head dimensions so the mask matches 4D attention shapes.

attention_mask = attention_mask.expand(batch_size, 1, -1, -1)  # Shape: (batch, 1, seq, seq)
# Step 3: one (view-only) copy of the mask per sequence in the batch.
# The size-1 heads dimension means the same mask is shared across all attention heads.

# Sanity-check the shapes before printing them.
assert hidden_states.shape == (batch_size, sequence_length, hidden_size)
assert position_ids.shape == (batch_size, sequence_length)
assert attention_mask.shape == (batch_size, 1, sequence_length, sequence_length)

print("Configuration:")
print(f"  hidden_size: {hidden_size}")
print(f"  num_attention_heads: {num_attention_heads}")
print(f"  num_key_value_heads: {num_key_value_heads}")
print(f"  head_dim: {head_dim}")

print("\nSample Input Shapes:")
print(f"  hidden_states: {hidden_states.shape}")
print(f"  position_ids: {position_ids.shape}")
print(f"  attention_mask: {attention_mask.shape}")

Configuration:
  hidden_size: 128
  num_attention_heads: 16
  num_key_value_heads: 4
  head_dim: 8

Sample Input Shapes:
  hidden_states: torch.Size([2, 10, 128])
  position_ids: torch.Size([2, 10])
  attention_mask: torch.Size([2, 1, 10, 10])


In [6]:
# ## Q, K, V Projections
#
# The first step of attention: project the input hidden states into
# Query (Q), Key (K), and Value (V) spaces using linear layers.
#
# - Q = "what am I looking for?" (the current token's query)
# - K = "what can I offer?" (the key of each token in the sequence)
# - V = "what information do I carry?" (the value of each token)
#
# Llama (like many modern transformers) uses GQA = Grouped-Query Attention:
# - There are more Q heads (16 here) than K/V heads (4 here).
# - Each K/V head is shared by a group of Q heads.
# - This reduces memory/computation without losing much performance.

# GQA only works if the Q heads split evenly into K/V groups.
assert num_attention_heads % num_key_value_heads == 0, "num_attention_heads must be divisible by num_key_value_heads"


# --- Define linear projection layers ---
# Each Linear layer creates a weight matrix W and (optionally) a bias vector b.
# PyTorch stores W with shape (out_features, in_features).
# During forward, the computation is:  output = input @ W.T + b

# Q projection: hidden_size=128 → num_attention_heads * head_dim = 16*8 = 128.
# Wq has shape (128, 128); each token is projected into 16 Q-heads of size 8.
q_proj = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=attention_bias)

# K projection: hidden_size=128 → num_key_value_heads * head_dim = 4*8 = 32.
# Wk has shape (32, 128); each token is projected into 4 K-heads of size 8.
k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)

# V projection: same shape as K (4 heads of size 8).
v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)

# O projection: after attention, the 16 Q-head outputs are concatenated back into
# a single vector (size 128), and o_proj maps it back to hidden_size=128.
o_proj = nn.Linear(num_attention_heads * head_dim, hidden_size, bias=attention_bias)


# --- Apply projections to the hidden states ---
# hidden_states has shape [batch_size, seq_len, hidden_size].
# nn.Linear acts on the last dimension only, so every token is projected independently.

# Leading (batch, seq) dims, taken from the tensor itself rather than from the
# batch_size / sequence_length globals so this cell keeps working if
# hidden_states is rebuilt with a different shape.
input_shape = hidden_states.shape[:-1]

query_states = q_proj(hidden_states)  # [B, S, num_attention_heads*head_dim] = [B, S, 128]
key_states = k_proj(hidden_states)    # [B, S, num_key_value_heads*head_dim] = [B, S, 32]
value_states = v_proj(hidden_states)  # [B, S, num_key_value_heads*head_dim] = [B, S, 32]

# Reshape Q, K, V into [B, num_heads, S, head_dim] for multi-head attention.
# --------------------------------------------------------------
# Toy example (much smaller than the real config above):
#   B=1, S=3, hidden_size=4, num_heads=2, head_dim=2.
#
# Before view:
#   query_states shape = [B, S, hidden_size] = [1, 3, 4]
#   (each token is just 4 numbers, all heads still concatenated)
#
# Step 1 (view): split hidden_size=4 into (num_heads=2, head_dim=2)
#   query_states.view(1, 3, 2, 2) → [1, 3, 2, 2]
#
#   token 0: [q0, q1 | q2, q3]    → head0=[q0,q1], head1=[q2,q3]
#   token 1: [q4, q5 | q6, q7]    → head0=[q4,q5], head1=[q6,q7]
#   token 2: [q8, q9 | q10, q11]  → head0=[q8,q9], head1=[q10,q11]
#
# Step 2 (transpose): move the heads dimension before the sequence dimension
#   query_states.transpose(1, 2) → [1, 2, 3, 2]
#
#   Now query_states[b, h] holds all tokens for head h of sequence b:
#     query_states[0, 0] = [[q0, q1], [q4, q5], [q8, q9]]     # head 0, all tokens
#     query_states[0, 1] = [[q2, q3], [q6, q7], [q10, q11]]   # head 1, all tokens
#
# Keys and values go through the same view + transpose, but with
# num_key_value_heads=4 heads instead of the 16 query heads. Resulting shapes:
#   Q: [B, num_attention_heads, S, head_dim]
#   K: [B, num_key_value_heads, S, head_dim]
#   V: [B, num_key_value_heads, S, head_dim]
query_states = query_states.view(*input_shape, num_attention_heads, head_dim).transpose(1, 2)
key_states = key_states.view(*input_shape, num_key_value_heads, head_dim).transpose(1, 2)
value_states = value_states.view(*input_shape, num_key_value_heads, head_dim).transpose(1, 2)

print("Projected Shapes:")
print(f"  query_states: {query_states.shape}")  # (batch_size, num_attention_heads, sequence_length, head_dim)
print(f"  key_states: {key_states.shape}")      # (batch_size, num_key_value_heads, sequence_length, head_dim)
print(f"  value_states: {value_states.shape}")  # (batch_size, num_key_value_heads, sequence_length, head_dim)

# Number of Q heads that share each K/V head.
num_key_value_groups = num_attention_heads // num_key_value_heads
print(f"\nNum Key/Value Groups (Q heads per K/V head): {num_key_value_groups}")

Projected Shapes:
  query_states: torch.Size([2, 16, 10, 8])
  key_states: torch.Size([2, 4, 10, 8])
  value_states: torch.Size([2, 4, 10, 8])

Num Key/Value Groups (Q heads per K/V head): 4
