# Day 7: Multi-Head Attention - Implementation and Examples

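This notebook contains all the code examples, visualizations, and hands-on exercises for Day 7 of Week 2. It builds a multi-head attention layer from scratch in PyTorch, visualizes what the individual heads attend to, walks through the concatenation and output-projection step, and compares single-head with multi-head attention.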

## Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Set random seeds for reproducibility
# (seeds the global PyTorch and NumPy generators; several later cells
# re-seed PyTorch themselves before sampling their inputs)
torch.manual_seed(42)
np.random.seed(42)

# Set up plotting
# Start from matplotlib's default style, then make seaborn's 'husl' palette
# the default color cycle for subsequent plots.
plt.style.use('default')
sns.set_palette('husl')

## 1. Multi-Head Attention Implementation

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Projects query, key and value with separate linear layers, splits each
    projection into ``num_heads`` heads of width ``d_k = d_model // num_heads``,
    runs scaled dot-product attention on all heads in parallel, concatenates
    the heads and applies a final output projection.

    Query and key/value sequences may have different lengths (cross-attention);
    key and value must share the same length.

    Note:
        Dropout is applied to the attention weights *before* they are
        returned. In training mode the returned weights therefore contain
        zeroed/rescaled entries and rows need not sum to 1. Call ``.eval()``
        on the module when the weights are meant to be inspected or plotted.

    Args:
        d_model: Width of the input and output features. Must be divisible
            by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, (
            f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        )

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Linear projections for Q, K, V
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)

        # Output projection
        self.w_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Scaled dot-product attention, applied to all heads at once.

        Args:
            Q: Queries, [batch_size, num_heads, query_len, d_k].
            K: Keys, [batch_size, num_heads, key_len, d_k].
            V: Values, [batch_size, num_heads, key_len, d_k].
            mask: Optional tensor broadcastable to the score shape
                [batch_size, num_heads, query_len, key_len]. Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            tuple: (output, attention_weights) with shapes
            [batch_size, num_heads, query_len, d_k] and
            [batch_size, num_heads, query_len, key_len]. The weights are
            returned after dropout.
        """
        d_k = Q.size(-1)
        # Scale by sqrt(d_k) so the logits do not grow with the head width.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d_k)

        if mask is not None:
            # A large negative logit becomes ~0 probability after softmax.
            scores = scores.masked_fill(mask == 0, -1e9)

        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        output = torch.matmul(attention_weights, V)
        return output, attention_weights

    def _split_heads(self, x):
        """Reshape [batch, length, d_model] -> [batch, num_heads, length, d_k]."""
        batch_size, length = x.size(0), x.size(1)
        return x.view(batch_size, length, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Run multi-head attention.

        Args:
            query: [batch_size, query_len, d_model].
            key: [batch_size, key_len, d_model].
            value: [batch_size, key_len, d_model].
            mask: Optional mask, see ``scaled_dot_product_attention``.

        Returns:
            tuple: (output, attention_weights) with shapes
            [batch_size, query_len, d_model] and
            [batch_size, num_heads, query_len, key_len].
        """
        batch_size, query_len = query.size(0), query.size(1)

        # 1-2. Linear projections, then split into heads. Each tensor is
        # reshaped using its OWN sequence length so that key/value may be
        # longer or shorter than the query.
        Q = self._split_heads(self.w_q(query))
        K = self._split_heads(self.w_k(key))
        V = self._split_heads(self.w_v(value))

        # 3. Apply attention to each head
        attention_output, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)

        # 4. Concatenate heads: [batch, heads, query_len, d_k] -> [batch, query_len, d_model]
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, query_len, self.d_model)

        # 5. Final linear projection
        output = self.w_o(attention_output)

        return output, attention_weights

In [None]:
def test_multihead_attention():
    """Smoke-test MultiHeadAttention on a random self-attention batch.

    Runs one forward pass in evaluation mode (so dropout does not perturb the
    attention weights), asserts the expected output shapes and that every
    attention row is a probability distribution, then prints a shape report.

    Returns:
        tuple: (output, attention_weights) with shapes
        [batch_size, seq_len, d_model] and
        [batch_size, num_heads, seq_len, seq_len].

    Raises:
        AssertionError: If a shape is wrong or an attention row does not
            sum to 1.
    """
    batch_size, seq_len, d_model = 2, 6, 64
    num_heads = 8

    # Create input
    x = torch.randn(batch_size, seq_len, d_model)

    # Initialize multi-head attention; eval() disables attention dropout.
    mha = MultiHeadAttention(d_model, num_heads)
    mha.eval()

    # Forward pass (no gradients needed for a shape check)
    with torch.no_grad():
        output, attention_weights = mha(x, x, x)

    # Actual checks: without these the function would only print.
    assert output.shape == (batch_size, seq_len, d_model)
    assert attention_weights.shape == (batch_size, num_heads, seq_len, seq_len)
    row_sums = attention_weights.sum(dim=-1)
    assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-5), \
        "each attention row should sum to 1"

    print("Multi-Head Attention Test:")
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Attention weights shape: {attention_weights.shape}")
    print(f"Number of heads: {num_heads}")
    print(f"d_k per head: {d_model // num_heads}")

    return output, attention_weights

# Run the smoke test; the returned tensors are bound at notebook scope.
output, weights = test_multihead_attention()

## 2. Visualizing Different Attention Heads

In [None]:
def visualize_attention_heads():
    """Plot per-head attention maps for a toy sentence and summarize them.

    Builds random embeddings for "The cat sat on the mat", nudges a few of
    them so that related tokens are similar, runs an untrained
    MultiHeadAttention in evaluation mode, draws one heatmap per head, and
    prints the strongest query->key pair and the mean row entropy per head.

    The module is put in eval mode because it applies dropout to the attention
    weights before returning them; in training mode the plotted rows would
    contain zeroed/rescaled entries and would not be valid distributions.

    Returns:
        tuple: (attention_weights, tokens) where attention_weights has shape
        [1, num_heads, seq_len, seq_len] and tokens is the list of token
        strings.
    """
    # Create structured input representing: "The cat sat on the mat"
    tokens = ["The", "cat", "sat", "on", "the", "mat"]
    seq_len, d_model = len(tokens), 64
    num_heads = 4

    # Create embeddings with linguistic structure
    torch.manual_seed(42)
    embeddings = torch.randn(1, seq_len, d_model)

    # Add structure to simulate different relationships
    # Articles similar
    embeddings[0, 4] = embeddings[0, 0] + 0.2 * torch.randn(d_model)
    # Nouns similar
    embeddings[0, 5] = embeddings[0, 1] + 0.3 * torch.randn(d_model)
    # Verb-noun relationship
    embeddings[0, 2] += 0.1 * embeddings[0, 1]  # sat influenced by cat

    # Apply multi-head attention with dropout disabled and no autograd graph.
    mha = MultiHeadAttention(d_model, num_heads)
    mha.eval()
    with torch.no_grad():
        _, attention_weights = mha(embeddings, embeddings, embeddings)

    # Visualize each head on a grid sized from num_heads (2 columns).
    n_cols = 2
    n_rows = (num_heads + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 5 * n_rows), squeeze=False)
    axes = axes.flatten()

    for head in range(num_heads):
        head_weights = attention_weights[0, head].numpy()

        sns.heatmap(head_weights,
                    xticklabels=tokens, yticklabels=tokens,
                    annot=True, fmt='.2f', cmap='Blues',
                    ax=axes[head])
        axes[head].set_title(f'Head {head + 1}')
        axes[head].set_xlabel('Key (attending to)')
        axes[head].set_ylabel('Query (attending from)')

    # Hide any unused grid cells (only relevant for an odd number of heads).
    for unused_ax in axes[num_heads:]:
        unused_ax.set_visible(False)

    plt.tight_layout()
    plt.show()

    # Analyze head specialization
    print("Head Specialization Analysis:")
    print("=" * 40)

    for head in range(num_heads):
        head_weights = attention_weights[0, head]

        # Find strongest attention pattern. argmax without `dim` indexes the
        # flattened matrix; divmod on a Python int recovers (row, col).
        max_attention = head_weights.max().item()
        i, j = divmod(torch.argmax(head_weights).item(), seq_len)

        # Mean Shannon entropy of the rows (how spread out attention is).
        # For 6 tokens the maximum possible value is ln(6) ~= 1.79.
        entropy = -torch.sum(head_weights * torch.log(head_weights + 1e-9), dim=-1).mean().item()

        print(f"Head {head + 1}:")
        print(f"  Strongest: '{tokens[i]}' → '{tokens[j]}' ({max_attention:.3f})")
        print(f"  Avg entropy: {entropy:.3f} ({'focused' if entropy < 1.5 else 'distributed'})")
        print()

    return attention_weights, tokens

# Draw the per-head heatmaps; keep the weights and token labels at notebook scope.
head_weights, tokens = visualize_attention_heads()

## 3. Head Specialization Analysis

In [None]:
def analyze_head_specialization():
    """Measure self/local/global attention mass per head on three toy inputs.

    One untrained MultiHeadAttention (evaluation mode, so attention rows are
    proper distributions) is applied to three 6-token sequences whose random
    embeddings are perturbed in a "syntactic", "semantic" or "positional"
    way. For each head the function records:

    * ``self_attention``   - mean of the diagonal of the attention matrix,
    * ``local_attention``  - mean per-query mass on the window [i-1, i+1]
      (this window includes the diagonal),
    * ``global_attention`` - mean per-query mass outside that window.

    With dropout disabled, local + global equals 1 for every head.

    Returns:
        dict: ``{case_name: {head_index: {metric_name: float}}}``.
    """
    # Create different types of sequences to test specialization
    test_cases = {
        'syntactic': ["The", "quick", "brown", "fox", "jumps", "high"],
        'semantic': ["cat", "dog", "animal", "pet", "furry", "cute"],
        'positional': ["first", "second", "third", "fourth", "fifth", "sixth"]
    }

    d_model, num_heads = 48, 6

    # Seed before building the model so its random weights do not depend on
    # whatever ran earlier in the kernel.
    torch.manual_seed(42)
    mha = MultiHeadAttention(d_model, num_heads)
    mha.eval()  # the module applies dropout to the weights it returns

    results = {}

    for case_name, tokens in test_cases.items():
        seq_len = len(tokens)

        # Same base embeddings for every case; only the added structure differs.
        torch.manual_seed(123)
        embeddings = torch.randn(1, seq_len, d_model)

        if case_name == 'syntactic':
            # Make the adjectives similar: "brown" becomes a noisy copy of
            # "quick" (same pattern as the semantic case below).
            embeddings[0, 2] = embeddings[0, 1] + 0.3 * torch.randn(d_model)
        elif case_name == 'semantic':
            # Make semantically related words similar
            embeddings[0, 1] = embeddings[0, 0] + 0.2 * torch.randn(d_model)  # cat, dog
            embeddings[0, 3] = embeddings[0, 2] + 0.2 * torch.randn(d_model)  # animal, pet
        elif case_name == 'positional':
            # Create positional patterns: a constant offset growing with position
            for i in range(seq_len):
                embeddings[0, i] += 0.1 * i * torch.ones(d_model)

        # Get attention patterns
        with torch.no_grad():
            _, attention_weights = mha(embeddings, embeddings, embeddings)

        # Analyze each head
        head_analysis = {}
        for head in range(num_heads):
            head_attn = attention_weights[0, head]

            # Measure different properties
            self_attention = torch.diag(head_attn).mean().item()
            local_attention = sum(
                head_attn[i, max(0, i - 1):min(seq_len, i + 2)].sum().item()
                for i in range(seq_len)
            ) / seq_len
            # Everything that is not in the local window, averaged per query.
            global_attention = head_attn.sum().item() / seq_len - local_attention

            head_analysis[head] = {
                'self_attention': self_attention,
                'local_attention': local_attention,
                'global_attention': global_attention
            }

        results[case_name] = head_analysis

    # Display results
    print("Head Specialization Across Different Sequence Types:")
    print("=" * 60)

    for case_name, head_data in results.items():
        print(f"\n{case_name.upper()} Sequences:")
        print("Head | Self-Attn | Local-Attn | Global-Attn")
        print("-" * 45)

        for head, metrics in head_data.items():
            print(f"{head+1:4d} | {metrics['self_attention']:8.3f} | "
                  f"{metrics['local_attention']:9.3f} | {metrics['global_attention']:10.3f}")

    return results

# Run the analysis; the nested metrics dict is kept at notebook scope.
specialization_results = analyze_head_specialization()

## 4. Concatenation and Projection Deep Dive

In [None]:
def demonstrate_concatenation_projection():
    """Show the concatenation and projection process in detail.

    Simulates the outputs of three attention heads with random tensors,
    concatenates them along the feature axis, passes the result through a
    freshly initialized linear output projection, prints the intermediate
    shapes/values and plots three panels: the individual head outputs side by
    side, the concatenated tensor, and the projected tensor.

    Returns:
        tuple: (concatenated, final_output), both of shape
        [batch_size, seq_len, d_model].
    """
    print("Multi-Head Concatenation and Projection")
    print("=" * 45)

    batch_size, seq_len, d_model = 1, 4, 12
    num_heads = 3
    d_k = d_model // num_heads  # 4

    # Simulate attention outputs from different heads
    torch.manual_seed(42)
    head_outputs = []

    for head in range(num_heads):
        # Each head outputs [batch_size, seq_len, d_k]
        head_output = torch.randn(batch_size, seq_len, d_k)
        head_outputs.append(head_output)
        print(f"Head {head + 1} output shape: {head_output.shape}")

    print(f"\nEach head dimension d_k: {d_k}")
    print(f"Total model dimension d_model: {d_model}")

    # Step 1: Concatenation
    concatenated = torch.cat(head_outputs, dim=-1)
    print(f"\nAfter concatenation shape: {concatenated.shape}")
    print("Concatenated output (first token, all dimensions):")
    print(concatenated[0, 0].detach().numpy().round(3))

    # Step 2: Linear projection
    w_o = nn.Linear(d_model, d_model)
    final_output = w_o(concatenated)

    print(f"\nAfter projection shape: {final_output.shape}")
    print("Final output (first token, first 6 dims):")
    print(final_output[0, 0, :6].detach().numpy().round(3))

    # Visualize the process
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # Head outputs: place the heads next to each other along the x-axis so all
    # of them are visible (drawing each head with its own imshow on the same
    # axes would leave only the last one showing).
    heads_side_by_side = torch.cat([h[0].T for h in head_outputs], dim=1)  # [d_k, num_heads * seq_len]
    axes[0].imshow(heads_side_by_side.detach().numpy(),
                   cmap='RdBu', aspect='auto')
    # Separators sit on the pixel edge between consecutive heads.
    for boundary in range(1, num_heads):
        axes[0].axvline(x=boundary * seq_len - 0.5, color='black', linewidth=2)
    axes[0].set_xticks(range(num_heads * seq_len))
    axes[0].set_xticklabels([str(pos) for _ in range(num_heads) for pos in range(seq_len)])
    axes[0].set_title('Individual Head Outputs')
    axes[0].set_xlabel('Sequence Position (heads 1..N, left to right)')
    axes[0].set_ylabel('Feature Dimension')

    # Concatenated
    axes[1].imshow(concatenated[0].detach().numpy().T,
                   cmap='RdBu', aspect='auto')
    axes[1].set_title('Concatenated Output')
    axes[1].set_xlabel('Sequence Position')
    axes[1].set_ylabel('Feature Dimension')

    # Final projection
    axes[2].imshow(final_output[0].detach().numpy().T,
                   cmap='RdBu', aspect='auto')
    axes[2].set_title('After Linear Projection')
    axes[2].set_xlabel('Sequence Position')
    axes[2].set_ylabel('Feature Dimension')

    plt.tight_layout()
    plt.show()

    return concatenated, final_output

# Run the walkthrough; keep the concatenated and projected tensors at notebook scope.
concat_output, final_output = demonstrate_concatenation_projection()

## 5. Comparing Single vs Multi-Head

In [None]:
def compare_single_vs_multihead(num_heads=8):
    """Compare a single-head and a multi-head attention layer on one input.

    Builds a random 8-token input with added local and long-range
    correlations, runs it through two untrained MultiHeadAttention modules
    (1 head vs ``num_heads`` heads) in evaluation mode, plots the single-head
    map, the mean multi-head map and the per-position standard deviation
    across heads, and prints entropy, output-difference and head-similarity
    statistics.

    Both modules are independently and randomly initialized, so the numbers
    illustrate the mechanics; they are not a measure of trained performance.

    Args:
        num_heads: Number of heads in the multi-head model. Must be at least
            2 and divide ``d_model`` (64).

    Raises:
        ValueError: If ``num_heads`` is smaller than 2 (diversity across heads
            is undefined for a single head).
    """
    if num_heads < 2:
        raise ValueError("num_heads must be at least 2 to compare heads")

    seq_len, d_model = 8, 64

    # Create input with complex relationships
    torch.manual_seed(42)
    x = torch.randn(1, seq_len, d_model)

    # Add multiple types of relationships
    # Local dependencies
    for i in range(seq_len - 1):
        x[0, i+1] += 0.2 * x[0, i]

    # Long-range dependencies
    x[0, -1] += 0.3 * x[0, 0]
    x[0, -2] += 0.3 * x[0, 1]

    # eval() disables the dropout that the module applies to the attention
    # weights it returns; otherwise the maps and entropies would be corrupted.
    single_head = MultiHeadAttention(d_model, num_heads=1)
    single_head.eval()
    multi_head = MultiHeadAttention(d_model, num_heads=num_heads)
    multi_head.eval()

    with torch.no_grad():
        single_output, single_weights = single_head(x, x, x)
        multi_output, multi_weights = multi_head(x, x, x)

    # Visualize attention patterns
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # Single head
    sns.heatmap(single_weights[0, 0].numpy(),
                annot=True, fmt='.2f', cmap='Blues', ax=axes[0])
    axes[0].set_title('Single Head Attention')

    # Multi-head (average)
    avg_multi_weights = multi_weights[0].mean(dim=0)
    sns.heatmap(avg_multi_weights.numpy(),
                annot=True, fmt='.2f', cmap='Blues', ax=axes[1])
    axes[1].set_title('Multi-Head (Average)')

    # Multi-head diversity (std across heads)
    std_multi_weights = multi_weights[0].std(dim=0)
    sns.heatmap(std_multi_weights.numpy(),
                annot=True, fmt='.2f', cmap='Reds', ax=axes[2])
    axes[2].set_title('Multi-Head Diversity (Std)')

    for ax in axes:
        ax.set_xlabel('Key position')
        ax.set_ylabel('Query position')

    plt.tight_layout()
    plt.show()

    # Quantitative comparison
    print("Single vs Multi-Head Comparison:")
    print("=" * 35)

    # Attention diversity: mean Shannon entropy of the attention rows
    single_entropy = -torch.sum(single_weights * torch.log(single_weights + 1e-9), dim=-1).mean()
    multi_entropy = -torch.sum(multi_weights * torch.log(multi_weights + 1e-9), dim=-1).mean()

    print(f"Single-head entropy: {single_entropy:.3f}")
    print(f"Multi-head entropy: {multi_entropy:.3f}")

    # Output difference
    output_diff = torch.norm(single_output - multi_output).item()
    print(f"Output difference norm: {output_diff:.3f}")

    # Head diversity in multi-head: cosine similarity of every pair of heads
    head_similarities = []
    for i in range(num_heads):
        for j in range(i + 1, num_heads):
            sim = F.cosine_similarity(
                multi_weights[0, i].flatten(),
                multi_weights[0, j].flatten(),
                dim=0
            )
            head_similarities.append(sim.item())

    avg_head_similarity = np.mean(head_similarities)
    print(f"Average head similarity: {avg_head_similarity:.3f}")
    print(f"Head diversity: {'High' if avg_head_similarity < 0.5 else 'Low'}")

# Run the comparison; it plots and prints its results and returns nothing.
compare_single_vs_multihead()