# Lab 2.3.3: Positional Encoding Study

**Module:** 2.3 - Natural Language Processing & Transformers  
**Time:** 2 hours  
**Difficulty:** ⭐⭐⭐

---

## Learning Objectives

By the end of this notebook, you will:
- [ ] Understand WHY position information is needed in Transformers
- [ ] Implement sinusoidal positional encodings (original Transformer)
- [ ] Implement Rotary Position Embeddings (RoPE, used in LLaMA)
- [ ] Understand ALiBi and other modern approaches
- [ ] Compare different strategies and their trade-offs

---

## Prerequisites

- Completed: Notebooks 01-02 (Attention and Transformer blocks)
- Knowledge of: Trigonometry basics (sin, cos), complex numbers (helpful for RoPE)

---

## Real-World Context

**Different models use different position strategies:**
- **Original Transformer**: Sinusoidal encodings
- **BERT**: Learned positional embeddings
- **GPT-2, GPT-3**: Learned positional embeddings
- **LLaMA, Mistral, Qwen**: Rotary Position Embeddings (RoPE)
- **BLOOM, MPT**: ALiBi (Attention with Linear Biases)

The choice affects:
- Maximum context length
- Ability to extrapolate to longer sequences
- Computational efficiency

---

## ELI5: Why Do Transformers Need Position Information?

> **Imagine reading a shuffled book.**
>
> Without page numbers, you'd have no idea what order the pages go in! The sentence "The cat ate the mouse" would look the same as "The mouse ate the cat" to a position-blind model.
>
> **The problem:** Self-attention treats input as a SET, not a SEQUENCE.
> - It computes attention between all pairs of words
> - Nothing tells it that word 1 comes before word 2
> - "I like cats" = "cats like I" = "like I cats" (all the same!)
>
> **The solution:** Add position information to each word embedding.
> - Word embedding: "What word is this?"
> - Position encoding: "Where is this word in the sequence?"
> - Combined: "This is the word 'cat' at position 3"

---

## Part 1: Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
import numpy as np

# Try to import seaborn for enhanced visualizations
HAS_SEABORN = False
try:
    import seaborn as sns
    HAS_SEABORN = True
except ImportError:
    print("⚠️ seaborn not installed. Using matplotlib defaults.")

# Set up plotting
plt.style.use("default")

# ============================================================
# Visualization Constants
# ============================================================
FIGURE_SIZE_SMALL = (10, 4)       # For simple 2-panel plots
FIGURE_SIZE_MEDIUM = (12, 5)      # For comparison plots
FIGURE_SIZE_LARGE = (14, 10)      # For multi-panel visualizations
FIGURE_SIZE_WIDE = (15, 4)        # For wide horizontal layouts
FIGURE_SIZE_GRID = (16, 8)        # For grid layouts

# Common plot settings
COLORMAP_SEQUENTIAL = "Blues"
COLORMAP_DIVERGING = "RdBu"
COLORMAP_VIRIDIS = "viridis"
GRID_ALPHA = 0.3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

torch.manual_seed(42)

---

## Part 2: Demonstrating the Problem

In [None]:
# Demonstrate that attention is permutation-equivariant
# (shuffling the inputs just shuffles the outputs in the same way)

def simple_attention(Q, K, V):
    """Simple self-attention without position info."""
    d_k = K.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V)

# Create some "word embeddings"
words = ["The", "cat", "sat"]
embeddings = torch.randn(1, 3, 8)  # (batch, seq, dim)

# Original order
out_original = simple_attention(embeddings, embeddings, embeddings)

# Shuffled order (swap positions 1 and 2)
shuffled = embeddings[:, [0, 2, 1], :]
out_shuffled = simple_attention(shuffled, shuffled, shuffled)

# Unshuffle the output to compare
out_shuffled_unshuffled = out_shuffled[:, [0, 2, 1], :]

# Compare
print("Attention is permutation-equivariant!")
print(f"Output for 'The cat sat': shape {out_original.shape}")
print(f"Output for 'The sat cat' (unshuffled): shape {out_shuffled_unshuffled.shape}")
print(f"Are they the same? {torch.allclose(out_original, out_shuffled_unshuffled, atol=1e-6)}")
print("\n⚠️ Without position info, the model can't tell word order!")

---

## Part 3: Sinusoidal Positional Encoding

### ELI5: The Sinusoidal Approach

> **Imagine a clock with many hands.**
>
> - The second hand moves fastest (high frequency)
> - The minute hand moves slower
> - The hour hand moves slowest (low frequency)
>
> **By looking at all the hands, you can tell exactly what time it is!**
>
> Sinusoidal encoding uses the same idea:
> - Different dimensions oscillate at different frequencies
> - Position 0 has a unique combination of values
> - Position 1 has a different combination
> - Each position gets a unique "fingerprint"

### The Formula

```
PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
```
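A quick numerical check of the formula, written as a small NumPy sketch: it evaluates PE(pos, ·) directly from the two equations and compares the result with the vectorized `exp(-2i · ln(10000) / d_model)` form that most implementations (including the class below) use. `sinusoidal_pe` is a helper defined only for this illustration.

```python
import math

import numpy as np


def sinusoidal_pe(pos: int, d_model: int) -> np.ndarray:
    """Loop-based evaluation of the sin/cos formula for one position (illustration only)."""
    pe = np.zeros(d_model)
    for i in range(d_model // 2):
        angle = pos / (10000 ** (2 * i / d_model))
        pe[2 * i] = math.sin(angle)      # even dimension 2i
        pe[2 * i + 1] = math.cos(angle)  # odd dimension 2i+1
    return pe


d_model = 8
# Vectorized form: exp(-2i * ln(10000) / d_model)
two_i = np.arange(0, d_model, 2)
div_term = np.exp(two_i * (-math.log(10000.0) / d_model))
# div_term is the reciprocal 1 / 10000^(2i/d_model), so it multiplies pos
assert np.allclose(div_term, 1.0 / 10000 ** (two_i / d_model))

pos = 3
vectorized = np.empty(d_model)
vectorized[0::2] = np.sin(pos * div_term)
vectorized[1::2] = np.cos(pos * div_term)
print(np.allclose(vectorized, sinusoidal_pe(pos, d_model)))  # True
print(sinusoidal_pe(0, d_model))  # [0. 1. 0. 1. 0. 1. 0. 1.]
```

Position 0 is always the alternating pattern of sin(0) = 0 and cos(0) = 1; every later position moves each (sin, cos) pair around its own circle at a different speed.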

In [None]:
class SinusoidalPositionalEncoding(nn.Module):
    """
    Sinusoidal Positional Encoding from "Attention Is All You Need".
    
    Uses sin and cos functions at different frequencies to encode position.
    """
    
    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        """
        Args:
            d_model: Model dimension
            max_len: Maximum sequence length to pre-compute
            dropout: Dropout probability
        """
        super().__init__()
        
        self.dropout = nn.Dropout(p=dropout)
        
        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        
        # Position indices: [0, 1, 2, ..., max_len-1]
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        
        # Frequency term for each even dimension index 2i:
        # div_term = exp(-2i * ln(10000) / d_model) = 1 / 10000^(2i/d_model)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        
        # Apply sin to even indices, cos to odd indices
        pe[:, 0::2] = torch.sin(position * div_term)  # Even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # Odd dimensions
        
        # Add batch dimension and register as buffer (not a parameter)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        """
        Add positional encoding to input.
        
        Args:
            x: Input tensor (batch, seq_len, d_model)
            
        Returns:
            Output tensor with position info added
        """
        seq_len = x.size(1)
        # Add position encoding (broadcasting handles batch dimension)
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)

# Create and test
d_model = 64
pe = SinusoidalPositionalEncoding(d_model, max_len=100)

# Test with dummy input
x = torch.zeros(1, 20, d_model)  # All zeros, so output = just PE
out = pe(x)
print(f"Positional encoding shape: {out.shape}")

In [None]:
# Visualize the positional encoding

def visualize_sinusoidal_pe(d_model=64, max_len=100):
    """Visualize sinusoidal positional encoding patterns."""
    
    pe = SinusoidalPositionalEncoding(d_model, max_len)
    encoding = pe.pe[0].numpy()  # (max_len, d_model)
    
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    
    # Full heatmap
    ax = axes[0, 0]
    im = ax.imshow(encoding[:50, :32], aspect='auto', cmap='RdBu')
    ax.set_xlabel('Dimension')
    ax.set_ylabel('Position')
    ax.set_title('Positional Encoding Heatmap\n(first 50 positions, 32 dims)')
    plt.colorbar(im, ax=ax)
    
    # Show different frequency waves
    ax = axes[0, 1]
    for dim in [0, 4, 16, 32, 48]:
        ax.plot(encoding[:50, dim], label=f'dim {dim}')
    ax.set_xlabel('Position')
    ax.set_ylabel('Value')
    ax.set_title('Different Dimensions = Different Frequencies')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # Position similarity matrix
    ax = axes[1, 0]
    positions = encoding[:20]
    similarity = np.dot(positions, positions.T)
    similarity = similarity / (np.linalg.norm(positions, axis=1, keepdims=True) @ 
                               np.linalg.norm(positions, axis=1, keepdims=True).T)
    im = ax.imshow(similarity, cmap='viridis')
    ax.set_xlabel('Position')
    ax.set_ylabel('Position')
    ax.set_title('Position Similarity\n(nearby positions are more similar)')
    plt.colorbar(im, ax=ax)
    
    # Relative position demo
    ax = axes[1, 1]
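    # PE(pos+k) is a fixed, k-dependent rotation of PE(pos) (one rotation per sin/cos pair),
    # so difference vectors for the same offset k match wherever that rotation angle is small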
    pos_0 = encoding[0]
    pos_5 = encoding[5]
    pos_10 = encoding[10]
    pos_15 = encoding[15]
    
    # Difference vectors
    diff_0_5 = pos_5 - pos_0
    diff_10_15 = pos_15 - pos_10
    
    ax.plot(diff_0_5[:32], label='PE(5) - PE(0)', alpha=0.8)
    ax.plot(diff_10_15[:32], label='PE(15) - PE(10)', alpha=0.8)
    ax.set_xlabel('Dimension')
    ax.set_ylabel('Difference')
    ax.set_title('Same Offset k, Different Start\n(patterns agree in the slower, higher-index dims)')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

visualize_sinusoidal_pe()

### Key Properties of Sinusoidal PE:

1. **Unique encoding**: Each position has a unique pattern
2. **Bounded**: Values are always between -1 and 1
3. **Relative position**: PE(pos+k) can be expressed as a linear function of PE(pos)
4. **Extrapolation**: Can theoretically handle longer sequences than seen during training
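Property 3 is the least obvious. Each (sin, cos) pair is a point on a circle, so shifting the position by k rotates every pair by an angle that depends only on k. The sketch below builds that rotation as an explicit block-diagonal matrix and checks it against the table. `sinusoidal_table` and `offset_matrix` are hypothetical helpers written for this illustration.

```python
import math
from typing import Tuple

import numpy as np


def sinusoidal_table(max_len: int, d_model: int) -> Tuple[np.ndarray, np.ndarray]:
    """Return the (max_len, d_model) sinusoidal table and the per-pair frequencies."""
    position = np.arange(max_len)[:, None]
    div_term = np.exp(np.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(position * div_term)  # even dims: sin
    pe[:, 1::2] = np.cos(position * div_term)  # odd dims: cos
    return pe, div_term


def offset_matrix(k: int, div_term: np.ndarray) -> np.ndarray:
    """Block-diagonal M_k such that PE(pos + k) = M_k @ PE(pos) for every pos."""
    d_model = 2 * len(div_term)
    m = np.zeros((d_model, d_model))
    for i, w in enumerate(div_term):
        c, s = math.cos(k * w), math.sin(k * w)
        # [sin(a + b), cos(a + b)] = [[cos b, sin b], [-sin b, cos b]] @ [sin a, cos a]
        m[2 * i:2 * i + 2, 2 * i:2 * i + 2] = [[c, s], [-s, c]]
    return m


pe, div_term = sinusoidal_table(max_len=50, d_model=16)
k = 5
M = offset_matrix(k, div_term)

# One matrix, built from k alone, maps every position p to p + k
print(all(np.allclose(M @ pe[p], pe[p + k]) for p in range(50 - k)))  # True
```

Because `M` depends only on the offset k and not on the starting position, a linear layer could in principle learn to attend "k tokens back" from absolute sinusoidal encodings.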

---

## Part 4: Learned Positional Embeddings

GPT-2 and many models use learned position embeddings instead.

In [None]:
class LearnedPositionalEmbedding(nn.Module):
    """
    Learned positional embeddings.
    
    Each position has a learnable embedding vector.
    Used in GPT-2, GPT-3, BERT.
    """
    
    def __init__(self, d_model: int, max_len: int = 512, dropout: float = 0.1):
        super().__init__()
        
        # Learnable embedding for each position
        self.embedding = nn.Embedding(max_len, d_model)
        self.dropout = nn.Dropout(p=dropout)
        
        # Initialize with small random values
        nn.init.normal_(self.embedding.weight, mean=0, std=0.02)
        
    def forward(self, x):
        """
        Add position embeddings to input.
        
        Args:
            x: Input tensor (batch, seq_len, d_model)
        """
        batch_size, seq_len, _ = x.shape
        
        # Create position indices [0, 1, 2, ..., seq_len-1]
        positions = torch.arange(seq_len, device=x.device)
        
        # Get position embeddings
        pos_emb = self.embedding(positions)  # (seq_len, d_model)
        
        # Add to input
        x = x + pos_emb
        return self.dropout(x)

# Compare parameters
d_model = 768
max_len = 512

sinusoidal = SinusoidalPositionalEncoding(d_model, max_len)
learned = LearnedPositionalEmbedding(d_model, max_len)

print("Parameter comparison:")
print(f"  Sinusoidal: {sum(p.numel() for p in sinusoidal.parameters()):,} (none! it's fixed)")
print(f"  Learned:    {sum(p.numel() for p in learned.parameters()):,} ({max_len} x {d_model})")

### Trade-offs:

| Aspect | Sinusoidal | Learned |
|--------|------------|----------|
| Parameters | 0 | max_len × d_model |
| Extrapolation | Computable at any length; quality often degrades past training length | No (limited to max_len) |
| Performance | Good | Sometimes slightly better |
| Training | Fixed | Learns from data |

---

## Part 5: Rotary Position Embeddings (RoPE)

### ELI5: RoPE

> **Imagine a compass needle.**
>
> - At position 0, the needle points North
> - At position 1, it rotates slightly
> - At position 2, it rotates more
> - The ANGLE tells you the position!
>
> **RoPE rotates the query and key vectors by an angle based on position.**
> - When computing attention, the dot product naturally encodes relative position
> - No need to add anything to the embeddings!
>
> **The magic:** q(pos_m) · k(pos_n) depends on the positions only through (m - n), the relative position!
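This is where the complex numbers from the prerequisites come in. A minimal sketch, assuming the same rotate-half pairing (dimension i with dimension i + d/2) and base 10000 as the implementation below: each pair is treated as one complex number and multiplied by e^{i·pos·θ}. `rope_rotate` is a hypothetical helper for this illustration, not a library function.

```python
import torch


def rope_rotate(x: torch.Tensor, pos: int, base: float = 10000.0) -> torch.Tensor:
    """Apply RoPE at position `pos` to a vector x of even size d (rotate-half pairing)."""
    d = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, d, 2, dtype=torch.float64) / d))
    # Pair dimension i with dimension i + d/2 and treat each pair as one complex number
    z = torch.complex(x[..., : d // 2], x[..., d // 2 :])
    # Multiplying by e^{i * pos * theta_i} rotates pair i by the angle pos * theta_i
    z = z * torch.polar(torch.ones_like(inv_freq), pos * inv_freq)
    return torch.cat([z.real, z.imag], dim=-1)


torch.manual_seed(0)
d = 8
q = torch.randn(d, dtype=torch.float64)
k = torch.randn(d, dtype=torch.float64)

score_near = rope_rotate(q, 7) @ rope_rotate(k, 3)      # m - n = 4
score_far = rope_rotate(q, 104) @ rope_rotate(k, 100)   # m - n = 4, much later in the sequence
print(torch.allclose(score_near, score_far))  # True
```

The dot product of the two rotated vectors equals Re[Σ z_q · conj(z_k) · e^{i(m-n)θ}], so the absolute positions cancel and only the offset m - n survives. Rotation also preserves vector length, which is why RoPE needs no extra parameters or normalization.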

In [None]:
class RotaryPositionalEmbedding(nn.Module):
    """
    Rotary Position Embeddings (RoPE).
    
    Used in LLaMA, Mistral, and many modern LLMs.
    
    Key insight: Apply rotation to Q and K, making the dot product
    naturally depend on relative position.
    """
    
    def __init__(self, dim: int, max_len: int = 2048, base: int = 10000):
        """
        Args:
            dim: Dimension of the embeddings (must be even)
            max_len: Maximum sequence length
            base: Base for the frequency calculation
        """
        super().__init__()
        
        assert dim % 2 == 0, "Dimension must be even for RoPE"
        
        self.dim = dim
        self.max_len = max_len
        self.base = base
        
        # Precompute the frequency for each dimension pair
        # freq_i = 1 / (base^(2i/dim))
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        
        # Precompute sin and cos for all positions
        self._build_cache(max_len)
        
    def _build_cache(self, seq_len: int):
        """Build sin/cos cache for given sequence length."""
        # Position indices
        t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        
        # Outer product: (seq_len,) x (dim/2,) -> (seq_len, dim/2)
        freqs = torch.outer(t, self.inv_freq)
        
        # Duplicate for pairs: (seq_len, dim)
        emb = torch.cat([freqs, freqs], dim=-1)
        
        # Cache sin and cos
        self.register_buffer('cos_cached', emb.cos(), persistent=False)
        self.register_buffer('sin_cached', emb.sin(), persistent=False)
        
    def forward(self, x, seq_len: int = None):
        """
        Get cos and sin values for rotary embedding.
        
        Args:
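            x: Input tensor (used for dtype; if seq_len is None, the
               sequence length is read from x.shape[-2])
            seq_len: Sequence length
            
        Returns:
            cos, sin: Tensors of shape (seq_len, dim)
        """
        if seq_len is None:
            # Works for both (batch, seq, dim) and (batch, heads, seq, dim)
            seq_len = x.shape[-2]
            
        if seq_len > self.max_len:
            # Extend the cache and record the new length so it is not rebuilt on every call
            self._build_cache(seq_len)
            self.max_len = seq_len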
            
        return (
            self.cos_cached[:seq_len].to(x.dtype),
            self.sin_cached[:seq_len].to(x.dtype)
        )


def rotate_half(x):
    """Rotate half the hidden dims of x."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat([-x2, x1], dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    """
    Apply rotary position embedding to queries and keys.
    
    Args:
        q: Query tensor (batch, heads, seq_len, dim)
        k: Key tensor (batch, heads, seq_len, dim)
        cos: Cosine values (seq_len, dim)
        sin: Sine values (seq_len, dim)
        
    Returns:
        Rotated q and k
    """
    # Reshape cos/sin for broadcasting
    cos = cos.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, dim)
    sin = sin.unsqueeze(0).unsqueeze(0)
    
    # Apply rotation
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    
    return q_embed, k_embed


# Demonstrate RoPE
dim = 64
seq_len = 10
batch_size = 2
num_heads = 4

rope = RotaryPositionalEmbedding(dim)

# Create queries and keys
q = torch.randn(batch_size, num_heads, seq_len, dim)
k = torch.randn(batch_size, num_heads, seq_len, dim)

# Get cos/sin
cos, sin = rope(q, seq_len)

# Apply RoPE
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)

print(f"Original Q shape: {q.shape}")
print(f"Rotated Q shape:  {q_rot.shape}")
print(f"cos/sin shape:    {cos.shape}")

In [None]:
# Visualize RoPE rotation

def visualize_rope():
    """Visualize how RoPE rotates vectors."""
    
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    
    # 2D visualization of rotation
    ax = axes[0]
    dim = 2
    rope = RotaryPositionalEmbedding(dim, max_len=20)
    
    # Original vector
    v = torch.tensor([[[[1.0, 0.0]]]])  # Unit vector pointing right
    
    colors = plt.cm.viridis(np.linspace(0, 1, 10))
    
    for pos in range(10):
        cos, sin = rope(v, pos + 1)
        cos, sin = cos[-1:], sin[-1:]  # Last position
        v_rot, _ = apply_rotary_pos_emb(v, v, cos, sin)
        
        ax.arrow(0, 0, v_rot[0, 0, 0, 0].item(), v_rot[0, 0, 0, 1].item(),
                head_width=0.05, head_length=0.05, fc=colors[pos], ec=colors[pos],
                label=f'pos {pos}' if pos % 2 == 0 else None)
    
    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-1.5, 1.5)
    ax.set_aspect('equal')
    ax.set_title('RoPE: Vector Rotation by Position')
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)
    
    # Show attention pattern with RoPE
    ax = axes[1]
    dim = 64
    seq_len = 20
    rope = RotaryPositionalEmbedding(dim)
    
    # Create Q and K (same values to isolate position effect)
    q = torch.ones(1, 1, seq_len, dim)
    k = torch.ones(1, 1, seq_len, dim)
    
    cos, sin = rope(q, seq_len)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    
    # Compute attention scores
    scores = torch.matmul(q_rot, k_rot.transpose(-2, -1)) / math.sqrt(dim)
    
    im = ax.imshow(scores[0, 0].detach().numpy(), cmap='Blues')
    ax.set_xlabel('Key Position')
    ax.set_ylabel('Query Position')
    ax.set_title('Attention Scores with RoPE\n(diagonal = self, off-diagonal = relative pos)')
    plt.colorbar(im, ax=ax)
    
    # Relative position effect
    ax = axes[2]
    # Show that score depends only on relative position
    relative_scores = []
    for i in range(seq_len):
        relative_scores.append(scores[0, 0, i, :].detach().numpy())
    
    # Plot score vs relative position for different query positions
    for i in [0, 5, 10, 15]:
        rel_pos = np.arange(seq_len) - i
        ax.plot(rel_pos, relative_scores[i], label=f'query pos {i}', alpha=0.7)
    
    ax.set_xlabel('Relative Position (key - query)')
    ax.set_ylabel('Attention Score')
    ax.set_title('RoPE: Scores Depend on Relative Position')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

visualize_rope()

### RoPE Advantages:

1. **Relative position encoding**: Attention naturally depends on relative position
2. **No extra parameters**: Just rotation, no learned embeddings needed
3. **Length extrapolation**: Vanilla RoPE degrades beyond the training length, but frequency-scaling techniques (Position Interpolation, NTK-aware scaling, YaRN) can extend it to much longer sequences
4. **Efficient**: Applied only to Q and K, not added to embeddings

---

## Part 6: ALiBi (Attention with Linear Biases)

### ELI5: ALiBi

> **Instead of modifying the embeddings, ALiBi adds a penalty to attention scores based on distance.**
>
> "The farther apart two words are, the less they should attend to each other."
>
> It's like saying: "Far-away words get a small penalty, nearby words don't."

In [None]:
class ALiBi(nn.Module):
    """
    Attention with Linear Biases.
    
    Adds a linear bias based on relative position to attention scores.
    Used in BLOOM, MPT, and other models.
    
    score(i, j) = q_i · k_j - m * |i - j|
    
    where m is a per-head slope.
    """
    
    def __init__(self, num_heads: int):
        """
        Args:
            num_heads: Number of attention heads
        """
        super().__init__()
        
        self.num_heads = num_heads
        
        # Compute slopes for each head
        # Slopes are geometric sequence: 2^(-8/n), 2^(-16/n), ...
        slopes = self._get_slopes(num_heads)
        self.register_buffer('slopes', torch.tensor(slopes).float())
        
    def _get_slopes(self, num_heads: int):
        """
        Get the slope values for each head.
        
        Returns geometric sequence for power-of-2 num_heads,
        interpolated sequence otherwise.
        """
        def get_slopes_power_of_2(n):
            start = 2 ** (-8 / n)
            ratio = start
            return [start * (ratio ** i) for i in range(n)]
        
        if math.log2(num_heads).is_integer():
            return get_slopes_power_of_2(num_heads)
        else:
            # For non-power-of-2, use closest power of 2 and interpolate
            closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
            slopes = get_slopes_power_of_2(closest_power_of_2)
            extra_slopes = get_slopes_power_of_2(2 * closest_power_of_2)[0::2]
            return slopes + extra_slopes[:num_heads - closest_power_of_2]
    
    def forward(self, seq_len: int):
        """
        Compute ALiBi bias matrix.
        
        Args:
            seq_len: Sequence length
            
        Returns:
            Bias tensor of shape (num_heads, seq_len, seq_len)
        """
        # Relative position matrix: relative_pos[i, j] = j - i
        positions = torch.arange(seq_len)
        relative_pos = positions.unsqueeze(0) - positions.unsqueeze(1)  # (seq, seq)
        
        # Penalty grows linearly with distance |i - j|. This matrix is symmetric;
        # for causal attention, future positions (j > i) must still be masked to
        # -inf separately, so only the lower triangle is actually used.
        bias = -torch.abs(relative_pos).float()  # (seq, seq)
        
        # Multiply by slopes for each head
        bias = bias.unsqueeze(0) * self.slopes.unsqueeze(-1).unsqueeze(-1)
        # (num_heads, seq, seq)
        
        return bias

# Visualize ALiBi
num_heads = 8
seq_len = 20
alibi = ALiBi(num_heads)

bias = alibi(seq_len)
print(f"ALiBi bias shape: {bias.shape}")
print(f"Slopes: {alibi.slopes.tolist()}")

In [None]:
# Visualize ALiBi patterns

fig, axes = plt.subplots(2, 4, figsize=(16, 8))

for i, ax in enumerate(axes.flat):
    head_bias = bias[i].detach().numpy()
    
    im = ax.imshow(head_bias, cmap='Reds_r', vmin=-5, vmax=0)
    ax.set_title(f'Head {i+1}\nSlope: {alibi.slopes[i]:.4f}')
    ax.set_xlabel('Key Position')
    ax.set_ylabel('Query Position')

plt.suptitle('ALiBi Bias Patterns per Head\n(Darker = Larger Distance Penalty, clipped at -5)', fontsize=12)
plt.tight_layout()
plt.show()

print("\nNote: Each head has a different slope, so different heads attend to different ranges!")
print("Head 1 (large slope): Strong local attention")
print("Head 8 (small slope): Weak distance penalty, more global attention")

### ALiBi Advantages:

1. **No learned parameters**: Just a fixed bias
2. **Strong extrapolation**: The ALiBi paper reports good results on sequences longer than those seen in training
3. **Simple implementation**: Just add bias to attention scores
4. **Multi-scale attention**: Different heads attend at different ranges
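The `ALiBi` class above returns only the bias matrix. A minimal sketch of how such a bias is typically plugged into causal attention, assuming a power-of-2 head count (the class above also handles other counts); `alibi_slopes` and `causal_alibi_attention` are illustrative names, not a library API. Because future keys are masked anyway, the sketch uses the one-sided distance i - j, which matches the symmetric -|i - j| matrix on every unmasked entry.

```python
import math

import torch
import torch.nn.functional as F


def alibi_slopes(num_heads: int) -> torch.Tensor:
    """Slopes 2^(-8/n), 2^(-16/n), ... (this sketch only handles power-of-2 head counts)."""
    assert math.log2(num_heads).is_integer(), "sketch assumes a power-of-2 head count"
    start = 2 ** (-8 / num_heads)
    return torch.tensor([start ** (i + 1) for i in range(num_heads)])


def causal_alibi_attention(q, k, v):
    """q, k: (batch, heads, seq_len, head_dim); v: (batch, heads, seq_len, v_dim).

    No positional information is added to q or k; position enters only via the bias.
    """
    num_heads, seq_len, head_dim = q.shape[1], q.shape[2], q.shape[3]
    scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)  # (batch, heads, seq, seq)

    pos = torch.arange(seq_len)
    distance = (pos[:, None] - pos[None, :]).clamp(min=0)  # i - j for past keys, 0 for future
    bias = -alibi_slopes(num_heads)[:, None, None] * distance  # (heads, seq, seq)
    scores = scores + bias

    # Causal mask: query i may not attend to keys j > i
    future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))
    return F.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
x = torch.randn(1, 8, 6, 16)
out = causal_alibi_attention(x, x, x)
print(out.shape)  # torch.Size([1, 8, 6, 16])
```

With all-zero queries and keys, every content score is equal, so each head's attention weights simply decay with distance at the rate set by its slope. This is an easy way to see the multi-scale behavior from the plots above.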

---

## Part 7: Comparison of Methods

In [None]:
# Comprehensive comparison

comparison = {
    "Method": ["Sinusoidal", "Learned", "RoPE", "ALiBi"],
    "Parameters": ["0", "max_len × d_model", "0", "0"],
    "Applied to": ["Embeddings", "Embeddings", "Q and K", "Attention scores"],
    "Position type": ["Absolute", "Absolute", "Relative", "Relative"],
    "Extrapolation": ["Limited", "Poor", "Good (with NTK/YaRN)", "Excellent"],
    "Used in": ["Original Transformer", "GPT-2, BERT", "LLaMA, Mistral", "BLOOM, MPT"]
}

print("Positional Encoding Comparison:")
print("=" * 80)

# Print as table
max_lens = [max(len(str(v)) for v in [k] + vals) for k, vals in comparison.items()]

# Header
header = " | ".join(k.ljust(l) for k, l in zip(comparison.keys(), max_lens))
print(header)
print("-" * len(header))

# Rows
for i in range(len(comparison["Method"])):
    row = " | ".join(str(comparison[k][i]).ljust(l) for k, l in zip(comparison.keys(), max_lens))
    print(row)

In [None]:
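# Test extrapolation mechanics

def test_extrapolation():
    """Check whether each method can *produce* encodings for sequences longer than training.

    Note: this only tests whether the code runs; it says nothing about how well a
    trained model would perform at those lengths.
    """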
    
    d_model = 64
    train_len = 128
    test_lens = [64, 128, 256, 512, 1024]
    
    results = {"Sinusoidal": [], "Learned": [], "RoPE": []}
    
    # Create models "trained" on max_len=128
    sinusoidal = SinusoidalPositionalEncoding(d_model, max_len=train_len)
    learned = LearnedPositionalEmbedding(d_model, max_len=train_len)
    rope = RotaryPositionalEmbedding(d_model, max_len=train_len)
    
    print(f"Testing extrapolation (trained on max_len={train_len}):")
    print("=" * 50)
    
    for test_len in test_lens:
        # Sinusoidal can extrapolate
        try:
            # Extend the buffer if needed
            if test_len > sinusoidal.pe.size(1):
                # Recompute for longer sequence
                new_sinusoidal = SinusoidalPositionalEncoding(d_model, max_len=test_len)
                x = torch.zeros(1, test_len, d_model)
                _ = new_sinusoidal(x)
            else:
                x = torch.zeros(1, test_len, d_model)
                _ = sinusoidal(x)
            results["Sinusoidal"].append("✓")
        except Exception as e:
            results["Sinusoidal"].append(f"✗ ({type(e).__name__})")
        
        # Learned cannot extrapolate beyond max_len
        try:
            x = torch.zeros(1, test_len, d_model)
            _ = learned(x)
            results["Learned"].append("✓")
        except Exception as e:
            results["Learned"].append(f"✗ ({type(e).__name__})")
        
        # RoPE can extrapolate (cache is extended)
        try:
            x = torch.zeros(1, 1, test_len, d_model)
            cos, sin = rope(x, test_len)
            results["RoPE"].append("✓")
        except Exception as e:
            results["RoPE"].append(f"✗ ({type(e).__name__})")
    
    print(f"\n{'Length':<10} {'Sinusoidal':<15} {'Learned':<20} {'RoPE':<15}")
    print("-" * 60)
    for i, test_len in enumerate(test_lens):
        print(f"{test_len:<10} {results['Sinusoidal'][i]:<15} {results['Learned'][i]:<20} {results['RoPE'][i]:<15}")

test_extrapolation()

---

## Try It Yourself: Exercises

### Exercise 1: Implement YaRN (Yet another RoPE extensioN)

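YaRN modifies RoPE to better handle very long sequences. The full method rescales different frequency bands by different amounts and adjusts the attention temperature; for this exercise, start with the simplest ingredient it builds on: scaling the rotary frequencies uniformly (linear Position Interpolation).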

<details>
<summary>Hint</summary>
Modify the frequency computation: `inv_freq = inv_freq / scale`
where scale depends on the ratio of target length to training length.
</details>

In [None]:
class YaRNPositionalEmbedding(nn.Module):
    """
    YaRN: Yet another RoPE extensioN
    
    Extends RoPE to handle longer sequences through frequency scaling.
    
    TODO: Implement this!
    """
    
    def __init__(self, dim, max_len=2048, base=10000, scale=1.0):
        super().__init__()
        # YOUR CODE HERE
        pass
    
    def forward(self, x, seq_len=None):
        # YOUR CODE HERE
        pass

# Test your implementation:
# yarn = YaRNPositionalEmbedding(64, max_len=128, scale=2.0)
# x = torch.randn(1, 1, 256, 64)
# cos, sin = yarn(x, 256)

### Exercise 2: Relative Position Bias (T5 style)

T5 uses learned relative position biases. Implement a simplified version.

In [None]:
class T5RelativePositionBias(nn.Module):
    """
    Relative position bias used in T5.
    
    Learns a bias for each relative position bucket.
    
    TODO: Implement this!
    """
    
    def __init__(self, num_heads, num_buckets=32, max_distance=128):
        super().__init__()
        # YOUR CODE HERE
        # Hint: Use nn.Embedding for the biases
        pass
    
    def _relative_position_bucket(self, relative_position, num_buckets, max_distance):
        """Convert relative position to bucket index."""
        # YOUR CODE HERE
        pass
    
    def forward(self, seq_len):
        # YOUR CODE HERE
        pass

---

## Common Mistakes

### Mistake 1: Adding PE at the wrong point

In [None]:
# Wrong: Adding position encoding after attention
def wrong_order(x, attention, pos_enc):
    x = attention(x, x, x)  # No position info!
    x = x + pos_enc  # Too late
    return x

# Right: Adding position encoding before attention
def right_order(x, attention, pos_enc):
    x = x + pos_enc  # Add position info first
    x = attention(x, x, x)  # Now attention sees position
    return x

print("Position encoding must be added BEFORE attention!")

### Mistake 2: Using absolute position with variable-length inputs

In [None]:
# Absolute positions: the same local pattern appears at different indices
# "The cat sat"           -> "cat sat" at positions (1, 2)
# "Yesterday the cat sat" -> "cat sat" at positions (2, 3)
# With absolute encodings, the model sees different position vectors for the same
# relationship ("sat" directly follows "cat") and must learn it at every offset.

print("Problem with absolute positions:")
print("  'The cat sat'           -> cat=pos 1, sat=pos 2")
print("  'Yesterday the cat sat' -> cat=pos 2, sat=pos 3")
print("\n  Same relationship (offset +1), different absolute positions!")
print("\nSolution: Relative methods (RoPE, ALiBi) encode the offset directly, which tends to generalize better.")

### Mistake 3: Forgetting to extend RoPE cache

In [None]:
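# Wrong: Fixed-size cache with no length check
class WrongRoPE(nn.Module):
    def __init__(self, dim, max_len):
        super().__init__()
        # Placeholder tables for illustration (real caches hold cos/sin of position * frequency)
        self.cos = torch.randn(max_len, dim)
        self.sin = torch.randn(max_len, dim)
        
    def forward(self, seq_len):
        # If seq_len > max_len, slicing does NOT raise: it silently returns only max_len rows,
        # and the mismatch surfaces later as a broadcasting error (or a subtle bug)
        return self.cos[:seq_len], self.sin[:seq_len]

# Right: Extend cache dynamically
# (as shown in our RotaryPositionalEmbedding implementation)

print("Always handle sequences longer than the initial cache!")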

---

## Checkpoint

You've learned:
- ✅ Why Transformers need position information
- ✅ Sinusoidal positional encoding (original Transformer)
- ✅ Learned positional embeddings (GPT-2, BERT)
- ✅ Rotary Position Embeddings (RoPE) - modern LLMs
- ✅ ALiBi - attention biases for long sequences
- ✅ Trade-offs between different approaches

---

## Challenge (Optional)

Implement **NTK-aware RoPE** which modifies the base frequency for better extrapolation:

```python
# Original: base = 10000
# NTK-aware: base = 10000 * (scale ** (dim / (dim - 2)))
```

Where `scale` is the ratio of target sequence length to training length.

In [None]:
# Your challenge implementation here

---

## Further Reading

- [RoFormer Paper](https://arxiv.org/abs/2104.09864) - Original RoPE paper
- [ALiBi Paper](https://arxiv.org/abs/2108.12409) - Attention with Linear Biases
- [YaRN Paper](https://arxiv.org/abs/2309.00071) - Yet another RoPE extensioN
- [Extending Context with RoPE](https://kaiokendev.github.io/context) - Great blog post

---

## Cleanup

In [None]:
import gc

# Clean up
if torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()

print("Memory cleared! Ready for the next notebook.")

---

## Next Up

In **Notebook 04: Tokenization Lab**, we'll learn how text becomes numbers:
- Byte Pair Encoding (BPE)
- SentencePiece
- Comparing tokenizers from different models

---

*Excellent work! Position encodings are subtle but crucial for Transformer success.*