# Week 5: Build Your Own Transformer - Interactive Lab

## Learning Objectives
1. Implement self-attention from scratch
2. Build multi-head attention
3. Create a complete transformer block
4. Compare with RNN performance
5. Visualize attention patterns

---

In [None]:
# Setup and imports
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Tuple, Optional
import time
import warnings
warnings.filterwarnings('ignore')

# Set style
# Global plotting defaults for every figure drawn later in the notebook.
# NOTE(review): the 'seaborn-v0_8-whitegrid' style name presumably needs a recent matplotlib — confirm.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')

# Check if GPU is available
# `device` is a notebook-wide global: later cells move models and inputs onto it with `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Set random seeds for reproducibility
# Seeds NumPy and PyTorch; the CUDA generators are only seeded when a GPU is present.
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

## Part 1: Understanding Self-Attention

### 1.1 The Math Behind Attention

The attention formula:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Let's implement it step by step!

In [None]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        Q: Query tensor [..., seq_len_q, d_k]
        K: Key tensor [..., seq_len_k, d_k]
        V: Value tensor [..., seq_len_k, d_v]
        mask: Optional tensor broadcastable to [..., seq_len_q, seq_len_k].
            Positions where ``mask == 0`` are excluded from attention.

    Returns:
        output: Attention output [..., seq_len_q, d_v]
        attention_weights: Softmax weights [..., seq_len_q, seq_len_k]
    """
    d_k = Q.size(-1)

    # Step 1: Compute QK^T
    scores = torch.matmul(Q, K.transpose(-2, -1))

    # Step 2: Scale by sqrt(d_k). A Python scalar keeps the dtype and device
    # of `scores` instead of introducing a separate CPU float32 tensor.
    scores = scores / (d_k ** 0.5)

    # Step 3: Apply mask if provided. The fill value is the most negative
    # number of the score dtype, so it stays representable in reduced
    # precision (a literal -1e9 does not fit in float16).
    if mask is not None:
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

    # Step 4: Softmax over the key axis turns scores into weights
    attention_weights = F.softmax(scores, dim=-1)

    # Step 5: Weighted sum of the values
    output = torch.matmul(attention_weights, V)

    return output, attention_weights

# Test with a simple example
# One batch of 4 positions with 8 features each; Q, K and V are independent random tensors.
batch_size, seq_len, d_model = 1, 4, 8
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)

# No mask is passed, so every query attends over all 4 keys.
# `weights` is reused by the heatmap cell that follows.
output, weights = scaled_dot_product_attention(Q, K, V)
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {weights.shape}")

### 1.2 Visualizing Attention Patterns

In [None]:
def visualize_attention(attention_weights, words=None, title="Attention Weights"):
    """
    Visualize attention weights as an annotated heatmap.

    Args:
        attention_weights: Tensor indexed as [batch, query, key]; only the
            first batch item is drawn.
        words: Optional tick labels used on both axes. When None, seaborn
            chooses the tick labels automatically.
        title: Figure title.
    """
    # Get the first batch item
    weights = attention_weights[0].detach().cpu().numpy()

    # seaborn.heatmap accepts "auto", a bool, an int or a list of labels here,
    # but not None, so the "no labels given" default is mapped to "auto".
    tick_labels = 'auto' if words is None else words

    plt.figure(figsize=(8, 6))
    sns.heatmap(weights, cmap='YlOrRd', cbar=True, square=True,
                xticklabels=tick_labels, yticklabels=tick_labels,
                vmin=0, vmax=1, annot=True, fmt='.2f')
    plt.title(title)
    plt.xlabel('Keys')
    plt.ylabel('Queries')
    plt.tight_layout()
    plt.show()

# Example sentence
# Four labels, matching the 4x4 `weights` matrix computed in the previous cell.
# Those weights come from random Q/K/V, so the word labels are illustrative only.
words = ['The', 'cat', 'sat', 'mat']
visualize_attention(weights, words, "Self-Attention: Each word attends to all words")

## Part 2: Multi-Head Attention

### 2.1 Why Multiple Heads?
Different heads can focus on different types of relationships!

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention over batch-first tensors.

    The input is projected to queries, keys and values, split into
    ``n_heads`` heads of size ``d_model // n_heads``, attended per head,
    concatenated again and passed through an output projection.

    Args:
        d_model: Model (embedding) dimension; must be divisible by n_heads.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        # Linear projections for Q, K, V and the output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x):
        """Reshape [batch, length, d_model] into [batch, n_heads, length, d_k]."""
        return x.view(x.size(0), x.size(1), self.n_heads, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: [batch, len_q, d_model]
            key: [batch, len_k, d_model]
            value: [batch, len_k, d_model]
            mask: Optional mask; zeros mark positions that must not be
                attended to. Accepted shapes: anything broadcastable to
                [batch, n_heads, len_q, len_k], or [batch, len_q, len_k]
                (applied identically to every head).

        Returns:
            output: [batch, len_q, d_model]
            attn_weights: [batch, n_heads, len_q, len_k] (after dropout)
        """
        batch_size = query.size(0)
        query_len = query.size(1)

        # 1. Linear projections in batch from d_model => h x d_k.
        # Each tensor is reshaped with its own length so that key/value
        # sequences may be longer or shorter than the query sequence.
        Q = self._split_heads(self.W_q(query))
        K = self._split_heads(self.W_k(key))
        V = self._split_heads(self.W_v(value))

        # A per-batch 3-D mask needs an explicit head axis; otherwise its
        # batch axis would be broadcast against the head axis of the scores.
        if mask is not None and mask.dim() == 3:
            mask = mask.unsqueeze(1)

        # 2. Apply attention on all the projected vectors in batch
        attn_output, attn_weights = self.attention(Q, K, V, mask)

        # 3. Concatenate heads and put through final linear layer
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, query_len, self.d_model
        )

        output = self.W_o(attn_output)
        return output, attn_weights

    def attention(self, Q, K, V, mask=None):
        """
        Scaled dot-product attention on already-split heads.

        Q, K, V are [batch, n_heads, length, d_k]; ``mask`` must be
        broadcastable to the [batch, n_heads, len_q, len_k] score tensor.
        Returns the attended values and the (dropout-applied) weights.
        """
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)

        if mask is not None:
            # dtype-aware fill: a literal -1e9 is not representable in float16
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        output = torch.matmul(attention_weights, V)
        return output, attention_weights

# Test multi-head attention
# These assignments rebind the notebook-level names (d_model, seq_len, batch_size) used by earlier cells.
d_model = 512
n_heads = 8
seq_len = 10
batch_size = 2

# Self-attention: the same tensor is passed as query, key and value.
# The module is freshly constructed and never switched to eval(), so the
# returned weights have dropout applied.
mha = MultiHeadAttention(d_model, n_heads)
x = torch.randn(batch_size, seq_len, d_model)
output, attn_weights = mha(x, x, x)

print(f"Multi-head output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
print(f"Each head dimension: {d_model // n_heads}")

### 2.2 Visualize Different Heads

In [None]:
def visualize_multihead_attention(attn_weights, n_heads_to_show=4):
    """
    Visualize attention patterns from multiple heads, one panel per head.

    Args:
        attn_weights: Tensor indexed as [batch, head, query, key]; only the
            first batch item is drawn.
        n_heads_to_show: Maximum number of heads to draw. The value is
            clamped to the range [1, number of heads available].
    """
    n_heads = attn_weights.shape[1]
    n_panels = max(1, min(n_heads_to_show, n_heads))

    # squeeze=False keeps `axes` a 2-D array even for a single panel, so the
    # indexing below works for any panel count.
    fig, axes = plt.subplots(1, n_panels, figsize=(15, 4), squeeze=False)
    axes = axes[0]

    for head_idx, ax in enumerate(axes):
        weights = attn_weights[0, head_idx].detach().cpu().numpy()

        im = ax.imshow(weights, cmap='Blues', aspect='auto', vmin=0, vmax=1)
        ax.set_title(f'Head {head_idx + 1}')
        ax.set_xlabel('Position')
        if head_idx == 0:
            ax.set_ylabel('Position')

    # One shared colorbar: every panel uses the same 0..1 colour scale.
    plt.colorbar(im, ax=axes.tolist(), fraction=0.02)
    plt.suptitle('Different Attention Heads Focus on Different Patterns', fontsize=14)
    plt.tight_layout()
    plt.show()

# `attn_weights` is the tensor returned by `mha(x, x, x)` in the multi-head test cell.
visualize_multihead_attention(attn_weights)

## Part 3: Complete Transformer Block

### 3.1 Positional Encoding

In [None]:
class PositionalEncoding(nn.Module):
    """
    Add fixed sinusoidal position information to batch-first embeddings.

    The table is stored in the buffer ``pe`` with shape
    [max_len, 1, d_model], where

        pe[p, 0, 2i]   = sin(p / 10000^(2i / d_model))
        pe[p, 0, 2i+1] = cos(p / 10000^(2i / d_model))

    Args:
        d_model: Embedding dimension (odd values are supported).
        max_len: Number of positions precomputed into the buffer.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        # Shape [max_len, 1, d_model]
        self.register_buffer('pe', self._sinusoid_table(max_len, d_model).unsqueeze(1))

    @staticmethod
    def _sinusoid_table(length, d_model):
        """Return the [length, d_model] sinusoidal table for positions 0..length-1."""
        position = torch.arange(0, length, dtype=torch.float).unsqueeze(1)

        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-np.log(10000.0) / d_model))

        table = torch.zeros(length, d_model)
        table[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine columns.
        table[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        return table

    def forward(self, x):
        """
        Args:
            x: Embeddings of shape [batch, seq_len, d_model].

        Returns:
            Tensor of the same shape with position p's encoding added to
            every batch item at sequence index p.
        """
        # The sequence axis of a batch-first tensor is dim 1; slicing by dim 0
        # would pick encodings by batch index and reuse one vector for every
        # position of a sequence.
        seq_len = x.size(1)
        if seq_len <= self.pe.size(0):
            encoding = self.pe[:seq_len, 0, :]
        else:
            # Sequence is longer than the precomputed table: build the
            # encodings for this call instead of failing.
            encoding = self._sinusoid_table(seq_len, self.pe.size(-1))

        encoding = encoding.to(device=x.device, dtype=x.dtype)
        return x + encoding.unsqueeze(0)

# Visualize positional encoding
def visualize_positional_encoding(d_model=128, max_len=100):
    """
    Draw the positional-encoding table as a heatmap.

    Positions run along the x-axis and encoding dimensions along the y-axis.

    Args:
        d_model: Encoding dimension passed to PositionalEncoding.
        max_len: Number of positions to build and plot.
    """
    table = PositionalEncoding(d_model, max_len).pe
    # Drop the singleton middle axis, then transpose to [dimension, position].
    grid = table[:max_len, 0, :].numpy().T

    plt.figure(figsize=(12, 4))
    plt.imshow(grid, cmap='RdBu', aspect='auto')
    plt.colorbar()

    plt.title('Positional Encoding Pattern (Sinusoidal)')
    plt.xlabel('Position')
    plt.ylabel('Dimension')

    plt.tight_layout()
    plt.show()

# Uses the function defaults: d_model=128, max_len=100.
visualize_positional_encoding()

### 3.2 Complete Transformer Block

In [None]:
class TransformerBlock(nn.Module):
    """
    One encoder block: self-attention and a position-wise feed-forward
    network, each wrapped in dropout, a residual connection and a layer
    norm applied after the residual sum.

    Args:
        d_model: Model dimension.
        n_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward network.
        dropout: Dropout probability used throughout the block.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()

        self.attention = MultiHeadAttention(d_model, n_heads, dropout)

        # Expand to d_ff, apply ReLU and dropout, project back to d_model.
        feed_forward_layers = [
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        ]
        self.ffn = nn.Sequential(*feed_forward_layers)

        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Return (block output, attention weights) for input ``x``."""
        # Sub-layer 1: self-attention (query, key and value are all x)
        attended, attn_weights = self.attention(x, x, x, mask)
        hidden = self.norm1(x + self.dropout(attended))

        # Sub-layer 2: feed-forward network
        transformed = self.ffn(hidden)
        result = self.norm2(hidden + self.dropout(transformed))

        return result, attn_weights

# Test transformer block
# Rebinds the notebook-level d_model / n_heads / seq_len / batch_size names.
d_model = 512
n_heads = 8
d_ff = 2048
seq_len = 20
batch_size = 4

# No mask is passed, so the block attends over all 20 positions.
transformer_block = TransformerBlock(d_model, n_heads, d_ff)
x = torch.randn(batch_size, seq_len, d_model)
output, attn_weights = transformer_block(x)

print(f"Transformer block output shape: {output.shape}")
# Total number of scalar entries across all parameters of the block.
print(f"Parameters: {sum(p.numel() for p in transformer_block.parameters()):,}")

## Part 4: Build a Mini Language Model

### 4.1 Complete Transformer Model

In [None]:
class MiniTransformer(nn.Module):
    """
    Small transformer language model: token embedding, positional encoding,
    a stack of TransformerBlock layers, a final layer norm and a projection
    to vocabulary logits.

    Args:
        vocab_size: Number of token ids.
        d_model: Model dimension.
        n_heads: Attention heads per block.
        n_layers: Number of transformer blocks.
        d_ff: Feed-forward hidden width inside each block.
        max_len: Length passed to PositionalEncoding.
        dropout: Dropout probability.
    """

    def __init__(self, vocab_size, d_model=256, n_heads=8, n_layers=4,
                 d_ff=1024, max_len=100, dropout=0.1):
        super().__init__()

        # Input side: token lookup plus position information
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)

        # Encoder stack
        blocks = []
        for _ in range(n_layers):
            blocks.append(TransformerBlock(d_model, n_heads, d_ff, dropout))
        self.transformer_blocks = nn.ModuleList(blocks)

        # Output side
        self.ln_f = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)

        self.dropout = nn.Dropout(dropout)
        self.d_model = d_model

    def forward(self, x, mask=None):
        """
        Args:
            x: Token ids, indexed as [batch, position].
            mask: Optional attention mask forwarded to every block.

        Returns:
            (logits over the vocabulary, list with one attention-weight
            tensor per block, in layer order).
        """
        # Embeddings are scaled by sqrt(d_model) before positions are added.
        scale = np.sqrt(self.d_model)
        hidden = self.embedding(x) * scale
        hidden = self.dropout(self.pos_encoding(hidden))

        per_layer_attention = []
        for block in self.transformer_blocks:
            hidden, layer_attention = block(hidden, mask)
            per_layer_attention.append(layer_attention)

        logits = self.fc_out(self.ln_f(hidden))
        return logits, per_layer_attention

# Create a mini transformer
# `model` is a notebook-level global reused by the generation, attention-range
# and gradient-flow cells further down. All other hyperparameters use the
# MiniTransformer defaults.
vocab_size = 10000
model = MiniTransformer(vocab_size).to(device)

# Count parameters
# Sums the element counts of every parameter tensor registered on the model.
total_params = sum(p.numel() for p in model.parameters())
print(f"Mini Transformer Parameters: {total_params:,}")
print(f"That's {total_params/1e6:.2f}M parameters!")

### 4.2 Generate Text with Your Transformer

In [None]:
def generate_text(model, start_tokens, max_length=50, temperature=1.0):
    """
    Generate tokens autoregressively with the given model.

    Args:
        model: Module whose forward pass returns ``(logits, extra)`` with
            logits indexed as [batch, position, vocab].
        start_tokens: Integer tensor [batch, prompt_len] used as the prompt.
            It is not modified.
        max_length: Number of new tokens to append.
        temperature: Softmax temperature. Values below 1 sharpen the
            distribution; ``0`` selects the most likely token (greedy
            decoding) instead of dividing by zero.

    Returns:
        Tensor [batch, prompt_len + max_length] containing the prompt
        followed by the generated tokens.
    """
    # Remember the current mode so generation does not permanently switch a
    # model that is being trained into eval mode.
    was_training = model.training
    model.eval()
    generated = start_tokens.clone()

    try:
        with torch.no_grad():
            for _ in range(max_length):
                # Get predictions for the whole sequence so far
                outputs, _ = model(generated)

                # Only the logits at the last position predict the next token
                next_token_logits = outputs[:, -1, :]

                if temperature == 0:
                    # Greedy decoding: the limit of sampling as temperature -> 0
                    next_token = next_token_logits.argmax(dim=-1, keepdim=True)
                else:
                    # Sample from the temperature-scaled distribution
                    probs = F.softmax(next_token_logits / temperature, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)

                # Append to generated sequence
                generated = torch.cat([generated, next_token], dim=1)
    finally:
        model.train(was_training)

    return generated

# Example generation (with random weights - not trained)
# A random 5-token prompt for a single sequence; 20 tokens are appended, so
# the generated sequence holds 25 token ids.
start_tokens = torch.randint(0, vocab_size, (1, 5)).to(device)
generated_sequence = generate_text(model, start_tokens, max_length=20)
print(f"Generated sequence shape: {generated_sequence.shape}")
print(f"Generated tokens: {generated_sequence[0].tolist()}")

## Part 5: Transformer vs RNN Comparison

### 5.1 Speed Comparison

In [None]:
class SimpleRNN(nn.Module):
    """
    LSTM baseline: token embedding -> stacked LSTM -> vocabulary projection.

    Args:
        vocab_size: Number of token ids.
        hidden_size: Width of both the embedding and the LSTM state.
        n_layers: Number of stacked LSTM layers.
    """

    def __init__(self, vocab_size, hidden_size=256, n_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=n_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        """Map token ids [batch, seq] to logits [batch, seq, vocab_size]."""
        embedded = self.embedding(x)
        # The final (h, c) state returned by the LSTM is not needed here.
        hidden_states, _final_state = self.rnn(embedded)
        logits = self.fc(hidden_states)
        return logits

def _time_forward_pass(net, batch):
    """Return the wall-clock seconds of one forward pass of ``net`` on ``batch``."""
    with torch.no_grad():
        # Warm-up run: keeps one-off costs (lazy initialisation, memory
        # allocation) out of the measurement.
        net(batch)
        if batch.is_cuda:
            torch.cuda.synchronize(batch.device)

        start = time.perf_counter()
        net(batch)
        # CUDA kernels are launched asynchronously; wait for them to finish
        # so the elapsed time covers the actual computation.
        if batch.is_cuda:
            torch.cuda.synchronize(batch.device)
        return time.perf_counter() - start


def compare_speed(seq_lengths=(10, 50, 100, 200, 500)):
    """
    Compare forward-pass time of a small Transformer and an LSTM.

    Both models are put in eval mode (no dropout) and timed on the same
    random batch of 32 sequences for each length, then the timings are
    plotted against sequence length.

    Args:
        seq_lengths: Iterable of sequence lengths to benchmark.

    Returns:
        (transformer_times, rnn_times): two lists of seconds, one entry per
        sequence length, in the order given.
    """
    seq_lengths = list(seq_lengths)
    vocab_size = 5000
    batch_size = 32

    transformer = MiniTransformer(vocab_size, d_model=128, n_heads=4, n_layers=2).to(device)
    rnn = SimpleRNN(vocab_size, hidden_size=128, n_layers=2).to(device)
    transformer.eval()
    rnn.eval()

    transformer_times = []
    rnn_times = []

    for seq_len in seq_lengths:
        x = torch.randint(0, vocab_size, (batch_size, seq_len)).to(device)

        transformer_times.append(_time_forward_pass(transformer, x))
        rnn_times.append(_time_forward_pass(rnn, x))

    # Plot results
    plt.figure(figsize=(10, 6))
    plt.plot(seq_lengths, transformer_times, 'o-', label='Transformer', linewidth=2, markersize=8)
    plt.plot(seq_lengths, rnn_times, 's-', label='RNN', linewidth=2, markersize=8)
    plt.xlabel('Sequence Length', fontsize=12)
    plt.ylabel('Processing Time (seconds)', fontsize=12)
    plt.title('Transformer vs RNN Speed Comparison', fontsize=14, fontweight='bold')
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    return transformer_times, rnn_times

# Runs the benchmark with the default sequence lengths and keeps both timing lists.
transformer_times, rnn_times = compare_speed()

### 5.2 Attention Range Analysis

In [None]:
def analyze_attention_range(model, seq_len=50, threshold=0.1):
    """
    Analyze how far transformer can attend.

    Feeds one random token sequence through the model, averages the last
    layer's attention over heads, and reports the |query - key| distance of
    every weight above ``threshold``.

    Args:
        model: Model returning ``(logits, list_of_attention_weights)`` with
            each attention tensor indexed as [batch, head, query, key].
            Token ids are drawn from [0, 1000), so its vocabulary must cover
            that range.
        seq_len: Length of the random input sequence.
        threshold: Minimum head-averaged attention weight for a
            (query, key) pair to be counted. Uniform attention is
            ``1 / seq_len``, so for long sequences few or no pairs may
            pass the default.
    """
    model.eval()
    x = torch.randint(0, 1000, (1, seq_len)).to(device)

    with torch.no_grad():
        _, attention_weights = model(x)

    # Get attention from last layer, first batch item, averaged over heads
    last_layer_attention = attention_weights[-1][0].mean(dim=0).cpu().numpy()

    # Distances |i - j| of all (query i, key j) pairs above the threshold
    query_idx, key_idx = np.nonzero(last_layer_attention > threshold)
    distances = np.abs(query_idx - key_idx)
    has_distances = distances.size > 0

    plt.figure(figsize=(12, 5))

    # Subplot 1: Attention heatmap
    plt.subplot(1, 2, 1)
    plt.imshow(last_layer_attention, cmap='YlOrRd', aspect='auto')
    plt.colorbar()
    plt.title('Attention Pattern (Last Layer Average)')
    plt.xlabel('Position')
    plt.ylabel('Position')

    # Subplot 2: Distance histogram
    plt.subplot(1, 2, 2)
    plt.xlabel('Attention Distance')
    plt.ylabel('Frequency')
    plt.title('Distribution of Attention Distances')
    if has_distances:
        plt.hist(distances, bins=20, color='skyblue', edgecolor='black', alpha=0.7)
        plt.axvline(np.mean(distances), color='red', linestyle='--',
                    label=f'Mean: {np.mean(distances):.1f}')
        plt.legend()
    else:
        # Nothing to histogram: say so in the panel instead of crashing.
        plt.text(0.5, 0.5, f'No attention weight above {threshold}',
                 ha='center', va='center', transform=plt.gca().transAxes)

    plt.tight_layout()
    plt.show()

    if has_distances:
        print(f"Average attention distance: {np.mean(distances):.2f}")
        print(f"Max attention distance: {np.max(distances)}")
    else:
        print(f"No attention weight exceeds the threshold {threshold}; "
              f"uniform attention over {seq_len} positions is {1 / seq_len:.3f}.")

# `model` is the MiniTransformer created earlier in the notebook; its weights have not been trained.
analyze_attention_range(model)

## Part 6: Positional Encoding Explorer

### 6.1 Interactive Position Encoding

In [None]:
def explore_positional_encoding(d_model_values=(64, 128, 256, 512)):
    """
    Explore how positional encoding changes with model dimension.

    Draws one heatmap per entry of ``d_model_values`` (first 50 positions),
    laid out two panels per row.

    Args:
        d_model_values: Iterable of model dimensions to plot. Any number of
            values is accepted; unused grid cells are hidden.
    """
    d_model_values = list(d_model_values)
    n_cols = 2
    n_rows = max(1, -(-len(d_model_values) // n_cols))  # ceiling division

    # 5 inches of height per row reproduces the 12x10 figure for 4 panels.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 5 * n_rows), squeeze=False)
    axes = axes.flatten()

    for ax, d_model in zip(axes, d_model_values):
        pe = PositionalEncoding(d_model, max_len=100)
        encoding = pe.pe[:50, 0, :].numpy()

        im = ax.imshow(encoding.T, cmap='RdBu', aspect='auto', vmin=-1, vmax=1)
        ax.set_title(f'd_model = {d_model}')
        ax.set_xlabel('Position')
        ax.set_ylabel('Dimension')
        plt.colorbar(im, ax=ax)

    # Hide grid cells left over when the panel count does not fill the grid.
    for ax in axes[len(d_model_values):]:
        ax.axis('off')

    plt.suptitle('Positional Encoding Patterns for Different Model Dimensions',
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Uses the default list of model dimensions.
explore_positional_encoding()

### 6.2 Effect of Removing Positional Encoding

In [None]:
def test_without_positional_encoding():
    """
    Show what happens without positional encoding.

    Three orderings of the same ten tokens are fed to two models that share
    identical weights and differ only in whether the positional-encoding
    buffer is zeroed. For each ordering the outputs are re-ordered by token
    id, so the comparison asks: "does the representation of a given token
    change when the tokens around it are reordered?"
    """
    vocab_size = 100
    d_model = 64

    # Test sequences that are permutations of the same tokens
    seq1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
    seq2 = torch.tensor([[10, 9, 8, 7, 6, 5, 4, 3, 2, 1]])  # Reversed
    seq3 = torch.tensor([[5, 3, 8, 1, 10, 2, 7, 4, 9, 6]])  # Shuffled

    # Model with positional encoding
    model_with_pe = MiniTransformer(vocab_size, d_model=d_model, n_heads=4, n_layers=1)
    model_with_pe.eval()

    # Model without positional encoding (hack: set PE to zero).
    # Copying the weights first makes the PE buffer the only difference
    # between the two models.
    model_without_pe = MiniTransformer(vocab_size, d_model=d_model, n_heads=4, n_layers=1)
    model_without_pe.load_state_dict(model_with_pe.state_dict())
    model_without_pe.pos_encoding.pe.data.fill_(0)  # Zero out positional encoding
    model_without_pe.eval()

    def outputs_by_token(net, seq):
        # Sort the output rows by token id so that row k always belongs to
        # the same token, wherever that token sat in the input.
        out, _ = net(seq)
        return out[:, torch.argsort(seq[0])]

    with torch.no_grad():
        # With PE
        out1_pe = outputs_by_token(model_with_pe, seq1)
        out2_pe = outputs_by_token(model_with_pe, seq2)
        out3_pe = outputs_by_token(model_with_pe, seq3)

        # Without PE
        out1_no_pe = outputs_by_token(model_without_pe, seq1)
        out2_no_pe = outputs_by_token(model_without_pe, seq2)
        out3_no_pe = outputs_by_token(model_without_pe, seq3)

    # Compare outputs
    print("(Outputs are aligned by token before comparing.)")
    print("With Positional Encoding:")
    print(f"  Difference between original and reversed: {torch.mean(torch.abs(out1_pe - out2_pe)):.4f}")
    print(f"  Difference between original and shuffled: {torch.mean(torch.abs(out1_pe - out3_pe)):.4f}")

    print("\nWithout Positional Encoding:")
    print(f"  Difference between original and reversed: {torch.mean(torch.abs(out1_no_pe - out2_no_pe)):.4f}")
    print(f"  Difference between original and shuffled: {torch.mean(torch.abs(out1_no_pe - out3_no_pe)):.4f}")

    print("\n💡 Insight: Without PE, the model can't distinguish between different orders!")

# Builds its own small models internally; it does not touch the notebook-level `model`.
test_without_positional_encoding()

## Part 7: Advanced Experiments

### 7.1 Attention Masking for Autoregressive Generation

In [None]:
def create_causal_mask(seq_len):
    """
    Build the boolean causal mask used for autoregressive generation.

    Returns a [seq_len, seq_len] bool tensor whose entry [i, j] is True when
    query position i may attend to key position j, i.e. when j <= i.
    """
    positions = torch.arange(seq_len)
    key_pos = positions.unsqueeze(0)    # varies along columns
    query_pos = positions.unsqueeze(1)  # varies along rows
    return key_pos <= query_pos

def visualize_causal_mask(seq_len=10):
    """
    Plot the causal mask for ``seq_len`` positions.

    Each cell (query row, key column) is coloured by whether attention is
    allowed and annotated with a check mark (allowed) or a cross (blocked).
    """
    mask = create_causal_mask(seq_len)

    plt.figure(figsize=(8, 6))
    plt.imshow(mask, cmap='RdYlGn', aspect='auto')
    plt.colorbar(label='Can Attend')
    plt.xlabel('Key Position')
    plt.ylabel('Query Position')
    plt.title('Causal Mask: Each Position Can Only Attend to Previous Positions')

    # Add annotations. imshow puts column j on the x-axis and row i on the
    # y-axis, hence text(j, i, ...).
    for i in range(seq_len):
        for j in range(seq_len):
            allowed = bool(mask[i, j])
            text = '✓' if allowed else '✗'
            color = 'white' if allowed else 'black'
            plt.text(j, i, text, ha='center', va='center', color=color, fontsize=12)

    plt.tight_layout()
    plt.show()

# Draws the mask for 8 positions instead of the default 10.
visualize_causal_mask(8)

### 7.2 Gradient Flow Analysis

In [None]:
def analyze_gradient_flow(model, seq_len=20):
    """
    Analyze gradient flow through transformer layers.

    Runs one forward/backward pass on a random token sequence with a dummy
    loss (the mean of the logits) and plots the mean absolute gradient of
    every parameter on a log scale.

    Args:
        model: Model returning ``(logits, extra)``. Token ids are drawn from
            [0, 1000), so its vocabulary must cover that range. The model is
            left in train mode with the freshly computed gradients attached.
        seq_len: Length of the random input sequence.
    """
    model.train()
    # Clear gradients left over from earlier backward passes; otherwise
    # repeated calls would accumulate and inflate the magnitudes.
    model.zero_grad()

    # Forward pass
    x = torch.randint(0, 1000, (1, seq_len)).to(device)
    output, _ = model(x)

    # Create a dummy loss and backward
    loss = output.mean()
    loss.backward()

    # Collect gradient magnitudes. Tick labels use the first component of the
    # parameter name, i.e. the top-level submodule it belongs to.
    grad_magnitudes = []
    layer_names = []

    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_magnitudes.append(param.grad.abs().mean().item())
            layer_names.append(name.split('.')[0])

    if not grad_magnitudes:
        print("No parameter received a gradient; nothing to plot.")
        return

    # Plot gradient flow
    plt.figure(figsize=(12, 6))
    plt.bar(range(len(grad_magnitudes)), grad_magnitudes)
    plt.xlabel('Layer')
    plt.ylabel('Average Gradient Magnitude')
    plt.title('Gradient Flow Through Transformer Layers')
    plt.xticks(range(len(layer_names)), layer_names, rotation=45, ha='right')
    plt.yscale('log')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    largest = max(grad_magnitudes)
    smallest = min(grad_magnitudes)
    # A parameter whose gradient is exactly zero would make the ratio a
    # division by zero; report it as infinite instead.
    ratio = largest / smallest if smallest > 0 else float('inf')

    print(f"Max gradient: {largest:.6f}")
    print(f"Min gradient: {smallest:.6f}")
    print(f"Gradient ratio (max/min): {ratio:.2f}")

# `model` is the notebook-level MiniTransformer; the call switches it to train mode.
analyze_gradient_flow(model)

## Summary and Key Takeaways

### What We've Learned:
1. **Self-Attention**: The core mechanism that lets each word look at all other words
2. **Multi-Head Attention**: Different heads capture different types of relationships
3. **Positional Encoding**: Essential for understanding word order
4. **Parallelization**: Transformers process all positions simultaneously
5. **Gradient Flow**: Residual connections help with training deep models

### Key Insights:
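- Transformers process all positions **in parallel**, so on parallel hardware they are typically **faster** than RNNs, which must step through the sequence one token at a time (note that self-attention cost still grows quadratically with sequence length)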
- They can capture **long-range dependencies** effectively
- Without positional encoding, they're **permutation equivariant**: reordering the input tokens only reorders the outputs, so word order carries no information
- The architecture is surprisingly **simple and elegant**

### Your Mini-Transformer:
- Has ~8.3M parameters with the default settings and `vocab_size=10000` (GPT-3 has 175B!)
- Can process sequences in parallel
- Uses the same architecture as state-of-the-art models
- Ready for training on real data!

### Next Steps:
1. Train your transformer on real text data
2. Experiment with different architectures
3. Try pre-training and fine-tuning
4. Scale up to bigger models!

In [None]:
# Closing summary. The emoji were stored as mis-decoded UTF-8 (mojibake) and
# are restored to the intended characters.
print("🎉 Congratulations! You've built your own Transformer from scratch!")
print("\n📊 Your achievements:")
print("✅ Implemented scaled dot-product attention")
print("✅ Built multi-head attention mechanism")
print("✅ Created positional encoding")
print("✅ Assembled a complete transformer block")
print("✅ Compared performance with RNNs")
print("✅ Visualized attention patterns")
print("\n🚀 You're ready to tackle Week 6: Pre-trained Language Models!")