# Day 4: Positional Encodings - Part 2

In this notebook, we'll explore Rotary Position Embedding (RoPE) and learned positional embeddings.

## Setup and Imports

In [None]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
from math import pi, sin, cos

# Plot styling: apply matplotlib's 'seaborn-v0_8' style sheet, then select
# "husl" as the seaborn colour palette.
# NOTE(review): the 'seaborn-v0_8' style name presumably requires a reasonably
# recent matplotlib release - confirm against the installed version.
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

## 3. Rotary Position Embedding (RoPE)

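RoPE is a modern positional encoding that rotates each query and key vector by an angle proportional to its position. Because both vectors are rotated, their dot product depends only on the relative offset between the two positions, which gives the attention mechanism a direct notion of relative position.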

In [None]:
class RotaryPositionalEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE).

    Consecutive feature pairs ``(x[..., 2i], x[..., 2i + 1])`` are viewed as
    complex numbers and multiplied by ``exp(1j * pos * inv_freq[i])``, i.e.
    each pair is rotated by an angle proportional to its sequence position.

    Args:
        dim: Size of the last (feature) dimension of the inputs. Must be even,
            because features are rotated in pairs.
        max_seq_len: Number of positions for which rotations are precomputed.
        base: Base of the geometric progression of rotation frequencies.

    Raises:
        ValueError: If ``dim`` is odd.
    """

    def __init__(self, dim, max_seq_len=2048, base=10000):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(
                f"RoPE rotates features in pairs, so dim must be even (got {dim})"
            )
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # One inverse frequency per feature pair: base^(-2i / dim).
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # Precompute the per-position rotations once.
        self._precompute_freqs_cis(max_seq_len)

    def _precompute_freqs_cis(self, seq_len):
        """Store unit complex rotations of shape ``(seq_len, dim // 2)``.

        Entry ``[p, i]`` is ``exp(1j * p * inv_freq[i])``. The result is
        registered as the ``freqs_cis`` buffer.
        """
        # Build positions on the same device as inv_freq so the outer product
        # also works if this is re-run after the module has been moved.
        t = torch.arange(seq_len, dtype=torch.float32, device=self.inv_freq.device)
        freqs = torch.outer(t, self.inv_freq)

        # polar(1, angle) = cos(angle) + 1j * sin(angle)
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        self.register_buffer('freqs_cis', freqs_cis)

    def _reshape_for_broadcast(self, freqs_cis, x):
        """Return ``freqs_cis`` viewed so it broadcasts against ``x``.

        ``x`` is the complex-valued input. Its dimension 1 (sequence) and its
        last dimension (feature pairs) are kept; every other dimension becomes
        size 1.
        """
        ndim = x.ndim
        assert ndim > 1, "input must have at least a batch and a sequence dimension"
        assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (
            f"rotation table has shape {tuple(freqs_cis.shape)} but the input needs "
            f"{(x.shape[1], x.shape[-1])}; the sequence may exceed max_seq_len or "
            "the feature size may differ from dim"
        )
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)

    def forward(self, x, start_pos=0):
        """Rotate ``x`` according to the position of each sequence element.

        Args:
            x: Tensor whose dimension 1 is the sequence and whose last
                dimension has size ``dim``, e.g. ``(batch, seq, dim)`` or
                ``(batch, seq, heads, dim)``.
            start_pos: Position assigned to the first sequence element.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        seq_len = x.shape[1]
        freqs_cis = self.freqs_cis[start_pos:start_pos + seq_len]

        # Pair up the features: (..., dim) -> (..., dim // 2) complex numbers.
        x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))

        freqs_cis = self._reshape_for_broadcast(freqs_cis, x_complex)

        # Complex multiplication by a unit-modulus number is a pure rotation.
        x_rotated = x_complex * freqs_cis

        # view_as_real appends a trailing (real, imag) axis. Merge only the last
        # two axes so the input shape is restored whatever the input rank is.
        x_out = torch.view_as_real(x_rotated).flatten(-2)

        return x_out.type_as(x)

### Demonstrating RoPE

Let's demonstrate how RoPE works and how it affects attention patterns:

In [None]:
def demonstrate_rope():
    """Apply RoPE to random queries/keys and compare the attention maps.

    Prints the tensor shapes and the mean absolute difference between the
    attention weights computed with and without RoPE, then shows both
    attention maps for the first batch element side by side.

    Returns:
        Tuple ``(queries_rope, keys_rope, attn_no_rope, attn_with_rope)``.
    """
    batch_size, seq_len, d_model = 2, 8, 64
    rope = RotaryPositionalEmbedding(d_model)

    # Random stand-ins for projected queries and keys.
    queries = torch.randn(batch_size, seq_len, d_model)
    keys = torch.randn(batch_size, seq_len, d_model)

    queries_rope, keys_rope = rope(queries), rope(keys)

    print("RoPE Demonstration:")
    print(f"Original queries shape: {queries.shape}")
    print(f"RoPE queries shape: {queries_rope.shape}")

    def compute_attention_pattern(q, k):
        """Scaled dot-product attention weights (softmax over keys)."""
        logits = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(d_model)
        return torch.softmax(logits, dim=-1)

    # Only the first batch element is inspected.
    attn_no_rope = compute_attention_pattern(queries[0], keys[0])
    attn_with_rope = compute_attention_pattern(queries_rope[0], keys_rope[0])

    mean_abs_diff = torch.mean(torch.abs(attn_no_rope - attn_with_rope))
    print(f"\nAttention pattern difference: {mean_abs_diff:.4f}")

    # One heat-map per variant, left to right.
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    panels = (
        (attn_no_rope, 'Attention WITHOUT RoPE'),
        (attn_with_rope, 'Attention WITH RoPE'),
    )
    for ax, (attn, title) in zip(axes, panels):
        image = ax.imshow(attn.detach().numpy(), cmap='Blues')
        ax.set_title(title)
        ax.set_xlabel('Key Position')
        ax.set_ylabel('Query Position')
        plt.colorbar(image, ax=ax)

    plt.tight_layout()
    plt.show()

    return queries_rope, keys_rope, attn_no_rope, attn_with_rope

# Run the demo. rope_results is the tuple
# (queries_rope, keys_rope, attn_no_rope, attn_with_rope) returned by the function.
rope_results = demonstrate_rope()

### Visualizing RoPE Rotations

Let's visualize how RoPE rotates vectors based on their position:

In [None]:
def visualize_rope_rotations(num_positions=8, angle_per_position=0.1):
    """Plot how a RoPE-style rotation turns a 2-D unit vector as position grows.

    The vector ``[1, 0]`` is rotated counter-clockwise by
    ``pos * angle_per_position`` radians for every position, which is the
    direction obtained by multiplying the pair, viewed as a complex number,
    by ``exp(1j * angle)``. The rotated vectors are drawn as arrows inside
    the unit circle.

    Args:
        num_positions: Number of positions to draw, starting at position 0.
        angle_per_position: Rotation step in radians between two consecutive
            positions.

    Returns:
        List with one NumPy array of shape ``(2,)`` per position.
    """
    # Row vector, so it can be multiplied with a 2x2 matrix directly.
    vector = torch.tensor([[1.0, 0.0]])

    positions = range(num_positions)
    rotated_vectors = []

    for pos in positions:
        angle = pos * angle_per_position
        # Standard counter-clockwise rotation matrix R(angle). The dtype is
        # pinned to the vector's dtype so matmul never sees mixed precisions.
        rotation_matrix = torch.tensor(
            [
                [np.cos(angle), -np.sin(angle)],
                [np.sin(angle), np.cos(angle)],
            ],
            dtype=vector.dtype,
        )

        # For a row vector v, "R applied to v" is v @ R.T; using R itself
        # would rotate in the opposite (clockwise) direction.
        rotated = torch.matmul(vector, rotation_matrix.T)
        rotated_vectors.append(rotated.numpy()[0])

    plt.figure(figsize=(8, 8))

    # Dashed unit circle as a visual reference for the (norm-preserving) rotation.
    circle_angles = np.linspace(0, 2*np.pi, 100)
    plt.plot(np.cos(circle_angles), np.sin(circle_angles), 'k--', alpha=0.3)

    # One colour per position, taken from the viridis colour map.
    colors = plt.cm.viridis(np.linspace(0, 1, len(positions)))

    for i, (pos, vec) in enumerate(zip(positions, rotated_vectors)):
        plt.arrow(0, 0, vec[0], vec[1], head_width=0.05, head_length=0.1,
                  fc=colors[i], ec=colors[i], label=f'Position {pos}')

    plt.xlim(-1.2, 1.2)
    plt.ylim(-1.2, 1.2)
    plt.grid(True, alpha=0.3)
    plt.axhline(y=0, color='k', linewidth=0.5, alpha=0.5)
    plt.axvline(x=0, color='k', linewidth=0.5, alpha=0.5)
    plt.title('RoPE Vector Rotations by Position')
    plt.legend()
    plt.axis('equal')
    plt.show()

    return rotated_vectors

# Draw the rotation plot. rotated_vectors is the list of rotated 2-D vectors
# (one per position) returned by the function.
rotated_vectors = visualize_rope_rotations()

## 4. Learned Positional Embeddings

Some models use learned positional embeddings instead of fixed mathematical functions.

In [None]:
class LearnedPositionalEmbedding(nn.Module):
    """Learned positional embedding: one trainable vector per position.

    Args:
        max_seq_len: Number of positions in the embedding table.
        d_model: Size of each position vector.
    """

    def __init__(self, max_seq_len, d_model):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.d_model = d_model

        # Trainable lookup table of shape (max_seq_len, d_model).
        self.position_embeddings = nn.Embedding(max_seq_len, d_model)

        # Start from small random values (normal distribution, std 0.02).
        nn.init.normal_(self.position_embeddings.weight, std=0.02)

    def forward(self, x, position_ids=None):
        """Add the learned position vectors to ``x``.

        Args:
            x: Tensor whose dimension 0 is the batch and dimension 1 the
                sequence; its last dimension must match ``d_model`` for the
                addition to work.
            position_ids: Optional integer tensor of position indices. When
                omitted, positions ``0 .. seq_len - 1`` are used for every
                batch element.

        Returns:
            ``x`` plus the looked-up position embeddings.

        Raises:
            IndexError: If ``position_ids`` is omitted and the sequence is
                longer than ``max_seq_len``.
        """
        seq_len = x.size(1)

        if position_ids is None:
            # Fail early with a readable message instead of an opaque
            # out-of-range error from the embedding lookup.
            if seq_len > self.max_seq_len:
                raise IndexError(
                    f"sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}"
                )
            position_ids = torch.arange(seq_len, dtype=torch.long, device=x.device)
            position_ids = position_ids.unsqueeze(0).expand(x.size(0), -1)

        position_embeddings = self.position_embeddings(position_ids)
        return x + position_embeddings

    def get_embedding(self, position):
        """Return the embedding vector for ``position`` as a NumPy array."""
        # .cpu() makes the conversion work when the module lives on another device.
        return self.position_embeddings.weight[position].detach().cpu().numpy()

### Comparing Different Positional Encoding Methods

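Let's compare sinusoidal and learned positional encodings. Note that the learned embeddings are shown at their random initialization, since no training has taken place yet: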

In [None]:
def compare_positional_encodings():
    """Plot sinusoidal encodings next to freshly initialised learned embeddings.

    Collects the vectors for positions 0-9 from both encoders, shows them as
    two heat-maps (embedding dimension on the vertical axis, position on the
    horizontal axis) and then plots the L2 norm of every position vector.

    Returns:
        Tuple ``(sin_encodings, learned_encodings)`` of NumPy arrays with one
        row per position.
    """
    d_model = 64
    max_seq_len = 20

    sinusoidal = SinusoidalPositionalEncoding(d_model, max_seq_len)
    learned = LearnedPositionalEmbedding(max_seq_len, d_model)

    # Only the first ten positions are compared.
    positions = range(10)
    sin_encodings = np.array([sinusoidal.get_encoding(p) for p in positions])
    learned_encodings = np.array([learned.get_embedding(p) for p in positions])

    # Side-by-side heat-maps, transposed so that positions run along the x axis.
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))
    panels = (
        (sin_encodings, 'Sinusoidal Positional Encodings'),
        (learned_encodings, 'Learned Positional Embeddings (Random Init)'),
    )
    for ax, (encodings, title) in zip(axes, panels):
        image = ax.imshow(encodings.T, cmap='RdBu', aspect='auto')
        ax.set_title(title)
        ax.set_xlabel('Position')
        ax.set_ylabel('Embedding Dimension')
        plt.colorbar(image, ax=ax)

    plt.tight_layout()
    plt.show()

    # Second figure: vector length per position for both methods.
    sin_norms = np.linalg.norm(sin_encodings, axis=1)
    learned_norms = np.linalg.norm(learned_encodings, axis=1)

    plt.figure(figsize=(10, 5))
    for norms, label in ((sin_norms, 'Sinusoidal'), (learned_norms, 'Learned')):
        plt.plot(positions, norms, 'o-', label=label)
    plt.title('Embedding Norms by Position')
    plt.xlabel('Position')
    plt.ylabel('L2 Norm')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

    return sin_encodings, learned_encodings

# compare_positional_encodings() needs SinusoidalPositionalEncoding, so it is
# defined here to keep this notebook self-contained (presumably mirroring the
# definition from Part 1).
class SinusoidalPositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding.

    Even feature indices hold ``sin(pos * w_i)`` and odd indices hold
    ``cos(pos * w_i)`` with ``w_i = 10000^(-2i / d_model)``.

    Args:
        d_model: Size of each encoding vector (even or odd).
        max_seq_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_seq_len=5000):
        super().__init__()
        self.d_model = d_model

        # Table of shape (max_seq_len, d_model), filled column-wise below.
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)

        # One frequency per sine column: exp(-2i * ln(10000) / d_model).
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-np.log(10000.0) / d_model))

        # Sine on even indices.
        pe[:, 0::2] = torch.sin(position * div_term)
        # Cosine on odd indices. For an odd d_model there is one fewer odd
        # column than even columns, so the last frequency is dropped here.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Buffer, not parameter: follows the module across devices but is not trained.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the encodings for positions ``0 .. x.size(1) - 1`` to ``x``."""
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len]

    def get_encoding(self, position):
        """Return the encoding vector for ``position`` as a NumPy array."""
        # .cpu() makes the conversion work when the module lives on another device.
        return self.pe[0, position].detach().cpu().numpy()

# Run the comparison. comparison_results is the tuple
# (sin_encodings, learned_encodings) returned by the function.
comparison_results = compare_positional_encodings()