# Multi-Head Attention: Parallel Attention Patterns

---

## What You'll Learn

1. **Single-head attention** from scratch in NumPy and PyTorch
2. **Why multiple heads matter** - different heads capture different linguistic relationships
3. **Multi-head attention** - how heads are split, computed, concatenated, and projected
4. **Visualizing attention patterns** - heatmaps showing what each head focuses on
5. **Comparing single-head vs multi-head** on a practical task

---

### The Core Idea

Single-head attention computes **one** set of attention weights. This forces all the information (subject-verb agreement, co-reference, positional relationships) into a single attention pattern.

Multi-head attention runs **multiple attention computations in parallel**, each with its own learned projection. Different heads can specialize:
- Head 1 might learn subject-verb relationships
- Head 2 might learn co-reference (pronouns -> nouns)
- Head 3 might learn positional/local context

The outputs are concatenated and projected back to the model dimension.

In [None]:
# Install dependencies
!pip install torch numpy matplotlib seaborn transformers -q

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Optional, Tuple

np.random.seed(42)
torch.manual_seed(42)

print("All imports ready!")

## Part 1: Single-Head Attention in NumPy

Let's build attention from first principles using only NumPy.

**Scaled Dot-Product Attention:**

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

Where:
- $Q$ = Query matrix (what am I looking for?)
- $K$ = Key matrix (what do I contain?)
- $V$ = Value matrix (what information do I carry?)
- $d_k$ = dimension of keys (for scaling)
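
To see why the $\sqrt{d_k}$ scaling is there, here is a minimal sketch. It assumes queries and keys with independent, unit-variance entries, which is an illustration and not a property of trained models; the helper name `score_std` is made up for this example. Under that assumption the raw dot product has a standard deviation of about $\sqrt{d_k}$, and dividing by $\sqrt{d_k}$ brings it back to about 1.

```python
import numpy as np

def score_std(d_k, n=2000, seed=0):
    """Empirical std of q.k before and after scaling (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    q = rng.standard_normal((n, d_k))
    k = rng.standard_normal((n, d_k))
    raw = np.sum(q * k, axis=-1)  # one dot product per row
    return raw.std(), (raw / np.sqrt(d_k)).std()

for d_k in (4, 64, 512):
    raw_std, scaled_std = score_std(d_k)
    print(f"d_k={d_k:4d}  raw std={raw_std:6.2f}  scaled std={scaled_std:.2f}")
```

Large unscaled scores push softmax toward a near one-hot output, where gradients are tiny. Scaling keeps the scores in a range where softmax stays responsive.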

In [None]:
def softmax_numpy(x, axis=-1):
    """Numerically stable softmax in NumPy."""
    x_max = np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x - x_max)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)


def single_head_attention_numpy(Q, K, V, mask=None):
    """
    Single-head scaled dot-product attention in NumPy.
    
    Args:
        Q: Query matrix, shape (seq_len, d_k)
        K: Key matrix, shape (seq_len, d_k)
        V: Value matrix, shape (seq_len, d_v)
        mask: Optional (seq_len, seq_len) mask; entries equal to 0 are blocked (e.g. a causal mask)
    
    Returns:
        output: shape (seq_len, d_v)
        attention_weights: shape (seq_len, seq_len)
    """
    d_k = Q.shape[-1]
    
    # Step 1: Compute raw attention scores
    # Q @ K^T -> (seq_len, seq_len)
    scores = Q @ K.T
    
    # Step 2: Scale by sqrt(d_k) to prevent softmax saturation
    scores = scores / np.sqrt(d_k)
    
    # Step 3: Apply mask (if causal attention)
    if mask is not None:
        scores = np.where(mask == 0, -1e9, scores)
    
    # Step 4: Softmax to get attention weights
    attention_weights = softmax_numpy(scores, axis=-1)
    
    # Step 5: Weighted sum of values
    output = attention_weights @ V
    
    return output, attention_weights


# Test with a small example
seq_len = 5
d_k = 4  # dimension of keys/queries
d_v = 4  # dimension of values

# Random embeddings (pretend these are word embeddings)
Q = np.random.randn(seq_len, d_k)
K = np.random.randn(seq_len, d_k)
V = np.random.randn(seq_len, d_v)

output, attn_weights = single_head_attention_numpy(Q, K, V)

print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"V shape: {V.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
print(f"\nAttention weights (rows sum to 1):")
print(np.round(attn_weights, 3))
print(f"\nRow sums: {attn_weights.sum(axis=-1)}")
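
The `mask` argument is not exercised in the test above, so here is a small sketch of how a causal mask could be built and applied. It re-implements the attention function in compact form so the block runs on its own; the name `masked_attention` is an assumption for this example.

```python
import numpy as np

def masked_attention(Q, K, V, mask=None):
    """Scaled dot-product attention with an optional 0/1 mask (sketch)."""
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    if mask is not None:
        scores = np.where(mask == 0, -1e9, scores)  # block masked positions
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
seq_len, d_k = 5, 4
Q, K, V = (rng.standard_normal((seq_len, d_k)) for _ in range(3))

# Lower-triangular mask: 1 = may attend, 0 = blocked (future positions)
causal_mask = np.tril(np.ones((seq_len, seq_len)))
_, weights = masked_attention(Q, K, V, causal_mask)

print(np.round(weights, 3))
```

Every entry above the diagonal comes out as zero, and the first row puts all of its weight on the first token, since that is the only position it is allowed to see.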

### Visualizing Single-Head Attention

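Let's label the positions with the tokens of a real sentence and see what single-head attention looks like. The embeddings and projection matrices below are random, so the heatmap illustrates the mechanics rather than a learned pattern.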

In [None]:
def visualize_attention(attention_weights, tokens, title="Attention Weights", ax=None):
    """Plot attention weights as a heatmap."""
    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 6))
    
    sns.heatmap(
        attention_weights,
        xticklabels=tokens,
        yticklabels=tokens,
        annot=True,
        fmt='.2f',
        cmap='Blues',
        ax=ax,
        vmin=0,
        vmax=1
    )
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Key (attending to)', fontsize=11)
    ax.set_ylabel('Query (attending from)', fontsize=11)
    
    return ax


# Simulate a sentence with meaningful tokens
tokens = ["The", "cat", "sat", "on", "mat"]
seq_len = len(tokens)
d_model = 8

# Create random pseudo-embeddings (stand-ins for learned word embeddings)
np.random.seed(42)
embeddings = np.random.randn(seq_len, d_model)

# Single projection matrices
W_q = np.random.randn(d_model, d_model) * 0.1
W_k = np.random.randn(d_model, d_model) * 0.1
W_v = np.random.randn(d_model, d_model) * 0.1

Q = embeddings @ W_q
K = embeddings @ W_k
V = embeddings @ W_v

output, attn_weights = single_head_attention_numpy(Q, K, V)

fig, ax = plt.subplots(figsize=(8, 6))
visualize_attention(attn_weights, tokens, "Single-Head Attention", ax)
plt.tight_layout()
plt.show()

print("\nNotice: A single head must encode ALL relationships in one pattern.")
print("It cannot separately capture syntax AND semantics AND position.")

## Part 2: Multi-Head Attention in NumPy

The key insight of multi-head attention:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O$$

where $\text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$

**Instead of one big attention with $d_{model}$ dimensions, we split into $h$ heads, each with $d_k = d_{model} / h$ dimensions.**

This gives each head a "subspace" to specialize in.
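
The link between the formula (a separate $W_i^Q$ per head) and the "split" view can be checked directly. The sketch below, with made-up random matrices, shows that reshaping one big projection into heads gives the same queries as projecting with each head's own column block of `W_q`.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, n_heads = 5, 8, 4
d_k = d_model // n_heads

X = rng.standard_normal((seq_len, d_model))
W_q = rng.standard_normal((d_model, d_model))

# Route 1: one big projection, then reshape into heads
Q_heads = (X @ W_q).reshape(seq_len, n_heads, d_k)

# Route 2: a separate (d_model, d_k) projection per head
for i in range(n_heads):
    W_q_i = W_q[:, i * d_k:(i + 1) * d_k]
    assert np.allclose(Q_heads[:, i, :], X @ W_q_i)

print("Each head's queries equal X @ (its column block of W_q)")
```

So $W_i^Q$ in the formula is simply columns `i * d_k` to `(i + 1) * d_k` of the full matrix. The same holds for the key and value projections.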

In [None]:
def multi_head_attention_numpy(X, n_heads, W_q, W_k, W_v, W_o, mask=None):
    """
    Multi-head attention in NumPy.
    
    Args:
        X: Input embeddings, shape (seq_len, d_model)
        n_heads: Number of attention heads
        W_q, W_k, W_v: Projection matrices, shape (d_model, d_model)
        W_o: Output projection, shape (d_model, d_model)
        mask: Optional causal mask
    
    Returns:
        output: shape (seq_len, d_model)
        all_attention_weights: list of attention weight matrices per head
    """
    seq_len, d_model = X.shape
    d_k = d_model // n_heads  # dimension per head
    
    # Step 1: Project input to Q, K, V
    Q = X @ W_q  # (seq_len, d_model)
    K = X @ W_k
    V = X @ W_v
    
    # Step 2: Split into heads
    # Reshape from (seq_len, d_model) to (seq_len, n_heads, d_k)
    Q_heads = Q.reshape(seq_len, n_heads, d_k)
    K_heads = K.reshape(seq_len, n_heads, d_k)
    V_heads = V.reshape(seq_len, n_heads, d_k)
    
    # Step 3: Compute attention for each head
    head_outputs = []
    all_attention_weights = []
    
    for h in range(n_heads):
        Q_h = Q_heads[:, h, :]  # (seq_len, d_k)
        K_h = K_heads[:, h, :]
        V_h = V_heads[:, h, :]
        
        output_h, attn_h = single_head_attention_numpy(Q_h, K_h, V_h, mask)
        head_outputs.append(output_h)
        all_attention_weights.append(attn_h)
    
    # Step 4: Concatenate all head outputs
    # Each head output is (seq_len, d_k), concatenate to (seq_len, d_model)
    concat = np.concatenate(head_outputs, axis=-1)
    
    # Step 5: Final linear projection
    output = concat @ W_o
    
    return output, all_attention_weights


# Set up dimensions
d_model = 8
n_heads = 4
d_k = d_model // n_heads  # = 2 per head

print(f"d_model = {d_model}")
print(f"n_heads = {n_heads}")
print(f"d_k per head = {d_k}")
print(f"Projection params per head: {3 * d_model * d_k} (a d_model x d_k slice of each of W_q, W_k, W_v)")
print(f"\nKey insight: {n_heads} heads x {d_k} dims = {n_heads * d_k} = d_model")
print("Multi-head attention does NOT add parameters - it REDISTRIBUTES them!")

In [None]:
# Run multi-head attention on our sentence
np.random.seed(42)

tokens = ["The", "cat", "sat", "on", "mat"]
seq_len = len(tokens)
d_model = 8
n_heads = 4

# Input embeddings
X = np.random.randn(seq_len, d_model)

# Projection matrices
W_q = np.random.randn(d_model, d_model) * 0.1
W_k = np.random.randn(d_model, d_model) * 0.1
W_v = np.random.randn(d_model, d_model) * 0.1
W_o = np.random.randn(d_model, d_model) * 0.1

output, all_attn = multi_head_attention_numpy(X, n_heads, W_q, W_k, W_v, W_o)

print(f"Input shape: {X.shape}")
print(f"Output shape: {output.shape}")
print(f"Number of attention heads: {len(all_attn)}")
print(f"Each head's attention shape: {all_attn[0].shape}")
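
The per-head loop is easy to read, but the heads do not depend on each other, so they can also be computed in one batched matmul. The following is a sketch of that idea in NumPy (the name `mha_vectorized` is an assumption, and masking is left out to keep it short).

```python
import numpy as np

def mha_vectorized(X, n_heads, W_q, W_k, W_v, W_o):
    """Multi-head attention with all heads computed in one batched matmul."""
    seq_len, d_model = X.shape
    d_k = d_model // n_heads

    def split(M):
        # (seq_len, d_model) -> (n_heads, seq_len, d_k)
        return M.reshape(seq_len, n_heads, d_k).transpose(1, 0, 2)

    Q, K, V = split(X @ W_q), split(X @ W_k), split(X @ W_v)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)  # (n_heads, seq_len, seq_len)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    context = weights @ V  # (n_heads, seq_len, d_k)
    # Put the heads back side by side: (seq_len, d_model)
    concat = context.transpose(1, 0, 2).reshape(seq_len, d_model)
    return concat @ W_o, weights

rng = np.random.default_rng(42)
seq_len, d_model, n_heads = 5, 8, 4
X = rng.standard_normal((seq_len, d_model))
W_q, W_k, W_v, W_o = (rng.standard_normal((d_model, d_model)) * 0.1 for _ in range(4))

output, weights = mha_vectorized(X, n_heads, W_q, W_k, W_v, W_o)
print(output.shape, weights.shape)  # (5, 8) (4, 5, 5)
```

This is the same layout trick the PyTorch version uses later: move the head axis in front of the sequence axis so that matmul treats it as a batch dimension.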

In [None]:
# Visualize all heads side by side
fig, axes = plt.subplots(1, n_heads, figsize=(24, 5))

for h in range(n_heads):
    visualize_attention(
        all_attn[h], tokens, 
        f"Head {h+1}", 
        axes[h]
    )

plt.suptitle("Multi-Head Attention: Each Head Produces Its Own Pattern (random weights)", 
             fontsize=16, fontweight='bold', y=1.05)
plt.tight_layout()
plt.show()

print("Notice that each head produces a different attention pattern.")
print("The weights here are random; after training, heads can specialize in different relationships.")

## Part 3: Multi-Head Attention in PyTorch

Now let's implement the same thing in PyTorch, which is how it's done in practice. We'll build it from scratch (not using `nn.MultiheadAttention`) to understand every step.

In [None]:
class SingleHeadAttention(nn.Module):
    """Single-head scaled dot-product attention."""
    
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
    
    def forward(self, x, mask=None):
        # x: (batch, seq_len, d_model)
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        
        # Attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_model ** 0.5)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        output = self.W_o(output)
        
        return output, attn_weights


# Test
d_model = 64
batch_size = 1
seq_len = 5

model = SingleHeadAttention(d_model)
x = torch.randn(batch_size, seq_len, d_model)

output, weights = model(x)
print(f"Single-Head Attention:")
print(f"  Input:  {x.shape}")
print(f"  Output: {output.shape}")
print(f"  Weights: {weights.shape}")
print(f"  Parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention from scratch.
    
    Key insight: each of Q, K, and V uses ONE d_model x d_model linear layer
    shared by all heads; the result is reshaped to split it into heads.
    This is more efficient than separate linear layers per head.
    """
    
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # dimension per head
        
        # Single projection for all heads (more efficient)
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
    
    def split_heads(self, x):
        """Reshape (batch, seq_len, d_model) -> (batch, n_heads, seq_len, d_k)"""
        batch_size, seq_len, _ = x.shape
        x = x.view(batch_size, seq_len, self.n_heads, self.d_k)
        return x.transpose(1, 2)  # (batch, n_heads, seq_len, d_k)
    
    def combine_heads(self, x):
        """Reshape (batch, n_heads, seq_len, d_k) -> (batch, seq_len, d_model)"""
        batch_size, _, seq_len, _ = x.shape
        x = x.transpose(1, 2)  # (batch, seq_len, n_heads, d_k)
        return x.contiguous().view(batch_size, seq_len, self.d_model)
    
    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        
        # Step 1: Project to Q, K, V
        Q = self.W_q(x)  # (batch, seq_len, d_model)
        K = self.W_k(x)
        V = self.W_v(x)
        
        # Step 2: Split into heads
        Q = self.split_heads(Q)  # (batch, n_heads, seq_len, d_k)
        K = self.split_heads(K)
        V = self.split_heads(V)
        
        # Step 3: Scaled dot-product attention (all heads in parallel!)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        # scores: (batch, n_heads, seq_len, seq_len)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attn_weights = F.softmax(scores, dim=-1)
        # attn_weights: (batch, n_heads, seq_len, seq_len)
        
        # Step 4: Apply attention to values
        context = torch.matmul(attn_weights, V)
        # context: (batch, n_heads, seq_len, d_k)
        
        # Step 5: Concatenate heads
        context = self.combine_heads(context)
        # context: (batch, seq_len, d_model)
        
        # Step 6: Final projection
        output = self.W_o(context)
        
        return output, attn_weights


# Test
d_model = 64
n_heads = 8

mha = MultiHeadAttention(d_model, n_heads)
x = torch.randn(1, seq_len, d_model)

output, weights = mha(x)
print(f"Multi-Head Attention ({n_heads} heads):")
print(f"  Input:  {x.shape}")
print(f"  Output: {output.shape}")
print(f"  Weights: {weights.shape} = (batch, n_heads, seq_len, seq_len)")
print(f"  d_k per head: {d_model // n_heads}")
print(f"  Parameters: {sum(p.numel() for p in mha.parameters()):,}")
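
As a sanity check on Steps 3 and 4, the manual computation can be compared against PyTorch's fused kernel. This sketch assumes PyTorch 2.0 or later, where `torch.nn.functional.scaled_dot_product_attention` is available; the tensors are random stand-ins for the already-split Q, K, and V.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, n_heads, seq_len, d_k = 1, 8, 5, 8
Q = torch.randn(batch, n_heads, seq_len, d_k)
K = torch.randn(batch, n_heads, seq_len, d_k)
V = torch.randn(batch, n_heads, seq_len, d_k)

# Manual version, mirroring Step 3 and Step 4 of MultiHeadAttention.forward
scores = Q @ K.transpose(-2, -1) / (d_k ** 0.5)
manual = F.softmax(scores, dim=-1) @ V

# Fused version (PyTorch 2.0+); it returns only the output, not the weights
fused = F.scaled_dot_product_attention(Q, K, V)

print(torch.allclose(manual, fused, atol=1e-5))
```

The fused function does not return the attention weights, which is one reason this notebook keeps the manual version: we want the weights for the heatmaps.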

### The Parameter Count is the Same!

Notice that single-head and multi-head attention have the **same number of parameters** when using the same `d_model`. Multi-head attention doesn't add parameters - it **reorganizes** them into parallel subspaces.
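
The arithmetic behind this is short enough to write out. The sketch below counts the weights in the four bias-free projections used by the classes above; `attention_param_count` is a hypothetical helper, not part of any library.

```python
def attention_param_count(d_model, n_heads):
    """Weights in W_q, W_k, W_v, W_o (no biases), for any head count."""
    d_k = d_model // n_heads
    qkv = 3 * d_model * (n_heads * d_k)  # n_heads column blocks of width d_k each
    out = (n_heads * d_k) * d_model      # W_o maps the concatenation back
    return qkv + out

for h in (1, 4, 8):
    print(h, attention_param_count(64, h))  # 16384 for every h
```

Because `n_heads * d_k` always equals `d_model`, the head count cancels out and the total is `4 * d_model ** 2`.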

In [None]:
# Demonstrate: same parameter count
d_model = 64

single = SingleHeadAttention(d_model)
multi_4 = MultiHeadAttention(d_model, 4)
multi_8 = MultiHeadAttention(d_model, 8)

for name, model in [("Single Head", single), ("4 Heads", multi_4), ("8 Heads", multi_8)]:
    params = sum(p.numel() for p in model.parameters())
    print(f"{name:>12}: {params:,} parameters")

print("\n=> Same parameters, different organization!")
print("   More heads = more subspaces = more room for specialization")
print("   But each head has fewer dimensions to work with")

## Part 4: Visualizing What Different Heads Learn

Let's use a real pretrained model to see how different heads specialize. We'll extract attention patterns from GPT-2.

In [None]:
from transformers import GPT2Tokenizer, GPT2Model

# Load GPT-2 (small, runs on Colab free tier)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2', output_attentions=True)
model.eval()

print(f"GPT-2 config:")
print(f"  d_model: {model.config.n_embd}")
print(f"  n_heads: {model.config.n_head}")
print(f"  n_layers: {model.config.n_layer}")
print(f"  d_k per head: {model.config.n_embd // model.config.n_head}")

In [None]:
def get_attention_patterns(text, model, tokenizer):
    """Extract attention patterns from GPT-2."""
    inputs = tokenizer(text, return_tensors='pt')
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
    
    with torch.no_grad():
        outputs = model(**inputs)
    
    # outputs.attentions is a tuple of (n_layers) tensors
    # Each tensor: (batch, n_heads, seq_len, seq_len)
    return tokens, outputs.attentions


# Test sentence designed to show different linguistic patterns
text = "The cat sat on the mat because it was tired"
tokens, attentions = get_attention_patterns(text, model, tokenizer)

print(f"Tokens: {tokens}")
print(f"Number of layers: {len(attentions)}")
print(f"Attention shape per layer: {attentions[0].shape}")

In [None]:
# Visualize attention patterns from different heads in Layer 0
layer_idx = 0
layer_attn = attentions[layer_idx][0].detach().numpy()  # (n_heads, seq_len, seq_len)

# Show 4 heads from layer 0
fig, axes = plt.subplots(2, 2, figsize=(16, 14))
axes = axes.flatten()

# Clean up token names for display
clean_tokens = [t.replace('\u0120', ' ') for t in tokens]

for h, ax in enumerate(axes):
    sns.heatmap(
        layer_attn[h],
        xticklabels=clean_tokens,
        yticklabels=clean_tokens,
        cmap='Blues',
        ax=ax,
        vmin=0,
        vmax=1
    )
    ax.set_title(f'Layer {layer_idx}, Head {h}', fontsize=13, fontweight='bold')
    ax.tick_params(axis='x', rotation=45)
    ax.tick_params(axis='y', rotation=0)

plt.suptitle(f'GPT-2 Attention Heads - Layer {layer_idx}\n"{text}"', 
             fontsize=15, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Visualize specific patterns: look for heads that show
# 1. Local attention (nearby tokens)
# 2. Position-based attention (attending to first token)
# 3. Content-based attention (specific relationships)

def analyze_attention_pattern(attn_matrix, tokens):
    """Classify what pattern a head shows."""
    seq_len = len(tokens)
    
    # Check for local attention: weight on the current token and the one before it
    diag_strength = np.mean([attn_matrix[i, max(0, i-1):i+1].sum() 
                            for i in range(seq_len)])
    
    # Check for position attention (first token)
    first_token_attn = np.mean(attn_matrix[:, 0])
    
    # Check for previous token attention
    prev_token_attn = np.mean([attn_matrix[i, i-1] 
                               for i in range(1, seq_len)])
    
    return {
        'local': diag_strength,
        'first_token': first_token_attn,
        'previous_token': prev_token_attn
    }


# Analyze all heads in layer 0
print("Attention Pattern Analysis (Layer 0):")
print("=" * 60)

for h in range(model.config.n_head):
    patterns = analyze_attention_pattern(layer_attn[h], clean_tokens)
    dominant = max(patterns, key=patterns.get)
    print(f"Head {h:2d}: Local={patterns['local']:.3f}  "
          f"First={patterns['first_token']:.3f}  "
          f"Prev={patterns['previous_token']:.3f}  "
          f"-> Dominant: {dominant}")

## Part 5: Subject-Verb and Co-reference Patterns

Let's look at deeper layers, where more complex linguistic patterns tend to emerge, such as:
- **Subject-verb agreement**: "The cats **are** sleeping" (head linking "cats" to "are")
- **Co-reference**: "The cat... **it** was tired" (head linking "it" to "cat")

In [None]:
# Test with a sentence designed for co-reference
text_coref = "The doctor said that she would help the patient"
tokens_coref, attentions_coref = get_attention_patterns(text_coref, model, tokenizer)
clean_coref = [t.replace('\u0120', ' ') for t in tokens_coref]

# Look at a middle layer (4) and a later layer (8); middle layers often show syntactic patterns
fig, axes = plt.subplots(2, 4, figsize=(28, 12))

for layer_idx, row_axes in zip([4, 8], axes):
    layer_attn = attentions_coref[layer_idx][0].detach().numpy()
    
    for h, ax in enumerate(row_axes):
        sns.heatmap(
            layer_attn[h],
            xticklabels=clean_coref,
            yticklabels=clean_coref,
            cmap='Blues',
            ax=ax,
            vmin=0,
            vmax=0.6
        )
        ax.set_title(f'Layer {layer_idx}, Head {h}', fontsize=11, fontweight='bold')
        ax.tick_params(axis='x', rotation=45, labelsize=8)
        ax.tick_params(axis='y', rotation=0, labelsize=8)

plt.suptitle(f'Attention Heads Across Layers\n"{text_coref}"\nLook for "she" attending to "doctor" (co-reference)', 
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Let's specifically look at where "she" attends across all layers and heads
she_idx = clean_coref.index(' she') if ' she' in clean_coref else None
if she_idx is None:
    she_idx = clean_coref.index('she') if 'she' in clean_coref else None

# Show the tokens and the position of 'she'
print(f"Tokens: {clean_coref}")
print(f"'she' is at index: {she_idx}")

# Find the index of the first token containing 'doctor'
doctor_idx = None
for i, t in enumerate(clean_coref):
    if 'doctor' in t.lower():
        doctor_idx = i
        break

print(f"'doctor' is at index: {doctor_idx}")

if she_idx is not None and doctor_idx is not None:
    print(f"\nAttention from 'she' -> 'doctor' across layers and heads:")
    print("=" * 60)
    
    best_layers = []
    for layer_idx in range(len(attentions_coref)):
        layer_attn = attentions_coref[layer_idx][0].detach().numpy()
        for h in range(layer_attn.shape[0]):
            attn_val = layer_attn[h, she_idx, doctor_idx]
            if attn_val > 0.1:  # Significant attention
                best_layers.append((layer_idx, h, attn_val))
                print(f"  Layer {layer_idx:2d}, Head {h:2d}: {attn_val:.4f} {'***' if attn_val > 0.2 else ''}")
    
    if not best_layers:
        print("  (No single head strongly links she->doctor in this model)")
        print("  This is normal - attention patterns are distributed across heads.")

## Part 6: The Concatenation and Projection Step

A critical step that's often glossed over: after computing attention for each head, we **concatenate** the outputs and apply a **linear projection**.

Why? Because each head operates in a different subspace. The output projection $W^O$ lets the model mix information across these subspaces.
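
One way to see the mixing is to notice that "concatenate, then project" is the same as giving each head its own row block of $W^O$ and summing the results. A minimal NumPy sketch with random stand-in values:

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, n_heads = 3, 8, 4
d_k = d_model // n_heads

head_outputs = [rng.standard_normal((seq_len, d_k)) for _ in range(n_heads)]
W_o = rng.standard_normal((d_model, d_model))

# Route 1: concatenate, then project
projected = np.concatenate(head_outputs, axis=-1) @ W_o

# Route 2: each head multiplies its own (d_k, d_model) row block of W_o
summed = sum(head_outputs[i] @ W_o[i * d_k:(i + 1) * d_k, :]
             for i in range(n_heads))

print(np.allclose(projected, summed))
```

Each head therefore writes into all `d_model` output dimensions, and the final output is the sum of the heads' contributions.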

In [None]:
# Demonstrate concatenation + projection visually
d_model = 8
n_heads = 4
d_k = d_model // n_heads  # 2
seq_len = 3

# Simulate head outputs
head_outputs = [torch.randn(1, seq_len, d_k) for _ in range(n_heads)]

print("Head outputs (each in its own subspace):")
for i, h in enumerate(head_outputs):
    print(f"  Head {i}: shape {h.shape}, values: {h[0, 0].tolist()}")

# Concatenate
concatenated = torch.cat(head_outputs, dim=-1)
print(f"\nAfter concatenation: shape {concatenated.shape}")
print(f"  Values: {concatenated[0, 0].tolist()}")

# Output projection
W_o = nn.Linear(d_model, d_model, bias=False)
projected = W_o(concatenated)
print(f"\nAfter W_o projection: shape {projected.shape}")
print(f"  Values: {projected[0, 0].tolist()}")
print(f"\nThe output projection mixes information from all {n_heads} heads!")

In [None]:
# Visualize the concatenation + projection process
fig, axes = plt.subplots(1, 3, figsize=(18, 4))

# Individual head outputs (random stand-in values, one row per head)
head_data = np.random.randn(n_heads, d_k)
axes[0].imshow(head_data, cmap='RdBu', aspect='auto')
axes[0].set_title(f'Individual Head Outputs\n({n_heads} heads x {d_k} dims)', fontweight='bold')
axes[0].set_ylabel('Head')
axes[0].set_xlabel('Dimension')
axes[0].set_yticks(range(n_heads))
axes[0].set_yticklabels([f'Head {i}' for i in range(n_heads)])

# Concatenated
concat_data = head_data.reshape(1, d_model)
axes[1].imshow(concat_data, cmap='RdBu', aspect='auto')
axes[1].set_title(f'After Concatenation\n(1 x {d_model} dims)', fontweight='bold')
axes[1].set_xlabel('Dimension')
axes[1].set_yticks([])

# Add dividers to show head boundaries
for i in range(1, n_heads):
    axes[1].axvline(x=i * d_k - 0.5, color='white', linewidth=2)

# After projection
W = np.random.randn(d_model, d_model) * 0.3
projected_data = (concat_data @ W).reshape(1, d_model)
axes[2].imshow(projected_data, cmap='RdBu', aspect='auto')
axes[2].set_title(f'After W_o Projection\n(mixed across heads)', fontweight='bold')
axes[2].set_xlabel('Dimension')
axes[2].set_yticks([])

plt.suptitle('Multi-Head Attention: Concat + Project', fontsize=14, fontweight='bold', y=1.05)
plt.tight_layout()
plt.show()

## Part 7: Single-Head vs Multi-Head Comparison

Let's train both on a simple sequence modeling task and compare their ability to capture patterns.

In [None]:
# Simple task: predict the next token in patterned sequences
# Pattern: A B C A B C A B C ... (repeating)
# The next token equals the token 3 positions before it, so at each
# position the model must attend 2 positions back from the current token

def generate_pattern_data(n_samples=500, seq_len=12, vocab_size=5):
    """Generate repeating pattern sequences."""
    data = []
    targets = []
    
    for _ in range(n_samples):
        # Random pattern of length 3
        pattern = np.random.randint(0, vocab_size, size=3)
        # Repeat to fill sequence
        seq = np.tile(pattern, seq_len // 3 + 1)[:seq_len + 1]
        data.append(seq[:-1])
        targets.append(seq[1:])
    
    return torch.LongTensor(np.array(data)), torch.LongTensor(np.array(targets))


class SimpleAttentionModel(nn.Module):
    """Minimal model: embedding -> attention -> output."""
    
    def __init__(self, vocab_size, d_model, n_heads, seq_len):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(seq_len, d_model)
        
        if n_heads == 1:
            self.attention = SingleHeadAttention(d_model)
        else:
            self.attention = MultiHeadAttention(d_model, n_heads)
        
        self.output = nn.Linear(d_model, vocab_size)
        self.n_heads = n_heads
    
    def forward(self, x):
        batch_size, seq_len = x.shape
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
        
        # Create causal mask
        mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device)).unsqueeze(0)
        if self.n_heads > 1:
            mask = mask.unsqueeze(1)  # (1, 1, seq_len, seq_len) for multi-head
        
        x = self.embedding(x) + self.pos_embedding(positions)
        x, attn = self.attention(x, mask=mask)
        logits = self.output(x)
        
        return logits, attn


# Create data
vocab_size = 5
d_model = 32
seq_len = 12

X_train, Y_train = generate_pattern_data(500, seq_len, vocab_size)
X_test, Y_test = generate_pattern_data(100, seq_len, vocab_size)

print(f"Training data: {X_train.shape}")
print(f"Sample sequence: {X_train[0].tolist()}")
print(f"Sample target:   {Y_train[0].tolist()}")
print(f"Pattern: {X_train[0][:3].tolist()} repeats")
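
Before training, it helps to know what a perfect model could score on this task. The sketch below rebuilds one sequence the same way `generate_pattern_data` does, checks which input position carries the answer, and computes an expected accuracy ceiling. It assumes the three pattern values are drawn independently and uniformly, as in the generator above.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, vocab_size = 12, 5

pattern = rng.integers(0, vocab_size, size=3)
seq = np.tile(pattern, seq_len // 3 + 1)[:seq_len + 1]
inputs, targets = seq[:-1], seq[1:]

# From position 2 onward, the target is the input two positions earlier
assert all(targets[t] == inputs[t - 2] for t in range(2, seq_len))

# Targets at positions 0 and 1 are fresh random draws, so the best a model
# can do there is chance level, 1 / vocab_size
ceiling = ((seq_len - 2) + 2 / vocab_size) / seq_len
print(f"Expected accuracy ceiling: {ceiling:.3f}")
```

So test accuracy in the high 80s is already close to the best achievable here. The first two targets cannot be predicted from earlier tokens.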

In [None]:
def train_model(model, X_train, Y_train, epochs=100, lr=0.005):
    """Train a model and return loss history (reports accuracy on the notebook-level X_test / Y_test)."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    losses = []
    
    for epoch in range(epochs):
        model.train()
        logits, _ = model(X_train)
        
        # Reshape for loss: (batch * seq_len, vocab_size) vs (batch * seq_len,)
        loss = criterion(logits.view(-1, logits.size(-1)), Y_train.view(-1))
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        losses.append(loss.item())
        
        if (epoch + 1) % 20 == 0:
            # Calculate accuracy
            model.eval()
            with torch.no_grad():
                test_logits, _ = model(X_test)
                preds = test_logits.argmax(dim=-1)
                acc = (preds == Y_test).float().mean().item()
            print(f"  Epoch {epoch+1:3d}: Loss={loss.item():.4f}, Test Acc={acc:.4f}")
    
    return losses


# Train single-head model
print("Training Single-Head Attention:")
torch.manual_seed(42)
model_1h = SimpleAttentionModel(vocab_size, d_model, n_heads=1, seq_len=seq_len)
losses_1h = train_model(model_1h, X_train, Y_train)

print("\nTraining Multi-Head Attention (4 heads):")
torch.manual_seed(42)
model_4h = SimpleAttentionModel(vocab_size, d_model, n_heads=4, seq_len=seq_len)
losses_4h = train_model(model_4h, X_train, Y_train)

print("\nTraining Multi-Head Attention (8 heads):")
torch.manual_seed(42)
model_8h = SimpleAttentionModel(vocab_size, d_model, n_heads=8, seq_len=seq_len)
losses_8h = train_model(model_8h, X_train, Y_train)

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(16, 5))

# Loss curves
axes[0].plot(losses_1h, label='1 Head', alpha=0.8, linewidth=2)
axes[0].plot(losses_4h, label='4 Heads', alpha=0.8, linewidth=2)
axes[0].plot(losses_8h, label='8 Heads', alpha=0.8, linewidth=2)
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Training Loss: Single vs Multi-Head', fontsize=13, fontweight='bold')
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)

# Final accuracy comparison
accs = []
for model_name, model in [('1 Head', model_1h), ('4 Heads', model_4h), ('8 Heads', model_8h)]:
    model.eval()
    with torch.no_grad():
        logits, _ = model(X_test)
        preds = logits.argmax(dim=-1)
        acc = (preds == Y_test).float().mean().item()
        accs.append(acc)

bars = axes[1].bar(['1 Head', '4 Heads', '8 Heads'], accs, 
                    color=['#e74c3c', '#3498db', '#2ecc71'], edgecolor='black')
axes[1].set_ylabel('Test Accuracy', fontsize=12)
axes[1].set_title('Test Accuracy Comparison', fontsize=13, fontweight='bold')
axes[1].set_ylim(0, 1.05)

for bar, acc in zip(bars, accs):
    axes[1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
                f'{acc:.1%}', ha='center', fontsize=12, fontweight='bold')

plt.tight_layout()
plt.show()

In [None]:
# Visualize learned attention patterns from the trained multi-head model
model_4h.eval()
with torch.no_grad():
    test_seq = X_test[0:1]  # Single sequence
    _, attn = model_4h(test_seq)

tokens_str = [str(t) for t in test_seq[0].tolist()]

fig, axes = plt.subplots(1, 4, figsize=(24, 5))
attn_np = attn[0].detach().numpy()  # (n_heads, seq_len, seq_len)

for h in range(4):
    sns.heatmap(
        attn_np[h],
        xticklabels=tokens_str,
        yticklabels=tokens_str,
        cmap='Blues',
        ax=axes[h],
        annot=True,
        fmt='.2f',
        vmin=0,
        vmax=1
    )
    axes[h].set_title(f'Trained Head {h}', fontsize=12, fontweight='bold')

plt.suptitle(f'Trained 4-Head Model: Learned Patterns\nSequence: {tokens_str} (pattern of 3 repeats)',
             fontsize=14, fontweight='bold', y=1.08)
plt.tight_layout()
plt.show()

print("Look for heads attending 2 positions back: that token equals the next token,")
print("because the pattern repeats every 3 positions.")
print("Different heads may specialize, for example in immediate context or in the periodic structure.")

## Part 8: Attention Head Diversity Analysis

One way to measure whether heads are truly specializing is to compute the **similarity** between attention patterns of different heads. Lower similarity means more diverse patterns.
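
To get a feel for the numbers, here is a small worked example on hand-made attention matrices (all three patterns are invented for illustration). Attention weights are never negative, so the cosine similarity between two heads always falls between 0 and 1.

```python
import numpy as np

def cosine(a, b):
    """Cosine similarity between two flattened attention matrices."""
    a, b = a.ravel(), b.ravel()
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

n = 4
self_head = np.eye(n)               # every token attends to itself
first_head = np.zeros((n, n))
first_head[:, 0] = 1.0              # every token attends to token 0
uniform_causal = np.tril(np.ones((n, n)))
uniform_causal /= uniform_causal.sum(axis=-1, keepdims=True)  # even spread over the past

print(f"self vs self:    {cosine(self_head, self_head):.2f}")
print(f"self vs first:   {cosine(self_head, first_head):.2f}")
print(f"self vs uniform: {cosine(self_head, uniform_causal):.2f}")
```

Even the two very different heads (self vs first) do not reach zero, because in a causal model the first row is always `[1, 0, 0, ...]`. Keep that baseline in mind when reading the GPT-2 heatmaps.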

In [None]:
# Compute pairwise cosine similarity between heads
def head_diversity(attention_weights):
    """Compute pairwise similarity between attention heads."""
    # attention_weights: (n_heads, seq_len, seq_len)
    n_heads = attention_weights.shape[0]
    
    # Flatten each head's attention to a vector
    flat = attention_weights.reshape(n_heads, -1)
    
    # Normalize
    norms = np.linalg.norm(flat, axis=1, keepdims=True)
    flat_norm = flat / (norms + 1e-8)
    
    # Pairwise cosine similarity
    similarity = flat_norm @ flat_norm.T
    
    return similarity


# Compare head diversity in GPT-2 across layers
text = "The cat sat on the mat because it was tired"
tokens_gpt2, attentions_gpt2 = get_attention_patterns(text, model, tokenizer)

fig, axes = plt.subplots(2, 3, figsize=(18, 11))
layers_to_show = [0, 2, 5, 7, 9, 11]

for idx, (layer, ax) in enumerate(zip(layers_to_show, axes.flatten())):
    layer_attn = attentions_gpt2[layer][0].detach().numpy()
    sim = head_diversity(layer_attn)
    
    sns.heatmap(sim, cmap='RdYlGn_r', vmin=0, vmax=1, 
                annot=True, fmt='.2f', ax=ax, square=True)
    ax.set_title(f'Layer {layer}', fontsize=12, fontweight='bold')
    ax.set_xlabel('Head')
    ax.set_ylabel('Head')

plt.suptitle('Head Diversity: Pairwise Cosine Similarity Between Attention Heads\n(Lower = more diverse)',
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("Green (low similarity) = heads attend to different positions (more diverse)")
print("Red (high similarity) = heads overlap and may be redundant")

## Part 9: Computational Analysis

Multi-head attention is appealing because it has **roughly the same computational cost** as single-head attention with the same `d_model`, while letting each head form its own attention pattern.
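
A rough operation count makes this concrete. The sketch below uses a simplified cost model (multiply-adds for the two attention matmuls, projections ignored since they are identical for every head count); `attention_cost` is a hypothetical helper.

```python
def attention_cost(seq_len, d_model, n_heads):
    """Rough per-sequence cost model, ignoring the Q/K/V/O projections."""
    d_k = d_model // n_heads
    score_macs = n_heads * seq_len * seq_len * d_k    # Q @ K^T
    context_macs = n_heads * seq_len * seq_len * d_k  # weights @ V
    softmax_elems = n_heads * seq_len * seq_len       # entries to exponentiate
    return score_macs + context_macs, softmax_elems

for h in (1, 4, 8, 16):
    macs, softmax_elems = attention_cost(128, 256, h)
    print(f"{h:2d} heads: {macs:,} matmul MACs, {softmax_elems:,} softmax entries")
```

The matmul work is identical for every head count because `n_heads * d_k` is always `d_model`. What does grow with the number of heads is the attention-weight tensor and the softmax over it, so measured timings need not be exactly equal.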

In [None]:
import time

# Benchmark: single-head vs multi-head speed
d_model = 256
seq_len = 128
batch_size = 16
n_runs = 50

# Models
sha = SingleHeadAttention(d_model)
mha_4 = MultiHeadAttention(d_model, 4)
mha_8 = MultiHeadAttention(d_model, 8)
mha_16 = MultiHeadAttention(d_model, 16)

x = torch.randn(batch_size, seq_len, d_model)

results = {}

# Loop variable is named `attn` so it does not overwrite the GPT-2 `model`
# that Part 10 still needs
for name, attn in [('1 Head', sha), ('4 Heads', mha_4),
                   ('8 Heads', mha_8), ('16 Heads', mha_16)]:
    attn.eval()
    
    # Warmup
    with torch.no_grad():
        for _ in range(5):
            attn(x)
    
    # Time
    times = []
    with torch.no_grad():
        for _ in range(n_runs):
            start = time.perf_counter()
            attn(x)
            times.append(time.perf_counter() - start)
    
    avg_time = np.mean(times) * 1000  # ms
    results[name] = avg_time
    params = sum(p.numel() for p in attn.parameters())
    print(f"{name:>8}: {avg_time:.2f} ms  ({params:,} params)")

print("\nAll configurations have the same parameter count!")
print("Speed differences should be small because the total FLOPs are nearly identical.")

In [None]:
# Visualize the speed comparison
fig, ax = plt.subplots(figsize=(10, 5))

names = list(results.keys())
times = list(results.values())
colors = ['#e74c3c', '#3498db', '#2ecc71', '#f39c12']

bars = ax.bar(names, times, color=colors, edgecolor='black', linewidth=1.2)
ax.set_ylabel('Time (ms)', fontsize=12)
ax.set_title('Inference Time: Single vs Multi-Head Attention\n(Same parameters, similar speed)',
             fontsize=13, fontweight='bold')

for bar, t in zip(bars, times):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.1,
            f'{t:.2f} ms', ha='center', fontsize=11, fontweight='bold')

ax.grid(True, alpha=0.3, axis='y')
plt.tight_layout()
plt.show()

## Part 10: Summary Visualization

Let's create a comprehensive visualization showing the full multi-head attention pipeline.

In [None]:
# Final comprehensive visualization
text = "The quick brown fox jumps over the lazy dog"
tokens_final, attentions_final = get_attention_patterns(text, model, tokenizer)
clean_final = [t.replace('\u0120', ' ').strip() for t in tokens_final]

# Show layer 5 (middle layer) - all 12 heads
layer = 5
layer_attn = attentions_final[layer][0].detach().numpy()

fig, axes = plt.subplots(3, 4, figsize=(24, 16))

for h in range(12):
    row, col = h // 4, h % 4
    ax = axes[row, col]
    
    sns.heatmap(
        layer_attn[h],
        xticklabels=clean_final,
        yticklabels=clean_final if col == 0 else False,
        cmap='Blues',
        ax=ax,
        vmin=0,
        vmax=0.5,
        cbar=col == 3
    )
    ax.set_title(f'Head {h}', fontsize=11, fontweight='bold')
    ax.tick_params(axis='x', rotation=45, labelsize=7)
    ax.tick_params(axis='y', rotation=0, labelsize=7)

plt.suptitle(f'GPT-2 Layer {layer}: All 12 Attention Heads\n"{text}"\n'
             f'Different heads can capture different aspects of the relationships between tokens',
             fontsize=15, fontweight='bold')
plt.tight_layout()
plt.show()
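
To complement the heatmaps, a couple of simple statistics can hint at what kind of pattern a head follows. The sketch below is an illustrative heuristic, not an established classification method: for each head it measures the average attention weight placed on the immediately preceding token and on the first token. It assumes an array shaped `(n_heads, seq_len, seq_len)` like `layer_attn`; the function name `head_pattern_stats` and the toy input are hypothetical.

```python
import numpy as np

def head_pattern_stats(attn):
    """Per-head mean attention on the previous token and on the first token."""
    attn = np.asarray(attn, dtype=float)
    n_heads, seq_len, _ = attn.shape
    rows = np.arange(1, seq_len)  # skip row 0: it has no previous token

    prev_mass = attn[:, rows, rows - 1].mean(axis=1)  # weight on token i-1
    first_mass = attn[:, rows, 0].mean(axis=1)        # weight on token 0
    return prev_mass, first_mass

# Toy input: head 0 always looks one step back, head 1 always looks at token 0
seq_len = 4
toy = np.zeros((2, seq_len, seq_len))
toy[:, 0, 0] = 1.0
for i in range(1, seq_len):
    toy[0, i, i - 1] = 1.0
    toy[1, i, 0] = 1.0

prev_mass, first_mass = head_pattern_stats(toy)
print(prev_mass, first_mass)
```

For the second token, the previous token is also the first token, so the two statistics overlap slightly on short sequences. On real data, `head_pattern_stats(layer_attn)` returns one pair of values per head that can be read alongside the heatmaps.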

---

## Key Takeaways

1. **Single-head attention** compresses all relationships into one attention pattern. It must trade off between different types of information (syntax, semantics, position).

2. **Multi-head attention** splits the model dimension into parallel subspaces. Each head gets $d_k = d_{model} / h$ dimensions and can specialize in capturing different patterns.

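3. **No extra parameters**: With $d_k = d_{model} / h$, multi-head attention has the **same** parameter count as single-head attention over the full $d_{model}$, because the query, key, value, and output projections stay $d_{model} \times d_{model}$ regardless of $h$. The improvement comes from better utilization of the same capacity.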

4. **Head specialization**: In trained models, many heads learn recognizably different patterns (though some end up redundant):
   - Local/positional attention (nearby tokens)
   - Syntactic attention (subject-verb)
   - Semantic attention (co-reference, entity tracking)
   - Delimiter/structural attention (punctuation, special tokens)

5. **Concatenation + Projection**: After parallel attention, outputs are concatenated and projected through $W^O$, which lets the model mix information across different heads' subspaces.

6. **For inference engineering**: Understanding multi-head attention matters because:
   - KV cache stores keys and values **per head** (memory planning)
   - Some heads can be pruned with minimal quality loss (optimization)
   - Grouped Query Attention (GQA) shares each key/value head across a group of query heads to reduce KV cache memory
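
The memory point in item 6 can be made concrete with a rough estimate. This sketch assumes the cache holds one key vector and one value vector per layer, per KV head, per token, stored in 16-bit precision. The function name `kv_cache_bytes` is hypothetical, the 12-layer, 12-head, 64-dimension shape is that of GPT-2 small, and the 4-KV-head variant is a hypothetical GQA configuration for comparison, not an actual GPT-2 setting.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, batch_size=1, bytes_per_elem=2):
    """Approximate KV cache size: keys + values for every layer, KV head, and token."""
    per_token = 2 * n_layers * n_kv_heads * d_head * bytes_per_elem  # 2 = K and V
    return per_token * seq_len * batch_size

# GPT-2 small shape: 12 layers, 12 heads of size 64, 1024-token context
mha = kv_cache_bytes(n_layers=12, n_kv_heads=12, d_head=64, seq_len=1024)
# Hypothetical GQA variant: 12 query heads sharing 4 KV heads (groups of 3)
gqa = kv_cache_bytes(n_layers=12, n_kv_heads=4, d_head=64, seq_len=1024)

print(f"MHA: {mha / 2**20:.1f} MiB, GQA: {gqa / 2**20:.1f} MiB")
```

Because the cache grows linearly with the number of KV heads, cutting them from 12 to 4 shrinks it by a factor of three in this estimate, while the number of query heads stays unchanged.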
   
---

*Next: We'll explore arithmetic intensity and the roofline model to understand when attention is compute-bound vs memory-bound.*