# Solutions: Lab 2.3.3 - Positional Encoding Study

This notebook contains solutions to the exercises from notebook 03.

---

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import math

# Seed PyTorch's global RNG so the torch.randn tensors created in the cells below are reproducible.
torch.manual_seed(42)

## Exercise 1: Implement RoPE (Rotary Position Embeddings)

**Task:** Implement Rotary Position Embeddings (RoPE), as used in LLaMA and many other modern LLMs.

In [None]:
class RoPE(nn.Module):
    """
    Rotary Position Embeddings (RoPE).

    Feature i and feature i + dim/2 of every query/key vector form a 2-D pair
    that is rotated by the angle ``position * theta_i``.  Because both Q and K
    are rotated, their dot product depends on the two positions only through
    their difference, i.e. it encodes relative position.

    Args:
        dim: Per-head feature size; must be even.
        max_seq_len: Number of positions whose cos/sin tables are pre-computed.
            The tables are rebuilt on demand when a longer sequence arrives.
        base: Base of the geometric frequency schedule.
    """

    def __init__(self, dim, max_seq_len=2048, base=10000):
        super().__init__()

        assert dim % 2 == 0, "Dimension must be even for RoPE"

        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # theta_i = 1 / base^(2i/dim) for i = 0 .. dim/2 - 1
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer('inv_freq', 1.0 / (base ** exponents))

        # Tables for positions 0 .. max_seq_len - 1.
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len):
        """(Re)build the cos/sin lookup tables for positions 0 .. seq_len - 1."""
        positions = torch.arange(seq_len, device=self.inv_freq.device)
        # angle[p, i] = p * theta_i  -> (seq_len, dim/2)
        angles = positions[:, None] * self.inv_freq[None, :]

        # Duplicate along the feature axis so slot i and slot i + dim/2
        # (the two members of one rotation pair) share the same angle.
        angles = torch.cat((angles, angles), dim=-1)  # (seq_len, dim)

        # Non-persistent: derived data, not written to the state_dict.
        self.register_buffer('cos_cached', angles.cos(), persistent=False)
        self.register_buffer('sin_cached', angles.sin(), persistent=False)

    def _rotate_half(self, x):
        """Map the two halves [a, b] of the last dimension to [-b, a]."""
        half = x.shape[-1] // 2
        first, second = x[..., :half], x[..., half:]
        return torch.cat((second.neg(), first), dim=-1)

    def forward(self, q, k, seq_len=None):
        """
        Rotate queries and keys according to their positions.

        Args:
            q: Query tensor (batch, heads, seq_len, dim)
            k: Key tensor (batch, heads, seq_len, dim)
            seq_len: Number of positions to use; defaults to q.size(2).

        Returns:
            Tuple (q_rotated, k_rotated) with the same shapes as the inputs.
        """
        n = q.size(2) if seq_len is None else seq_len

        # Grow the tables when a longer sequence than any seen so far arrives.
        if n > self.max_seq_len:
            self._build_cache(n)
            self.max_seq_len = n

        # (seq, dim) -> (1, 1, seq, dim) so the tables broadcast over batch/heads.
        cos = self.cos_cached[:n].to(q.dtype)[None, None]
        sin = self.sin_cached[:n].to(q.dtype)[None, None]

        def rotate(x):
            # 2-D rotation of each pair: x * cos + rotate_half(x) * sin
            return x * cos + self._rotate_half(x) * sin

        return rotate(q), rotate(k)

# Test RoPE
rope = RoPE(dim=64, max_seq_len=1024)
q = torch.randn(2, 8, 100, 64)  # (batch, heads, seq, dim)
k = torch.randn(2, 8, 100, 64)

q_rot, k_rot = rope(q, k)

print(f"Input Q shape: {q.shape}")
print(f"Output Q shape: {q_rot.shape}")

# Back the claim printed below with actual checks:
#  - the output keeps the input shape;
#  - RoPE only rotates pairs of features, so each vector's norm is unchanged
#    (up to float32 rounding);
#  - position 0 has angle 0 (cos = 1, sin = 0), so it is left untouched.
assert q_rot.shape == q.shape and k_rot.shape == k.shape
assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-4)
assert torch.allclose(q_rot[:, :, 0], q[:, :, 0])

print("\nRoPE preserves shape and applies position-dependent rotation!")

## Exercise 2: Implement ALiBi (Attention with Linear Biases)

**Task:** Implement ALiBi, which adds a linear bias to the attention scores based on the relative distance between query and key positions.

In [None]:
class ALiBi(nn.Module):
    """
    Attention with Linear Biases (ALiBi).

    Produces a fixed, per-head bias that is added to the raw attention scores:
    attention_scores = Q @ K^T + alibi_bias, where the bias for query i and
    key j in head h is -slope_h * |i - j|.  Nothing here is learned.

    Args:
        num_heads: Number of attention heads; each head gets its own slope.
    """

    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads

        # One slope per head, stored as a float buffer so it follows .to(device).
        self.register_buffer('slopes', torch.tensor(self._get_slopes(num_heads)).float())

    def _get_slopes(self, num_heads):
        """
        Return a list of ``num_heads`` slopes.

        For a power of two n the slopes are the geometric sequence
        r, r^2, ..., r^n with r = 2^(-8/n).  Otherwise the first
        2^floor(log2 n) slopes come from that sequence, and the remaining ones
        are every other slope of the sequence built for twice that power of two.
        """
        def geometric(n):
            ratio = 2 ** (-8 / n)
            return [ratio * (ratio ** i) for i in range(n)]

        if not math.log2(num_heads).is_integer():
            base_count = 2 ** math.floor(math.log2(num_heads))
            interleaved = geometric(2 * base_count)[::2]
            return geometric(base_count) + interleaved[:num_heads - base_count]
        return geometric(num_heads)

    def forward(self, seq_len):
        """
        Build the bias tensor for a sequence of length ``seq_len``.

        Returns:
            Tensor of shape (num_heads, seq_len, seq_len) whose entry [h, i, j]
            is -slopes[h] * |j - i| (closer keys get a higher score).
        """
        idx = torch.arange(seq_len, device=self.slopes.device)

        # -|j - i|, negated while still an integer tensor, then cast to float.
        neg_distance = (idx[None, :] - idx[:, None]).abs().neg().float()

        # (1, seq, seq) * (heads, 1, 1) -> (heads, seq, seq)
        return neg_distance[None] * self.slopes[:, None, None]

# Test ALiBi
alibi = ALiBi(num_heads=8)
bias = alibi(seq_len=10)

print(f"ALiBi bias shape: {bias.shape}")
print("\nBias for head 0 (first 5x5):")
print(bias[0, :5, :5].numpy().round(3))

# Visualize the first few heads (never more than the module actually has).
num_shown = min(4, alibi.num_heads)
fig, axes = plt.subplots(1, num_shown, figsize=(10, 4), squeeze=False)
for head, ax in enumerate(axes[0]):
    # All bias values are <= 0; a sequential colormap avoids suggesting a
    # sign change at an arbitrary midpoint, as an autoscaled diverging one would.
    im = ax.imshow(bias[head].numpy(), cmap='viridis')
    ax.set_title(f'Head {head}')
    fig.colorbar(im, ax=ax)
fig.suptitle('ALiBi Bias Patterns for Different Heads')
fig.tight_layout()
plt.show()

## Exercise 3: Compare Position Encoding Methods

**Task:** Compare how different position encodings affect attention patterns.

In [None]:
import torch.nn.functional as F

def compare_position_encodings(*, seq_len=50, d_model=64, num_heads=4, head=0):
    """
    Compare attention patterns with no positional encoding, RoPE, and ALiBi.

    Random queries and keys are scored three ways: the plain scaled dot product,
    the dot product after applying RoPE, and the plain scores plus the ALiBi
    bias.  The resulting softmax attention maps for one head are plotted side
    by side, and a short list of observations is printed.

    Args:
        seq_len: Number of positions in the random sequence.
        d_model: Model width; split evenly across heads.
        num_heads: Number of attention heads.
        head: Index of the head whose attention maps are plotted.

    Raises:
        ValueError: If d_model is not divisible by num_heads, or head is not a
            valid head index.
    """
    if d_model % num_heads != 0:
        raise ValueError(
            f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        )
    if not 0 <= head < num_heads:
        raise ValueError(f"head must be in [0, {num_heads}), got {head}")

    head_dim = d_model // num_heads
    scale = math.sqrt(head_dim)

    # Random Q, K: (batch=1, heads, seq, head_dim)
    Q = torch.randn(1, num_heads, seq_len, head_dim)
    K = torch.randn(1, num_heads, seq_len, head_dim)

    # 1. No position encoding (baseline)
    scores_none = torch.matmul(Q, K.transpose(-2, -1)) / scale
    attn_none = F.softmax(scores_none, dim=-1)

    # 2. RoPE rotates Q and K before the dot product.
    rope = RoPE(dim=head_dim)
    Q_rope, K_rope = rope(Q, K)
    scores_rope = torch.matmul(Q_rope, K_rope.transpose(-2, -1)) / scale
    attn_rope = F.softmax(scores_rope, dim=-1)

    # 3. ALiBi adds a (heads, seq, seq) bias that broadcasts over the batch dim.
    alibi = ALiBi(num_heads)
    scores_alibi = scores_none + alibi(seq_len)
    attn_alibi = F.softmax(scores_alibi, dim=-1)

    # Visualize the selected head for each method.
    panels = [
        (attn_none, 'No Position Encoding'),
        (attn_rope, 'With RoPE'),
        (attn_alibi, 'With ALiBi'),
    ]
    fig, axes = plt.subplots(1, len(panels), figsize=(15, 4))
    for ax, (attn, title) in zip(axes, panels):
        ax.imshow(attn[0, head].detach().numpy(), cmap='Blues')
        ax.set_title(title)
        ax.set_xlabel('Key Position')
    axes[0].set_ylabel('Query Position')

    fig.suptitle(f'Attention Patterns with Different Position Encodings (Head {head})')
    fig.tight_layout()
    plt.show()

    print("Observations:")
    print("- No position: Attention based purely on content")
    print("- RoPE: Relative position affects Q-K similarity")
    print("- ALiBi: Strong bias toward nearby positions (local attention)")

# Run the comparison: plots the attention maps and prints the observations.
compare_position_encodings()

---

## Key Takeaways

1. **RoPE** encodes position by rotating Q and K, so that the Q·K dot product depends on the two positions only through their relative offset
2. **ALiBi** adds a linear bias that penalizes distant tokens, promoting local attention
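3. Both are relative-position schemes and typically extrapolate to longer sequences better than learned absolute position embeddings; ALiBi was designed specifically for this, whereas RoPE usually needs extra techniques (e.g., position interpolation) to extend far beyond its training length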
4. RoPE is used in LLaMA, Mistral, etc.; ALiBi is used in BLOOM, MPT

---