# Topic 7: Attention Mechanisms - A Comprehensive Guide

## Learning Objectives

By the end of this notebook, you will:
- Understand **why** attention was invented and what problem it solves
- Learn the intuition behind Queries, Keys, and Values
- Build self-attention from scratch, step by step
- Understand cross-attention and when to use it
- Implement multi-head attention from first principles
- Use PyTorch's optimized attention functions
- Connect attention to transformers and modern LLMs

## The Big Picture: Why Attention?

### The Problem: Fixed-Length Bottleneck

Before attention, sequence-to-sequence models (like machine translation) used an **encoder-decoder** architecture:

```
Input Sequence → Encoder → Fixed-Size Vector → Decoder → Output Sequence
```

**The critical flaw**: All input information must be compressed into a single fixed-size vector!

**Real-world example** (English to French translation):
```
Input: "The cat sat on the mat because it was comfortable"
Problem: When translating "it", the decoder has lost the context
         that "it" refers to "the mat", not "the cat"
```

**Why this fails**:
- **Information loss**: Long sequences cannot be squeezed into a fixed-size vector without losing detail
- **No selectivity**: Decoder can't focus on relevant parts of input
- **Gradient problems**: Backprop through long sequences is unstable
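
To make the bottleneck concrete, here is a minimal sketch that assumes a single-layer GRU encoder (a stand-in for the RNN encoders of that era; the sizes are arbitrary). No matter how long the input is, the final hidden state handed to the decoder has the same shape.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical encoder: 8-dim token embeddings -> 16-dim hidden state
encoder = nn.GRU(input_size=8, hidden_size=16, batch_first=True)

short_input = torch.randn(1, 5, 8)   # 5 tokens
long_input = torch.randn(1, 50, 8)   # 50 tokens

# The second return value is the final hidden state h_n: (num_layers, batch, hidden_size)
_, h_short = encoder(short_input)
_, h_long = encoder(long_input)

# Both sequences end up squeezed into the same 16 numbers per example
print(h_short.shape)  # torch.Size([1, 1, 16])
print(h_long.shape)   # torch.Size([1, 1, 16])
```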

### The Attention Solution

**Key insight**: Instead of a fixed vector, let the decoder **attend to the entire input sequence** and dynamically focus on relevant parts!

```
When translating "it":
↓
Look at all input words with learned weights:
  "The"    (0.02)
  "cat"    (0.05)
  "sat"    (0.03)
  "on"     (0.04)
  "the"    (0.08)
  "mat"    (0.65)  ← High attention!
  "because"(0.05)
  "it"     (0.03)
  "was"    (0.03)
  "comfortable" (0.02)
```
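
The following is a sketch of what the decoder does with these weights. It reuses the illustrative weights from the diagram and uses random 4-dim vectors as hypothetical stand-ins for real encoder hidden states. The context vector is the weighted sum of all encoder states, so it is dominated by the state for "mat".

```python
import torch

torch.manual_seed(0)

words = ["The", "cat", "sat", "on", "the", "mat",
         "because", "it", "was", "comfortable"]
# Illustrative attention weights from the diagram above
weights = torch.tensor([0.02, 0.05, 0.03, 0.04, 0.08,
                        0.65, 0.05, 0.03, 0.03, 0.02])

# Hypothetical encoder hidden states: one 4-dim vector per input word
encoder_states = torch.randn(len(words), 4)

# Context vector = sum_i weight_i * h_i (a convex combination of the states)
context = weights @ encoder_states  # shape: (4,)

print(f"Weights sum to {weights.sum().item():.2f}")
print(f"Largest weight on: {words[weights.argmax().item()]}")
print("Context vector:", context)
```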

**Impact**: Attention revolutionized NLP and enabled:
- Transformers (GPT, BERT, LLaMA)
- Vision transformers (ViT)
- Multimodal models (CLIP, Flamingo)
- State-of-the-art results across language, vision, speech, and multimodal tasks

**Why it cannot be skipped**: Attention is the core building block of most modern deep learning architectures. Understanding it deeply is essential for working with state-of-the-art models.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print(f"PyTorch version: {torch.__version__}")

# Set up visualization style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 4)

## Understanding Queries, Keys, and Values

### The Search Engine Analogy

Think of attention like a search engine:

**1. Query (Q)**: What you're searching for
- "Best Italian restaurants in Manhattan"
- Represents what information you need

**2. Key (K)**: Indexed titles/tags of documents
- "Italian cuisine in NYC", "Manhattan dining guide", "Best pizza in Brooklyn"
- Represents what information is available

**3. Value (V)**: The actual document content
- Full restaurant reviews, menus, addresses
- Represents the information to retrieve

**The attention process**:
1. **Compare Query with Keys**: Calculate relevance scores (which documents match your search?)
2. **Softmax**: Convert scores to probabilities (how much to focus on each document?)
3. **Weighted sum of Values**: Get a blend of most relevant content
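
The three steps can be sketched with hand-picked toy vectors. All of the numbers and "documents" here are hypothetical. The query points in the same direction as the first key, so the blended result is pulled toward the first value.

```python
import torch
import torch.nn.functional as F

# Toy 2-dim "index": each key describes one document
keys = torch.tensor([[1.0, 0.0],     # doc 0: "Italian cuisine in NYC"
                     [0.0, 1.0],     # doc 1: "Best pizza in Brooklyn"
                     [-1.0, 0.0]])   # doc 2: unrelated
# Each value is the content retrieved from that document
values = torch.tensor([[10.0, 0.0],
                       [0.0, 10.0],
                       [-10.0, -10.0]])
query = torch.tensor([2.0, 0.0])     # "Best Italian restaurants in Manhattan"

scores = keys @ query                # Step 1: compare query with each key -> (3,)
weights = F.softmax(scores, dim=-1)  # Step 2: probabilities over documents
result = weights @ values            # Step 3: weighted blend of values -> (2,)

print("scores: ", scores)            # tensor([ 2.,  0., -2.])
print("weights:", weights)
print("result: ", result)
```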

### Why This Design?

**Separation of concerns**:
- **Keys**: What can be attended to (searchable index)
- **Values**: What information is retrieved (actual content)
- **Queries**: What is being looked for (current context)

**Flexibility**:
- Same Keys/Values, different Queries → different focus
- Learned projections allow the model to transform inputs into optimal Q, K, V representations

**Why it's needed**: Without separate Q, K, V, the model can't learn what to search for vs. what to retrieve. This separation is crucial for flexible, context-dependent information flow.

## Self-Attention: Step-by-Step from Scratch

### What is Self-Attention?

**Self-attention** means the input sequence attends to itself:
- Queries, Keys, and Values all come from the same source
- Each position looks at all positions (including itself)
- Captures relationships within a single sequence

**Example**: In the sentence "The cat sat on the mat because it was comfortable"
- When processing "sat", attend to "cat" (subject) and "mat" (location)
- When processing "it", attend to "cat" or "mat" (pronoun resolution)

### The Scaled Dot-Product Attention Formula

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

**Breaking it down**:
1. $QK^T$: Compute similarity between queries and keys (dot product)
2. $\frac{1}{\sqrt{d_k}}$: Scale by sqrt of dimension (prevents huge values)
3. $\text{softmax}$: Convert to probabilities (sums to 1)
4. Multiply by $V$: Weighted combination of values

**Why each step is needed**:
- **Dot product**: Measures alignment between vectors (high value = similar)
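- **Scaling**: Dot products grow with $d_k$ (for unit-variance inputs their variance is $d_k$). Without scaling, softmax saturates, putting nearly all weight on one position, and its gradients vanish.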
- **Softmax**: Normalizes scores, allows gradient flow to all positions
- **Multiply V**: Actually retrieve and blend information
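
The following is a quick numerical check of the scaling argument. It assumes queries and keys with independent standard-normal entries, which is an idealization; trained activations are not exactly like this. Under that assumption, the variance of $q \cdot k$ grows roughly like $d_k$, and dividing by $\sqrt{d_k}$ brings it back to about 1.

```python
import torch

torch.manual_seed(0)
n_samples = 10_000
results = {}

# Variance of q·k for random unit-variance vectors, before and after scaling
for d_k in (4, 64, 512):
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)
    dots = (q * k).sum(dim=-1)   # one dot product per sample
    scaled = dots / d_k ** 0.5   # scaled dot product
    results[d_k] = (dots.var().item(), scaled.var().item())
    print(f"d_k={d_k:4d}  var(q·k)={results[d_k][0]:8.1f}  "
          f"var(q·k/sqrt(d_k))={results[d_k][1]:.2f}")

# Effect on softmax: one query against 10 random keys with d_k = 512
query = torch.randn(512)
keys = torch.randn(10, 512)
print("max weight, unscaled:", torch.softmax(keys @ query, dim=-1).max().item())
print("max weight, scaled:  ", torch.softmax(keys @ query / 512 ** 0.5, dim=-1).max().item())
```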

In [None]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention from scratch.
    
    Args:
        Q: Queries (batch_size, num_heads, seq_len_q, d_k)
        K: Keys (batch_size, num_heads, seq_len_k, d_k)
        V: Values (batch_size, num_heads, seq_len_k, d_v)
        mask: Optional mask broadcastable to (batch_size, num_heads, seq_len_q, seq_len_k);
              positions where mask == 0 are blocked
    
    Returns:
        output: Attention output (batch_size, num_heads, seq_len_q, d_v)
        attention_weights: Attention probabilities (batch_size, num_heads, seq_len_q, seq_len_k)
    """
    # Step 1: Get dimension of keys (for scaling)
    d_k = K.size(-1)
    
    # Step 2: Compute attention scores (Q * K^T)
    # Why dot product? It measures similarity between query and key vectors
    scores = torch.matmul(Q, K.transpose(-2, -1))  # (batch, heads, seq_len_q, seq_len_k)
    
    # Step 3: Scale by sqrt(d_k)
    # Why? Keeps the variance of the scores near 1 so softmax does not saturate
    scores = scores / (d_k ** 0.5)
    
    # Step 4: Apply mask (if provided)
    # Why? To prevent attending to certain positions (e.g., future tokens in causal attention)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    
    # Step 5: Apply softmax to get attention probabilities
    # Why? Converts scores to probabilities that sum to 1
    attention_weights = F.softmax(scores, dim=-1)  # (batch, heads, seq_len_q, seq_len_k)
    
    # Step 6: Weighted sum of values
    # Why? This is where we actually retrieve and blend information
    output = torch.matmul(attention_weights, V)  # (batch, heads, seq_len_q, d_v)
    
    return output, attention_weights

print("Scaled dot-product attention function defined!")
print("\nThis is the CORE of all attention mechanisms.")
print("Standard transformers and LLMs build on this computation; optimized kernels")
print("(e.g., FlashAttention) compute the same result more efficiently.")

### Let's See Attention in Action: A Simple Example

We'll create a small sequence and visualize how attention weights are computed.

In [None]:
# Create a simple example: 5-token sequence with 8-dimensional embeddings
seq_len = 5
d_model = 8

# For this example, Q=K=V (self-attention)
X = torch.randn(1, 1, seq_len, d_model)  # (batch=1, heads=1, seq_len=5, d_model=8)

# In self-attention, Q, K, V are projections of the same input
# For simplicity, we'll just use X directly
Q = K = V = X

# Compute attention
output, attention_weights = scaled_dot_product_attention(Q, K, V)

# Visualize attention weights
plt.figure(figsize=(8, 6))
sns.heatmap(attention_weights[0, 0].detach().numpy(), 
            annot=True, fmt='.3f', cmap='YlOrRd',
            xticklabels=[f'Token {i}' for i in range(seq_len)],
            yticklabels=[f'Token {i}' for i in range(seq_len)],
            cbar_kws={'label': 'Attention Weight'})
plt.title('Self-Attention Weights\n(Row i = where token i attends to)')
plt.xlabel('Key positions')
plt.ylabel('Query positions')
plt.tight_layout()
plt.show()

print("\nHow to read this heatmap:")
print("- Each row shows where one token attends")
print("- Brighter colors = higher attention weight")
print("- Each row sums to 1.0 (probability distribution)")
print("\nNotice:")
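print("- The diagonal is high here because Q = K = V: a token's score with itself")
print("  is its squared norm, which is often its largest dot product")
print("- Each token can attend to all other tokens")
print("- No learned parameters yet: in a real layer, trained projections W_Q, W_K, W_V")
print("  determine these weights")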

## Building a Complete Self-Attention Layer

### Why Do We Need Learned Projections?

In practice, we don't use input embeddings directly as Q, K, V. Instead:

**1. We project them through learned weight matrices**:
```python
Q = X @ W_Q  # Learn what to search for
K = X @ W_K  # Learn what to index
V = X @ W_V  # Learn what to retrieve
```

**Why this is crucial**:
- **Flexibility**: Different transformations for different purposes
- **Learning**: Model learns optimal representations for attention
- **Expressiveness**: Can attend to different aspects of inputs

**2. We add an output projection**:
```python
output = attention_output @ W_O
```

**Why it's needed**:
- Mixes the retrieved value features into the output representation
- Projects the $d_v$-dimensional attention output back to $d_{model}$, so layers can be stacked
- Adds additional learnable transformation
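
In PyTorch, each projection such as $X W_Q$ is typically implemented as a bias-free `nn.Linear` layer. The small sketch below uses arbitrary sizes. It confirms that `nn.Linear` stores its weight as `(out_features, in_features)`, so applying the layer computes `X @ W.T`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_k = 8, 4
X = torch.randn(2, 5, d_model)             # (batch, seq_len, d_model)

W_Q = nn.Linear(d_model, d_k, bias=False)  # weight shape: (d_k, d_model)
Q_layer = W_Q(X)                           # (batch, seq_len, d_k)
Q_manual = X @ W_Q.weight.T                # the same computation written out

print(W_Q.weight.shape)                    # torch.Size([4, 8])
print(Q_layer.shape)                       # torch.Size([2, 5, 4])
print(torch.allclose(Q_layer, Q_manual))   # True
```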

In [None]:
class SelfAttention(nn.Module):
    """
    Self-attention layer with learned Q, K, V projections.
    
    This is the building block of transformers!
    """
    def __init__(self, d_model, d_k, d_v):
        """
        Args:
            d_model: Dimension of input embeddings
            d_k: Dimension of queries and keys
            d_v: Dimension of values
        """
        super(SelfAttention, self).__init__()
        
        self.d_k = d_k
        
        # Learned projection matrices
        # Why separate W_Q, W_K, W_V? Each has a different role!
        self.W_Q = nn.Linear(d_model, d_k, bias=False)
        self.W_K = nn.Linear(d_model, d_k, bias=False)
        self.W_V = nn.Linear(d_model, d_v, bias=False)
        
        # Output projection
        # Why needed? Combines attention output back to d_model dimension
        self.W_O = nn.Linear(d_v, d_model, bias=False)
    
    def forward(self, X, mask=None):
        """
        Args:
            X: Input tensor (batch_size, seq_len, d_model)
            mask: Optional attention mask
        
        Returns:
            output: Attention output (batch_size, seq_len, d_model)
            attention_weights: Attention probabilities
        """
        # Step 1: Project to Q, K, V
        # This is where the model learns WHAT to attend to
        Q = self.W_Q(X)  # (batch, seq_len, d_k)
        K = self.W_K(X)  # (batch, seq_len, d_k)
        V = self.W_V(X)  # (batch, seq_len, d_v)
        
        # Add head dimension for compatibility with scaled_dot_product_attention
        Q = Q.unsqueeze(1)  # (batch, 1, seq_len, d_k)
        K = K.unsqueeze(1)
        V = V.unsqueeze(1)
        
        # Step 2: Compute attention
        attn_output, attention_weights = scaled_dot_product_attention(Q, K, V, mask)
        
        # Remove head dimension
        attn_output = attn_output.squeeze(1)  # (batch, seq_len, d_v)
        
        # Step 3: Project output back to d_model
        output = self.W_O(attn_output)  # (batch, seq_len, d_model)
        
        return output, attention_weights

# Create and test the self-attention layer
d_model = 64
d_k = d_v = 64
seq_len = 10
batch_size = 2

self_attn = SelfAttention(d_model, d_k, d_v)
X = torch.randn(batch_size, seq_len, d_model)

output, attn_weights = self_attn(X)

print(f"Input shape: {X.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
print("\nNotice: Output shape matches input shape!")
print("This allows stacking attention layers in deep networks.")

## Cross-Attention: Attending Across Sequences

### What is Cross-Attention?

- **Self-attention**: Q, K, V all come from the same sequence
- **Cross-attention**: Q comes from one sequence; K and V come from another

**Use case**: Decoder attending to encoder outputs in sequence-to-sequence tasks

```
Machine Translation Example:
Encoder input:  "Le chat est noir"     (French)
Decoder:        "The cat is ___"       (English, being generated)

Cross-attention:
  Q: From decoder ("The cat is")
  K, V: From encoder ("Le chat est noir")
  
The decoder queries what it needs from the encoder!
```

**Why it's different from self-attention**:
- **Information flow**: One sequence queries information from another
- **Asymmetric**: The query sequence can have a different length from the key/value sequence
- **Purpose**: Allows the decoder to focus on relevant encoder states

**Real-world applications**:
- **Translation**: Decoder attends to source language
- **Image captioning**: Text decoder attends to image features
- **Text-to-image diffusion (e.g., Stable Diffusion)**: Image features attend to text embeddings
- **Multimodal LLMs**: Text attends to image patches

In [None]:
class CrossAttention(nn.Module):
    """
    Cross-attention: Attend from one sequence to another.
    
    Used in encoder-decoder architectures!
    """
    def __init__(self, d_model, d_k, d_v):
        super(CrossAttention, self).__init__()
        
        self.d_k = d_k
        
        # Q from decoder (what we're looking for)
        self.W_Q = nn.Linear(d_model, d_k, bias=False)
        
        # K, V from encoder (what information is available)
        self.W_K = nn.Linear(d_model, d_k, bias=False)
        self.W_V = nn.Linear(d_model, d_v, bias=False)
        
        self.W_O = nn.Linear(d_v, d_model, bias=False)
    
    def forward(self, decoder_hidden, encoder_outputs, mask=None):
        """
        Args:
            decoder_hidden: Decoder states (batch, decoder_len, d_model)
            encoder_outputs: Encoder states (batch, encoder_len, d_model)
            mask: Optional mask (batch, 1, decoder_len, encoder_len)
        
        Returns:
            output: Cross-attention output (batch, decoder_len, d_model)
            attention_weights: Where decoder attended in encoder
        """
        # Q from decoder (what does decoder need?)
        Q = self.W_Q(decoder_hidden).unsqueeze(1)  # (batch, 1, decoder_len, d_k)
        
        # K, V from encoder (what information is available?)
        K = self.W_K(encoder_outputs).unsqueeze(1)  # (batch, 1, encoder_len, d_k)
        V = self.W_V(encoder_outputs).unsqueeze(1)  # (batch, 1, encoder_len, d_v)
        
        # Compute cross-attention
        # This is where decoder "looks at" encoder!
        attn_output, attention_weights = scaled_dot_product_attention(Q, K, V, mask)
        
        attn_output = attn_output.squeeze(1)
        output = self.W_O(attn_output)
        
        return output, attention_weights

# Example: Translation scenario
encoder_len = 7  # e.g., "Le chat est noir" after tokenization (words + special tokens)
decoder_len = 4  # e.g., a start token + "The cat is" (3 words generated so far)
d_model = 64

cross_attn = CrossAttention(d_model, d_k=64, d_v=64)

# Encoder has processed the French sentence
encoder_outputs = torch.randn(1, encoder_len, d_model)

# Decoder is generating English
decoder_hidden = torch.randn(1, decoder_len, d_model)

# Cross-attention: Decoder attends to encoder
output, attn_weights = cross_attn(decoder_hidden, encoder_outputs)

print(f"Encoder sequence length: {encoder_len}")
print(f"Decoder sequence length: {decoder_len}")
print(f"Cross-attention output: {output.shape}")
print(f"Attention weights: {attn_weights.shape}")

# Visualize cross-attention
plt.figure(figsize=(10, 5))
sns.heatmap(attn_weights[0, 0].detach().numpy(),
            annot=True, fmt='.3f', cmap='Blues',
            xticklabels=[f'Enc{i}' for i in range(encoder_len)],
            yticklabels=[f'Dec{i}' for i in range(decoder_len)],
            cbar_kws={'label': 'Attention Weight'})
plt.title('Cross-Attention Weights\n(Decoder attending to Encoder)')
plt.xlabel('Encoder positions (Keys/Values)')
plt.ylabel('Decoder positions (Queries)')
plt.tight_layout()
plt.show()

print("\nNotice: Decoder positions attend to ENCODER positions!")
print("This is fundamentally different from self-attention.")

## Multi-Head Attention: Learning Different Relationships

### Why Multiple Heads?

**Problem with single-head attention**: One attention pattern might not capture all relationships!

**Example sentence**: "The bank by the river is steep"

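Different types of relationships that separate heads *might* capture (illustrative only; heads are not assigned roles, and any specialization emerges during training):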
- **Head 1**: Syntactic relationships ("bank" → "is")
- **Head 2**: Semantic relationships ("bank" → "river")
- **Head 3**: Long-range dependencies ("steep" → "bank")

**Multi-head attention**: Run attention multiple times in parallel, each with different learned projections

```python
# Instead of one attention:
output = Attention(Q, K, V)

# Use multiple heads:
head_1 = Attention(Q1, K1, V1)
head_2 = Attention(Q2, K2, V2)
...
head_h = Attention(Qh, Kh, Vh)

output = Concat(head_1, ..., head_h) @ W_O
```
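
Splitting into heads is only a reshape. A `d_model`-dim vector is viewed as `h` chunks of size `d_model / h`, and concatenating the chunks restores it exactly. The minimal sketch below shows that round trip with arbitrary sizes, using the common `view`/`transpose` pattern.

```python
import torch

torch.manual_seed(0)
batch, seq_len, d_model, num_heads = 2, 5, 12, 3
d_k = d_model // num_heads  # 4 dims per head

x = torch.randn(batch, seq_len, d_model)

# Split: (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
heads = x.view(batch, seq_len, num_heads, d_k).transpose(1, 2)

# Head 1 sees dimensions [d_k, 2*d_k) of every token
assert torch.equal(heads[:, 1], x[..., d_k:2 * d_k])

# Concat: transpose back and flatten the head axis into d_model
merged = heads.transpose(1, 2).contiguous().view(batch, seq_len, d_model)

print(heads.shape)             # torch.Size([2, 3, 5, 4])
print(torch.equal(merged, x))  # True
```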

**Why this works**:
1. **Diversity**: Each head can specialize in different patterns
2. **Ensemble effect**: Multiple perspectives are better than one
3. **Efficiency**: With $d_k = d_{model} / h$, the total cost is about the same as single-head attention over the full dimension
4. **Representation power**: Can represent complex attention patterns

**Standard configuration**:
- **8 heads** (e.g., the original Transformer base model) to **64 or more** in large LLMs (GPT-3 175B uses 96)
- Each head has dimension $d_k = d_{model} / h$ (split the dimensions)
- Total parameters similar to single-head with full dimension
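
The parameter claim can be checked by counting projection weights. The sketch below assumes bias-free projections and $d_k = d_v = d_{model}/h$. Placing the $h$ per-head matrices side by side gives exactly one $d_{model} \times d_{model}$ matrix, so the count does not depend on $h$.

```python
import torch.nn as nn

d_model = 512

def projection_params(num_heads):
    """Count Q/K/V/O projection weights for bias-free multi-head attention."""
    d_k = d_model // num_heads
    # One (d_model -> d_k) projection per head for each of Q, K, V
    per_head_qkv = [nn.Linear(d_model, d_k, bias=False) for _ in range(3 * num_heads)]
    # Output projection maps the concatenated heads (num_heads * d_k) back to d_model
    out_proj = nn.Linear(num_heads * d_k, d_model, bias=False)
    layers = per_head_qkv + [out_proj]
    return sum(p.numel() for layer in layers for p in layer.parameters())

for h in (1, 8, 64):
    print(f"{h:2d} heads: {projection_params(h):,} parameters")
# Each line reports 4 * 512 * 512 = 1,048,576
```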

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention: The secret sauce of transformers!
    
    Key insight: Multiple attention patterns in parallel.
    """
    def __init__(self, d_model, num_heads):
        """
        Args:
            d_model: Model dimension (e.g., 512)
            num_heads: Number of attention heads (e.g., 8)
        
        Note: d_model must be divisible by num_heads!
        Why? We split d_model across heads: d_k = d_model / num_heads
        """
        super(MultiHeadAttention, self).__init__()
        
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Dimension per head
        
        # Projections for all heads (done in one matrix multiply for efficiency)
        # Why single matrices? More efficient than separate matrices per head
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        
        # Output projection
        self.W_O = nn.Linear(d_model, d_model, bias=False)
    
    def split_heads(self, x):
        """
        Split the last dimension into (num_heads, d_k).
        
        Why? Each head processes a different subspace of the embeddings.
        
        Input:  (batch, seq_len, d_model)
        Output: (batch, num_heads, seq_len, d_k)
        """
        batch_size, seq_len, d_model = x.size()
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
    
    def combine_heads(self, x):
        """
        Inverse of split_heads: merge heads back together.
        
        Input:  (batch, num_heads, seq_len, d_k)
        Output: (batch, seq_len, d_model)
        """
        batch_size, num_heads, seq_len, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
    
    def forward(self, X, mask=None):
        """
        Args:
            X: Input (batch, seq_len, d_model)
            mask: Optional mask (batch, 1, seq_len, seq_len)
        
        Returns:
            output: (batch, seq_len, d_model)
            attention_weights: (batch, num_heads, seq_len, seq_len)
        """
        # Step 1: Linear projections for all heads at once
        Q = self.W_Q(X)  # (batch, seq_len, d_model)
        K = self.W_K(X)
        V = self.W_V(X)
        
        # Step 2: Split into multiple heads
        # This is where we create parallel attention patterns!
        Q = self.split_heads(Q)  # (batch, num_heads, seq_len, d_k)
        K = self.split_heads(K)
        V = self.split_heads(V)
        
        # Step 3: Scaled dot-product attention for all heads in parallel
        # Each head learns different attention patterns!
        attn_output, attention_weights = scaled_dot_product_attention(Q, K, V, mask)
        
        # Step 4: Concatenate heads
        attn_output = self.combine_heads(attn_output)  # (batch, seq_len, d_model)
        
        # Step 5: Final linear projection
        # Why? Combine information from all heads
        output = self.W_O(attn_output)
        
        return output, attention_weights

# Create multi-head attention
d_model = 512
num_heads = 8
seq_len = 10
batch_size = 2

mha = MultiHeadAttention(d_model, num_heads)
X = torch.randn(batch_size, seq_len, d_model)

output, attn_weights = mha(X)

print(f"Input shape: {X.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
print(f"\nConfiguration:")
print(f"  d_model: {d_model}")
print(f"  num_heads: {num_heads}")
print(f"  d_k per head: {d_model // num_heads}")
print(f"\nEach of the {num_heads} heads attends in its own {d_model // num_heads}-dim subspace.")
print("(The model is untrained, so the projections are still random.)")

### Visualizing Multi-Head Attention

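Let's compare the patterns produced by different heads. The model above is untrained, so any differences come from each head's random projections, not from learning.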

In [None]:
# Visualize attention patterns from different heads
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.flatten()

for head_idx in range(num_heads):
    ax = axes[head_idx]
    
    # Get attention weights for this head
    head_weights = attn_weights[0, head_idx].detach().numpy()
    
    sns.heatmap(head_weights, ax=ax, cmap='viridis',
                xticklabels=False, yticklabels=False,
                cbar=True, square=True)
    ax.set_title(f'Head {head_idx + 1}')
    ax.set_xlabel('Key positions')
    ax.set_ylabel('Query positions')

plt.suptitle('Multi-Head Attention: Each Head Produces a Different Pattern (untrained)', 
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("\nNotice how different heads produce different attention patterns,")
print("even at random initialization: each head uses its own slice of the projections.")
print("In trained models, heads often specialize (e.g., some attend to nearby tokens,")
print("others to distant or syntactically related ones).")
print("\nThis diversity is what makes multi-head attention so powerful!")

## Using PyTorch's Optimized Attention

### Why Use PyTorch's Built-in Functions?

While it's important to understand attention from scratch, PyTorch provides **highly optimized** implementations:

**`F.scaled_dot_product_attention`** (PyTorch 2.0+):
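- **Flash Attention**: Memory-efficient attention algorithm (used on supported GPUs)
- **Kernel fusion**: Optimized CUDA kernels
- **Automatic selection**: Chooses the best available backend for your hardware and inputs
- **Often substantially faster** and more memory-efficient than naive implementations, especially on GPUs with long sequences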

**When to use what**:
- **Learning**: Use from-scratch implementation
- **Production**: Use PyTorch's optimized functions
- **Research**: Understand both!

In [None]:
# PyTorch 2.0+ has optimized attention
class OptimizedMultiHeadAttention(nn.Module):
    """
    Multi-head attention using PyTorch's optimized implementation.
    
    This is what you should use in production!
    """
    def __init__(self, d_model, num_heads, dropout=0.1):
        super(OptimizedMultiHeadAttention, self).__init__()
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)
        
        self.dropout = dropout
    
    def forward(self, X, mask=None):
        batch_size, seq_len, d_model = X.size()
        
        # Project and reshape
        Q = self.W_Q(X).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_K(X).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_V(X).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        
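        # F.scaled_dot_product_attention treats a boolean attn_mask as
        # "True = may attend", but ADDS a float attn_mask to the scores.
        # Convert 0/1 masks to bool so this matches the from-scratch version.
        attn_mask = mask.bool() if mask is not None else None
        
        # Use PyTorch's optimized attention
        # On supported GPUs this can dispatch to a Flash Attention kernel
        attn_output = F.scaled_dot_product_attention(
            Q, K, V,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=False  # True for GPT-style causal masking (then don't pass attn_mask)
        )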
        
        # Reshape and project
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        output = self.W_O(attn_output)
        
        return output

# Run the optimized implementation on a test batch
optimized_mha = OptimizedMultiHeadAttention(d_model=512, num_heads=8)
X_test = torch.randn(2, 10, 512)

output_optimized = optimized_mha(X_test)

print("Optimized Multi-Head Attention:")
print(f"Output shape: {output_optimized.shape}")
print("\nBenefits of PyTorch's implementation:")
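print("✓ Flash Attention (supported GPUs): memory grows linearly with sequence length,")
print("  since the full N×N score matrix is never materialized")
print("✓ Fused kernels: Fewer memory transfers")
print("✓ Hardware-specific optimizations")
print("✓ Optimized backward pass for the fused backends")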
print("\nIn practice: Use this for training large models!")
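
Because the optimized kernel should compute the same function as the from-scratch version, a useful sanity check is to compare the two on the same inputs. The sketch below assumes PyTorch 2.0+ (where `F.scaled_dot_product_attention` exists), dropout disabled, and a boolean causal mask (`True` = may attend). The helper `naive_attention` is defined here for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def naive_attention(Q, K, V, mask=None):
    """From-scratch scaled dot-product attention (mask: bool, True = may attend)."""
    scores = Q @ K.transpose(-2, -1) / Q.size(-1) ** 0.5
    if mask is not None:
        scores = scores.masked_fill(~mask, float('-inf'))
    return F.softmax(scores, dim=-1) @ V

batch, heads, seq_len, d_k = 2, 8, 10, 64
Q, K, V = (torch.randn(batch, heads, seq_len, d_k) for _ in range(3))
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

out_naive = naive_attention(Q, K, V, mask=causal)
out_fused = F.scaled_dot_product_attention(Q, K, V, attn_mask=causal)

# Should be on the order of float32 rounding error
max_diff = (out_naive - out_fused).abs().max().item()
print(f"Max absolute difference: {max_diff:.2e}")
```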

## Causal (Masked) Attention for Language Modeling

### Why Do We Need Masking?

In language modeling (e.g., GPT), we want to predict the next token. **Problem**: We can't let the model cheat by looking at future tokens!

**Example**:
```
Sentence: "The cat sat on the mat"

When predicting "sat" (from the position of "cat"):
  Can attend to: "The", "cat"
  CANNOT attend to: "sat", "on", "the", "mat" (the target and everything after it)
```

**Solution: Causal masking**
- Mask out future positions by setting their attention scores to -∞
- After softmax, -∞ becomes 0 (no attention)
- Creates a lower-triangular attention pattern

**Why this is critical**:
- **Training**: Prevents information leakage from future tokens
- **Autoregressive generation**: Enables next-token prediction
- **Foundation of GPT**: This masking is what makes GPT-style models work
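
The masking steps above can be sketched numerically with made-up scores: first fill the future positions with `-inf`, then apply softmax.

```python
import torch
import torch.nn.functional as F

# Made-up attention scores for a 4-token sequence (rows = queries)
scores = torch.tensor([[2.0, 1.0, 0.5, 3.0],
                       [0.1, 0.2, 4.0, 0.3],
                       [1.0, 1.0, 1.0, 1.0],
                       [0.5, 2.0, 1.5, 0.0]])

# Lower-triangular mask: query i may only see keys 0..i
allowed = torch.tril(torch.ones(4, 4, dtype=torch.bool))
masked_scores = scores.masked_fill(~allowed, float('-inf'))

weights = F.softmax(masked_scores, dim=-1)
print(weights)
# Row 0 puts all its weight on token 0; every entry above the diagonal is 0
print(weights.sum(dim=-1))  # each row still sums to 1
```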

In [None]:
def create_causal_mask(seq_len):
    """
    Create a causal (lower-triangular) mask.
    
    Why lower-triangular? Token i can only attend to tokens 0...i
    
    Returns:
        mask: (seq_len, seq_len) with 1s in lower triangle, 0s above
    """
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask

# Visualize causal mask
seq_len = 8
causal_mask = create_causal_mask(seq_len)

plt.figure(figsize=(8, 8))
sns.heatmap(causal_mask.numpy(), annot=True, fmt='.0f',
            cmap='RdYlGn', cbar=False, square=True,
            xticklabels=[f'T{i}' for i in range(seq_len)],
            yticklabels=[f'T{i}' for i in range(seq_len)])
plt.title('Causal Mask for GPT-style Models\n(1 = can attend, 0 = masked)', fontsize=14)
plt.xlabel('Key positions (what can be attended to)')
plt.ylabel('Query positions (current token)')
plt.tight_layout()
plt.show()

print("How to read this:")
print("- Row 0 (token 0): Can only attend to itself")
print("- Row 3 (token 3): Can attend to tokens 0, 1, 2, 3")
print("- Upper triangle is masked: Can't see the future!")
print("\nThis is exactly how GPT generates text:")
print("Each token prediction can only use past context.")

In [None]:
# Apply causal attention
seq_len = 8
d_model = 64

# Create input
X = torch.randn(1, seq_len, d_model)

# Create causal mask (1, 1, seq_len, seq_len)
causal_mask = create_causal_mask(seq_len).unsqueeze(0).unsqueeze(0)

# Apply multi-head attention with causal mask
mha = MultiHeadAttention(d_model, num_heads=4)
output, attn_weights = mha(X, mask=causal_mask)

# Visualize one head's causal attention
plt.figure(figsize=(10, 8))
sns.heatmap(attn_weights[0, 0].detach().numpy(),
            annot=True, fmt='.3f', cmap='Blues',
            xticklabels=[f'T{i}' for i in range(seq_len)],
            yticklabels=[f'T{i}' for i in range(seq_len)],
            cbar_kws={'label': 'Attention Weight'})
plt.title('Causal Attention Weights (Head 1)\nUpper triangle is zero!', fontsize=14)
plt.xlabel('Key positions')
plt.ylabel('Query positions')
plt.tight_layout()
plt.show()

print("Notice:")
print("- Upper triangle is all zeros (masked future)")
print("- Each row still sums to 1.0 (valid probability distribution)")
print("- This pattern enables autoregressive generation!")
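
In PyTorch 2.0+, `F.scaled_dot_product_attention` can apply the causal mask itself via `is_causal=True`, so no mask tensor needs to be built. The sketch below uses random inputs and arbitrary sizes to check that this matches passing an explicit boolean lower-triangular mask. Note that the function expects either `attn_mask` or `is_causal=True`, not both.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, d_k = 1, 4, 8, 16
Q, K, V = (torch.randn(batch, heads, seq_len, d_k) for _ in range(3))

# Option 1: explicit boolean mask (True = may attend)
tril = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
out_mask = F.scaled_dot_product_attention(Q, K, V, attn_mask=tril)

# Option 2: let PyTorch apply the causal mask internally
out_causal = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

print(torch.allclose(out_mask, out_causal, atol=1e-5))
```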

## Mini Exercises

Test your understanding with these exercises.

### Exercise 1: Attention Score Calculation

Given:
```python
Q = torch.tensor([[1.0, 0.0, 0.0]])
K = torch.tensor([[1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0],
                  [0.0, 0.0, 1.0]])
```

Calculate the scaled attention scores (before softmax) and the resulting attention weights (after softmax). What does this tell you about attention?

In [None]:
# YOUR CODE HERE


# SOLUTION
def show_solution_1():
    Q = torch.tensor([[1.0, 0.0, 0.0]])
    K = torch.tensor([[1.0, 0.0, 0.0],
                      [0.0, 1.0, 0.0],
                      [0.0, 0.0, 1.0]])
    
    d_k = K.size(-1)
    
    # Compute scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    
    print("Attention scores (before softmax):")
    print(scores)
    
    # Apply softmax
    weights = F.softmax(scores, dim=-1)
    print("\nAttention weights (after softmax):")
    print(weights)
    
    print("\nInterpretation:")
    print("- Q is [1,0,0] and K[0] is [1,0,0]: Perfect alignment → highest score")
    print("- Q is [1,0,0] but K[1] is [0,1,0]: Orthogonal → score of 0")
    print("- Dot product measures similarity/alignment of vectors")
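    print("- Softmax converts scores to probabilities")
    print("- Scores are small here (1/sqrt(3) ≈ 0.577), so softmax stays soft:")
    print("  the aligned key gets ≈ 0.47, the orthogonal keys ≈ 0.26 each")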

# Uncomment to see solution:
# show_solution_1()

### Exercise 2: Implement Additive Attention

The original attention mechanism (Bahdanau et al., 2015) used **additive attention** instead of dot-product:

$$\text{score}(Q, K) = v^T \tanh(W_Q Q + W_K K)$$

Implement additive attention and compare with dot-product attention. Why did dot-product win?

In [None]:
# YOUR CODE HERE
class AdditiveAttention(nn.Module):
    def __init__(self, d_model):
        super(AdditiveAttention, self).__init__()
        # Define W_Q, W_K, and v
        pass
    
    def forward(self, Q, K, V, mask=None):
        # Implement additive attention
        pass


# SOLUTION
def show_solution_2():
    class AdditiveAttention(nn.Module):
        def __init__(self, d_model):
            super(AdditiveAttention, self).__init__()
            self.W_Q = nn.Linear(d_model, d_model, bias=False)
            self.W_K = nn.Linear(d_model, d_model, bias=False)
            self.v = nn.Linear(d_model, 1, bias=False)
        
        def forward(self, Q, K, V, mask=None):
            # Q: (batch, seq_len_q, d_model)
            # K: (batch, seq_len_k, d_model)
            
            # Project Q and K
            Q_proj = self.W_Q(Q).unsqueeze(2)  # (batch, seq_len_q, 1, d_model)
            K_proj = self.W_K(K).unsqueeze(1)  # (batch, 1, seq_len_k, d_model)
            
            # Additive combination
            combined = torch.tanh(Q_proj + K_proj)  # (batch, seq_len_q, seq_len_k, d_model)
            
            # Score
            scores = self.v(combined).squeeze(-1)  # (batch, seq_len_q, seq_len_k)
            
            if mask is not None:
                scores = scores.masked_fill(mask == 0, float('-inf'))
            
            weights = F.softmax(scores, dim=-1)
            output = torch.matmul(weights, V)
            
            return output, weights
    
    print("Additive vs Dot-Product Attention:")
    print("\nAdditive Attention:")
    print("  Pros: More expressive (learned tanh non-linearity when scoring each query-key pair)")
    print("  Cons: Slower and memory-hungry (builds a (batch, seq_len_q, seq_len_k, d_model) tensor)")
    print("\nDot-Product Attention:")
    print("  Pros: Faster (one batched matrix multiply, highly optimized)")
    print("  Pros: Fewer parameters (no extra scoring vector v)")
    print("  Cons: Less flexible (a plain dot product); needs 1/sqrt(d_k) scaling when d_k is large")
    print("\nWhy dot-product won:")
    print("  - Speed and memory matter when training models with billions of parameters")
    print("  - Multi-head attention adds expressiveness back")
    print("  - Hardware optimization (GPUs excel at matrix multiply)")

# Uncomment to see solution:
# show_solution_2()
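The solution above defines `AdditiveAttention` inside a function and never calls it. A minimal standalone copy (same math, slightly condensed) lets you check the shapes, confirm that each row of weights sums to 1, and see that a zero in `mask` removes a key entirely. Note the `(batch, seq_len_q, seq_len_k, d_model)` intermediate: that tensor is the main reason additive attention costs more memory than dot-product attention.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AdditiveAttention(nn.Module):
    """Bahdanau-style additive attention: score(q, k) = v^T tanh(W_Q q + W_K k)."""

    def __init__(self, d_model):
        super().__init__()
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.v = nn.Linear(d_model, 1, bias=False)

    def forward(self, Q, K, V, mask=None):
        # Broadcast every query against every key: (batch, Lq, Lk, d_model)
        combined = torch.tanh(self.W_Q(Q).unsqueeze(2) + self.W_K(K).unsqueeze(1))
        scores = self.v(combined).squeeze(-1)  # (batch, Lq, Lk)
        if mask is not None:
            # Convention used in this notebook: mask == 0 means "blocked"
            scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = F.softmax(scores, dim=-1)
        return torch.matmul(weights, V), weights


torch.manual_seed(0)
attn = AdditiveAttention(d_model=16)
Q = torch.randn(2, 3, 16)  # 3 queries per batch element
K = torch.randn(2, 5, 16)  # 5 keys
V = torch.randn(2, 5, 16)  # 5 values
out, w = attn(Q, K, V)
print(out.shape, w.shape)  # torch.Size([2, 3, 16]) torch.Size([2, 3, 5])

mask = torch.ones(2, 3, 5)
mask[..., -1] = 0  # hide the last key from every query
_, w_masked = attn(Q, K, V, mask)
print(w_masked[..., -1].abs().max().item())  # 0.0
```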

### Exercise 3: Attention for Sequence Classification

Build a simple sequence classifier that pools token representations with attention: a learned query vector attends over all tokens, and the resulting weighted sum feeds a linear classifier. Use this attention pooling instead of taking the last token or averaging all tokens.

In [None]:
# YOUR CODE HERE
class AttentionClassifier(nn.Module):
    def __init__(self, d_model, num_classes):
        super(AttentionClassifier, self).__init__()
        # Add your code
        pass
    
    def forward(self, x):
        # Add your code
        pass


# SOLUTION
def show_solution_3():
    class AttentionPooling(nn.Module):
        """
        Attention-based pooling: Learn to attend to important tokens.
        """
        def __init__(self, d_model):
            super(AttentionPooling, self).__init__()
            # Learnable query vector
            self.query = nn.Parameter(torch.randn(1, 1, d_model))
            self.W_K = nn.Linear(d_model, d_model)
            self.W_V = nn.Linear(d_model, d_model)
        
        def forward(self, x):
            # x: (batch, seq_len, d_model)
            batch_size = x.size(0)
            
            # Expand query for batch
            Q = self.query.expand(batch_size, -1, -1)  # (batch, 1, d_model)
            K = self.W_K(x)  # (batch, seq_len, d_model)
            V = self.W_V(x)
            
            # Attention scores
            scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(K.size(-1))
            weights = F.softmax(scores, dim=-1)  # (batch, 1, seq_len)
            
            # Weighted sum
            pooled = torch.matmul(weights, V).squeeze(1)  # (batch, d_model)
            
            return pooled, weights
    
    class AttentionClassifier(nn.Module):
        def __init__(self, d_model, num_classes):
            super(AttentionClassifier, self).__init__()
            self.attention_pool = AttentionPooling(d_model)
            self.classifier = nn.Linear(d_model, num_classes)
        
        def forward(self, x):
            # Pool with attention
            pooled, weights = self.attention_pool(x)
            # Classify
            logits = self.classifier(pooled)
            return logits, weights
    
    # Test
    model = AttentionClassifier(d_model=64, num_classes=5)
    x = torch.randn(2, 10, 64)  # (batch=2, seq_len=10, d_model=64)
    logits, weights = model(x)
    
    print(f"Input shape: {x.shape}")
    print(f"Output logits: {logits.shape}")
    print(f"Attention weights: {weights.shape}")
    print("\nWhy attention pooling?")
    print("- Learns which tokens are important for classification")
    print("- More flexible than averaging or taking [CLS] token")
    print("- Provides interpretability (can visualize which tokens matter)")

# Uncomment to see solution:
# show_solution_3()
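One way to see why attention pooling is "more flexible than averaging": mean pooling is a special case of it. The sketch below repeats the solution's `AttentionPooling` in standalone form (using `math.sqrt` instead of `np.sqrt`) and zeroes the learned query. That makes every score 0, the softmax uniform, and the pooled vector exactly the mean of the value vectors. Training moves the query away from zero so the weights can favor informative tokens. Padding masks are not handled in this sketch.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionPooling(nn.Module):
    """Pool a sequence into one vector using a single learned query."""

    def __init__(self, d_model):
        super().__init__()
        self.query = nn.Parameter(torch.randn(1, 1, d_model))
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)

    def forward(self, x):
        Q = self.query.expand(x.size(0), -1, -1)   # (batch, 1, d_model)
        K, V = self.W_K(x), self.W_V(x)             # (batch, seq_len, d_model)
        scores = Q @ K.transpose(-2, -1) / math.sqrt(K.size(-1))
        weights = F.softmax(scores, dim=-1)         # (batch, 1, seq_len)
        return (weights @ V).squeeze(1), weights    # (batch, d_model)


pool = AttentionPooling(d_model=8)
x = torch.randn(4, 6, 8)

# With a zero query every score is 0, so softmax gives 1/6 to each of the 6 tokens
# and the pooled vector equals the mean of the value vectors.
with torch.no_grad():
    pool.query.zero_()
pooled, weights = pool(x)
mean_of_values = pool.W_V(x).mean(dim=1)
print(torch.allclose(pooled, mean_of_values, atol=1e-6))  # True
```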

## Comprehensive Exercise: Build a Complete Attention Module

Build a production-ready attention module that supports:
1. Multi-head self-attention
2. Optional causal masking
3. Dropout for regularization
4. Residual connection and layer normalization

This is what you'd actually use in a transformer!

In [None]:
# YOUR CODE HERE


# SOLUTION
def show_comprehensive_solution():
    class TransformerAttentionBlock(nn.Module):
        """
        Complete attention block as used in transformers.
        
        Includes:
        - Multi-head attention
        - Dropout
        - Residual connection
        - Layer normalization
        """
        def __init__(self, d_model, num_heads, dropout=0.1):
            super(TransformerAttentionBlock, self).__init__()
            
            self.attention = MultiHeadAttention(d_model, num_heads)  # implemented earlier in this notebook
            self.norm = nn.LayerNorm(d_model)
            self.dropout = nn.Dropout(dropout)
        
        def forward(self, x, mask=None):
            """
            Args:
                x: (batch, seq_len, d_model)
                mask: Optional attention mask (e.g. a lower-triangular causal mask
                      for autoregressive models)
            
            Returns:
                output: (batch, seq_len, d_model)
            """
            # Multi-head attention
            attn_output, _ = self.attention(x, mask)
            
            # Dropout
            attn_output = self.dropout(attn_output)
            
            # Residual connection + Layer norm (Post-LN, as in the original Transformer;
            # GPT-2 and many later LLMs use Pre-LN instead: x + attn(norm(x)))
            # Why residual? Helps gradients flow, enables deep networks
            # Why layer norm? Stabilizes training
            output = self.norm(x + attn_output)
            
            return output
    
    # Test the complete block
    block = TransformerAttentionBlock(d_model=512, num_heads=8, dropout=0.1)
    x = torch.randn(2, 10, 512)
    
    output = block(x)
    
    print("Complete Transformer Attention Block:")
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print("\nThis block is the attention sublayer of a transformer:")
    print("✓ Multi-head attention (parallel attention patterns)")
    print("✓ Dropout (regularization)")
    print("✓ Residual connection (gradient flow)")
    print("✓ Layer normalization (training stability)")
    print("\nAdd a feed-forward sublayer, stack multiple blocks → you have a transformer!")

# Uncomment to see solution:
# show_comprehensive_solution()
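The exercise asks for optional causal masking, but the solution only forwards a caller-supplied `mask`. Below is a self-contained sketch that builds the causal mask inside the block. It swaps in PyTorch's built-in `torch.nn.MultiheadAttention` instead of this notebook's `MultiHeadAttention` so that it runs on its own; the class name `CausalAttentionBlock` and the `causal` flag are illustrative, not part of any library. Watch the mask convention: for a boolean `attn_mask`, `nn.MultiheadAttention` treats `True` as "blocked", the opposite of the `mask == 0` convention used earlier in this notebook.

```python
import torch
import torch.nn as nn


class CausalAttentionBlock(nn.Module):
    """Post-LN attention sublayer with an optional causal mask (illustrative sketch)."""

    def __init__(self, d_model, num_heads, dropout=0.1, causal=False):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.causal = causal

    def forward(self, x):
        attn_mask = None
        if self.causal:
            seq_len = x.size(1)
            # Boolean mask for nn.MultiheadAttention: True = "may NOT attend".
            # Everything above the diagonal (future positions) is blocked.
            attn_mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
                diagonal=1,
            )
        attn_output, weights = self.attention(x, x, x, attn_mask=attn_mask)
        output = self.norm(x + self.dropout(attn_output))  # residual + Post-LN
        return output, weights


torch.manual_seed(0)
block = CausalAttentionBlock(d_model=64, num_heads=4, causal=True).eval()  # eval(): no dropout
x = torch.randn(2, 5, 64)
out, weights = block(x)
print(out.shape)      # torch.Size([2, 5, 64])
print(weights.shape)  # torch.Size([2, 5, 5])
# No position puts any weight on a later position:
print(torch.triu(weights[0], diagonal=1).abs().max().item())  # 0.0
```

The returned weights are averaged over heads (the `nn.MultiheadAttention` default), which is why they have shape `(batch, seq_len, seq_len)` rather than one matrix per head.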

## Key Takeaways

### Core Concepts

**1. Why attention was invented**:
- Solves the fixed-length bottleneck in sequence-to-sequence models
- Allows dynamic focus on relevant parts of input
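- Gives every output position a direct path to every input position, which makes long-range dependencies easier to learn than when information must pass through many recurrent steps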

**2. The Query-Key-Value paradigm**:
- **Query**: What information to look for
- **Key**: What information is available (searchable index)
- **Value**: The actual information to retrieve
- Separation enables flexible, learned information retrieval
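A tiny numeric illustration of these roles, treating attention as a soft dictionary lookup. The keys, values, and query are made-up numbers: the query points in the same direction as the first key, so nearly all the weight lands on that slot and the result comes out close to its value, 10. A hard dictionary would return exactly 10; attention returns a blend dominated by the best match.

```python
import torch
import torch.nn.functional as F

# A tiny "soft dictionary": 3 keys (searchable index), each paired with a value.
keys = torch.tensor([[1.0, 0.0],
                     [0.0, 1.0],
                     [-1.0, 0.0]])
values = torch.tensor([[10.0], [20.0], [30.0]])

# A query pointing the same way as the first key (large norm sharpens the match).
query = torch.tensor([[5.0, 0.0]])

scores = query @ keys.T              # similarity of the query to each key
weights = F.softmax(scores, dim=-1)  # how much to read from each slot
result = weights @ values            # weighted blend of the values

print(weights)  # almost all weight on the first key
print(result)   # close to 10
```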

**3. Scaled dot-product attention**:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
- Dot product measures similarity
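- Scaling by $\sqrt{d_k}$ keeps the score variance near 1, so the softmax does not saturate and its gradients do not vanish
- Softmax turns each row of scores into weights that sum to 1
- Multiplying by V returns a weighted average of the values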
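The formula can be checked term by term against PyTorch's fused `torch.nn.functional.scaled_dot_product_attention` (PyTorch 2.0+), which uses the same $1/\sqrt{d_k}$ scale by default. The second half is an empirical look at why the scale is there: with unit-variance random entries (an assumption chosen for illustration), the raw dot product's variance grows roughly like $d_k$, while the scaled version stays near 1.

```python
import math

import torch
import torch.nn.functional as F


def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V, written out term by term."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # pairwise similarities
    weights = F.softmax(scores, dim=-1)                 # each row sums to 1
    return weights @ V                                  # weighted average of values


torch.manual_seed(0)
Q, K, V = torch.randn(2, 4, 64), torch.randn(2, 6, 64), torch.randn(2, 6, 64)

# PyTorch's fused attention defaults to the same 1/sqrt(d_k) scale.
fused = F.scaled_dot_product_attention(Q, K, V)
print(torch.allclose(attention(Q, K, V), fused, atol=1e-5))

# Why scale? With unit-variance entries, q.k has variance of about d_k,
# so unscaled scores grow with d_k and push the softmax toward one-hot.
variances = {}
for d_k in (16, 256, 1024):
    q, k = torch.randn(5000, d_k), torch.randn(5000, d_k)
    raw = (q * k).sum(dim=-1)
    variances[d_k] = (raw.var().item(), (raw / math.sqrt(d_k)).var().item())
    print(f"d_k={d_k:5d}  var(raw)={variances[d_k][0]:8.1f}  var(scaled)={variances[d_k][1]:.2f}")
```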

**4. Self-attention vs Cross-attention**:
- **Self-attention**: Sequence attends to itself (Q, K, V from same source)
- **Cross-attention**: One sequence attends to another (Q from one, K/V from another)
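A shape-level sketch of the difference, using `torch.nn.MultiheadAttention` with made-up sizes (7 decoder tokens, 12 encoder tokens). The same module is reused for both calls only to keep the example short; a real decoder has separate self- and cross-attention layers. The output always has the query's length, while the weight matrix is `(seq_len_q, seq_len_k)`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, num_heads = 32, 4
mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

decoder_states = torch.randn(2, 7, d_model)   # queries: 7 target tokens
encoder_states = torch.randn(2, 12, d_model)  # keys/values: 12 source tokens

# Self-attention: the same tensor supplies Q, K and V.
self_out, self_w = mha(decoder_states, decoder_states, decoder_states)

# Cross-attention: Q from the decoder, K and V from the encoder.
cross_out, cross_w = mha(decoder_states, encoder_states, encoder_states)

print(self_out.shape, self_w.shape)    # torch.Size([2, 7, 32]) torch.Size([2, 7, 7])
print(cross_out.shape, cross_w.shape)  # torch.Size([2, 7, 32]) torch.Size([2, 7, 12])
```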

**5. Multi-head attention**:
- Multiple attention patterns in parallel
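- Each head has its own Q/K/V projections, so heads can specialize (e.g., syntax, semantics, long-range)
- Head outputs are concatenated and projected back to `d_model`, combining these patterns (an ensemble-like effect)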
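A minimal sketch of how multi-head attention splits `d_model` into heads, runs them in parallel, and merges them back. The helper names `split_heads` and `merge_heads` are illustrative; the per-head computation uses `torch.nn.functional.scaled_dot_product_attention`, which treats the head dimension as an extra batch dimension. The final check confirms that head 0 is just ordinary attention on its own slice of the projections.

```python
import torch
import torch.nn.functional as F


def split_heads(x, num_heads):
    """(batch, seq, d_model) -> (batch, heads, seq, d_head)"""
    b, s, d = x.shape
    return x.view(b, s, num_heads, d // num_heads).transpose(1, 2)


def merge_heads(x):
    """(batch, heads, seq, d_head) -> (batch, seq, d_model)"""
    b, h, s, d_head = x.shape
    return x.transpose(1, 2).reshape(b, s, h * d_head)


torch.manual_seed(0)
batch, seq, d_model, num_heads = 2, 5, 64, 8
x = torch.randn(batch, seq, d_model)
W_q, W_k, W_v, W_o = (torch.nn.Linear(d_model, d_model) for _ in range(4))

Q, K, V = (split_heads(W(x), num_heads) for W in (W_q, W_k, W_v))
print(Q.shape)  # torch.Size([2, 8, 5, 8]): 8 heads, each with d_head = 64 / 8 = 8

# All heads run attention independently, in one batched call.
heads = F.scaled_dot_product_attention(Q, K, V)
out = W_o(merge_heads(heads))
print(out.shape)  # torch.Size([2, 5, 64])

# Head 0 on its own is plain single-head attention over its slice.
h0 = F.scaled_dot_product_attention(Q[:, 0], K[:, 0], V[:, 0])
print(torch.allclose(h0, heads[:, 0], atol=1e-5))
```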

**6. Causal masking**:
- Prevents attending to future tokens
- Essential for autoregressive models (GPT)
- Creates lower-triangular attention pattern
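A minimal sketch of causal masking: build the lower-triangular mask with `torch.tril`, fill blocked scores with `-inf` before the softmax, and compare against the `is_causal=True` flag of `torch.nn.functional.scaled_dot_product_attention`. Note that this function's boolean masks use `True` for "may attend", whereas `nn.MultiheadAttention` uses `True` for "blocked".

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d = 4, 16
Q, K, V = (torch.randn(1, seq_len, d) for _ in range(3))

# Lower-triangular boolean mask: position i may look at positions 0..i only.
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
print(causal.int())

scores = Q @ K.transpose(-2, -1) / d ** 0.5
scores = scores.masked_fill(~causal, float("-inf"))  # block the future
weights = torch.softmax(scores, dim=-1)
manual = weights @ V

print(weights[0])  # zeros above the diagonal

# PyTorch's built-in flag applies the same lower-triangular mask.
builtin = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
print(torch.allclose(manual, builtin, atol=1e-5))
```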

### Connection to Modern AI

**Transformers = Attention + FFN + Norm**:
- Attention is the core building block
- Next topic: Complete transformer architecture

**Used everywhere in 2025**:
- **Language models**: GPT, LLaMA, Claude
- **Vision**: Vision transformers (ViT)
- **Multimodal**: CLIP, Flamingo, GPT-4V
- **Audio**: Whisper, MusicGen
- **Protein folding**: AlphaFold

**Why attention cannot be skipped**:
- Core component of nearly all state-of-the-art models
- Underpins the large pretrained models used for transfer learning
- Reading any modern architecture starts with understanding its attention layers

### What's Next?

You've mastered attention! Next:
- **Positional encoding**: How to inject position information
- **Transformer architecture**: Putting it all together
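- **Advanced attention**: FlashAttention and grouped-query attention (GQA), plus mixture-of-experts (MoE) layers, which scale the feed-forward blocks rather than attention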

Understanding attention deeply is your ticket to understanding modern AI!

## Further Reading

### Essential Papers
1. **Bahdanau et al. (2015)**: "Neural Machine Translation by Jointly Learning to Align and Translate" (Original attention)
2. **Vaswani et al. (2017)**: "Attention Is All You Need" (Transformers, multi-head attention)
3. **Dao et al. (2022)**: "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (Modern optimization)

### Tutorials and Visualizations
4. **The Illustrated Transformer** (Jay Alammar): http://jalammar.github.io/illustrated-transformer/
5. **Attention? Attention!** (Lilian Weng): https://lilianweng.github.io/posts/2018-06-24-attention/

### Advanced Topics
6. **Grouped Query Attention** (GQA): Efficient attention for large models
7. **Flash Attention**: Memory-efficient attention computation
8. **Sparse Attention**: Reducing O(N²) complexity

### Implementation Resources
- PyTorch documentation: `torch.nn.MultiheadAttention`, `torch.nn.functional.scaled_dot_product_attention`
- Hugging Face Transformers library
- Annotated Transformer (Harvard NLP): http://nlp.seas.harvard.edu/annotated-transformer/