# Build DeepSeek from Scratch - Multi-Head Attention Implementation

## Overview
This notebook provides a comprehensive mathematical walkthrough and Python implementation of multi-head attention. We'll see every matrix operation step-by-step, understand the dimensional transformations, and implement the complete multi-head attention class.

## Learning Objectives
- Master the mathematical implementation of multi-head attention
- Understand every matrix multiplication and dimensional transformation
- Learn how to reshape and group matrices for multiple heads
- Implement the complete multi-head attention class in Python
- Bridge the gap between conceptual understanding and practical coding

## Prerequisites
- Understanding of self-attention mechanism
- Knowledge of causal attention and masking
- Familiarity with the conceptual overview of multi-head attention

## Recap: Multi-Head Attention Concept

### The Core Problem
**Single-head attention limitation**: Each token gets a single attention pattern, so the layer can capture only one "perspective" of the input text at a time

**Example**: "The artist painted the portrait of a woman with a brush"
- **Perspective 1**: Artist uses brush to paint
- **Perspective 2**: Woman in portrait holds brush

### The Solution Strategy
1. **Split** query, key, value matrices into multiple heads
2. **Process** each head independently to capture different perspectives  
3. **Merge** results to create richer context vectors
4. **Maintain** same output dimensions as single-head attention

### Key Parameters for Today's Implementation
- **Input tokens**: 3 (simplified example)
- **Input dimension (d_in)**: 6
- **Output dimension (d_out)**: 6
- **Number of heads**: 2
- **Head dimension**: d_out / num_heads = 6 / 2 = 3

## Understanding the Key Parameters with Examples

Let's break down each parameter with concrete, intuitive examples:

### 1. **Input tokens = 3** (Sequence Length)
This means we're processing a sentence with 3 words/tokens.

**Example sentence**: "The cat sleeps"
- Token 1: "The" 
- Token 2: "cat"
- Token 3: "sleeps"

**In practice**: Real inputs often contain hundreds or thousands of tokens, but we use 3 for simplicity.

### 2. **Input dimension (d_in) = 6** (Embedding Size)
Each token is represented by a 6-dimensional vector (typically hundreds to several thousand dimensions in real models).

**Example for "The"**:
```
"The" → [0.2, -0.1, 0.8, 0.3, -0.5, 0.7]  # 6 numbers
```

**What these numbers might represent** (purely illustrative: learned embedding dimensions are rarely this cleanly interpretable):
- Position 1: "definiteness" (0.2 = somewhat definite)
- Position 2: "animacy" (-0.1 = not animate)
- Position 3: "frequency" (0.8 = very common word)
- Position 4: "grammatical role" (0.3 = article)
- Position 5: "sentiment" (-0.5; arbitrary here, since "The" carries no sentiment)
- Position 6: "semantic category" (0.7 = function word)

### 3. **Output dimension (d_out) = 6** (Context Size)
After attention, each token still has 6 dimensions, but now enriched with context.

**Before attention** - "cat" in isolation:
```
[0.5, 0.9, 0.1, -0.2, 0.3, 0.8]  # Just "cat" features
```

**After attention** - "cat" with context from "The" and "sleeps":
```
[0.6, 0.8, 0.2, 0.1, 0.4, 0.7]   # "cat" + context from other words
```

### 4. **Number of heads = 2** (Multiple Perspectives)
We split the attention into 2 different "views" of the same sentence.

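**Head 1 might focus on**: Grammar relationships
- "cat" → "The" (a noun looking back at its determiner)
- "sleeps" → "cat" (a verb looking back at its subject)

**Head 2 might focus on**: Semantic relationships
- "sleeps" → "cat" (who is doing the sleeping? an animal)
- "cat" → "cat" (a token can also attend to itself to keep its own meaning)

Here "A → B" means "A attends to B". With the causal attention used later in this notebook, a token can only attend to itself and earlier tokens, so every arrow points backward or to the token itself.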

### 5. **Head dimension = 3** (d_out ÷ num_heads)
Each head works with 3 dimensions instead of all 6.

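**Original 6 dimensions** split into **2 heads of 3 dimensions each** (raw vectors shown for illustration; in the actual computation the split is applied to the projected query, key, and value vectors, as Step 3 shows):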

**Head 1 gets dimensions 1-3**:
```
"The": [0.2, -0.1, 0.8] → focuses on grammar
"cat": [0.5, 0.9, 0.1] → focuses on grammar  
"sleeps": [0.1, 0.4, 0.6] → focuses on grammar
```

**Head 2 gets dimensions 4-6**:
```
"The": [0.3, -0.5, 0.7] → focuses on meaning
"cat": [-0.2, 0.3, 0.8] → focuses on meaning
"sleeps": [0.2, 0.1, 0.9] → focuses on meaning
```

### Real-World Analogy
Think of it like **two people reading the same sentence**:

**Person 1 (Head 1)**: Grammar expert
- Notices: "The" is an article, "cat" is a noun, "sleeps" is a verb
- Focuses on: Sentence structure, word roles, syntax

**Person 2 (Head 2)**: Meaning expert  
- Notices: This is about an animal, the action is peaceful, it's present tense
- Focuses on: Semantics, concepts, relationships

**Final result**: Combine both perspectives for richer understanding!

### Why These Numbers?
- **Small numbers** (3 tokens, 6 dimensions) make it easy to follow the math
- **Real models** use much larger numbers (1000+ tokens, 512+ dimensions)
- **Same principles** apply regardless of size

## Key Parameters Revisited with "The cat sat"

### 1. **Input Tokens = 3** (Sequence Length)

**What it means**: We're processing 3 words/tokens at once

**Real Example**:
```
Input sentence: "The cat sat"
Token 1: "The"
Token 2: "cat"  
Token 3: "sat"
```

**Why 3 tokens?**
- Small example for easy understanding
- In real models: 512, 1024, 2048+ tokens
- Each token represents one word/subword

### 2. **Input Dimension (d_in) = 6** (Embedding Size)

**What it means**: Each token is represented by a 6-dimensional vector

**Real Example**:
```
"The" → [0.1, -0.3, 0.7, 0.2, -0.1, 0.5]  # 6 numbers
"cat" → [0.4, 0.8, -0.2, 0.3, 0.6, -0.4]  # 6 numbers  
"sat" → [-0.2, 0.1, 0.9, -0.5, 0.3, 0.7]  # 6 numbers
```

**Why 6 dimensions?**
- Small example for easy math
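- In real models: 768 (BERT-base, GPT-2 small), 1024 (GPT-2 medium), 12,288 (GPT-3 175B)
- Together, the dimensions encode learned features of the token (individual dimensions are rarely interpretable on their own)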

### 3. **Output Dimension (d_out) = 6** (Context Size)

**What it means**: Each token produces a 6-dimensional context vector

**Real Example**:
```
Input:  "The" → [0.1, -0.3, 0.7, 0.2, -0.1, 0.5]
Output: "The" → [0.3, 0.1, -0.2, 0.8, 0.4, -0.1]  # New context-aware representation
```

**Why same as input?**
- Common in transformers: residual connections add the attention output back to its input, so matching shapes are convenient
- Could be different (e.g., d_in=512, d_out=1024)
- Output contains richer context information

### 4. **Number of Heads = 2** (Multiple Perspectives)

**What it means**: We split attention into 2 different "viewpoints"

**Real Example with "The cat sat"**:
```
Head 1 might focus on: Grammar relationships
- "The" pays attention to "cat" (article → noun)
- "cat" pays attention to "sat" (subject → verb)

Head 2 might focus on: Semantic relationships  
- "cat" pays attention to "sat" (who does what)
- "sat" pays attention to "The cat" (complete subject)
```

**Why 2 heads?**
- Simple example for learning
- Real models: 8, 12, 16+ heads
- Each head learns different patterns

### 5. **Head Dimension = 3** (Per-Head Size)

**What it means**: Each head processes 3-dimensional vectors

**Mathematical relationship**:
```
head_dim = d_out / num_heads = 6 / 2 = 3
```

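**Real Example** (raw "The" vector shown for illustration; in the model, the split is applied to its projected query, key, and value vectors):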
```
Original "The" vector: [0.1, -0.3, 0.7, 0.2, -0.1, 0.5]

Split into 2 heads:
Head 1: [0.1, -0.3, 0.7]  # First 3 dimensions
Head 2: [0.2, -0.1, 0.5]  # Last 3 dimensions
```
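A minimal PyTorch sketch of this split, reusing the illustrative "The" vector above. In the real computation the tensor being split would be a projected query, key, or value, not the raw embedding.

```python
import torch

# Illustrative 6-dim vector for "The" (values from the example above)
the_vec = torch.tensor([0.1, -0.3, 0.7, 0.2, -0.1, 0.5])

num_heads, head_dim = 2, 3
# view(num_heads, head_dim) reads the 6 numbers row by row:
# row 0 -> first 3 dims (head 1), row 1 -> last 3 dims (head 2)
per_head = the_vec.view(num_heads, head_dim)

print(per_head[0])  # head 1 slice
print(per_head[1])  # head 2 slice
```

No numbers are copied or created; `view` only changes how the same 6 values are indexed.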

**Why this split?**
- Allows parallel processing
- Each head has smaller, focused representation
- Final output recombines all heads

### Visual Summary

```
Input Matrix (3 tokens × 6 dimensions):
[0.1, -0.3, 0.7, 0.2, -0.1, 0.5]  ← "The"
[0.4,  0.8,-0.2, 0.3,  0.6,-0.4]  ← "cat"
[-0.2, 0.1, 0.9,-0.5,  0.3, 0.7]  ← "sat"

↓ Multi-Head Attention Processing ↓

Output Matrix (3 tokens × 6 dimensions):
[0.3,  0.1,-0.2, 0.8,  0.4,-0.1]  ← "The" (context-aware)
[0.5, -0.3, 0.6, 0.2, -0.7, 0.4]  ← "cat" (context-aware)
[0.1,  0.9,-0.1, 0.3,  0.5,-0.2]  ← "sat" (context-aware)
```

**Key Insight**: Same input/output dimensions, but each output vector now mixes in information from the tokens it is allowed to attend to (with causal masking: itself and all earlier tokens), processed through multiple attention heads!

## Step 1: Input Setup and Initialization

### Input Embedding Matrix Structure
```python
# Input shape: [batch_size, num_tokens, d_in]
X = torch.randn(1, 3, 6)  # 1 sequence in the batch, 3 tokens, 6-dimensional embeddings
```

### Visual Representation
```
Token 1: [x₁₁, x₁₂, x₁₃, x₁₄, x₁₅, x₁₆]  # Input embedding
Token 2: [x₂₁, x₂₂, x₂₃, x₂₄, x₂₅, x₂₆]  # Input embedding  
Token 3: [x₃₁, x₃₂, x₃₃, x₃₄, x₃₅, x₃₆]  # Input embedding
```

### Key Dimensions to Track
- **Batch size**: 1 (a batch can hold several sequences; the same code works unchanged)
- **Number of tokens**: 3 (sequence length)
- **Input dimension**: 6 (d_in)
- **Output dimension**: 6 (d_out)

### Input Embedding Composition
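**Recall**: In GPT-2-style models, input embedding = token embedding + positional embedding
- Each token gets a fixed-size vector of d_in numbers
- That vector carries both semantic (token) and positional information

(DeepSeek models instead encode position with rotary position embeddings (RoPE) applied to queries and keys inside attention.)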
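A sketch of how such an input matrix could be built, assuming GPT-2-style learned absolute positional embeddings. The vocabulary size, context length, and token IDs below are hypothetical placeholders, not values from this notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

vocab_size = 50      # hypothetical vocabulary size
context_length = 8   # hypothetical maximum sequence length
d_in = 6

tok_emb = nn.Embedding(vocab_size, d_in)       # one learned vector per token ID
pos_emb = nn.Embedding(context_length, d_in)   # one learned vector per position

token_ids = torch.tensor([[4, 17, 31]])        # hypothetical IDs for "The cat sat", shape [1, 3]
positions = torch.arange(token_ids.shape[1])   # [0, 1, 2]

# Broadcasting adds the same [3, 6] positional block to every sequence in the batch
X = tok_emb(token_ids) + pos_emb(positions)
print(X.shape)  # torch.Size([1, 3, 6])
```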

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Set up example input
batch_size = 1
num_tokens = 3
d_in = 6
d_out = 6
num_heads = 2
head_dim = d_out // num_heads  # = 3

# Create example input embedding matrix
# Shape: [batch_size, num_tokens, d_in]
X = torch.randn(batch_size, num_tokens, d_in)
print(f"Input X shape: {X.shape}")
print(f"X:\n{X}")

print(f"\nKey parameters:")
print(f"- Input dimension (d_in): {d_in}")
print(f"- Output dimension (d_out): {d_out}")
print(f"- Number of heads: {num_heads}")
print(f"- Head dimension: {head_dim}")

Input X shape: torch.Size([1, 3, 6])
X:
tensor([[[ 1.0155, -0.5620,  0.1173,  0.5436, -0.4497, -0.4543],
         [-0.3182,  0.6049, -1.0629, -0.9308, -1.8998,  1.4582],
         [-0.6771, -1.4818,  0.1613,  1.3851,  2.0653,  0.8286]]])

Key parameters:
- Input dimension (d_in): 6
- Output dimension (d_out): 6
- Number of heads: 2
- Head dimension: 3


## Step 2: Initialize Trainable Weight Matrices

### Weight Matrix Dimensions
All weight matrices have shape `[d_in, d_out]` = `[6, 6]`

```python
# WQ: [6, 6]  Query weight matrix
# WK: [6, 6]  Key weight matrix
# WV: [6, 6]  Value weight matrix
```

### Key Insight
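- **Same parameter count** as single-head attention with the same d_out: 3 × d_in × d_out = 3 × 6 × 6 = 108 weights
- **Different organization**: The output columns will be split across heads later
- **Random initialization**: Weights start random and are optimized through backpropagation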

### Matrix Multiplication Overview
```python
# Standard linear transformations
Q = X @ WQ  # [1, 3, 6] @ [6, 6] = [1, 3, 6]
K = X @ WK  # [1, 3, 6] @ [6, 6] = [1, 3, 6]  
V = X @ WV  # [1, 3, 6] @ [6, 6] = [1, 3, 6]
```

### Dimensional Analysis
- **Input**: [batch_size, num_tokens, d_in]
- **Output**: [batch_size, num_tokens, d_out]
- **Space transformation**: From input dimension space to output dimension space
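In the `X @ WQ` notation above, `WQ` has shape `[d_in, d_out]`. `nn.Linear`, which the next cell uses, stores its weight transposed as `[d_out, d_in]` and computes `X @ weight.T`. A quick check, using a non-square `d_out = 4` (an arbitrary choice for this sketch) so the two shapes are distinguishable:

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
d_in, d_out = 6, 4                               # non-square so the stored shape is visible
X = torch.randn(1, 3, d_in)                      # [batch, tokens, d_in]
W_query = nn.Linear(d_in, d_out, bias=False)

print(W_query.weight.shape)                      # torch.Size([4, 6]), i.e. [d_out, d_in]
Q_linear = W_query(X)                            # the form used in this notebook
Q_matmul = X @ W_query.weight.T                  # the "X @ WQ" form with WQ = weight.T
print(Q_linear.shape, torch.allclose(Q_linear, Q_matmul))
```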

In [2]:
# Initialize trainable weight matrices
# nn.Linear with bias=False: a plain matrix multiply with PyTorch's default (Kaiming-uniform) weight initialization
W_query = nn.Linear(d_in, d_out, bias=False)
W_key = nn.Linear(d_in, d_out, bias=False)
W_value = nn.Linear(d_in, d_out, bias=False)

# Generate Q, K, V matrices
Q = W_query(X)  # [1, 3, 6]
K = W_key(X)    # [1, 3, 6]
V = W_value(X)  # [1, 3, 6]

print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"V shape: {V.shape}")

print(f"\nQ matrix:\n{Q}")
print(f"\nK matrix:\n{K}")
print(f"\nV matrix:\n{V}")

# Verify dimensions match expected output space
assert Q.shape == (batch_size, num_tokens, d_out)
assert K.shape == (batch_size, num_tokens, d_out)
assert V.shape == (batch_size, num_tokens, d_out)
print("\n✓ All Q, K, V matrices have correct dimensions")

Q shape: torch.Size([1, 3, 6])
K shape: torch.Size([1, 3, 6])
V shape: torch.Size([1, 3, 6])

Q matrix:
tensor([[[ 0.2434,  0.4607, -0.5537, -0.5116, -0.0451,  0.1184],
         [-0.5975, -0.5909, -0.6584, -0.2954, -0.6365, -0.7123],
         [ 0.4812, -0.1247,  0.3195,  1.0179,  0.8944,  0.8886]]],
       grad_fn=<UnsafeViewBackward0>)

K matrix:
tensor([[[-0.3222,  0.3691,  0.3103, -0.5221, -0.0345,  0.4966],
         [-0.5679,  0.7716,  0.3563, -0.4399,  1.3386,  0.2529],
         [ 0.5660,  0.5104, -0.6236,  1.3696, -0.8633, -0.0945]]],
       grad_fn=<UnsafeViewBackward0>)

V matrix:
tensor([[[-0.8460,  0.2317,  0.0061, -0.1790,  0.0405,  0.0707],
         [ 1.4305, -0.4608,  1.1821,  1.2324,  0.0492, -0.3842],
         [-0.3349,  1.5204, -1.7049, -0.3751,  0.8196,  0.6283]]],
       grad_fn=<UnsafeViewBackward0>)

✓ All Q, K, V matrices have correct dimensions


## Step 3: Reshape for Multiple Heads (Unrolling)

### The Critical Transformation
**Goal**: Split the output dimension across multiple heads

**Before**: `[batch_size, num_tokens, d_out]` = `[1, 3, 6]`

**After**: `[batch_size, num_tokens, num_heads, head_dim]` = `[1, 3, 2, 3]`

### Visual Understanding of Unrolling

**Original Q matrix** (6 columns):
```
Token 1: [q₁₁, q₁₂, q₁₃, q₁₄, q₁₅, q₁₆]
Token 2: [q₂₁, q₂₂, q₂₃, q₂₄, q₂₅, q₂₆]
Token 3: [q₃₁, q₃₂, q₃₃, q₃₄, q₃₅, q₃₆]
```

**After unrolling** (2 heads × 3 dimensions each):
```
Token 1, Head 1: [q₁₁, q₁₂, q₁₃]  Token 1, Head 2: [q₁₄, q₁₅, q₁₆]
Token 2, Head 1: [q₂₁, q₂₂, q₂₃]  Token 2, Head 2: [q₂₄, q₂₅, q₂₆]
Token 3, Head 1: [q₃₁, q₃₂, q₃₃]  Token 3, Head 2: [q₃₄, q₃₅, q₃₆]
```

### Dimension Formula
```
head_dim = d_out / num_heads
[B, T, D] → [B, T, H, head_dim]
```

### Why This Works
- **No parameter addition**: Just reorganization
- **Head isolation**: Each head gets its own subspace
- **Parallel processing**: All heads computed simultaneously
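To confirm that `view` hands each head a contiguous block of columns (columns 0-2 to head 1, columns 3-5 to head 2), this sketch compares it with explicit slicing on a random stand-in for `Q`:

```python
import torch

B, T, d_out, H = 1, 3, 6, 2
head_dim = d_out // H
Q = torch.randn(B, T, d_out)  # random stand-in for the projected queries

Q_reshaped = Q.view(B, T, H, head_dim)  # [B, T, H, head_dim]

# Head h of every token should be columns h*head_dim : (h+1)*head_dim of Q
for h in range(H):
    block = Q[:, :, h * head_dim:(h + 1) * head_dim]
    print(h, torch.equal(Q_reshaped[:, :, h, :], block))
```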

In [3]:
# Step 3: Unroll last dimension to create multiple heads
# Transform from [B, T, D] to [B, T, H, head_dim]

Q_reshaped = Q.view(batch_size, num_tokens, num_heads, head_dim)
K_reshaped = K.view(batch_size, num_tokens, num_heads, head_dim)
V_reshaped = V.view(batch_size, num_tokens, num_heads, head_dim)

print(f"After unrolling:")
print(f"Q_reshaped shape: {Q_reshaped.shape}")
print(f"K_reshaped shape: {K_reshaped.shape}")
print(f"V_reshaped shape: {V_reshaped.shape}")

print(f"\nQ_reshaped:\n{Q_reshaped}")

# Verify the reshape preserved the data
print(f"\nVerification - Original Q flattened equals reshaped Q flattened:")
print(f"Original: {Q.flatten()[:6]}")
print(f"Reshaped: {Q_reshaped.flatten()[:6]}")
print(f"Equal: {torch.allclose(Q.flatten(), Q_reshaped.flatten())}")

# Visualize the head separation
print(f"\nHead separation visualization:")
print(f"Token 1, Head 1: {Q_reshaped[0, 0, 0, :]}")  # [q11, q12, q13]
print(f"Token 1, Head 2: {Q_reshaped[0, 0, 1, :]}")  # [q14, q15, q16]
print(f"Token 2, Head 1: {Q_reshaped[0, 1, 0, :]}")  # [q21, q22, q23]
print(f"Token 2, Head 2: {Q_reshaped[0, 1, 1, :]}")  # [q24, q25, q26]

After unrolling:
Q_reshaped shape: torch.Size([1, 3, 2, 3])
K_reshaped shape: torch.Size([1, 3, 2, 3])
V_reshaped shape: torch.Size([1, 3, 2, 3])

Q_reshaped:
tensor([[[[ 0.2434,  0.4607, -0.5537],
          [-0.5116, -0.0451,  0.1184]],

         [[-0.5975, -0.5909, -0.6584],
          [-0.2954, -0.6365, -0.7123]],

         [[ 0.4812, -0.1247,  0.3195],
          [ 1.0179,  0.8944,  0.8886]]]], grad_fn=<ViewBackward0>)

Verification - Original Q flattened equals reshaped Q flattened:
Original: tensor([ 0.2434,  0.4607, -0.5537, -0.5116, -0.0451,  0.1184],
       grad_fn=<SliceBackward0>)
Reshaped: tensor([ 0.2434,  0.4607, -0.5537, -0.5116, -0.0451,  0.1184],
       grad_fn=<SliceBackward0>)
Equal: True

Head separation visualization:
Token 1, Head 1: tensor([ 0.2434,  0.4607, -0.5537], grad_fn=<SliceBackward0>)
Token 1, Head 2: tensor([-0.5116, -0.0451,  0.1184], grad_fn=<SliceBackward0>)
Token 2, Head 1: tensor([-0.5975, -0.5909, -0.6584], grad_fn=<SliceBackward0>)
Token 2, Head 2:

## Step 4: Transpose for Head-Grouped Processing

### The Grouping Problem
**Current grouping**: By tokens `[B, T, H, head_dim]`
- Token 1: [Head 1, Head 2]
- Token 2: [Head 1, Head 2]  
- Token 3: [Head 1, Head 2]

**Desired grouping**: By heads `[B, H, T, head_dim]`
- Head 1: [Token 1, Token 2, Token 3]
- Head 2: [Token 1, Token 2, Token 3]

### Why Transpose is Necessary
**For efficient computation**:
- Q1 @ K1.T (Head 1 processing)
- Q2 @ K2.T (Head 2 processing)

**Need clear head separation**:
- All Q1 data together
- All K1 data together
- All Q2 data together
- All K2 data together

### Transpose Operation
```python
# Swap dimensions 1 (tokens) and 2 (heads)
Q_heads = Q_reshaped.transpose(1, 2)  # [B, T, H, head_dim] → [B, H, T, head_dim]
```

### Visual Result After Transpose
```
Head 1:
  Token 1: [q₁₁, q₁₂, q₁₃]
  Token 2: [q₂₁, q₂₂, q₂₃]  
  Token 3: [q₃₁, q₃₂, q₃₃]

Head 2:
  Token 1: [q₁₄, q₁₅, q₁₆]
  Token 2: [q₂₄, q₂₅, q₂₆]
  Token 3: [q₃₄, q₃₅, q₃₆]
```
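A common bug is to skip the transpose and reshape straight to `[B, H, T, head_dim]`. This sketch, on random stand-in data, shows that the shortcut gives the right shape but scrambles tokens and heads, while `view` followed by `transpose(1, 2)` matches the per-head slices:

```python
import torch

B, T, d_out, H = 1, 3, 6, 2
head_dim = d_out // H
Q = torch.randn(B, T, d_out)

correct = Q.view(B, T, H, head_dim).transpose(1, 2)  # [B, H, T, head_dim]
wrong = Q.view(B, H, T, head_dim)                    # same shape, wrong grouping

# Head 1 should hold columns 0-2 of every token
head1_expected = Q[:, :, :head_dim]                  # [B, T, head_dim]
print(torch.equal(correct[:, 0], head1_expected))    # True
print(torch.equal(wrong[:, 0], head1_expected))      # False
```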

In [4]:
# Step 4: Transpose to group by heads instead of tokens
# Transform from [B, T, H, head_dim] to [B, H, T, head_dim]

Q_heads = Q_reshaped.transpose(1, 2)  # [1, 2, 3, 3]
K_heads = K_reshaped.transpose(1, 2)  # [1, 2, 3, 3]
V_heads = V_reshaped.transpose(1, 2)  # [1, 2, 3, 3]

print(f"After transpose (grouped by heads):")
print(f"Q_heads shape: {Q_heads.shape}")
print(f"K_heads shape: {K_heads.shape}")
print(f"V_heads shape: {V_heads.shape}")

print(f"\nQ_heads structure:")
print(f"Q_heads:\n{Q_heads}")

print(f"\nHead separation:")
print(f"Q1 (Head 1): \n{Q_heads[0, 0, :, :]}")  # All tokens for head 1
print(f"Q2 (Head 2): \n{Q_heads[0, 1, :, :]}")  # All tokens for head 2

print(f"\nK1 (Head 1): \n{K_heads[0, 0, :, :]}")  # All tokens for head 1
print(f"K2 (Head 2): \n{K_heads[0, 1, :, :]}")  # All tokens for head 2

# Verify we now have clear head groupings
print(f"\nDimension verification:")
print(f"- Batch size: {Q_heads.shape[0]}")
print(f"- Number of heads: {Q_heads.shape[1]}")
print(f"- Number of tokens: {Q_heads.shape[2]}")
print(f"- Head dimension: {Q_heads.shape[3]}")

After transpose (grouped by heads):
Q_heads shape: torch.Size([1, 2, 3, 3])
K_heads shape: torch.Size([1, 2, 3, 3])
V_heads shape: torch.Size([1, 2, 3, 3])

Q_heads structure:
Q_heads:
tensor([[[[ 0.2434,  0.4607, -0.5537],
          [-0.5975, -0.5909, -0.6584],
          [ 0.4812, -0.1247,  0.3195]],

         [[-0.5116, -0.0451,  0.1184],
          [-0.2954, -0.6365, -0.7123],
          [ 1.0179,  0.8944,  0.8886]]]], grad_fn=<TransposeBackward0>)

Head separation:
Q1 (Head 1): 
tensor([[ 0.2434,  0.4607, -0.5537],
        [-0.5975, -0.5909, -0.6584],
        [ 0.4812, -0.1247,  0.3195]], grad_fn=<SliceBackward0>)
Q2 (Head 2): 
tensor([[-0.5116, -0.0451,  0.1184],
        [-0.2954, -0.6365, -0.7123],
        [ 1.0179,  0.8944,  0.8886]], grad_fn=<SliceBackward0>)

K1 (Head 1): 
tensor([[-0.3222,  0.3691,  0.3103],
        [-0.5679,  0.7716,  0.3563],
        [ 0.5660,  0.5104, -0.6236]], grad_fn=<SliceBackward0>)
K2 (Head 2): 
tensor([[-0.5221, -0.0345,  0.4966],
        [-0.4399,  1

## Step 5: Compute Attention Scores for Each Head

### The Parallel Computation
**Objective**: Compute attention scores for each head independently

**Mathematical operation**:
```python
# Head 1: Q1 @ K1.T
attention_scores_1 = Q1 @ K1.transpose(-2, -1)

# Head 2: Q2 @ K2.T  
attention_scores_2 = Q2 @ K2.transpose(-2, -1)
```

### Batch Operation Magic
**Instead of separate operations**, PyTorch handles all heads simultaneously:
```python
# All heads at once
attention_scores = Q_heads @ K_heads.transpose(-2, -1)
```

### Dimensional Analysis
**Input dimensions**:
- Q_heads: `[B, H, T, head_dim]` = `[1, 2, 3, 3]`
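- K_heads transposed (`K_heads.transpose(-2, -1)`, not `.T`, which on a 4-D tensor would reverse all four dimensions): `[B, H, head_dim, T]` = `[1, 2, 3, 3]`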

**Matrix multiplication**:
- `[B, H, T, head_dim] @ [B, H, head_dim, T]`
- Result: `[B, H, T, T]` = `[1, 2, 3, 3]`

### Key Insight
**Attention scores dimensions**: Always `[num_tokens × num_tokens]` regardless of head_dim
- Represents token-to-token relationships
- Same shape for every head, but different values
- Each head can capture a different perspective on the token relationships
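The batched product is equivalent to one ordinary matrix multiplication per head, stacked back together. This sketch, using random stand-ins for `Q_heads` and `K_heads`, checks that equivalence for every head rather than just Head 1:

```python
import torch

B, H, T, head_dim = 1, 2, 3, 3
Q_heads = torch.randn(B, H, T, head_dim)
K_heads = torch.randn(B, H, T, head_dim)

# Batched: matmul broadcasts over the leading [B, H] dimensions
batched = Q_heads @ K_heads.transpose(-2, -1)        # [B, H, T, T]

# Explicit loop: one [T, head_dim] @ [head_dim, T] product per head
looped = torch.stack(
    [Q_heads[:, h] @ K_heads[:, h].transpose(-2, -1) for h in range(H)],
    dim=1,
)                                                    # [B, H, T, T]

print(batched.shape, torch.allclose(batched, looped))
```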

In [5]:
# Step 5: Compute attention scores for each head
# Q_heads @ K_heads.transpose(-2, -1)

attention_scores = Q_heads @ K_heads.transpose(-2, -1)

print(f"Attention scores shape: {attention_scores.shape}")
print(f"Expected: [batch_size, num_heads, num_tokens, num_tokens]")
print(f"Actual: {list(attention_scores.shape)}")

print(f"\nAttention scores:\n{attention_scores}")

print(f"\nHead 1 attention scores:")
print(f"{attention_scores[0, 0, :, :]}")

print(f"\nHead 2 attention scores:")
print(f"{attention_scores[0, 1, :, :]}")

# Verify dimensions
assert attention_scores.shape == (batch_size, num_heads, num_tokens, num_tokens)
print(f"\n✓ Attention scores have correct dimensions")

# Manual verification for Head 1
Q1 = Q_heads[0, 0, :, :]  # [3, 3]
K1 = K_heads[0, 0, :, :]  # [3, 3]
manual_scores_1 = Q1 @ K1.T
print(f"\nManual Head 1 computation verification:")
print(f"PyTorch result:\n{attention_scores[0, 0, :, :]}")
print(f"Manual result:\n{manual_scores_1}")
print(f"Equal: {torch.allclose(attention_scores[0, 0, :, :], manual_scores_1)}")

Attention scores shape: torch.Size([1, 2, 3, 3])
Expected: [batch_size, num_heads, num_tokens, num_tokens]
Actual: [1, 2, 3, 3]

Attention scores:
tensor([[[[-0.0802,  0.0199,  0.7182],
          [-0.2299, -0.3512, -0.2293],
          [-0.1020, -0.2557,  0.0095]],

         [[ 0.3275,  0.1947, -0.6730],
          [-0.1775, -0.9021,  0.2122],
          [-0.1211,  0.9741,  0.5378]]]], grad_fn=<UnsafeViewBackward0>)

Head 1 attention scores:
tensor([[-0.0802,  0.0199,  0.7182],
        [-0.2299, -0.3512, -0.2293],
        [-0.1020, -0.2557,  0.0095]], grad_fn=<SliceBackward0>)

Head 2 attention scores:
tensor([[ 0.3275,  0.1947, -0.6730],
        [-0.1775, -0.9021,  0.2122],
        [-0.1211,  0.9741,  0.5378]], grad_fn=<SliceBackward0>)

✓ Attention scores have correct dimensions

Manual Head 1 computation verification:
PyTorch result:
tensor([[-0.0802,  0.0199,  0.7182],
        [-0.2299, -0.3512, -0.2293],
        [-0.1020, -0.2557,  0.0095]], grad_fn=<SliceBackward0>)
Manual result:
t

## Step 6: Apply Scaling, Masking, and Softmax

### The Processing Pipeline
1. **Scale** by dividing by √(head_dim) so the scores don't grow with head_dim
2. **Apply causal mask** to prevent looking into the future
3. **Apply softmax** to normalize to probabilities
4. **Optional dropout** for regularization

### Scaling Rationale
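**Problem**: As head_dim increases, dot products grow in magnitude (their variance scales with head_dim), pushing softmax toward near one-hot outputs with very small gradients

**Solution**: Divide by √(head_dim) to keep the variance ≈ 1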

```python
scaled_scores = attention_scores / math.sqrt(head_dim)
```
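The √(head_dim) factor follows from a short variance calculation. It assumes the components of a query q and a key k are independent with mean 0 and variance 1, which is an idealization rather than something a trained model guarantees. Write d_h for head_dim:

$$
q \cdot k = \sum_{i=1}^{d_h} q_i k_i,
\qquad
\mathbb{E}[q_i k_i] = \mathbb{E}[q_i]\,\mathbb{E}[k_i] = 0,
\qquad
\operatorname{Var}(q_i k_i) = \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] - 0 = 1
$$

$$
\operatorname{Var}(q \cdot k) = \sum_{i=1}^{d_h} \operatorname{Var}(q_i k_i) = d_h
\quad\Longrightarrow\quad
\operatorname{Var}\!\left(\frac{q \cdot k}{\sqrt{d_h}}\right) = \frac{d_h}{d_h} = 1
$$

With head_dim = 3 in this notebook, the unscaled scores would have variance of about 3 under these assumptions. Dividing by √3 brings that back to about 1.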

### Causal Masking Implementation
**Objective**: Set the elements above the diagonal (strictly upper triangular) to -∞

**Result**: After softmax, these become 0 (a token can't attend to future tokens)

```
Original:     After masking:
[a  b  c]     [a  -∞  -∞]
[d  e  f] →   [d   e  -∞]
[g  h  i]     [g   h   i]
```

### Softmax Properties
- **Row normalization**: Each row sums to 1
- **-∞ handling**: e^(-∞) = 0
- **Probability interpretation**: Attention weights as probabilities
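Here is a tiny illustration on a hand-picked 3×3 score matrix with arbitrary numbers. The masked entries become exactly 0 after softmax and every row still sums to 1. The first token can see only itself, so it puts all of its weight on position 0.

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([[0.5, 1.0, 2.0],
                       [0.3, 0.1, 0.4],
                       [0.2, 0.9, 0.7]])  # arbitrary example scores

# True strictly above the diagonal = future positions
causal_mask = torch.triu(torch.ones(3, 3), diagonal=1).bool()
weights = F.softmax(scores.masked_fill(causal_mask, float('-inf')), dim=-1)

print(weights)
print(weights.sum(dim=-1))  # each row sums to 1
```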

### Mathematical Sequence
```
scores → scale → mask → softmax → weights
```
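PyTorch 2.x provides a fused `torch.nn.functional.scaled_dot_product_attention`, which can serve as a cross-check of the manual scale → mask → softmax → weighted-sum pipeline. With its default scale of 1/√(last dim) and `is_causal=True`, it should reproduce the manual result. This is a swapped-in built-in used only for verification, not the method this notebook implements. The sketch assumes PyTorch ≥ 2.0 and uses random stand-ins shaped like `Q_heads`, `K_heads`, and `V_heads`:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, head_dim = 1, 2, 3, 3
Q = torch.randn(B, H, T, head_dim)
K = torch.randn(B, H, T, head_dim)
V = torch.randn(B, H, T, head_dim)

# Manual pipeline: scores -> scale -> causal mask -> softmax -> weights @ V
scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)
mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
weights = F.softmax(scores.masked_fill(mask, float('-inf')), dim=-1)
manual = weights @ V

# Fused built-in (PyTorch >= 2.0); no dropout in this comparison
fused = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

print(torch.allclose(manual, fused, atol=1e-5))
```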

In [6]:
# Step 6: Apply scaling, causal masking, and softmax

# 6a. Scale by square root of head dimension
scaled_scores = attention_scores / math.sqrt(head_dim)
print(f"Scaled scores shape: {scaled_scores.shape}")
print(f"Scaling factor: 1/√{head_dim} = {1/math.sqrt(head_dim):.4f}")

# 6b. Create causal mask (upper triangular = True)
causal_mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()
print(f"\nCausal mask:\n{causal_mask}")
print(f"True = mask (set to -∞), False = keep")

# 6c. Apply causal mask
masked_scores = scaled_scores.masked_fill(causal_mask, float('-inf'))
print(f"\nMasked scores shape: {masked_scores.shape}")

print(f"\nHead 1 - Before masking:")
print(f"{scaled_scores[0, 0, :, :]}")
print(f"\nHead 1 - After masking:")
print(f"{masked_scores[0, 0, :, :]}")

print(f"\nHead 2 - Before masking:")
print(f"{scaled_scores[0, 1, :, :]}")
print(f"\nHead 2 - After masking:")
print(f"{masked_scores[0, 1, :, :]}")

# 6d. Apply softmax to get attention weights
attention_weights = F.softmax(masked_scores, dim=-1)
print(f"\nAttention weights shape: {attention_weights.shape}")

print(f"\nHead 1 attention weights:")
print(f"{attention_weights[0, 0, :, :]}")
print(f"Row sums: {attention_weights[0, 0, :, :].sum(dim=-1)}")

print(f"\nHead 2 attention weights:")
print(f"{attention_weights[0, 1, :, :]}")
print(f"Row sums: {attention_weights[0, 1, :, :].sum(dim=-1)}")

# Verify properties
assert torch.allclose(attention_weights[0, 0, :, :].sum(dim=-1), torch.ones(num_tokens))
assert torch.allclose(attention_weights[0, 1, :, :].sum(dim=-1), torch.ones(num_tokens))
print(f"\n✓ All rows sum to 1 (proper probability distributions)")

Scaled scores shape: torch.Size([1, 2, 3, 3])
Scaling factor: 1/√3 = 0.5774

Causal mask:
tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])
True = mask (set to -∞), False = keep

Masked scores shape: torch.Size([1, 2, 3, 3])

Head 1 - Before masking:
tensor([[-0.0463,  0.0115,  0.4146],
        [-0.1327, -0.2028, -0.1324],
        [-0.0589, -0.1476,  0.0055]], grad_fn=<SliceBackward0>)

Head 1 - After masking:
tensor([[-0.0463,    -inf,    -inf],
        [-0.1327, -0.2028,    -inf],
        [-0.0589, -0.1476,  0.0055]], grad_fn=<SliceBackward0>)

Head 2 - Before masking:
tensor([[ 0.1891,  0.1124, -0.3885],
        [-0.1025, -0.5208,  0.1225],
        [-0.0699,  0.5624,  0.3105]], grad_fn=<SliceBackward0>)

Head 2 - After masking:
tensor([[ 0.1891,    -inf,    -inf],
        [-0.1025, -0.5208,    -inf],
        [-0.0699,  0.5624,  0.3105]], grad_fn=<SliceBackward0>)

Attention weights shape: torch.Size([1, 2, 3, 3])

Head 1 attention weights:

## Step 7: Compute Context Vectors for Each Head

### The Final Transformation
**Objective**: Convert attention weights into context vectors using value matrices

**Mathematical operation**:
```python
# For each head: attention_weights @ values
context_vectors = attention_weights @ V_heads
```

### Dimensional Analysis
**Input dimensions**:
- attention_weights: `[B, H, T, T]` = `[1, 2, 3, 3]`
- V_heads: `[B, H, T, head_dim]` = `[1, 2, 3, 3]`

**Matrix multiplication**:
- `[B, H, T, T] @ [B, H, T, head_dim]`
- Result: `[B, H, T, head_dim]` = `[1, 2, 3, 3]`

### Interpretation
**Context vector meaning**:
- Each token gets a context vector per head
- Context vector = weighted sum of the value vectors the token may attend to (with causal masking: its own and earlier tokens')
- Weights come from the attention weights computed in Step 6
- Different heads → different perspectives

### Head-Specific Context
- **Head 1 context**: Head 1's value vectors (`V_heads[:, 0]`) weighted by Head 1's attention weights
- **Head 2 context**: Head 2's value vectors (`V_heads[:, 1]`) weighted by Head 2's attention weights
- **Each head captures**: Different aspect of token relationships
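The causal mask lets the first token attend only to itself, so its context vector in every head is exactly its own value vector. That is why Token 1's Head 1 context in the notebook output equals the first three entries of V's first row. This sketch uses random stand-ins to show that property and to spell out the weighted sum for the last token:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
H, T, head_dim = 2, 3, 3
scores = torch.randn(H, T, T)
V_heads = torch.randn(H, T, head_dim)   # random stand-in, one value matrix per head

mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
weights = F.softmax(scores.masked_fill(mask, float('-inf')), dim=-1)  # [H, T, T]
context = weights @ V_heads                                            # [H, T, head_dim]

# Token 1 sees only itself: weight 1.0 on position 0 in every head
print(torch.allclose(context[:, 0], V_heads[:, 0]))

# Token 3 in head 0: explicit weighted sum of that head's value vectors
manual = sum(weights[0, 2, j] * V_heads[0, j] for j in range(T))
print(torch.allclose(context[0, 2], manual))
```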

In [7]:
# Step 7: Compute context vectors for each head
# attention_weights @ V_heads

context_vectors = attention_weights @ V_heads
print(f"Context vectors shape: {context_vectors.shape}")
print(f"Expected: [batch_size, num_heads, num_tokens, head_dim]")

print(f"\nContext vectors:\n{context_vectors}")

print(f"\nHead 1 context vectors:")
print(f"{context_vectors[0, 0, :, :]}")
print(f"Shape: {context_vectors[0, 0, :, :].shape}")

print(f"\nHead 2 context vectors:")
print(f"{context_vectors[0, 1, :, :]}")
print(f"Shape: {context_vectors[0, 1, :, :].shape}")

# Interpret the results
print(f"\nInterpretation:")
print(f"- Token 1, Head 1 context: {context_vectors[0, 0, 0, :]}")
print(f"- Token 1, Head 2 context: {context_vectors[0, 1, 0, :]}")
print(f"- Token 2, Head 1 context: {context_vectors[0, 0, 1, :]}")
print(f"- Token 2, Head 2 context: {context_vectors[0, 1, 1, :]}")

# Verify dimensions
assert context_vectors.shape == (batch_size, num_heads, num_tokens, head_dim)
print(f"\n✓ Context vectors have correct dimensions")

# Manual verification for one computation
print(f"\nManual verification for Token 1, Head 1:")
manual_context = attention_weights[0, 0, 0, :] @ V_heads[0, 0, :, :]
print(f"Manual computation: {manual_context}")
print(f"PyTorch result: {context_vectors[0, 0, 0, :]}")
print(f"Equal: {torch.allclose(manual_context, context_vectors[0, 0, 0, :])}")

Context vectors shape: torch.Size([1, 2, 3, 3])
Expected: [batch_size, num_heads, num_tokens, head_dim]

Context vectors:
tensor([[[[-0.8460,  0.2317,  0.0061],
          [ 0.2524, -0.1025,  0.5735],
          [ 0.0355,  0.4801, -0.2450]],

         [[-0.1790,  0.0405,  0.0707],
          [ 0.3812,  0.0439, -0.1098],
          [ 0.3663,  0.3066,  0.0614]]]], grad_fn=<UnsafeViewBackward0>)

Head 1 context vectors:
tensor([[-0.8460,  0.2317,  0.0061],
        [ 0.2524, -0.1025,  0.5735],
        [ 0.0355,  0.4801, -0.2450]], grad_fn=<SliceBackward0>)
Shape: torch.Size([3, 3])

Head 2 context vectors:
tensor([[-0.1790,  0.0405,  0.0707],
        [ 0.3812,  0.0439, -0.1098],
        [ 0.3663,  0.3066,  0.0614]], grad_fn=<SliceBackward0>)
Shape: torch.Size([3, 3])

Interpretation:
- Token 1, Head 1 context: tensor([-0.8460,  0.2317,  0.0061], grad_fn=<SliceBackward0>)
- Token 1, Head 2 context: tensor([-0.1790,  0.0405,  0.0707], grad_fn=<SliceBackward0>)
- Token 2, Head 1 context: tensor([

## Step 8: Merge Heads (Concatenation)

### The Merging Process
**Objective**: Combine all head outputs into single context matrix

**Current structure**: `[B, H, T, head_dim]` - grouped by heads

**Target structure**: `[B, T, d_out]` - grouped by tokens

### Two-Step Process

#### Step 8a: Transpose Back to Token Grouping
```python
# [B, H, T, head_dim] → [B, T, H, head_dim]
context_vectors.transpose(1, 2)
```

#### Step 8b: Flatten Head Dimensions
```python
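# [B, T, H, head_dim] → [B, T, H * head_dim]
# [B, T, 2, 3] → [B, T, 6]
# transpose returns a non-contiguous view, so call .contiguous() before .view()
final_context = context_vectors.transpose(1, 2).contiguous().view(B, T, d_out)  # 8a + 8b chained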
```
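The sketch below shows what goes wrong without `.contiguous()`, using a random tensor shaped like `context_vectors`. `transpose` only permutes strides, and `view` cannot merge dimensions whose strides no longer line up. `reshape`, which copies only when it has to, is a common alternative:

```python
import torch

B, H, T, head_dim = 1, 2, 3, 3
d_out = H * head_dim
context_vectors = torch.randn(B, H, T, head_dim)

transposed = context_vectors.transpose(1, 2)   # [B, T, H, head_dim], non-contiguous
print(transposed.is_contiguous())              # False

try:
    transposed.view(B, T, d_out)
except RuntimeError as err:
    print("view failed:", type(err).__name__)

merged_view = transposed.contiguous().view(B, T, d_out)  # copy into row-major order, then view
merged_reshape = transposed.reshape(B, T, d_out)         # copies automatically when needed
print(torch.equal(merged_view, merged_reshape))          # True
```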

### Visual Understanding

**Before merging** (grouped by heads):
```
Head 1:
  Token 1: [c₁₁, c₁₂, c₁₃]
  Token 2: [c₂₁, c₂₂, c₂₃]
  Token 3: [c₃₁, c₃₂, c₃₃]

Head 2:
  Token 1: [c₁₄, c₁₅, c₁₆]
  Token 2: [c₂₄, c₂₅, c₂₆]
  Token 3: [c₃₄, c₃₅, c₃₆]
```

**After merging** (grouped by tokens):
```
Token 1: [c₁₁, c₁₂, c₁₃, c₁₄, c₁₅, c₁₆]  # Head 1 + Head 2
Token 2: [c₂₁, c₂₂, c₂₃, c₂₄, c₂₅, c₂₆]  # Head 1 + Head 2
Token 3: [c₃₁, c₃₂, c₃₃, c₃₄, c₃₅, c₃₆]  # Head 1 + Head 2
```

### Final Result Properties
- **Same dimensions** as single-head attention output
- **Richer content**: Multiple perspectives embedded
- **Seamless integration**: Can replace single-head attention directly
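The transpose-and-view merge is identical to concatenating the per-head outputs along the last dimension, which is why each merged token row reads as "Head 1 + Head 2". A quick check on random stand-in data:

```python
import torch

B, H, T, head_dim = 1, 2, 3, 3
context_vectors = torch.randn(B, H, T, head_dim)  # [B, H, T, head_dim]

merged = context_vectors.transpose(1, 2).contiguous().view(B, T, H * head_dim)
concatenated = torch.cat([context_vectors[:, h] for h in range(H)], dim=-1)  # [B, T, H*head_dim]

print(merged.shape, torch.equal(merged, concatenated))
```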

In [8]:
# Step 8: Merge heads by concatenating head outputs

# 8a. Transpose back to group by tokens
# [B, H, T, head_dim] → [B, T, H, head_dim]
context_transposed = context_vectors.transpose(1, 2)
print(f"After transpose back to token grouping:")
print(f"Shape: {context_transposed.shape}")
print(f"Expected: [batch_size, num_tokens, num_heads, head_dim]")

print(f"\nContext transposed:\n{context_transposed}")

print(f"\nToken-wise view:")
print(f"Token 1: Head 1 = {context_transposed[0, 0, 0, :]}, Head 2 = {context_transposed[0, 0, 1, :]}")
print(f"Token 2: Head 1 = {context_transposed[0, 1, 0, :]}, Head 2 = {context_transposed[0, 1, 1, :]}")
print(f"Token 3: Head 1 = {context_transposed[0, 2, 0, :]}, Head 2 = {context_transposed[0, 2, 1, :]}")

# 8b. Flatten the head dimensions to concatenate
# [B, T, H, head_dim] → [B, T, H * head_dim] = [B, T, d_out]
final_context = context_transposed.contiguous().view(batch_size, num_tokens, d_out)

print(f"\nFinal context matrix:")
print(f"Shape: {final_context.shape}")
print(f"Expected: [batch_size, num_tokens, d_out] = [1, 3, 6]")

print(f"\nFinal context:\n{final_context}")

print(f"\nFinal context breakdown:")
print(f"Token 1 final context: {final_context[0, 0, :]} (Head1: {final_context[0, 0, :3]}, Head2: {final_context[0, 0, 3:]})")
print(f"Token 2 final context: {final_context[0, 1, :]} (Head1: {final_context[0, 1, :3]}, Head2: {final_context[0, 1, 3:]})")
print(f"Token 3 final context: {final_context[0, 2, :]} (Head1: {final_context[0, 2, :3]}, Head2: {final_context[0, 2, 3:]})")

# Verify final dimensions
assert final_context.shape == (batch_size, num_tokens, d_out)
print(f"\n✓ Final context has correct dimensions: {final_context.shape}")

# Compare to single-head output dimensions
print(f"\nDimension comparison:")
print(f"Single-head output would be: [1, 3, 6]")
print(f"Multi-head output is:        {list(final_context.shape)}")
print(f"Same dimensions: {final_context.shape == (1, 3, 6)}")
print(f"But multi-head contains multiple perspectives!")

After transpose back to token grouping:
Shape: torch.Size([1, 3, 2, 3])
Expected: [batch_size, num_tokens, num_heads, head_dim]

Context transposed:
tensor([[[[-0.8460,  0.2317,  0.0061],
          [-0.1790,  0.0405,  0.0707]],

         [[ 0.2524, -0.1025,  0.5735],
          [ 0.3812,  0.0439, -0.1098]],

         [[ 0.0355,  0.4801, -0.2450],
          [ 0.3663,  0.3066,  0.0614]]]], grad_fn=<TransposeBackward0>)

Token-wise view:
Token 1: Head 1 = tensor([-0.8460,  0.2317,  0.0061], grad_fn=<SliceBackward0>), Head 2 = tensor([-0.1790,  0.0405,  0.0707], grad_fn=<SliceBackward0>)
Token 2: Head 1 = tensor([ 0.2524, -0.1025,  0.5735], grad_fn=<SliceBackward0>), Head 2 = tensor([ 0.3812,  0.0439, -0.1098], grad_fn=<SliceBackward0>)
Token 3: Head 1 = tensor([ 0.0355,  0.4801, -0.2450], grad_fn=<SliceBackward0>), Head 2 = tensor([0.3663, 0.3066, 0.0614], grad_fn=<SliceBackward0>)

Final context matrix:
Shape: torch.Size([1, 3, 6])
Expected: [batch_size, num_tokens, d_out] = [1, 3, 6]

Fi
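
Step 8b calls `contiguous()` before `view()`. The sketch below shows why, using a hypothetical `arange` tensor as a stand-in for the real context vectors. `transpose()` only swaps strides, so `view()` on the transposed tensor fails. `contiguous().view()` (or `reshape()`) succeeds and puts head 1's values in columns 0-2 and head 2's values in columns 3-5 of each token row.

```python
import torch

# Hypothetical stand-in for the context vectors grouped by head: [B, H, T, head_dim]
B, H, T, head_dim = 1, 2, 3, 3
context = torch.arange(B * H * T * head_dim, dtype=torch.float32).view(B, H, T, head_dim)

# Transpose back to token grouping: [B, T, H, head_dim]
ctx_t = context.transpose(1, 2)
print(ctx_t.is_contiguous())  # False: transpose swaps strides without moving data

# view() cannot merge the last two dims because their strides are no longer compatible
try:
    ctx_t.view(B, T, H * head_dim)
    view_failed = False
except RuntimeError:
    view_failed = True
print(view_failed)  # True

# contiguous() copies the data into row-major order, after which view() works
merged = ctx_t.contiguous().view(B, T, H * head_dim)
print(merged.shape)  # torch.Size([1, 3, 6])

# Token 2 (index 1): the first head_dim columns come from head 1, the rest from head 2
print(torch.equal(merged[0, 1, :head_dim], context[0, 0, 1]))  # True
print(torch.equal(merged[0, 1, head_dim:], context[0, 1, 1]))  # True

# reshape() performs the same copy automatically when it is needed
print(torch.equal(merged, ctx_t.reshape(B, T, H * head_dim)))  # True
```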

## Complete Multi-Head Attention Implementation

### Full Class Implementation
Now let's implement the complete multi-head attention class that encapsulates all the steps we've seen:

```python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, num_heads, dropout=0.0):
        super().__init__()
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # requires d_out to be divisible by num_heads
        
        # Linear layers for Q, K, V transformations
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)
        
        # Dropout for regularization
        self.dropout = nn.Dropout(dropout)
        
        # Register causal mask as buffer: 1s above the diagonal mark future positions
        # (convert with .bool() before passing it to masked_fill)
        self.register_buffer(
            'mask', 
            torch.triu(torch.ones(1000, 1000), diagonal=1)
        )
    
    def forward(self, x):
        # Implementation of all the steps (full version in the next cell)
        ...
```

### Key Implementation Details
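1. **Efficient mask handling**: The causal mask is built once and registered as a buffer, so it moves with the module across devices and is saved in `state_dict` without being trained
2. **Flexible sequence length**: The mask is sized for the maximum context length and sliced to `num_tokens` in each forward pass
3. **Dropout integration**: Optional regularization applied to the attention weights
4. **Batch processing**: Handles a whole batch of sequences in a single forward pass
5. **Memory efficiency**: `view()` and `transpose()` change only tensor metadata; `contiguous()` copies data once, just before the heads are merged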
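
The mask-handling points can be checked in isolation. This is a minimal sketch. `CausalMaskDemo` is a hypothetical helper that only registers the buffer. The all-zero scores are chosen so the expected attention weights are easy to verify by hand. Note the `.bool()` conversion: `masked_fill` rejects a float mask.

```python
import torch
import torch.nn as nn


class CausalMaskDemo(nn.Module):
    """Hypothetical module that only registers a causal mask, mirroring the class above."""

    def __init__(self, context_length=8):
        super().__init__()
        # 1.0 above the diagonal marks "future" positions a token must not attend to
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )


demo = CausalMaskDemo(context_length=8)
# A buffer is stored in state_dict (and follows .to(device)) but is not a trainable parameter
print("mask" in demo.state_dict())   # True
print(len(list(demo.parameters())))  # 0

# Slice the oversized mask to the actual sequence length and convert it to bool
num_tokens = 3
mask = demo.mask[:num_tokens, :num_tokens].bool()

# With equal (all-zero) scores, each token spreads attention evenly over itself
# and the earlier tokens: rows [1, 0, 0], [0.5, 0.5, 0], [1/3, 1/3, 1/3]
scores = torch.zeros(num_tokens, num_tokens)
weights = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
print(weights)
```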

In [9]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, num_heads, context_length=1000, dropout=0.0):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        
        # Linear transformations for Q, K, V
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)
        
        # Dropout layer
        self.dropout = nn.Dropout(dropout)
        
        # Register causal mask buffer
        self.register_buffer('mask', 
                            torch.triu(torch.ones(context_length, context_length), diagonal=1))
    
    def forward(self, x):
        # Input shape: [batch_size, num_tokens, d_in]
        batch_size, num_tokens, d_in = x.shape
        
        # Step 1: Generate Q, K, V matrices
        Q = self.W_query(x)  # [batch_size, num_tokens, d_out]
        K = self.W_key(x)    # [batch_size, num_tokens, d_out]
        V = self.W_value(x)  # [batch_size, num_tokens, d_out]
        
        # Step 2: Reshape for multiple heads
        Q = Q.view(batch_size, num_tokens, self.num_heads, self.head_dim)
        K = K.view(batch_size, num_tokens, self.num_heads, self.head_dim)
        V = V.view(batch_size, num_tokens, self.num_heads, self.head_dim)
        
        # Step 3: Transpose to group by heads
        Q = Q.transpose(1, 2)  # [batch_size, num_heads, num_tokens, head_dim]
        K = K.transpose(1, 2)  # [batch_size, num_heads, num_tokens, head_dim]
        V = V.transpose(1, 2)  # [batch_size, num_heads, num_tokens, head_dim]
        
        # Step 4: Compute attention scores
        attention_scores = Q @ K.transpose(-2, -1)  # [batch_size, num_heads, num_tokens, num_tokens]
        
        # Step 5: Scale scores
        attention_scores = attention_scores / math.sqrt(self.head_dim)
        
        # Step 6: Apply causal mask
        mask = self.mask[:num_tokens, :num_tokens].bool()  # masked_fill requires a boolean mask
        attention_scores = attention_scores.masked_fill(mask, float('-inf'))
        
        # Step 7: Apply softmax and dropout
        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        
        # Step 8: Compute context vectors
        context_vectors = attention_weights @ V  # [batch_size, num_heads, num_tokens, head_dim]
        
        # Step 9: Transpose back and merge heads
        context_vectors = context_vectors.transpose(1, 2)  # [batch_size, num_tokens, num_heads, head_dim]
        context_vectors = context_vectors.contiguous().view(batch_size, num_tokens, self.d_out)
        
        return context_vectors

# Test the implementation
mha = MultiHeadAttention(d_in=6, d_out=6, num_heads=2, dropout=0.0)

# Test with our example input
test_input = torch.randn(1, 3, 6)
output = mha(test_input)

print(f"MultiHeadAttention test:")
print(f"Input shape: {test_input.shape}")
print(f"Output shape: {output.shape}")
print(f"Output:\n{output}")

# Test with multiple batches
batch_input = torch.randn(2, 3, 6)
batch_output = mha(batch_input)
print(f"\nBatch test:")
print(f"Batch input shape: {batch_input.shape}")
print(f"Batch output shape: {batch_output.shape}")

# Compare with the manual computation: mha has its own randomly initialized
# weights, so only the shapes (not the values) are expected to match final_context
manual_output = final_context
class_output = mha(X)
print(f"\nVerification (manual vs class):")
print(f"Manual output shape: {manual_output.shape}")
print(f"Class output shape: {class_output.shape}")
print(f"Shapes match: {manual_output.shape == class_output.shape}")
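
One way to sanity-check causality is to perturb only the last token and confirm that the earlier tokens' outputs do not change. The sketch below is a hypothetical functional version of the forward pass above. It omits dropout and uses plain weight matrices instead of `nn.Linear`, so it is not the class itself.

```python
import math

import torch
import torch.nn.functional as F


def causal_mha(x, W_q, W_k, W_v, num_heads):
    """Hypothetical functional sketch of MultiHeadAttention.forward (no dropout).

    W_q, W_k, W_v are [d_in, d_out] matrices; nn.Linear stores the transpose of this.
    """
    B, T, _ = x.shape
    d_out = W_q.shape[1]
    head_dim = d_out // num_heads
    # Project, split d_out into heads, and group by head: [B, H, T, head_dim]
    Q, K, V = (
        (x @ W).view(B, T, num_heads, head_dim).transpose(1, 2) for W in (W_q, W_k, W_v)
    )
    scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)
    mask = torch.triu(torch.ones(T, T), diagonal=1).bool()  # True = future position
    weights = F.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
    # Merge heads back into [B, T, d_out]
    return (weights @ V).transpose(1, 2).contiguous().view(B, T, d_out)


torch.manual_seed(0)
d_in, d_out, num_heads = 6, 6, 2
W_q, W_k, W_v = (torch.randn(d_in, d_out) for _ in range(3))

x = torch.randn(1, 3, d_in)
x_changed = x.clone()
x_changed[0, 2] = torch.randn(d_in)  # perturb only the third (last) token

out = causal_mha(x, W_q, W_k, W_v, num_heads)
out_changed = causal_mha(x_changed, W_q, W_k, W_v, num_heads)

# Tokens 1 and 2 cannot attend to token 3, so their context vectors are unchanged
print(torch.allclose(out[0, :2], out_changed[0, :2]))  # True
print(out.shape)  # torch.Size([1, 3, 6])
```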

## Performance Analysis and Key Insights

### Computational Complexity Analysis

#### Parameter Count Comparison
**Single-head attention**:
```
WQ: d_in × d_out = 6 × 6 = 36 parameters
WK: d_in × d_out = 6 × 6 = 36 parameters  
WV: d_in × d_out = 6 × 6 = 36 parameters
Total: 108 parameters
```

**Multi-head attention (2 heads)**:
```
Same weight matrices, just reshaped!
Total: 108 parameters (no increase)
```
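
The claim "same weight matrices, just reshaped" can be checked numerically. In this sketch, the weights are random and `attend` and `split_heads` are hypothetical helper names. Head h uses columns h·head_dim to (h+1)·head_dim of each 6 × 6 matrix. Running each head separately and concatenating the results should therefore match the batched reshape-and-transpose version.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, T, d_in, d_out, H = 1, 3, 6, 6, 2
head_dim = d_out // H

# Parameter count: three bias-free 6x6 projections, regardless of H
projections = [nn.Linear(d_in, d_out, bias=False) for _ in range(3)]
print(sum(p.numel() for lin in projections for p in lin.parameters()))  # 108

# Use the transposed nn.Linear weights as plain [d_in, d_out] matrices
W_q, W_k, W_v = (lin.weight.detach().T for lin in projections)
x = torch.randn(B, T, d_in)
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()


def attend(q, k, v):
    """Causal scaled dot-product attention over the last two dimensions."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return F.softmax(scores.masked_fill(mask, float("-inf")), dim=-1) @ v


def split_heads(t):
    """[B, T, d_out] -> [B, H, T, head_dim]"""
    return t.view(B, T, H, head_dim).transpose(1, 2)


# (a) Batched: project once, split columns into heads, attend, merge
batched = attend(split_heads(x @ W_q), split_heads(x @ W_k), split_heads(x @ W_v))
batched = batched.transpose(1, 2).contiguous().view(B, T, d_out)

# (b) Loop: head h uses only its own column slice of each weight matrix
outputs = []
for h in range(H):
    cols = slice(h * head_dim, (h + 1) * head_dim)
    outputs.append(attend(x @ W_q[:, cols], x @ W_k[:, cols], x @ W_v[:, cols]))
looped = torch.cat(outputs, dim=-1)

print(torch.allclose(batched, looped, atol=1e-6))  # True
```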

#### Memory Complexity
**Attention matrices per head**: O(seq_len²)
- Single head: 1 × [3 × 3] = 9 elements
- Multi-head: 2 × [3 × 3] = 18 elements
- **Memory scales linearly with number of heads**

#### Time Complexity
- **Per head**: O(seq_len² × head_dim)
- **Total**: O(seq_len² × d_out), the same as single-head attention, because num_heads × head_dim = d_out
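
The following short derivation shows how the head count cancels. It assumes the number of heads H divides d_out exactly. It counts multiply-adds for the QKᵀ product only; the weighted sum with V costs the same again. Here T is seq_len and d_h is head_dim.

$$
\underbrace{H}_{\text{heads}} \cdot \underbrace{T^2 d_h}_{QK^\top \text{ per head}} = T^2 \,(H d_h) = T^2 \, d_{out},
\qquad d_h = \frac{d_{out}}{H}
$$

$$
\text{attention-weight elements} = H \cdot T^2 \quad (\text{running example: } 2 \cdot 3^2 = 18)
$$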

### Why Multi-Head Attention Works

#### 1. **No Parameter Overhead**
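- Same parameter count as single-head attention with the same d_out
- The weight matrices are identical; only their columns are divided among the heads
- The benefit comes from the architecture, not from extra parameters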

#### 2. **Parallel Perspective Learning**
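- Heads can learn to specialize in different patterns during training (roles are not assigned in advance)
- One head might track syntactic relationships
- Another might focus on semantic relationships
- Others might follow positional patterns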

#### 3. **Empirical Superiority**
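- In the original Transformer ablations, single-head attention scored 0.9 BLEU below the best multi-head setting (quality also dropped with too many heads)
- Multiple heads can represent several relationships at once instead of averaging them into a single attention pattern
- Multi-head attention (or a variant of it) is standard in modern transformer models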

#### 4. **Scalability**
- Easily scales to many heads (8, 12, 16+)
- GPU-friendly parallel computation
- Modular design for different architectures

In [10]:
# Performance analysis and comparisons

def analyze_multihead_performance():
    """Analyze the performance characteristics of multi-head attention"""
    
    # Parameter count analysis
    d_in, d_out = 512, 512  # Typical transformer dimensions
    num_heads_options = [1, 8, 12, 16]  # 12 does not divide 512, so that row's head_dim is truncated to 42
    
    print("Parameter Count Analysis:")
    print("=" * 50)
    
    base_params = 3 * d_in * d_out  # WQ + WK + WV
    print(f"Base parameters (WQ + WK + WV): {base_params:,}")
    
    for num_heads in num_heads_options:
        head_dim = d_out // num_heads
        # Parameters remain the same regardless of heads!
        print(f"Heads: {num_heads:2d}, Head dim: {head_dim:3d}, Parameters: {base_params:,}")
    
    print("\nMemory Analysis (for seq_len=1024):")
    print("=" * 50)
    
    seq_len = 1024
    for num_heads in num_heads_options:
        attention_matrices = num_heads * seq_len * seq_len
        memory_mb = attention_matrices * 4 / (1024 * 1024)  # 4 bytes per float32
        print(f"Heads: {num_heads:2d}, Attention matrices: {attention_matrices:,} elements, Memory: {memory_mb:.1f} MB")
    
    print("\nComputational Complexity Analysis:")
    print("=" * 50)
    
    for num_heads in num_heads_options:
        head_dim = d_out // num_heads
        ops_per_head = seq_len * seq_len * head_dim
        total_ops = num_heads * ops_per_head
        # Note: total_ops = seq_len² × d_out whenever num_heads divides d_out (the 12-head row is the exception)
        print(f"Heads: {num_heads:2d}, Ops per head: {ops_per_head:,}, Total ops: {total_ops:,}")

analyze_multihead_performance()

# Demonstrate head specialization potential
def demonstrate_head_specialization():
    """Show how different heads can capture different patterns"""
    
    print("\n\nHead Specialization Demonstration:")
    print("=" * 50)
    
    # Create a model with many heads to show specialization potential
    specialized_mha = MultiHeadAttention(d_in=128, d_out=128, num_heads=8)
    
    # Example sentences that could benefit from different perspectives (illustrative; not passed through the model here)
    sentences = [
        "The artist painted the portrait",
        "The government should regulate speech", 
        "Time flies like an arrow",
        "Bank accounts and river banks"
    ]
    
    # Illustrative roles only: trained heads are not assigned these roles in advance
    print("Multi-head attention enables models to capture:")
    print("• Head 1: Syntactic relationships (subject-verb-object)")
    print("• Head 2: Semantic relationships (word meanings)")
    print("• Head 3: Positional patterns (sequence order)")
    print("• Head 4: Long-range dependencies")
    print("• Head 5: Entity relationships")
    print("• Head 6: Discourse structure")
    print("• Head 7: Attention to rare words")
    print("• Head 8: Contextual disambiguation")
    
    print(f"\nEach sentence can be analyzed from {specialized_mha.num_heads} different perspectives!")
    print("Large LLMs commonly use many heads per layer (e.g., 12 in GPT-2 small, 96 in GPT-3 175B).")

demonstrate_head_specialization()

Parameter Count Analysis:
Base parameters (WQ + WK + WV): 786,432
Heads:  1, Head dim: 512, Parameters: 786,432
Heads:  8, Head dim:  64, Parameters: 786,432
Heads: 12, Head dim:  42, Parameters: 786,432
Heads: 16, Head dim:  32, Parameters: 786,432

Memory Analysis (for seq_len=1024):
Heads:  1, Attention matrices: 1,048,576 elements, Memory: 4.0 MB
Heads:  8, Attention matrices: 8,388,608 elements, Memory: 32.0 MB
Heads: 12, Attention matrices: 12,582,912 elements, Memory: 48.0 MB
Heads: 16, Attention matrices: 16,777,216 elements, Memory: 64.0 MB

Computational Complexity Analysis:
Heads:  1, Ops per head: 536,870,912, Total ops: 536,870,912
Heads:  8, Ops per head: 67,108,864, Total ops: 536,870,912
Heads: 12, Ops per head: 44,040,192, Total ops: 528,482,304
Heads: 16, Ops per head: 33,554,432, Total ops: 536,870,912


Head Specialization Demonstration:
Multi-head attention enables models to capture:
• Head 1: Syntactic relationships (subject-verb-object)
• Head 2: Semantic relatio
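
In the output above, the 12-head row is the odd one out. Because 512 is not divisible by 12, `d_out // num_heads` truncates head_dim to 42, and the total comes to 528,482,304 instead of 536,870,912. The `MultiHeadAttention` class would reject this configuration with its assertion. The small helper below is hypothetical and included only for illustration. It lists the head counts that split a given width evenly. A width of 768, the hidden size of BERT-Base and GPT-2 small, does support 12 heads of 64 dimensions.

```python
def valid_head_counts(d_out, candidates=(1, 4, 8, 12, 16)):
    """Map each candidate head count that divides d_out evenly to its head_dim."""
    return {h: d_out // h for h in candidates if d_out % h == 0}


print(valid_head_counts(512))  # {1: 512, 4: 128, 8: 64, 16: 32}
print(valid_head_counts(768))  # {1: 768, 4: 192, 8: 96, 12: 64, 16: 48}
```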

## Connection to Modern Language Models

### Multi-Head Attention in Practice

#### GPT Models
- **GPT-1**: 12 heads per layer, 12 layers
- **GPT-2**: 12-25 heads per layer, 12-48 layers (small to XL)
- **GPT-3 (175B)**: 96 heads per layer, 96 layers
- **GPT-4**: Architecture not publicly disclosed

#### BERT Models
- **BERT-Base**: 12 heads per layer, 12 layers
- **BERT-Large**: 16 heads per layer, 24 layers

#### Other Architectures
- **T5**: 8-128 heads depending on size
- **PaLM**: 16-48 heads depending on size (48 in PaLM 540B)
- **DeepSeek**: Multi-head latent attention (MLA), introduced in DeepSeek-V2

### Research Insights on Head Specialization

#### Empirical Findings
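Analyses of trained models have reported heads that appear to specialize, for example:

1. **Syntactic heads**: Track grammatical relationships (e.g., from a verb to its object)
2. **Semantic heads**: Focus on word meanings and concepts
3. **Positional heads**: Attend mostly to neighboring positions, such as the previous token
4. **Rare word heads**: Attend to the least frequent tokens in the sentence
5. **Long-range heads**: Capture distant dependencies

#### Head Pruning Studies
- Many heads can be removed from a trained model without significant performance loss
- A small subset of heads tends to matter far more than the rest
- This redundancy suggests the model is robust to losing individual heads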

### The Path to Advanced Architectures

#### Current Position in Learning Journey
1. ✅ **Self-attention**: Foundation mechanism
2. ✅ **Causal attention**: Prevents future information leakage  
3. ✅ **Multi-head attention**: Multiple perspectives
4. 🔄 **Next: Key-Value caching** - Efficiency optimization
5. 🔄 **Final: Multi-head latent attention** - DeepSeek's innovation

#### Why This Foundation Matters
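- **KV caching**: Stores the K and V matrices of earlier tokens so they are not recomputed at every generation step
- **Flash attention**: Computes exact attention in tiles without storing the full seq_len × seq_len matrix in GPU memory
- **Multi-head latent attention**: Compresses keys and values into a smaller latent representation to shrink the KV cache
- **Understanding prerequisites**: Advanced concepts build on these basics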

## Key Takeaways and Summary

### Core Concepts Mastered

#### 1. **Mathematical Implementation**
- **8-step process**: From input to final context vectors
- **Dimensional transformations**: Tracking shapes through all operations
- **Matrix operations**: Understanding every multiplication and transpose

#### 2. **Multi-Head Mechanism**
- **Parameter efficiency**: No increase in parameter count
- **Perspective diversity**: Each head captures different relationships
- **Parallel processing**: All heads computed simultaneously

#### 3. **Implementation Details**
- **Reshaping operations**: view() for dimension changes
- **Transpose operations**: Grouping by heads vs tokens
- **Masking and normalization**: Causal attention integration

### Critical Insights

#### **The Magic of Multi-Head Attention**
```
Same parameters + Different organization = Multiple perspectives
```

#### **Dimensional Consistency**
```
Input:  [B, T, d_in]  → Processing → Output: [B, T, d_out]
Single-head and multi-head have identical input/output shapes!
```

#### **Computational Efficiency**
```
Attention time: O(seq_len² × d_out), independent of the number of heads (num_heads × head_dim = d_out)
Attention-weight memory: O(num_heads × seq_len²), linear in the number of heads
```

### Why This Matters for DeepSeek

#### **Foundation for Advanced Concepts**
- Multi-head attention is the baseline that DeepSeek improves upon
- Key-Value caching optimizes the computation we just learned
- Multi-head latent attention compresses the multi-head structure

#### **Understanding Prerequisites**
- Can't understand advanced optimizations without mastering the basics
- Every modern LLM builds on these fundamental mechanisms
- Critical for implementing and modifying transformer architectures

### Next Steps

#### **Immediate Applications**
- Implement transformer blocks using this multi-head attention
- Experiment with different numbers of heads
- Analyze attention patterns in pre-trained models

#### **Advanced Topics Coming**
- **Key-Value caching**: Memory and speed optimizations
- **Flash attention**: Memory-efficient computation
- **Multi-head latent attention**: DeepSeek's key innovation

This multi-head attention implementation is the workhorse of modern AI - understanding it deeply is essential for anyone working with large language models.