# Day 4: Positional Encodings in Transformer Models

In this notebook, we'll explore why transformers need positional information and implement different types of positional encodings, including sinusoidal encodings and Rotary Position Embeddings (RoPE).

## Setup and Imports

In [None]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns

# Set style for plots
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

## 1. The Position Problem

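RNNs and LSTMs read tokens one at a time, so word order is built into how they compute. Self-attention instead processes all tokens in parallel and has no notion of order on its own: permuting the input tokens simply permutes the outputs in the same way. Let's demonstrate why this is a problem.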

In [None]:
def demonstrate_position_problem():
    """Show why transformers need positional encoding."""
    
    # Simulate attention without position information
    def simple_attention(queries, keys, values):
        """Simplified attention mechanism."""
        # Compute attention scores
        scores = torch.matmul(queries, keys.transpose(-2, -1))
        attention_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, values)
        return output, attention_weights
    
    # Create identical embeddings for different positions
    embedding_dim = 4
    seq_len = 3
    
    # Same word at different positions
    word_embedding = torch.tensor([1.0, 0.5, -0.2, 0.8])
    
    # Without positional encoding - all positions identical
    embeddings_no_pos = word_embedding.unsqueeze(0).repeat(seq_len, 1)
    print("Embeddings without positional encoding:")
    print(embeddings_no_pos)
    
    # Attention treats all positions identically
    output_no_pos, weights_no_pos = simple_attention(
        embeddings_no_pos.unsqueeze(0),
        embeddings_no_pos.unsqueeze(0), 
        embeddings_no_pos.unsqueeze(0)
    )
    
    print("\nAttention weights without position (should be uniform):")
    print(weights_no_pos[0])
    
    # Let's visualize the attention pattern
    plt.figure(figsize=(8, 6))
    plt.imshow(weights_no_pos[0].detach().numpy(), cmap='Blues')
    plt.title('Attention Pattern Without Position Information')
    plt.xlabel('Key Position')
    plt.ylabel('Query Position')
    plt.colorbar(label='Attention Weight')
    plt.show()
    
    return embeddings_no_pos, weights_no_pos

# Demonstrate the position problem
embeddings_no_pos, weights_no_pos = demonstrate_position_problem()

### Understanding the Problem

As we can see above, when the same word is repeated at different positions, every query produces exactly the same score against every key. The attention weights are therefore uniform (1/3 each here), and every position receives the same output. Attention scores are plain dot products between query and key vectors, and nothing in that computation refers to where a token sits in the sequence.

In natural language, word order is crucial. "Dog bites man" means something very different from "Man bites dog." Without position information, a transformer sees both sentences as the same bag of words: each token receives the same output in either sentence, just listed in a different order.
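The "bag of words" claim can be checked directly. Below is a minimal NumPy sketch of the same unparameterized attention used above (queries = keys = values, no learned projections). The random token embeddings are hypothetical stand-ins for real word vectors; reversing them plays the role of turning "dog bites man" into "man bites dog".

```python
import numpy as np

def self_attention(x):
    """Unparameterized dot-product self-attention (queries = keys = values = x)."""
    scores = x @ x.T
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x

rng = np.random.default_rng(0)
tokens = rng.normal(size=(3, 4))   # hypothetical embeddings for "dog", "bites", "man"
perm = np.array([2, 1, 0])         # reversed order: "man", "bites", "dog"

out_original = self_attention(tokens)
out_permuted = self_attention(tokens[perm])

# Reordering the inputs only reorders the outputs: each token's output is unchanged.
print(np.allclose(out_permuted, out_original[perm]))  # True
```

This property (permutation equivariance) holds for any permutation, which is why some explicit position signal has to be added to the inputs.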

## 2. Sinusoidal Positional Encoding

The original Transformer paper ("Attention Is All You Need", Vaswani et al., 2017) introduced sinusoidal positional encodings built from sine and cosine functions whose frequencies decrease geometrically across dimensions. Let's implement and visualize them.
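For reference, here is the encoding as defined in that paper, where $pos$ is the token position and $i$ indexes pairs of dimensions. Writing $\omega_i$ (our notation, not the paper's) for the frequency of pair $i$ makes it easier to match the formula to the `div_term` tensor in the class below, which computes $\omega_i$ via `exp` and `log`:

$$
\mathrm{PE}(pos, 2i) = \sin(\omega_i \, pos), \qquad
\mathrm{PE}(pos, 2i+1) = \cos(\omega_i \, pos), \qquad
\omega_i = 10000^{-2i/d_{\text{model}}} = \exp\!\left(-\frac{2i \ln 10000}{d_{\text{model}}}\right)
$$

`torch.arange(0, d_model, 2)` produces the values $2i$, so `div_term` holds exactly $\omega_0, \omega_1, \dots$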

In [None]:
class SinusoidalPositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (Vaswani et al., 2017)."""
    
    def __init__(self, d_model, max_seq_len=5000):
        super().__init__()
        self.d_model = d_model
        
        # Encoding table: one row per position, one column per embedding dimension
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        
        # Frequencies 10000^(-2i/d_model), computed as exp(-2i * ln(10000) / d_model)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                           (-np.log(10000.0) / d_model))
        
        # Sine on even dimensions
        pe[:, 0::2] = torch.sin(position * div_term)
        # Cosine on odd dimensions (slice div_term so an odd d_model also works)
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        
        # Register as a buffer: saved with the model and moved by .to(device), but not trained
        self.register_buffer('pe', pe.unsqueeze(0))  # shape (1, max_seq_len, d_model)
    
    def forward(self, x):
        """Add positional encoding to embeddings of shape (batch, seq_len, d_model)."""
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len]
    
    def get_encoding(self, position):
        """Return the encoding for a single position as a NumPy array."""
        return self.pe[0, position].detach().cpu().numpy()

### Visualizing Sinusoidal Positional Encodings

Let's create and visualize the sinusoidal positional encodings:

In [None]:
# Create positional encoding
d_model = 128
max_seq_len = 100
pos_encoder = SinusoidalPositionalEncoding(d_model, max_seq_len)

def visualize_positional_encodings(pos_encoder, max_pos=50, max_dim=64):
    """Visualize positional encoding patterns."""
    
    # Get encodings for different positions
    encodings = []
    for pos in range(max_pos):
        encoding = pos_encoder.get_encoding(pos)[:max_dim]
        encodings.append(encoding)
    
    encodings = np.array(encodings)
    
    # Create heatmap
    plt.figure(figsize=(12, 8))
    plt.imshow(encodings.T, cmap='RdBu', aspect='auto')
    plt.colorbar(label='Encoding Value')
    plt.xlabel('Position')
    plt.ylabel('Embedding Dimension')
    plt.title('Sinusoidal Positional Encodings')
    plt.show()
    
    # Show specific positions
    positions_to_show = [0, 1, 5, 10, 20]
    plt.figure(figsize=(12, 6))
    
    for pos in positions_to_show:
        encoding = pos_encoder.get_encoding(pos)[:32]  # Show first 32 dims
        plt.plot(encoding, label=f'Position {pos}', alpha=0.7)
    
    plt.xlabel('Embedding Dimension')
    plt.ylabel('Encoding Value')
    plt.title('Positional Encodings for Different Positions')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()
    
    return encodings

# Visualize the encodings
encodings = visualize_positional_encodings(pos_encoder)

### Analyzing Properties of Sinusoidal Encodings

Let's analyze some key properties of sinusoidal positional encodings:
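The relative-position property is easiest to see one frequency at a time. With $\omega_i = 10000^{-2i/d_{\text{model}}}$, the angle-addition identities give, for any offset $k$:

$$
\begin{pmatrix} \sin(\omega_i (pos+k)) \\ \cos(\omega_i (pos+k)) \end{pmatrix}
=
\begin{pmatrix} \cos(\omega_i k) & \sin(\omega_i k) \\ -\sin(\omega_i k) & \cos(\omega_i k) \end{pmatrix}
\begin{pmatrix} \sin(\omega_i \, pos) \\ \cos(\omega_i \, pos) \end{pmatrix}
$$

Stacking these $2 \times 2$ rotations into a block-diagonal matrix $M_k$ (our notation) gives $\mathrm{PE}(pos+k) = M_k \, \mathrm{PE}(pos)$ for every $pos$, with $M_k$ depending only on $k$. This is a linear relationship between $\mathrm{PE}(pos)$ and $\mathrm{PE}(pos+k)$; it does not imply $\mathrm{PE}(a) + \mathrm{PE}(b) = \mathrm{PE}(a+b)$. Because each block is a rotation, it also follows that $\mathrm{PE}(pos) \cdot \mathrm{PE}(pos+k) = \sum_i \cos(\omega_i k)$ and $\lVert \mathrm{PE}(pos) \rVert^2 = d_{\text{model}}/2$ for even $d_{\text{model}}$.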

In [None]:
def analyze_sinusoidal_properties(pos_encoder, max_pos=20):
    """Analyze key properties of sinusoidal positional encodings."""
    
    print("Properties of Sinusoidal Positional Encodings:")
    print("=" * 50)
    
    # Property 1: Unique encoding for each position
    encodings = []
    for pos in range(max_pos):
        encoding = pos_encoder.get_encoding(pos)
        encodings.append(encoding)
    
    encodings = np.array(encodings)
    
    # Check uniqueness
    unique_count = len(np.unique(encodings, axis=0))
    print(f"1. Uniqueness: {unique_count}/{max_pos} positions have unique encodings")
    
    # Property 2: Relative position information
    def relative_position_similarity(pos1, pos2):
        enc1 = pos_encoder.get_encoding(pos1)
        enc2 = pos_encoder.get_encoding(pos2)
        return np.dot(enc1, enc2) / (np.linalg.norm(enc1) * np.linalg.norm(enc2))
    
    print("\n2. Relative Position Similarities:")
    reference_pos = 5
    for offset in [1, 2, 5, 10]:
        sim = relative_position_similarity(reference_pos, reference_pos + offset)
        print(f"   Position {reference_pos} vs {reference_pos + offset}: {sim:.4f}")
    
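    # Property 3: A relative offset acts as a linear map
    # For any fixed offset k, PE(pos + k) = M_k @ PE(pos), where M_k is block-diagonal
    # with one 2x2 rotation by angle w_i * k per (sin, cos) pair. This is NOT
    # additivity: PE(a) + PE(b) is generally far from PE(a + b).
    pos, k = 3, 7
    d = pos_encoder.d_model  # assumes an even d_model
    freqs = np.exp(np.arange(0, d, 2) * (-np.log(10000.0) / d))
    
    enc_pos = pos_encoder.get_encoding(pos)
    enc_k = pos_encoder.get_encoding(k)
    enc_target = pos_encoder.get_encoding(pos + k)
    
    # Apply M_k pair by pair: rotate each (sin, cos) pair by w_i * k
    sin_part, cos_part = enc_pos[0::2], enc_pos[1::2]
    predicted = np.empty_like(enc_pos)
    predicted[0::2] = sin_part * np.cos(freqs * k) + cos_part * np.sin(freqs * k)
    predicted[1::2] = cos_part * np.cos(freqs * k) - sin_part * np.sin(freqs * k)
    
    print(f"\n3. Relative offset as a linear map:")
    print(f"   ||M_{k} PE({pos}) - PE({pos + k})|| = {np.linalg.norm(predicted - enc_target):.2e}")
    print(f"   For contrast, ||PE({pos}) + PE({k}) - PE({pos + k})|| = "
          f"{np.linalg.norm(enc_pos + enc_k - enc_target):.4f}")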
    
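    # Property 4: Fixed norm. Each (sin, cos) pair has unit length, so every
    # position has norm sqrt(d_model / 2), i.e. 8.0 for d_model = 128.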
    norms = np.linalg.norm(encodings, axis=1)
    print(f"\n4. Norm consistency:")
    print(f"   Mean norm: {np.mean(norms):.4f}")
    print(f"   Std dev of norms: {np.std(norms):.4f}")
    
    # Property 5: Frequency spectrum
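    print(f"\n5. Frequency spectrum:")
    # Estimate each dimension's frequency (cycles per position) from zero crossings:
    # a full cycle has two sign changes. Low-frequency dimensions barely change over
    # max_pos positions, so their estimate is 0 or close to it here.
    freq_by_dim = np.zeros(pos_encoder.d_model)
    for dim in range(pos_encoder.d_model):
        # np.diff on a boolean array flags each sign change
        zero_crossings = np.sum(np.diff(np.signbit(encodings[:, dim])))
        freq_by_dim[dim] = zero_crossings / (2 * (max_pos - 1))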
    
    # Plot frequency spectrum
    plt.figure(figsize=(10, 5))
    plt.plot(freq_by_dim)
    plt.title('Frequency Spectrum of Sinusoidal Encodings')
    plt.xlabel('Dimension')
    plt.ylabel('Frequency (estimated)')
    plt.grid(True, alpha=0.3)
    plt.show()
    
    return encodings

# Analyze properties
properties = analyze_sinusoidal_properties(pos_encoder)
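Property 2 above showed similarities that depend on the offset. A small NumPy check, assuming an even `d_model`, shows the stronger statement: the raw dot product $\mathrm{PE}(pos) \cdot \mathrm{PE}(pos+k)$ is the same for every $pos$ and equals $\sum_i \cos(\omega_i k)$. The helper `sinusoidal_pe` is hypothetical and simply mirrors `SinusoidalPositionalEncoding` without PyTorch.

```python
import numpy as np

def sinusoidal_pe(max_len, d_model):
    """Hypothetical NumPy helper mirroring SinusoidalPositionalEncoding (even d_model only)."""
    pos = np.arange(max_len)[:, None]                                        # (max_len, 1)
    omega = np.exp(np.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))  # (d_model / 2,)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(pos * omega)
    pe[:, 1::2] = np.cos(pos * omega)
    return pe, omega

pe, omega = sinusoidal_pe(max_len=100, d_model=128)

k = 4
# Row-wise dot products PE(pos) . PE(pos + k) for pos = 0 .. 95
dots = np.einsum("ij,ij->i", pe[:-k], pe[k:])
closed_form = np.cos(omega * k).sum()   # sum_i cos(omega_i * k)
print(np.allclose(dots, closed_form))   # True: independent of pos

# Every position has the same norm, sqrt(d_model / 2)
norms = np.linalg.norm(pe, axis=1)
print(np.allclose(norms, np.sqrt(128 / 2)))  # True
```

Since cosine is even, offsets $+k$ and $-k$ give the same dot product, so the dot product alone cannot tell whether the other position comes before or after.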

### Demonstrating Positional Encoding in Action

Let's see how positional encoding changes the attention pattern for identical tokens:

In [None]:
def demonstrate_positional_encoding_effect():
    """Demonstrate how positional encoding affects attention."""
    
    # Simplified attention function
    def simple_attention(queries, keys, values):
        """Simplified attention mechanism."""
        # Compute attention scores
        scores = torch.matmul(queries, keys.transpose(-2, -1))
        attention_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, values)
        return output, attention_weights
    
    # Setup (fixed seed so the random word embedding is reproducible)
    torch.manual_seed(0)
    embedding_dim = 32
    seq_len = 5
    pos_encoder = SinusoidalPositionalEncoding(embedding_dim, 100)
    
    # Create identical embeddings (same word repeated)
    word_embedding = torch.randn(embedding_dim)
    embeddings_no_pos = word_embedding.unsqueeze(0).repeat(seq_len, 1)
    
    # Add positional encoding
    embeddings_with_pos = pos_encoder(embeddings_no_pos.unsqueeze(0)).squeeze(0)
    
    # Compute attention without positional encoding
    output_no_pos, attn_no_pos = simple_attention(
        embeddings_no_pos.unsqueeze(0),
        embeddings_no_pos.unsqueeze(0),
        embeddings_no_pos.unsqueeze(0)
    )
    
    # Compute attention with positional encoding
    output_with_pos, attn_with_pos = simple_attention(
        embeddings_with_pos.unsqueeze(0),
        embeddings_with_pos.unsqueeze(0),
        embeddings_with_pos.unsqueeze(0)
    )
    
    # Visualize attention patterns
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
    
    # Without positional encoding
    im1 = ax1.imshow(attn_no_pos[0].detach().numpy(), cmap='Blues')
    ax1.set_title('Attention WITHOUT Positional Encoding')
    ax1.set_xlabel('Key Position')
    ax1.set_ylabel('Query Position')
    plt.colorbar(im1, ax=ax1)
    
    # With positional encoding
    im2 = ax2.imshow(attn_with_pos[0].detach().numpy(), cmap='Blues')
    ax2.set_title('Attention WITH Positional Encoding')
    ax2.set_xlabel('Key Position')
    ax2.set_ylabel('Query Position')
    plt.colorbar(im2, ax=ax2)
    
    plt.tight_layout()
    plt.show()
    
    print("Without positional encoding, attention is uniform across all positions.")
    print("With positional encoding, attention patterns become position-aware.")
    print("Notice how positions tend to attend more to themselves and nearby positions.")
    
    return embeddings_no_pos, embeddings_with_pos, attn_no_pos, attn_with_pos

# Demonstrate the effect
embeddings_no_pos, embeddings_with_pos, attn_no_pos, attn_with_pos = demonstrate_positional_encoding_effect()
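Why does the plot above tend to emphasize the diagonal? For a repeated word embedding $w$, the score between positions $i$ and $j$ expands to $(w + p_i) \cdot (w + p_j) = w \cdot w + w \cdot p_i + w \cdot p_j + p_i \cdot p_j$. The first two terms are constant along row $i$, and softmax ignores constant shifts, so only $w \cdot p_j$ and $p_i \cdot p_j$ matter. The NumPy sketch below checks this, using a hypothetical random $w$ and the same `sinusoidal_pe` helper idea (an assumption mirroring the class, even `d_model` only).

```python
import numpy as np

def softmax(scores):
    """Row-wise softmax."""
    scores = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(scores)
    return e / e.sum(axis=-1, keepdims=True)

def sinusoidal_pe(max_len, d_model):
    """Hypothetical NumPy helper mirroring SinusoidalPositionalEncoding (even d_model only)."""
    pos = np.arange(max_len)[:, None]
    omega = np.exp(np.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(pos * omega)
    pe[:, 1::2] = np.cos(pos * omega)
    return pe

d_model, seq_len = 32, 5
rng = np.random.default_rng(0)
w = rng.normal(size=d_model)         # one word embedding, repeated at every position
p = sinusoidal_pe(seq_len, d_model)  # p[i] is the encoding of position i
x = w + p                            # row i is w + p_i (broadcasting)

full = softmax(x @ x.T)

# Drop the terms that are constant along each row (w.w and w.p_i)
column_bias = (p @ w)[None, :]       # w.p_j, identical for every query i
position_term = p @ p.T              # p_i.p_j, depends only on i - j
reduced = softmax(column_bias + position_term)

print(np.allclose(full, reduced))                                  # True
print(np.all(position_term.argmax(axis=1) == np.arange(seq_len)))  # True: peaks at j = i
```

The position-only term $p_i \cdot p_j$ is largest on the diagonal (it equals $d_{\text{model}}/2$ there), while $w \cdot p_j$ is a column bias that depends on the particular random embedding. That is why the diagonal pattern is a tendency rather than a guarantee. A full transformer layer also applies learned query/key projections and scales scores by $1/\sqrt{d_k}$, both of which this simplified demo omits.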