# Task 8.1: Attention from Scratch

**Module:** 8 - Natural Language Processing & Transformers  
**Time:** 2 hours  
**Difficulty:** ⭐⭐⭐

---

## Learning Objectives

By the end of this notebook, you will:
- [ ] Understand what attention means and why it revolutionized NLP
- [ ] Implement scaled dot-product attention from scratch
- [ ] Build multi-head attention and understand why multiple heads matter
- [ ] Visualize attention patterns to see what the model "looks at"
- [ ] Apply masking for causal (autoregressive) models

---

## Prerequisites

- Completed: Module 6 (PyTorch Deep Learning)
- Knowledge of: Matrix multiplication, softmax, neural network basics

---

## Real-World Context

**Attention is the core innovation behind:**
- **Google Translate** - Understanding which words in "The cat sat on the mat" correspond to words in "Le chat s'est assis sur le tapis"
- **ChatGPT/Claude** - Knowing that "it" in "I dropped the glass and it broke" refers to the glass
- **GitHub Copilot** - Understanding context from your entire codebase, not just the current line

Before attention (introduced for neural machine translation in 2014), encoder-decoder sequence models had to compress an entire input sentence into a single fixed-size vector, losing information. Attention lets models dynamically focus on the relevant parts of the input.

---

## ELI5: What is Attention?

> **Imagine you're at a crowded party.** Dozens of conversations are happening around you. Somehow, when someone across the room says your name, you instantly tune in to *that* conversation. Your brain didn't process every word equally—it learned to **attend** to what matters.
>
> **Now imagine reading this sentence:** "The animal didn't cross the street because it was too tired."
>
> What does "it" refer to? Your brain instantly knows it's the animal (not the street). How? You're **paying attention** to the relationship between words.
>
> **In AI terms:**
> - Each word asks: "Which other words should I pay attention to?"
> - The model calculates "attention scores" between all word pairs
> - Words with high attention scores influence each other more
> - This happens for all words in parallel!

### The Key Insight

Before attention, models read sentences like a person with memory problems—by the time they reach the end, they've forgotten the beginning. Attention lets models look back at any part of the input at any time.

---

## Part 1: Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
import numpy as np

# Set up plotting with seaborn if available
try:
    import seaborn as sns
    # set_theme() applies a style and palette in one call (it replaces the older sns.set())
    sns.set_theme(style="whitegrid", palette="husl")
    HAS_SEABORN = True
except ImportError:
    HAS_SEABORN = False
    print("⚠️ seaborn not installed. Using matplotlib defaults.")

# Check our hardware
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Set random seed for reproducibility
torch.manual_seed(42)

---

## Part 2: Understanding Q, K, V (Query, Key, Value)

Before we implement attention, we need to understand the three key players: **Query**, **Key**, and **Value**.

### ELI5: The Library Analogy

> **Imagine a library search system:**
>
> - **Query (Q)**: Your search question - "books about space"
> - **Key (K)**: The labels on each book - "Astronomy 101", "Cooking for Beginners", "Galactic Adventures"
> - **Value (V)**: The actual content of each book
>
> **How it works:**
> 1. You compare your **Query** to all **Keys** (how relevant is each book?)
> 2. Books with high relevance get high scores
> 3. You blend all **Values** based on those scores
> 4. Result: A weighted summary focused on space-related content!

### In Mathematical Terms

```
Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V
```

Where:
- `Q @ K^T` computes a similarity score between every query and every key
- Dividing by `sqrt(d_k)` keeps those scores from growing with the key dimension `d_k`
- `softmax` turns each row of scores into probabilities that sum to 1
- Multiplying by `V` gives, for each query, a weighted average of the value vectors
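
To make the formula concrete, here is a minimal sketch that applies it to the library analogy. The 2-D "topic" features (space, cooking) for the query and the three book keys are hand-picked assumptions for illustration, not learned values, and each book's value is a one-hot ID so the output shows the blend directly.

```python
import math
import torch

# Hand-picked 2-D "topic" features: [space, cooking] (illustrative assumptions)
books = ["Astronomy 101", "Cooking for Beginners", "Galactic Adventures"]
K = torch.tensor([[3.0, 0.0],    # Astronomy 101: strongly about space
                  [0.0, 3.0],    # Cooking for Beginners: about cooking
                  [2.5, 0.0]])   # Galactic Adventures: mostly about space
V = torch.eye(3)                 # one-hot "content" per book, so output == attention weights
Q = torch.tensor([[3.0, 0.0]])   # the query: "books about space"

d_k = K.size(-1)
weights = torch.softmax(Q @ K.T / math.sqrt(d_k), dim=-1)  # softmax(Q @ K^T / sqrt(d_k))
output = weights @ V                                        # ... @ V

for title, w in zip(books, output[0]):
    print(f"{title:22s} {w.item():.3f}")
```

With these numbers the cooking book receives almost no weight, while the two space books share nearly all of it.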

In [None]:
# Let's build intuition with a simple example

# Suppose we have 4 words in a sentence
sentence = ["The", "cat", "sat", "down"]

# Each word is represented as a vector (embedding)
# In practice, these come from an embedding layer
# For now, let's use random vectors of dimension 8

seq_len = 4
d_model = 8  # Embedding dimension

# Random embeddings for our words
embeddings = torch.randn(seq_len, d_model)
print(f"Word embeddings shape: {embeddings.shape}")
print(f"Each word is a vector of {d_model} dimensions")

In [None]:
# In self-attention, Q, K, and V all come from the same input
# but are transformed by different weight matrices

# Create learnable projections
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

# Project embeddings to Q, K, V
Q = W_q(embeddings)  # What am I looking for?
K = W_k(embeddings)  # What do I contain?
V = W_v(embeddings)  # What's my actual content?

print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"V shape: {V.shape}")

### What Just Happened?

We created three different "views" of our input:
1. **Q (Query)**: "What information am I looking for?"
2. **K (Key)**: "What information do I have to offer?"
3. **V (Value)**: "Here's my actual content"

These are learned during training—the model figures out the best way to project words for attention!
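
A "projection" here is nothing more exotic than a matrix multiply. This small check (its own random tensors, illustrative only) confirms that `nn.Linear(..., bias=False)` computes `x @ weight.T`, where `weight` is the matrix that training adjusts:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 8
x = torch.randn(4, d_model)                    # 4 "words", as in the cell above
W_q = nn.Linear(d_model, d_model, bias=False)  # holds a (d_model, d_model) weight matrix

# Without a bias, nn.Linear is exactly x @ weight^T
manual_Q = x @ W_q.weight.T
print(torch.allclose(W_q(x), manual_Q))
print(W_q.weight.shape)
```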

---

## Part 3: Scaled Dot-Product Attention

Now let's implement the core attention mechanism step by step.

In [None]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention.
    
    Args:
        Q: Query tensor of shape (..., seq_len_q, d_k)
        K: Key tensor of shape (..., seq_len_k, d_k)
        V: Value tensor of shape (..., seq_len_k, d_v)
        mask: Optional mask tensor (True = keep, False = mask out)
    
    Returns:
        output: Attention-weighted values
        attention_weights: The attention pattern (for visualization)
    """
    # Get the dimension of keys for scaling
    d_k = K.size(-1)
    
    # Step 1: Compute attention scores
    # Q @ K^T gives us how much each query attends to each key
    scores = torch.matmul(Q, K.transpose(-2, -1))
    print(f"  Raw attention scores shape: {scores.shape}")
    
    # Step 2: Scale by sqrt(d_k)
    # Without scaling, large d_k leads to huge dot products,
    # pushing softmax into regions with tiny gradients
    scores = scores / math.sqrt(d_k)
    print(f"  Scaled scores range: [{scores.min():.2f}, {scores.max():.2f}]")
    
    # Step 3: Apply mask (if provided)
    if mask is not None:
        # Replace masked positions with -inf (becomes 0 after softmax)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    
    # Step 4: Softmax to get attention weights (probabilities)
    attention_weights = F.softmax(scores, dim=-1)
    print(f"  Attention weights sum per row: {attention_weights.sum(dim=-1)}")
    
    # Step 5: Apply attention to values
    output = torch.matmul(attention_weights, V)
    print(f"  Output shape: {output.shape}")
    
    return output, attention_weights

In [None]:
# Run attention on our example
print("Computing attention:")
output, attention = scaled_dot_product_attention(Q, K, V)

print(f"\nInput shape: {embeddings.shape}")
print(f"Output shape: {output.shape}")
print("\n✅ Output has same shape as input - each word now contains")
print("   information from every word (itself included), weighted by attention!")

### Visualizing Attention Patterns

One of the beautiful things about attention is that we can visualize what the model is "looking at"!

In [None]:
def plot_attention(attention_weights, x_labels, y_labels, title="Attention Pattern"):
    """
    Visualize attention weights as a heatmap.

    Args:
        attention_weights: Tensor of shape (seq_len, seq_len)
        x_labels: Labels for keys (columns)
        y_labels: Labels for queries (rows)
        title: Plot title
    """
    fig, ax = plt.subplots(figsize=(8, 6))

    # Move to CPU and convert to numpy for plotting
    weights = attention_weights.detach().cpu().numpy()

    # Create heatmap - use seaborn if available, otherwise matplotlib
    if HAS_SEABORN:
        sns.heatmap(
            weights,
            xticklabels=x_labels,
            yticklabels=y_labels,
            annot=True,
            fmt=".2f",
            cmap="Blues",
            ax=ax,
            vmin=0,
            vmax=1
        )
    else:
        im = ax.imshow(weights, cmap="Blues", vmin=0, vmax=1, aspect='auto')
        ax.set_xticks(range(len(x_labels)))
        ax.set_xticklabels(x_labels)
        ax.set_yticks(range(len(y_labels)))
        ax.set_yticklabels(y_labels)
        # Add annotations
        for i in range(len(y_labels)):
            for j in range(len(x_labels)):
                ax.text(j, i, f"{weights[i, j]:.2f}", ha='center', va='center', fontsize=8)
        plt.colorbar(im, ax=ax)

    ax.set_xlabel("Keys (attending to)")
    ax.set_ylabel("Queries (from)")
    ax.set_title(title)
    plt.tight_layout()
    plt.show()

# Visualize our attention pattern
plot_attention(
    attention,
    sentence,
    sentence,
    "Self-Attention Pattern (Random Weights)"
)

### What the Heatmap Shows

- **Rows** = Query words ("I am looking...")
- **Columns** = Key words ("...at these words")
- **Cell value** = How much the query word attends to the key word
- Each row sums to 1.0 (it's a probability distribution)

With random weights, the attention is essentially random. After training, meaningful patterns emerge!
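
Recent PyTorch releases (2.0 and later, an assumption about your environment) ship a fused `F.scaled_dot_product_attention`. As a sanity check, this sketch re-implements the same math in a hypothetical helper `manual_attention` (no debug prints) and compares it with the built-in on random inputs with a causal mask. In the built-in, a boolean `attn_mask` uses True to mean "may attend", the same convention as our `mask`.

```python
import math
import torch
import torch.nn.functional as F

def manual_attention(Q, K, V, mask=None):
    """Same math as scaled_dot_product_attention above, without the prints."""
    scores = Q @ K.transpose(-2, -1) / math.sqrt(K.size(-1))
    if mask is not None:
        scores = scores.masked_fill(~mask, float('-inf'))  # False = masked out
    return torch.softmax(scores, dim=-1) @ V

torch.manual_seed(0)
Q, K, V = (torch.randn(2, 5, 16) for _ in range(3))      # (batch, seq, d_k)
causal = torch.ones(5, 5, dtype=torch.bool).tril()        # broadcasts over the batch

ours = manual_attention(Q, K, V, causal)
builtin = F.scaled_dot_product_attention(Q, K, V, attn_mask=causal)  # default scale 1/sqrt(d_k)
print(torch.allclose(ours, builtin, atol=1e-5))
```

The built-in is usually the faster choice in practice because it can dispatch to fused kernels; the hand-written version is easier to read and inspect.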

---

## Part 4: Why Scale by sqrt(d_k)?

This is a common interview question! Let's understand why scaling matters.

In [None]:
# Experiment: What happens without scaling?

def demonstrate_scaling(n_samples=2000):
    """Show why scaling matters for attention stability."""
    
    dimensions = [8, 64, 512, 2048]
    
    print("Effect of dimension on dot product magnitude:")
    print("=" * 60)
    
    for d_k in dimensions:
        # Many random query/key pairs; every entry has mean 0 and variance 1
        q = torch.randn(n_samples, d_k)
        k = torch.randn(n_samples, d_k)
        
        # One dot product per (q, k) pair, without scaling
        raw_scores = (q * k).sum(dim=-1)
        
        # The same dot products, scaled by sqrt(d_k)
        scaled_scores = raw_scores / math.sqrt(d_k)
        
        # A single draw is noisy, so report the spread (std) over all pairs
        print(f"d_k={d_k:4d}: std(raw)={raw_scores.std().item():7.2f}, "
              f"std(scaled)={scaled_scores.std().item():5.2f}, "
              f"sqrt(d_k)={math.sqrt(d_k):.1f}")

demonstrate_scaling()

In [None]:
# What does this mean for softmax?

def demonstrate_softmax_saturation():
    """Show how large values cause softmax to saturate."""
    
    # Simulated attention scores
    small_scores = torch.tensor([1.0, 2.0, 3.0, 4.0])
    large_scores = torch.tensor([10.0, 20.0, 30.0, 40.0])
    
    small_probs = F.softmax(small_scores, dim=-1)
    large_probs = F.softmax(large_scores, dim=-1)
    
    print("Softmax behavior with different score magnitudes:")
    print("=" * 50)
    print(f"Small scores: {small_scores.tolist()}")
    print(f"Softmax:      {small_probs.tolist()}")
    print("Gradient OK: probabilities are spread out")
    print()
    print(f"Large scores: {large_scores.tolist()}")
    print(f"Softmax:      {large_probs.tolist()}")
    print("Problem: one element takes ~100%, so gradients nearly vanish!")

demonstrate_softmax_saturation()

### The Intuition

Without scaling:
- If the entries of q and k are independent with mean 0 and variance 1, their dot product has variance d_k, so scores grow with dimension
- Large scores push softmax toward extremes (0 or 1)
- A saturated softmax has tiny gradients, so learning slows or stalls

With scaling by √d_k:
- Under the same assumption, scaled dot products have variance 1 regardless of dimension
- Softmax stays in the range where its outputs still respond to changes in the scores
- Gradients flow, and training works!
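
The "tiny gradients" claim can be checked directly. The Jacobian of softmax with respect to its input scores is `diag(p) - p p^T`, where `p` is the softmax output. This sketch uses `torch.autograd.functional.jacobian` to compare its size (Frobenius norm) for the small and large score vectors from the saturation demo above; the helper name `softmax_jacobian_norm` is illustrative.

```python
import torch
from torch.autograd.functional import jacobian

def softmax_jacobian_norm(scores):
    """Frobenius norm of d softmax(scores) / d scores."""
    J = jacobian(lambda s: torch.softmax(s, dim=-1), scores)  # (n, n), equals diag(p) - p p^T
    return J.norm().item()

small_scores = torch.tensor([1.0, 2.0, 3.0, 4.0])
large_scores = torch.tensor([10.0, 20.0, 30.0, 40.0])

small_norm = softmax_jacobian_norm(small_scores)
large_norm = softmax_jacobian_norm(large_scores)
print(f"Jacobian norm, small scores: {small_norm:.2e}")
print(f"Jacobian norm, large scores: {large_norm:.2e}")
```

When softmax saturates, every entry of `diag(p) - p p^T` is close to zero, so almost no gradient reaches the scores (and therefore Q and K).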

---

## Part 5: Multi-Head Attention

### ELI5: Why Multiple Heads?

> **Imagine you're analyzing a movie scene.** You might notice:
> - The **dialogue** (what characters say)
> - The **cinematography** (camera angles, lighting)
> - The **music** (emotional cues)
> - The **acting** (facial expressions, body language)
>
> **Each "head" looks at the same scene differently!**
>
> In attention:
> - One head might focus on syntax (grammar relationships)
> - Another on semantics (meaning relationships)
> - Another on coreference (what "it" refers to)
> - Another on position (nearby words)
>
> Multi-head attention = multiple perspectives combined!
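
Mechanically, "multiple heads" means splitting the projected `d_model` features into `num_heads` slices of size `d_k` and running attention on each slice independently. This sketch shows only the split; the sizes are arbitrary, chosen to make the shapes easy to read:

```python
import torch

batch, seq_len, d_model, num_heads = 2, 5, 12, 3
d_k = d_model // num_heads                  # 4 features per head
x = torch.randn(batch, seq_len, d_model)    # stands in for projected Q (or K, or V)

# (batch, seq, d_model) -> (batch, seq, heads, d_k) -> (batch, heads, seq, d_k)
heads = x.view(batch, seq_len, num_heads, d_k).transpose(1, 2)

# Head h is just the feature slice [h*d_k, (h+1)*d_k) at every position
for h in range(num_heads):
    assert torch.equal(heads[:, h], x[..., h * d_k:(h + 1) * d_k])
print(heads.shape)
```

Because each head works on `d_k = d_model / num_heads` features, running all heads costs roughly the same as one full-width head.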

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention mechanism.
    
    Instead of one attention function, we run h parallel attention "heads"
    and concatenate their outputs.
    """
    
    def __init__(self, d_model, num_heads):
        """
        Args:
            d_model: Model dimension (embedding size)
            num_heads: Number of attention heads
        """
        super().__init__()
        
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Dimension per head
        
        # Linear projections for Q, K, V (all heads combined)
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        
        # Final output projection
        self.W_o = nn.Linear(d_model, d_model)
        
    def forward(self, query, key, value, mask=None):
        """
        Compute multi-head attention.
        
        Args:
            query: Query tensor (batch, seq_len_q, d_model)
            key: Key tensor (batch, seq_len_k, d_model)
            value: Value tensor (batch, seq_len_k, d_model)
            mask: Optional attention mask, broadcastable to
                  (batch, heads, seq_len_q, seq_len_k); 1/True = keep
            
        Returns:
            output: Attention output (batch, seq_len_q, d_model)
            attention_weights: Attention patterns per head
                               (batch, heads, seq_len_q, seq_len_k)
        """
        batch_size = query.size(0)
        seq_len = query.size(1)   # query length
        kv_len = key.size(1)      # key/value length (differs from seq_len in cross-attention)
        
        # Step 1: Linear projections
        Q = self.W_q(query)  # (batch, seq_len, d_model)
        K = self.W_k(key)    # (batch, kv_len, d_model)
        V = self.W_v(value)  # (batch, kv_len, d_model)
        
        # Step 2: Reshape for multi-head: (batch, heads, length, d_k)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, kv_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, kv_len, self.num_heads, self.d_k).transpose(1, 2)
        
        # Step 3: Scaled dot-product attention (for all heads at once!)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attention_weights = F.softmax(scores, dim=-1)
        
        # Apply attention to values
        context = torch.matmul(attention_weights, V)  # (batch, heads, seq_len, d_k)
        
        # Step 4: Concatenate heads: (batch, seq_len, d_model)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        
        # Step 5: Final projection
        output = self.W_o(context)
        
        return output, attention_weights

# Create multi-head attention
d_model = 64
num_heads = 8
mha = MultiHeadAttention(d_model, num_heads)

print(f"Multi-Head Attention Configuration:")
print(f"  d_model (total): {d_model}")
print(f"  num_heads: {num_heads}")
print(f"  d_k (per head): {d_model // num_heads}")

In [None]:
# Test with a batch of sentences
batch_size = 2
seq_len = 6

# Random embeddings (in practice, these come from an embedding layer)
x = torch.randn(batch_size, seq_len, d_model)

# Self-attention: query, key, value are all the same input
output, attention = mha(x, x, x)

print(f"Input shape:  {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attention.shape}")
print(f"  (batch={batch_size}, heads={num_heads}, seq={seq_len}, seq={seq_len})")

In [None]:
# Visualize attention from different heads
words = ["The", "quick", "brown", "fox", "jumps", "."]

fig, axes = plt.subplots(2, 4, figsize=(16, 8))

for head_idx, ax in enumerate(axes.flat):
    # Get attention for first example, specific head
    head_attention = attention[0, head_idx].detach().numpy()

    if HAS_SEABORN:
        sns.heatmap(
            head_attention,
            xticklabels=words,
            yticklabels=words,
            cmap="Blues",
            ax=ax,
            cbar=False,
            vmin=0,
            vmax=1
        )
    else:
        im = ax.imshow(head_attention, cmap="Blues", vmin=0, vmax=1, aspect='auto')
        ax.set_xticks(range(len(words)))
        ax.set_xticklabels(words)
        ax.set_yticks(range(len(words)))
        ax.set_yticklabels(words)
    ax.set_title(f"Head {head_idx + 1}")
    ax.tick_params(axis='x', rotation=45)

plt.suptitle("Different Attention Heads See Different Patterns\n(Random weights - patterns emerge after training)", 
             fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

### What Just Happened?

Each head works on its own slice of the Q, K, V projections (equivalent to giving every head its own smaller projection matrices), so each head can learn to attend differently:
- One head might learn to look at the previous word
- Another might learn to link subjects to their verbs
- Another might learn syntactic dependencies

After training on real data, these patterns become meaningful!
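
PyTorch also ships a built-in layer with the same overall structure: Q/K/V projections, per-head attention, and an output projection. It stores its weights differently from our class, so the numbers will not match ours, but the shapes line up. This sketch assumes a PyTorch version recent enough to support `batch_first` and `average_attn_weights`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, num_heads = 64, 8
mha_builtin = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

x = torch.randn(2, 6, d_model)  # (batch, seq, d_model)
# average_attn_weights=False returns one pattern per head instead of their average
out, weights = mha_builtin(x, x, x, average_attn_weights=False)
print(out.shape)      # (batch, seq, d_model)
print(weights.shape)  # (batch, heads, seq, seq)
```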

---

## Part 6: Causal Masking (for GPT-style models)

### ELI5: Why Mask?

> **Imagine you're taking a fill-in-the-blank test:**
> "The sun rises in the ____"
>
> **If you could see the answer key, that would be cheating!**
>
> For language models that predict the next word, we need to prevent them from "peeking ahead." This is called **causal** or **autoregressive** masking.
>
> When predicting word 5, the model can only see words 1-4, not word 5 or anything after it.

In [None]:
def create_causal_mask(seq_len):
    """
    Create a causal (look-ahead) mask.
    
    Position i can only attend to positions <= i.
    
    Returns:
        mask: Boolean tensor where True = keep, False = mask
    """
    # Lower triangular matrix
    mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
    return mask

# Create and visualize causal mask
seq_len = 6
causal_mask = create_causal_mask(seq_len)

print("Causal Mask (True = can attend, False = masked):")
print(causal_mask.int())

# Visualize
plt.figure(figsize=(6, 5))
if HAS_SEABORN:
    sns.heatmap(
        causal_mask.int().numpy(),
        annot=True,
        fmt="d",
        cmap="RdYlGn",
        xticklabels=[f"pos {i}" for i in range(seq_len)],
        yticklabels=[f"pos {i}" for i in range(seq_len)],
        cbar=False
    )
else:
    plt.imshow(causal_mask.int().numpy(), cmap="RdYlGn", aspect='auto')
    plt.xticks(range(seq_len), [f"pos {i}" for i in range(seq_len)])
    plt.yticks(range(seq_len), [f"pos {i}" for i in range(seq_len)])
    # Add annotations
    for i in range(seq_len):
        for j in range(seq_len):
            plt.text(j, i, str(causal_mask[i, j].int().item()), ha='center', va='center')
plt.title("Causal Mask\n(Green=1=can see, Red=0=cannot see)")
plt.xlabel("Key positions")
plt.ylabel("Query positions")
plt.tight_layout()
plt.show()

In [None]:
# Apply causal masking to attention
def causal_attention_demo():
    """Demonstrate causal vs. non-causal attention."""
    
    # Input sequence
    words = ["I", "love", "machine", "learning", "!", "<END>"]
    seq_len = len(words)
    d_model = 32
    
    # Random embeddings
    x = torch.randn(1, seq_len, d_model)
    
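    # For simplicity, skip the learned projections: Q, K, V are the raw embeddings.
    # (Because Q = K, each word tends to score highest against itself,
    # so expect a strong diagonal.)
    Q = x
    K = x
    V = x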
    
    # Compute attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_model)
    
    # Create causal mask
    mask = create_causal_mask(seq_len)
    
    # Apply mask
    masked_scores = scores.masked_fill(~mask.unsqueeze(0), float('-inf'))
    
    # Softmax
    attention_no_mask = F.softmax(scores, dim=-1)
    attention_with_mask = F.softmax(masked_scores, dim=-1)
    
    # Visualize comparison
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    if HAS_SEABORN:
        sns.heatmap(
            attention_no_mask[0].detach().numpy(),
            xticklabels=words,
            yticklabels=words,
            annot=True,
            fmt=".2f",
            cmap="Blues",
            ax=axes[0]
        )
        sns.heatmap(
            attention_with_mask[0].detach().numpy(),
            xticklabels=words,
            yticklabels=words,
            annot=True,
            fmt=".2f",
            cmap="Blues",
            ax=axes[1]
        )
    else:
        # Matplotlib fallback
        for idx, (attn, ax) in enumerate([(attention_no_mask[0], axes[0]), (attention_with_mask[0], axes[1])]):
            attn_np = attn.detach().numpy()
            im = ax.imshow(attn_np, cmap="Blues", vmin=0, vmax=1, aspect='auto')
            ax.set_xticks(range(len(words)))
            ax.set_xticklabels(words, rotation=45)
            ax.set_yticks(range(len(words)))
            ax.set_yticklabels(words)
            for i in range(len(words)):
                for j in range(len(words)):
                    ax.text(j, i, f"{attn_np[i, j]:.2f}", ha='center', va='center', fontsize=7)
    
    axes[0].set_title("Bidirectional Attention (BERT-style)\nEach word sees ALL words")
    axes[1].set_title("Causal Attention (GPT-style)\nEach word only sees past words")
    
    plt.tight_layout()
    plt.show()

causal_attention_demo()

### Key Difference:

- **Bidirectional (BERT)**: Word at position 3 can attend to all 6 words
- **Causal (GPT)**: Word at position 3 can only attend to positions 0, 1, 2, 3

This prevents "information leakage" during training for language models!
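
One way to verify that nothing leaks is to change the "future" tokens and confirm that earlier outputs stay identical. This sketch uses a simplified self-attention with Q = K = V = x (no learned projections, an assumption to keep it short); `causal_attention` is a hypothetical helper name:

```python
import math
import torch

def causal_attention(x):
    """Self-attention with Q = K = V = x and a causal mask."""
    seq_len, d = x.shape[-2], x.shape[-1]
    scores = x @ x.transpose(-2, -1) / math.sqrt(d)
    mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()  # True = may attend
    scores = scores.masked_fill(~mask, float('-inf'))
    return torch.softmax(scores, dim=-1) @ x

torch.manual_seed(0)
x = torch.randn(6, 16)
x_changed = x.clone()
x_changed[4:] = torch.randn(2, 16)   # overwrite the last two ("future") tokens

out, out_changed = causal_attention(x), causal_attention(x_changed)
print(torch.allclose(out[:4], out_changed[:4]))  # positions 0-3 never see positions 4-5
print(torch.allclose(out[4:], out_changed[4:]))  # positions 4-5 did change
```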

---

## Part 7: Attention Types Overview

There are several types of attention used in different contexts:

In [None]:
# Create a comprehensive diagram of attention types

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Self-Attention: every position can attend to every position
self_attn = torch.ones(5, 5)

# Causal Self-Attention: each position attends only to itself and earlier positions
causal = torch.tril(torch.ones(5, 5))

# Cross-Attention: 5 decoder queries x 7 encoder keys
cross = torch.ones(5, 7)

# Pin the color range to [0, 1]: without it, an all-ones matrix has no
# range to normalize over and is drawn in the palest color
if HAS_SEABORN:
    sns.heatmap(self_attn.numpy(), ax=axes[0], cmap="Blues", cbar=False,
                annot=False, square=True, vmin=0, vmax=1)
    sns.heatmap(causal.numpy(), ax=axes[1], cmap="Blues", cbar=False,
                annot=False, square=True, vmin=0, vmax=1)
    sns.heatmap(cross.numpy(), ax=axes[2], cmap="Blues", cbar=False,
                annot=False, vmin=0, vmax=1)
else:
    axes[0].imshow(self_attn.numpy(), cmap="Blues", aspect='auto', vmin=0, vmax=1)
    axes[1].imshow(causal.numpy(), cmap="Blues", aspect='auto', vmin=0, vmax=1)
    axes[2].imshow(cross.numpy(), cmap="Blues", aspect='auto', vmin=0, vmax=1)

axes[0].set_title("Self-Attention (Encoder)\nAll positions see all positions", fontsize=11)
axes[0].set_xlabel("Keys (same sequence)")
axes[0].set_ylabel("Queries (same sequence)")

axes[1].set_title("Causal Self-Attention (Decoder)\nEach position sees only past", fontsize=11)
axes[1].set_xlabel("Keys (same sequence)")
axes[1].set_ylabel("Queries (same sequence)")

axes[2].set_title("Cross-Attention (Encoder-Decoder)\nDecoder queries attend to encoder keys", fontsize=11)
axes[2].set_xlabel("Keys (encoder sequence)")
axes[2].set_ylabel("Queries (decoder sequence)")

plt.tight_layout()
plt.show()

print("Attention Types Summary:")
print("=" * 60)
print("1. Self-Attention: Q, K, V all from same sequence")
print("   - Used in: BERT encoder, GPT (with mask), ViT")
print()
print("2. Causal Self-Attention: Self-attention + future mask")
print("   - Used in: GPT, LLaMA, text generation")
print()
print("3. Cross-Attention: Q from decoder, K/V from encoder")
print("   - Used in: T5, translation, encoder-decoder models")
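
In cross-attention, the only structural change is where Q, K, and V come from, so the query and key lengths can differ. A minimal sketch with random, untrained projections (sizes are illustrative assumptions):

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16
decoder_states = torch.randn(1, 4, d_model)  # 4 target-side positions -> queries
encoder_states = torch.randn(1, 7, d_model)  # 7 source-side positions -> keys and values

# Untrained projections, for shape illustration only
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

Q = W_q(decoder_states)   # (1, 4, d_model)
K = W_k(encoder_states)   # (1, 7, d_model)
V = W_v(encoder_states)   # (1, 7, d_model)

weights = torch.softmax(Q @ K.transpose(-2, -1) / math.sqrt(d_model), dim=-1)  # (1, 4, 7)
context = weights @ V     # (1, 4, d_model): one encoder summary per decoder position
print(weights.shape, context.shape)
```

Each of the 4 decoder positions gets its own probability distribution over the 7 encoder positions.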

---

## Try It Yourself: Exercises

### Exercise 1: Implement Attention with Dropout

Add dropout to the attention weights (commonly used to prevent overfitting).

<details>
<summary>Hint</summary>
Apply dropout AFTER softmax but BEFORE multiplying with V.
</details>

In [None]:
def scaled_dot_product_attention_with_dropout(Q, K, V, dropout_p=0.1, mask=None, training=True):
    """
    Scaled dot-product attention with dropout.
    
    TODO: Implement attention with dropout on the attention weights.
    
    Args:
        Q, K, V: Query, Key, Value tensors
        dropout_p: Dropout probability
        mask: Optional attention mask
        training: Whether in training mode (dropout active)
    """
    d_k = K.size(-1)
    
    # YOUR CODE HERE
    # 1. Compute attention scores
    # 2. Scale by sqrt(d_k)
    # 3. Apply mask if provided
    # 4. Apply softmax
    # 5. Apply dropout to attention weights
    # 6. Multiply with values
    
    pass  # Replace with your implementation

# Test your implementation
# Q = torch.randn(1, 4, 8)
# K = torch.randn(1, 4, 8)
# V = torch.randn(1, 4, 8)
# output, attention = scaled_dot_product_attention_with_dropout(Q, K, V, dropout_p=0.1)

### Exercise 2: Attention Complexity Analysis

What is the time and space complexity of self-attention? Why is this a problem for long sequences?

In [None]:
def analyze_attention_complexity():
    """
    Measure memory usage for different sequence lengths.
    
    TODO: 
    1. Create attention matrices for sequences of length 128, 512, 1024, 2048
    2. Measure memory usage
    3. Plot the relationship
    """
    # YOUR CODE HERE
    pass

# Your analysis here

### Exercise 3: Visualize Trained Attention

Let's load a pre-trained model and visualize what its attention heads learned!

In [None]:
# We'll explore this more in later notebooks, but here's a preview:

# TODO: Load a pre-trained BERT model and extract attention weights
# from transformers import BertModel, BertTokenizer
# 
# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
# 
# text = "The cat sat on the mat because it was tired."
# inputs = tokenizer(text, return_tensors='pt')
# outputs = model(**inputs)
# 
# # Extract attention from the last layer
# attention = outputs.attentions[-1]  # (batch, heads, seq, seq)

print("Exercise: Uncomment and run the code above to visualize BERT attention!")
print("You'll see how different heads learn to attend to different relationships.")

---

## Common Mistakes

### Mistake 1: Forgetting to Scale

In [None]:
# Wrong: No scaling
def attention_wrong(Q, K, V):
    scores = torch.matmul(Q, K.transpose(-2, -1))  # Missing / sqrt(d_k)!
    return F.softmax(scores, dim=-1) @ V

# Right: With scaling
def attention_right(Q, K, V):
    d_k = K.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)  # Scaled!
    return F.softmax(scores, dim=-1) @ V

print("Why it matters: Without scaling, attention becomes 'peaky' (almost one-hot)")
print("for large d_k, leading to vanishing gradients during training.")

### Mistake 2: Wrong Mask Application

In [None]:
# Wrong: Applying mask AFTER softmax
def attention_mask_wrong(Q, K, V, mask):
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(K.size(-1))
    attention = F.softmax(scores, dim=-1)
    attention = attention * mask  # Wrong! Softmax already computed
    return attention @ V

# Right: Applying mask BEFORE softmax
def attention_mask_right(Q, K, V, mask):
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(K.size(-1))
    scores = scores.masked_fill(mask == 0, float('-inf'))  # Before softmax!
    attention = F.softmax(scores, dim=-1)  # -inf becomes 0
    return attention @ V

print("Why it matters: Masking after softmax leaves rows that no longer sum to 1.")
print("Masking before softmax with -inf ensures proper probability distribution.")
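
A quick numeric check of that claim, on random scores with a causal mask (sizes are illustrative):

```python
import math
import torch

torch.manual_seed(0)
Q, K = torch.randn(4, 8), torch.randn(4, 8)
mask = torch.ones(4, 4, dtype=torch.bool).tril()   # True = keep
scores = Q @ K.T / math.sqrt(8)

wrong = torch.softmax(scores, dim=-1) * mask                          # mask AFTER softmax
right = torch.softmax(scores.masked_fill(~mask, float('-inf')), dim=-1)  # mask BEFORE

print(wrong.sum(dim=-1))  # only the last row (nothing masked) sums to 1
print(right.sum(dim=-1))  # every row sums to 1

# Renormalizing the wrong version recovers the right one
renormalized = wrong / wrong.sum(dim=-1, keepdim=True)
print(torch.allclose(renormalized, right, atol=1e-6))
```

The last check shows that the missing renormalization is exactly what masking with -inf before softmax provides.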

### Mistake 3: Shape Mismatch in Multi-Head Attention

In [None]:
# Wrong: Forgetting to transpose back
def multihead_wrong(x, num_heads):
    batch, seq_len, d_model = x.shape
    d_k = d_model // num_heads
    
    # Reshape to heads
    x = x.view(batch, seq_len, num_heads, d_k)  # (batch, seq, heads, d_k)
    x = x.transpose(1, 2)  # (batch, heads, seq, d_k)
    
    # ... attention computation ...
    
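    # Wrong: viewing without transposing back first.
    # x is now non-contiguous, so .view() raises a RuntimeError; .reshape() would run
    # but silently interleave values from different heads and positions.
    # output = x.view(batch, seq_len, d_model)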
    pass

# Right: Proper transpose then contiguous then view
def multihead_right(x, num_heads):
    batch, seq_len, d_model = x.shape
    d_k = d_model // num_heads
    
    # Reshape to heads
    x = x.view(batch, seq_len, num_heads, d_k)
    x = x.transpose(1, 2)  # (batch, heads, seq, d_k)
    
    # ... attention computation ...
    
    # Right: transpose, contiguous, then view
    x = x.transpose(1, 2)  # (batch, seq, heads, d_k)
    x = x.contiguous()  # Make memory contiguous
    output = x.view(batch, seq_len, d_model)  # Works!
    return output

print("Why it matters: .view() only works when the memory layout matches the new shape.")
print("After transpose, call .contiguous() before .view() (or use .reshape(), which copies if needed).")
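
Both failure modes show up on a tiny tensor whose values are easy to trace (sizes chosen only for readability):

```python
import torch

batch, seq_len, num_heads, d_k = 1, 3, 2, 2
d_model = num_heads * d_k
x = torch.arange(batch * seq_len * d_model, dtype=torch.float32).view(batch, seq_len, d_model)

heads = x.view(batch, seq_len, num_heads, d_k).transpose(1, 2)  # (batch, heads, seq, d_k)

try:
    heads.view(batch, seq_len, d_model)   # non-contiguous layout: view refuses
    view_failed = False
except RuntimeError:
    view_failed = True

scrambled = heads.reshape(batch, seq_len, d_model)  # runs, but mixes heads and positions
restored = heads.transpose(1, 2).contiguous().view(batch, seq_len, d_model)

print(view_failed)
print(torch.equal(restored, x))   # transposing back restores the original layout
print(torch.equal(scrambled, x))  # reshaping directly does not
```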

---

## Checkpoint

You've learned:
- ✅ The intuition behind attention (library search, party analogy)
- ✅ Query, Key, Value projections and their roles
- ✅ Scaled dot-product attention implementation
- ✅ Why scaling by √d_k prevents gradient problems
- ✅ Multi-head attention and why multiple perspectives help
- ✅ Causal masking for autoregressive models
- ✅ Different attention types (self, causal, cross)

---

## Challenge (Optional)

Implement **Relative Position Attention**, where attention scores also depend on the relative distance between positions:

```
score(i, j) = (q_i · k_j + q_i · r_(i-j)) / sqrt(d_k)
```

Here `r_(i-j)` is a learned embedding for the relative offset `i - j`, and `sqrt(d_k)` is the same scaling used in Part 3.

Variants of this idea are used in models like Transformer-XL and Music Transformer!

In [None]:
# Challenge: Implement relative position attention
# YOUR CODE HERE

---

## Further Reading

- [The Illustrated Transformer](http://jalammar.github.io/illustrated-transformer/) - Best visual guide
- [Attention Is All You Need](https://arxiv.org/abs/1706.03762) - Original paper
- [The Annotated Transformer](http://nlp.seas.harvard.edu/annotated-transformer/) - Code walkthrough
- [BertViz](https://github.com/jessevig/bertviz) - Interactive attention visualization

---

## Cleanup

In [None]:
# Clear GPU memory
import gc

# Delete large tensors
del mha, x, output, attention

# Clear cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()

print("Memory cleared! Ready for the next notebook.")

---

## Next Up

In **Notebook 02: Transformer Block**, we'll combine attention with:
- Feed-forward networks
- Layer normalization
- Residual connections

...to build a complete Transformer encoder layer!

---

*Great job completing your first deep dive into attention! This is the foundation of modern NLP.*