# Topic 11: Flash Attention - Optimizing Attention for Speed and Memory

## Learning Objectives

By the end of this notebook, you will:
- Understand why standard attention is memory-inefficient
- Learn the Flash Attention algorithm and its optimizations
- Use PyTorch 2.5+ FlexAttention API for flexible attention patterns
- Leverage the cuDNN Fused Flash Attention backend on H100 GPUs (up to 75% faster than Flash Attention 2, per the PyTorch 2.5 release notes)
- Know when to apply Flash Attention in your models
- Understand the trade-offs and performance characteristics

---

## 1. The Big Picture: Why Flash Attention?

### The Attention Memory Problem

Standard attention has revolutionized AI, but it has a **critical bottleneck**: **O(n²) memory complexity**.

**What does this mean?**
- For a sequence of length n = 1024 tokens, the attention matrix has 1024 × 1024 = **1,048,576 elements**
- For n = 4096 (common in modern LLMs): **16,777,216 elements**
- For n = 16384 (long context): **268,435,456 elements**, exactly 1 GiB in float32 for a single head of a single sequence

**Why is this a problem?**
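1. **GPU memory is limited**: Data-center GPUs have 40-80 GB, and the n × n matrix is needed per head and per sequence in the batch, so long sequences exhaust it quickly
2. **Memory bandwidth is the bottleneck**: Reading and writing the n × n matrix between off-chip GPU memory and the compute units takes longer than the arithmetic itself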
3. **Limits Context Length**: Can't fit long documents/conversations
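To check the element counts listed above, and to see how quickly they turn into gigabytes, here is a small pure-Python helper. `attention_matrix_bytes` is a hypothetical name used only for this illustration; it counts one materialized score matrix and ignores the softmax copy that a naive implementation also keeps.

```python
GIB = 1024 ** 3


def attention_matrix_bytes(seq_len, bytes_per_element=4, batch_size=1, num_heads=1):
    """Bytes needed to materialize one (seq_len x seq_len) score matrix per head and sequence."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_element


for n in (1024, 4096, 16384):
    fp32 = attention_matrix_bytes(n, bytes_per_element=4)
    fp16 = attention_matrix_bytes(n, bytes_per_element=2)
    print(f"n={n:>6}: {n * n:>12,} elements | {fp32 / GIB:.4f} GiB fp32 | {fp16 / GIB:.4f} GiB fp16")
```

Multiply by batch size and number of heads (the optional arguments) to get the per-layer cost in a real model.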

### Enter Flash Attention

**Flash Attention** solves this by:
1. **Tiling**: Processing attention in smaller blocks that fit in fast SRAM
2. **Recomputation**: Recalculating values instead of storing them (compute vs memory trade-off)
3. **Kernel Fusion**: Combining operations to minimize memory reads/writes

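**Result**: The same output (exact attention, not an approximation; equal up to floating-point rounding). The original paper reports **2-4x faster** attention and memory savings that grow with sequence length (about **10x at 2K tokens and 20x at 4K**).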

### Real-World Impact
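- **Open-weight LLMs** (LLaMA, Mistral, and others): Routinely trained and served with Flash Attention-style kernels; Mistral 7B, for example, combines Flash Attention with sliding-window attention
- **Long-context models** such as Llama 3.1 (128k tokens), GPT-4 (32k in its long-context variant), and Claude 3 (200k): A full n × n score matrix per head is impractical at these lengths; the closed models do not publish their kernels, but some form of memory-efficient attention is required
- **Your models**: Train with larger batches and longer sequences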

In [None]:
# Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import time
from torch.nn.attention import SDPBackend, sdpa_kernel

# Check PyTorch version
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

---

## 2. Standard Attention: Understanding the Bottleneck

### Attention Formula Recap

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$

### Memory Analysis: Step by Step

For sequence length **n** and head dimension **d** (one attention head, one sequence):

1. **Input**: Q, K, V each are (n, d) → **3nd** memory
2. **Attention Scores**: QK^T is (n, n) → **n²** memory ⚠️
3. **Softmax**: Still (n, n) → **n²** memory ⚠️
4. **Output**: (n, n) @ (n, d) → **nd** memory

**Total**: O(n² + nd) ≈ **O(n²)** for large n

**The bottleneck**: Storing the (n, n) attention matrix!
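A worked instance makes "for large n" concrete. Take one head with an illustrative head dimension d = 64 and compare the score matrix against the four (n, d) tensors Q, K, V, and the output:

$$
\frac{n^2}{4nd} = \frac{n}{4d},
\qquad
n = 4096,\; d = 64:\quad
\frac{4096^2}{4 \cdot 4096 \cdot 64} = \frac{16{,}777{,}216}{1{,}048{,}576} = 16
$$

So at 4096 tokens the score matrix alone is 16x larger than all the linear-size tensors combined, and it overtakes them as soon as n > 4d (here, n > 256).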

In [None]:
def standard_attention(Q, K, V, mask=None):
    """
    Standard attention implementation (memory-inefficient but clear)
    
    Args:
        Q: (batch, n, d_k) - Queries
        K: (batch, n, d_k) - Keys
        V: (batch, n, d_v) - Values
        mask: Optional attention mask
    
    Returns:
        output: (batch, n, d_v)
        attention_weights: (batch, n, n)
    """
    d_k = Q.size(-1)
    
    # Step 1: Compute attention scores (THIS IS THE BOTTLENECK!)
    # Shape: (batch, n, n) - stores ALL pairwise scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d_k)
    
    # Step 2: Apply mask (if provided)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    
    # Step 3: Softmax (still n×n matrix)
    attention_weights = F.softmax(scores, dim=-1)
    
    # Step 4: Apply attention to values
    output = torch.matmul(attention_weights, V)
    
    return output, attention_weights


# Demo: Memory usage visualization
def analyze_memory(seq_len, d_model=512):
    """Calculate memory for attention matrix"""
    # Attention matrix: (seq_len, seq_len)
    attention_elements = seq_len * seq_len
    # float32 = 4 bytes
    attention_memory_mb = (attention_elements * 4) / (1024 ** 2)
    
    # Input matrices: Q, K, V each (seq_len, d_model)
    input_elements = 3 * seq_len * d_model
    input_memory_mb = (input_elements * 4) / (1024 ** 2)
    
    return {
        'attention_matrix_mb': attention_memory_mb,
        'input_matrices_mb': input_memory_mb,
        'total_mb': attention_memory_mb + input_memory_mb
    }

# Compare different sequence lengths
seq_lengths = [512, 1024, 2048, 4096, 8192, 16384]
results = [analyze_memory(n) for n in seq_lengths]

print("Memory Usage Analysis (one sequence, one attention head, Q/K/V at d_model=512, float32):")
print("="*60)
print(f"{'Seq Len':>10} {'Attention Matrix':>20} {'Input Matrices':>20} {'Total':>10}")
print("="*60)
for n, r in zip(seq_lengths, results):
    print(f"{n:>10} {r['attention_matrix_mb']:>18.1f} MB {r['input_matrices_mb']:>18.1f} MB {r['total_mb']:>8.1f} MB")

print("\n⚠️ Notice: Attention matrix dominates for large sequences!")

In [None]:
# Visualize memory growth
plt.figure(figsize=(12, 5))

# Plot 1: Memory vs Sequence Length
plt.subplot(1, 2, 1)
attention_mem = [r['attention_matrix_mb'] for r in results]
input_mem = [r['input_matrices_mb'] for r in results]

plt.plot(seq_lengths, attention_mem, 'r-o', label='Attention Matrix (O(n²))', linewidth=2)
plt.plot(seq_lengths, input_mem, 'b-s', label='Input Matrices (O(n))', linewidth=2)
plt.xlabel('Sequence Length', fontsize=12)
plt.ylabel('Memory (MB)', fontsize=12)
plt.title('Memory Usage: O(n²) vs O(n)', fontsize=14)
plt.legend()
plt.grid(True, alpha=0.3)
plt.yscale('log')

# Plot 2: Percentage breakdown
plt.subplot(1, 2, 2)
percentages = [(r['attention_matrix_mb'] / r['total_mb'] * 100) for r in results]
plt.bar(range(len(seq_lengths)), percentages, color='coral')
plt.xticks(range(len(seq_lengths)), seq_lengths)
plt.xlabel('Sequence Length', fontsize=12)
plt.ylabel('Attention Matrix %', fontsize=12)
plt.title('Attention Matrix as % of Total Memory', fontsize=14)
plt.axhline(y=50, color='r', linestyle='--', label='50%')
plt.legend()
plt.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print(f"\n💡 Insight: At seq_len=16384, attention matrix is {percentages[-1]:.1f}% of memory!")

---

## 3. Flash Attention Algorithm

### Core Insight: IO-Awareness

Modern GPUs have a **memory hierarchy**:
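1. **HBM (High Bandwidth Memory)**: Large (40-80 GB on an A100) but comparatively slow (about 1.5-2.0 TB/s)
2. **SRAM (on-chip memory)**: Small (192 KB per streaming multiprocessor, about 20 MB in total on an A100) but roughly **10x higher bandwidth** (around 19 TB/s)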

**Standard attention**: Reads/writes the full (n, n) matrix from/to HBM → slow!

**Flash Attention**: Keeps working set in SRAM, minimizes HBM access
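The FlashAttention paper makes this IO saving precise by counting HBM accesses. With sequence length n, head dimension d, and SRAM size M (in elements, assuming d ≤ M ≤ nd):

$$
\underbrace{\Theta\!\left(nd + n^2\right)}_{\text{standard attention}}
\qquad\text{vs.}\qquad
\underbrace{\Theta\!\left(\frac{n^2 d^2}{M}\right)}_{\text{Flash Attention}}
$$

For typical head dimensions (d = 64 to 128) and M around 100 KB, d² is many times smaller than M, so Flash Attention performs many times fewer HBM accesses even though it does at least as many FLOPs.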

### The Three Key Ideas

#### 1. Tiling (Block-wise Computation)
- Divide Q, K, V into blocks that fit in SRAM
- Compute partial attention for each (query block, key block) pair
- Combine results online with a running softmax (no need to store the full attention matrix)
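The "combine results online" step is the heart of the algorithm. Below is a minimal pure-Python sketch for a single query row, assuming scalar values (d_v = 1) to keep it short; the function names are hypothetical. It keeps a running max, a running denominator, and a running weighted sum, rescaling them whenever a new block raises the max, and it matches the one-shot softmax.

```python
import math
import random


def softmax_weighted_sum(scores, values):
    """Reference: softmax over all scores at once, then the weighted sum of values."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * x for w, x in zip(weights, values)) / sum(weights)


def online_softmax_weighted_sum(scores, values, block_size):
    """Same result, visiting the scores one block at a time.

    Only three running numbers are kept for the query row:
      running_max - max of the scores seen so far
      denom       - sum of exp(score - running_max)
      acc         - sum of exp(score - running_max) * value
    """
    running_max, denom, acc = float("-inf"), 0.0, 0.0
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        new_max = max(running_max, max(s_blk))
        # Rescale what was accumulated under the old max (exp(-inf) = 0 on the first block)
        correction = math.exp(running_max - new_max)
        p = [math.exp(s - new_max) for s in s_blk]
        denom = denom * correction + sum(p)
        acc = acc * correction + sum(pi * x for pi, x in zip(p, v_blk))
        running_max = new_max
    # Normalize once at the end
    return acc / denom


random.seed(0)
scores = [random.gauss(0.0, 3.0) for _ in range(1000)]
values = [random.gauss(0.0, 1.0) for _ in range(1000)]
print(softmax_weighted_sum(scores, values))
print(online_softmax_weighted_sum(scores, values, block_size=64))
```

The two printed numbers agree up to floating-point rounding for any block size; the real kernel does the same thing with vectors of values and whole blocks of query rows at once.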

#### 2. Recomputation in Backward Pass
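- Don't store the n × n attention matrix for gradients; keep only the output and per-row softmax statistics (O(n) memory)
- Recompute the attention blocks from Q, K, and those statistics during the backward pass, tile by tile in SRAM
- Trade-off: more FLOPs but far fewer HBM reads/writes, which in practice still makes the backward pass faster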
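Why is one number per row enough? If the forward pass saves the log-sum-exp L = m + log(l) of each row, any tile of softmax probabilities can later be rebuilt as exp(score − L). A pure-Python sketch under that assumption (function names hypothetical; in the real kernel the score tile is itself recomputed from Q and K blocks, whereas here the scores are given):

```python
import math


def forward_logsumexp(scores, block_size):
    """Forward pass: stream over score blocks, keep only L = m + log(l) for the row."""
    m, l_sum = float("-inf"), 0.0
    for start in range(0, len(scores), block_size):
        blk = scores[start:start + block_size]
        m_new = max(m, max(blk))
        # Rescale the running sum to the new max before adding this block
        l_sum = l_sum * math.exp(m - m_new) + sum(math.exp(s - m_new) for s in blk)
        m = m_new
    return m + math.log(l_sum)


def recompute_block_probs(score_block, lse):
    """Backward pass: softmax probabilities of one tile, rebuilt from scores and L alone."""
    return [math.exp(s - lse) for s in score_block]


scores = [0.5, -1.2, 3.3, 2.0, -0.7, 1.1, 0.0, 2.8]
lse = forward_logsumexp(scores, block_size=3)   # one float saved per query row

# Rebuild the probabilities tile by tile, as a backward kernel would
probs = []
for start in range(0, len(scores), 3):
    probs += recompute_block_probs(scores[start:start + 3], lse)

print(f"sum of recomputed probabilities: {sum(probs):.6f}")  # -> 1.000000
```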

#### 3. Kernel Fusion
- Fuse softmax, dropout, masking into attention kernel
- Fewer memory reads/writes
- Better GPU utilization

### Algorithm Visualization

```
Standard Attention:
  Q, K, V (in HBM) → Attention Matrix (in HBM) → Output
  ^^^^^^^^^^^^^^^^    ^^^^^^^^^^^^^^^^^^^^^^^
      Slow I/O              Huge memory

Flash Attention:
  Load Q_block, K_block → Compute in SRAM → Update output
  Repeat for all blocks → No full matrix stored!
  ^^^^^^^^^^^^^^^^^^^^^    ^^^^^^^^^^^^^^^^^^^^
    Minimal I/O              Memory efficient
```

In [None]:
def flash_attention_simplified(Q, K, V, block_size=64):
    """
    Simplified Flash Attention for educational purposes
    (the real implementation is a fused CUDA kernel; this Python loop only mirrors the math)
    
    Key idea: process attention in blocks and merge them with an online softmax,
    so no (n, n) matrix is ever materialized.
    """
    batch_size, seq_len, d_k = Q.shape
    d_v = V.size(-1)
    scale = d_k ** -0.5
    
    # Unnormalized output accumulator plus running softmax statistics per query row
    output = torch.zeros(batch_size, seq_len, d_v, device=Q.device, dtype=Q.dtype)
    row_max = torch.full((batch_size, seq_len), float('-inf'), device=Q.device, dtype=Q.dtype)
    row_sum = torch.zeros(batch_size, seq_len, device=Q.device, dtype=Q.dtype)
    
    # Process in blocks (tiling)
    num_blocks = (seq_len + block_size - 1) // block_size
    
    for i in range(num_blocks):
        # Query block
        q_start = i * block_size
        q_end = min((i + 1) * block_size, seq_len)
        Q_block = Q[:, q_start:q_end, :]  # (batch, block_size, d_k)
        
        for j in range(num_blocks):
            # Key/Value block
            k_start = j * block_size
            k_end = min((j + 1) * block_size, seq_len)
            K_block = K[:, k_start:k_end, :]  # (batch, block_size, d_k)
            V_block = V[:, k_start:k_end, :]  # (batch, block_size, d_v)
            
            # Attention scores for this block pair only: (batch, q_block, k_block)
            scores = torch.matmul(Q_block, K_block.transpose(-2, -1)) * scale
            
            # Online softmax: merge this block's max into the running max
            block_max = scores.max(dim=-1).values  # (batch, q_block)
            prev_max = row_max[:, q_start:q_end]
            new_max = torch.maximum(prev_max, block_max)
            
            # Probabilities relative to the updated max: exp(scores - new_max)
            p = torch.exp(scores - new_max.unsqueeze(-1))
            # Rescale factor for everything accumulated under the old max (0 on the first block)
            old_scale = torch.exp(prev_max - new_max)
            
            # Update accumulator AND normalizer with the same rescaling (the clever part!)
            output[:, q_start:q_end, :] = (output[:, q_start:q_end, :] * old_scale.unsqueeze(-1)
                                           + torch.matmul(p, V_block))
            row_sum[:, q_start:q_end] = row_sum[:, q_start:q_end] * old_scale + p.sum(dim=-1)
            row_max[:, q_start:q_end] = new_max
    
    # Final normalization: divide by the softmax denominator once
    output = output / row_sum.unsqueeze(-1)
    
    return output

# Test: Verify it produces same results as standard attention
batch, n, d = 2, 128, 64
Q = torch.randn(batch, n, d)
K = torch.randn(batch, n, d)
V = torch.randn(batch, n, d)

# Standard attention
standard_out, _ = standard_attention(Q, K, V)

# Flash attention (simplified)
flash_out = flash_attention_simplified(Q, K, V, block_size=32)

# Compare
difference = torch.abs(standard_out - flash_out).max().item()
print(f"Maximum difference: {difference:.6f}")
print(f"Results match: {torch.allclose(standard_out, flash_out, atol=1e-4)}")
print("\n✅ The tiled version matches standard attention (up to floating-point rounding) without ever building the full n×n matrix!")

---

## 4. PyTorch's Built-in Flash Attention

### Scaled Dot-Product Attention (SDPA)

PyTorch 2.0+ includes **optimized attention** via `F.scaled_dot_product_attention()`

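**It dispatches to one of several backends**, depending on the GPU, dtype, tensor shapes, and arguments (such as whether an explicit `attn_mask` is passed):
1. **Flash Attention 2** (CUDA, fp16/bf16 inputs only)
2. **Memory-efficient attention** (based on xFormers; also accepts fp32 and custom masks)
3. **cuDNN Fused Flash Attention** (added in PyTorch 2.5, targeting H100 and newer GPUs)
4. **Math implementation** (reference fallback that works everywhere)

**You call one function**, and PyTorch picks the fastest backend that supports your inputs.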

In [None]:
# Using PyTorch's optimized attention
def efficient_attention(Q, K, V, mask=None, is_causal=False):
    """
    Use PyTorch's scaled_dot_product_attention (dispatches to Flash Attention when the inputs are eligible)
    """
    output = F.scaled_dot_product_attention(
        Q, K, V,
        attn_mask=mask,
        is_causal=is_causal,  # For autoregressive models
        dropout_p=0.0
    )
    return output

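# Check which backends are enabled (this does not show which one a given call will use)
print("Enabled attention backends:")
print(f"  Flash Attention: {torch.backends.cuda.flash_sdp_enabled()}")
print(f"  Memory Efficient: {torch.backends.cuda.mem_efficient_sdp_enabled()}")
print(f"  Math (fallback): {torch.backends.cuda.math_sdp_enabled()}")
if hasattr(torch.backends.cuda, "cudnn_sdp_enabled"):  # only in newer PyTorch releases
    print(f"  cuDNN: {torch.backends.cuda.cudnn_sdp_enabled()}")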

In [None]:
# Performance comparison: Standard vs Flash Attention
def benchmark_attention(seq_len, d_model=512, num_heads=8, device='cuda'):
    """
    Time SDPA (Flash Attention when eligible) and measure its peak memory
    """
    if device == 'cuda' and not torch.cuda.is_available():
        print("CUDA not available, using CPU")
        device = 'cpu'
    
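    # Create random data. The CUDA Flash kernel only accepts fp16/bf16 inputs;
    # with fp32, SDPA silently falls back to another backend.
    batch_size = 16
    d_k = d_model // num_heads
    dtype = torch.float16 if device == 'cuda' else torch.float32
    
    Q = torch.randn(batch_size, num_heads, seq_len, d_k, device=device, dtype=dtype)
    K = torch.randn(batch_size, num_heads, seq_len, d_k, device=device, dtype=dtype)
    V = torch.randn(batch_size, num_heads, seq_len, d_k, device=device, dtype=dtype)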
    
    # Warm up
    for _ in range(3):
        _ = F.scaled_dot_product_attention(Q, K, V)
    
    if device == 'cuda':
        torch.cuda.synchronize()
    
    # Benchmark Flash Attention (via SDPA)
    start = time.time()
    for _ in range(10):
        output = F.scaled_dot_product_attention(Q, K, V)
    if device == 'cuda':
        torch.cuda.synchronize()
    flash_time = (time.time() - start) / 10
    
    # Memory usage
    if device == 'cuda':
        torch.cuda.reset_peak_memory_stats()
        output = F.scaled_dot_product_attention(Q, K, V)
        flash_memory = torch.cuda.max_memory_allocated() / 1e6  # MB
    else:
        flash_memory = 0
    
    return {
        'time': flash_time,
        'memory_mb': flash_memory
    }

# Run benchmarks
if torch.cuda.is_available():
    seq_lengths = [512, 1024, 2048, 4096]
    results = []
    
    print("Benchmarking Flash Attention (via SDPA):")
    print("="*50)
    print(f"{'Seq Len':>10} {'Time (ms)':>15} {'Memory (MB)':>15}")
    print("="*50)
    
    for seq_len in seq_lengths:
        result = benchmark_attention(seq_len, device='cuda')
        results.append(result)
        print(f"{seq_len:>10} {result['time']*1000:>14.2f} {result['memory_mb']:>14.1f}")
    
    print("\n✅ With fp16 inputs on a supported GPU, SDPA dispatches to Flash Attention automatically.")
else:
    print("⚠️ CUDA not available; skipping the GPU benchmark.")
    print("(SDPA still runs on CPU, but these timings would not reflect the CUDA Flash Attention kernel.)")

---

## 5. FlexAttention API (PyTorch 2.5+)

### The Problem with Custom Attention Patterns

Modern models need **different attention patterns**:
- **Causal masking** (GPT): Can't attend to future tokens
- **Sliding window** (Longformer): Local attention only
- **Prefix attention** (PrefixLM): Different masks for prefix vs generation
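- **ALiBi**: Custom positional biases added to the attention scores (RoPE, by contrast, rotates Q and K before the dot product and needs no score modification)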

**Problem**: Writing custom CUDA kernels for each pattern is hard!

### FlexAttention: Flexible + Fast

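PyTorch 2.5 introduced **FlexAttention**:
- Write the attention pattern as a small **Python** function (`mask_mod` for masks, `score_mod` for score modifications)
- Under `torch.compile`, PyTorch **generates a fused kernel** for that pattern and skips fully masked blocks
- The PyTorch team reports performance close to hand-written Flash Attention 2 kernels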
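A block mask is what makes FlexAttention fast on sparse patterns: the sequence is cut into tiles, and tiles where the mask is all False are skipped. The pure-Python sketch below counts empty, partially masked, and fully visible tiles for 128-token blocks; `classify_tiles` is a hypothetical name, and it drops the batch/head arguments of a real `mask_mod`. In spirit, this is the bookkeeping `create_block_mask` does once up front so the kernel doesn't repeat it.

```python
def classify_tiles(mask_mod, seq_len, block_size):
    """Count (query tile, key tile) pairs that are empty, partially masked, or fully visible.

    mask_mod(q_idx, kv_idx) -> bool plays the role of FlexAttention's mask_mod,
    without the batch and head arguments.
    """
    counts = {"empty": 0, "partial": 0, "full": 0}
    for q0 in range(0, seq_len, block_size):
        for k0 in range(0, seq_len, block_size):
            visible = [mask_mod(q, kv)
                       for q in range(q0, min(q0 + block_size, seq_len))
                       for kv in range(k0, min(k0 + block_size, seq_len))]
            if not any(visible):
                counts["empty"] += 1    # tile can be skipped entirely
            elif all(visible):
                counts["full"] += 1     # computed without per-element masking
            else:
                counts["partial"] += 1  # mask applied element by element
    return counts


def causal(q_idx, kv_idx):
    return kv_idx <= q_idx


def sliding_window(q_idx, kv_idx, window_size=256):
    return abs(q_idx - kv_idx) <= window_size


print(classify_tiles(causal, seq_len=512, block_size=128))
# -> {'empty': 6, 'partial': 4, 'full': 6}
print(classify_tiles(sliding_window, seq_len=1024, block_size=128))
# -> {'empty': 30, 'partial': 12, 'full': 22}
```

With causal masking, only the 4 diagonal tiles need element-wise masking and 6 of 16 tiles are skipped; with a 256-token window at 1024 tokens, 30 of 64 tiles are skipped.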

In [None]:
# FlexAttention example (PyTorch 2.5+)
try:
    from torch.nn.attention.flex_attention import flex_attention, create_block_mask
    
    # Define custom attention patterns as mask_mod functions.
    # They receive index *tensors* (FlexAttention vmaps them), so use tensor
    # ops such as |, & and torch.where instead of Python if/else.
    def sliding_window_mask(b, h, q_idx, kv_idx):
        """Sliding window: attend to keys at most 256 positions away (on either side)"""
        window_size = 256
        return (q_idx - kv_idx).abs() <= window_size
    
    def prefix_lm_mask(b, h, q_idx, kv_idx):
        """Prefix attends bidirectionally within itself; later tokens are causal"""
        prefix_len = 128
        # Equivalent to: if q < prefix_len, attend to kv < prefix_len; otherwise kv <= q
        return (kv_idx < prefix_len) | (kv_idx <= q_idx)
    
    # Usage
    batch, num_heads, seq_len, d_k = 4, 8, 512, 64
    Q = torch.randn(batch, num_heads, seq_len, d_k, device=device)
    K = torch.randn(batch, num_heads, seq_len, d_k, device=device)
    V = torch.randn(batch, num_heads, seq_len, d_k, device=device)
    
    # Create block mask from function (device defaults to "cuda", so pass it explicitly)
    block_mask = create_block_mask(sliding_window_mask, B=batch, H=num_heads, 
                                    Q_LEN=seq_len, KV_LEN=seq_len, device=device)
    
    # Eager flex_attention runs a slow reference path; torch.compile generates the fused kernel
    flex_fn = torch.compile(flex_attention) if device.type == 'cuda' else flex_attention
    output = flex_fn(Q, K, V, block_mask=block_mask)
    
    print("✅ FlexAttention successfully applied!")
    print(f"Output shape: {output.shape}")
    print("\n💡 Under torch.compile (on GPU), the custom pattern is fused into a single kernel!")
    
except ImportError:
    print("⚠️ FlexAttention requires PyTorch 2.5+")
    print("Current version:", torch.__version__)
    print("\nFlexAttention allows custom attention patterns with Flash Attention speed!")

---

## 6. cuDNN Fused Flash Attention (H100)

### The Latest Optimization

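PyTorch 2.5 added a **cuDNN Fused Flash Attention** backend for SDPA:
- Up to **75% speedup** over Flash Attention 2 on H100 GPUs, according to the PyTorch 2.5 release notes
- Announced as enabled by default on H100 and newer GPUs in that release; backend priorities can change between versions, so check which backend actually runs
- No code changes needed: it is reached through the same `F.scaled_dot_product_attention` call

**Why is it faster?**
- NVIDIA tunes the cuDNN kernels specifically for the Hopper (H100) architecture
- Better instruction scheduling and use of Hopper's Tensor Cores
- Optimized for H100's memory hierarchy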

In [None]:
# Controlling attention backend
from torch.nn.attention import SDPBackend

# Force specific backend
def attention_with_backend(Q, K, V, backend=SDPBackend.FLASH_ATTENTION):
    """
    Use specific attention backend
    
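    Backends:
    - FLASH_ATTENTION: Flash Attention 2 (CUDA, fp16/bf16 only)
    - EFFICIENT_ATTENTION: Memory-efficient (xFormers-based)
    - MATH: Reference PyTorch implementation (slowest)
    - CUDNN_ATTENTION: cuDNN Fused Flash Attention (H100 and newer)
    
    Raises RuntimeError if the chosen backend cannot handle the inputs.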
    """
    with sdpa_kernel([backend]):
        output = F.scaled_dot_product_attention(Q, K, V)
    return output

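# Check GPU capability (H100 is compute capability 9.0, the "Hopper" architecture)
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    major, minor = torch.cuda.get_device_capability(0)
    print(f"GPU: {gpu_name} (compute capability {major}.{minor})")
    
    if (major, minor) >= (9, 0):
        print("\n✅ Hopper-class GPU detected! The cuDNN Fused Flash Attention backend should be available.")
        print("Reported speedup: up to ~75% vs Flash Attention 2")
    else:
        print(f"\n💡 On {gpu_name}, SDPA will typically use Flash Attention 2 (Ampere or newer, fp16/bf16) or memory-efficient attention.")
        print("cuDNN Fused Flash Attention targets H100 (compute capability 9.0) or newer.")
else:
    print("⚠️ No GPU available. Flash Attention optimizations require CUDA.")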

---

## 7. Practical Multi-Head Attention with Flash Attention

### Building a Production-Ready Attention Module

In [None]:
class FlashMultiHeadAttention(nn.Module):
    """
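    Multi-Head Attention built on F.scaled_dot_product_attention
    
    SDPA picks a backend on each call, for example:
    - cuDNN Fused Flash Attention (H100 and newer)
    - Flash Attention 2 (Ampere+ GPUs, fp16/bf16 inputs, no explicit attn_mask)
    - Memory-efficient attention or the math fallback (fp32 inputs, custom masks, CPU)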
    """
    
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Linear projections
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = dropout
    
    def forward(self, Q, K, V, mask=None, is_causal=False):
        """
        Args:
            Q: (batch, seq_len, d_model)
            K: (batch, kv_len, d_model)
            V: (batch, kv_len, d_model)
            mask: Optional attention mask broadcastable to (batch, num_heads, seq_len, kv_len);
                  boolean True = "may attend", or a float tensor added to the scores
            is_causal: Use causal masking (for GPT-style models); don't combine with mask
        """
        batch_size, seq_len, _ = Q.shape
        
        # Linear projections and reshape to (batch, num_heads, length, d_k);
        # -1 lets K/V have a different length than Q (cross-attention)
        Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # Flash Attention! (automatically selects best backend)
        attn_output = F.scaled_dot_product_attention(
            Q, K, V,
            attn_mask=mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal
        )
        
        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.W_o(attn_output)
        
        return output


# Demo
d_model = 512
num_heads = 8
seq_len = 1024
batch_size = 4

model = FlashMultiHeadAttention(d_model, num_heads).to(device)
x = torch.randn(batch_size, seq_len, d_model, device=device)

# Self-attention
output = model(x, x, x, is_causal=True)  # Causal for GPT-style

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print("\n✅ SDPA picked the attention backend internally (on CUDA, use fp16/bf16, e.g. model.half(), so Flash Attention is eligible)")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
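Padding masks are the most common use of the `mask` argument. One detail trips people up: for `F.scaled_dot_product_attention`, a boolean `attn_mask` marks positions that *may* be attended to (True = keep), the opposite of `nn.MultiheadAttention`'s `key_padding_mask`. The sketch below (the helper `padding_mask_for_sdpa` is hypothetical) builds a (batch, 1, 1, seq_len) mask that broadcasts over heads and queries, and checks that padded keys have no effect. In current PyTorch releases an explicit `attn_mask` generally routes the call to the memory-efficient or math backend rather than Flash Attention, and it should not be combined with `is_causal=True`.

```python
import torch
import torch.nn.functional as F


def padding_mask_for_sdpa(lengths, max_len):
    """Boolean mask (batch, 1, 1, max_len): True where the key is a real (non-padding) token."""
    positions = torch.arange(max_len, device=lengths.device)
    valid = positions[None, :] < lengths[:, None]   # (batch, max_len)
    return valid[:, None, None, :]                   # broadcasts over heads and queries


torch.manual_seed(0)
lengths = torch.tensor([5, 3])          # sequence 0 has 5 real tokens, sequence 1 has 3
mask = padding_mask_for_sdpa(lengths, max_len=6)

q = torch.randn(2, 4, 6, 16)            # (batch, heads, seq_len, d_k)
k = torch.randn(2, 4, 6, 16)
v = torch.randn(2, 4, 6, 16)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Padded keys get zero weight, so changing their values must not change the output
v_perturbed = v.clone()
v_perturbed[0, :, 5:] += 100.0          # position 5 of sequence 0 is padding
v_perturbed[1, :, 3:] += 100.0          # positions 3..5 of sequence 1 are padding
out2 = F.scaled_dot_product_attention(q, k, v_perturbed, attn_mask=mask)
print("padding ignored:", torch.allclose(out, out2))
```

The same mask can be passed as `mask=` to `FlashMultiHeadAttention`, since it broadcasts to (batch, num_heads, seq_len, seq_len).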

---

## 8. When to Use Flash Attention

### Decision Guide

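**Use `F.scaled_dot_product_attention` (and so Flash Attention, when eligible) by default when**:
- ✅ Sequences are hundreds of tokens or longer; the savings grow quadratically with n
- ✅ Training transformers (memory savings allow larger batches)
- ✅ Long context models (4k+ tokens)
- ✅ An NVIDIA GPU is available (Ampere/A100 or newer for Flash Attention 2, H100 for the cuDNN backend) and inputs are fp16/bf16

**A manual implementation is still useful when**:
- ⚠️ You need the attention weights themselves (SDPA returns only the output)
- ⚠️ Very short sequences (< 128 tokens), where the gain is negligible
- ⚠️ CPU only (the CUDA Flash kernels don't apply; SDPA still works, with smaller gains)
- ⚠️ Debugging (the explicit version is easier to inspect)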

### Performance Characteristics

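| Metric | Standard Attention | Flash Attention |
|--------|-------------------|------------------|
| **Extra memory** (beyond Q, K, V, output) | O(n²) | O(n) |
| **Speed** | Baseline | ~2-4x faster (paper-reported; hardware dependent) |
| **Accuracy** | Exact | Exact (same result up to floating-point rounding) |
| **Hardware** | Any | CUDA kernels need an NVIDIA GPU |
| **Practical sequence length** | Limited by O(n²) memory (typically a few thousand tokens) | 100k+ tokens feasible |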

---

## Mini Exercises

### Exercise 1: Memory Calculation

Calculate the attention matrix memory for a model with:
- Sequence length: 8192 tokens
- Number of heads: 32
- Batch size: 8
- Data type: float16 (2 bytes per element)

How much memory (in GB) is needed just for attention matrices?

In [None]:
# Your code here


In [None]:
# Solution
seq_len = 8192
num_heads = 32
batch_size = 8
bytes_per_element = 2  # float16

# Attention matrix per head: (seq_len, seq_len)
elements_per_head = seq_len * seq_len

# Total elements: batch * heads * matrix
total_elements = batch_size * num_heads * elements_per_head

# Memory in GB
memory_bytes = total_elements * bytes_per_element
memory_gb = memory_bytes / (1024 ** 3)

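print(f"Attention matrix memory: {memory_gb:.2f} GB")
print(f"\n💡 That's {memory_gb:.1f} GB just for the attention scores of one layer!")

# Flash Attention never stores these matrices; its extra memory is the per-row
# softmax statistics (one fp32 log-sum-exp value per query, per head, per sequence)
lse_bytes = batch_size * num_heads * seq_len * 4
print(f"With Flash Attention: ~{lse_bytes / (1024 ** 2):.0f} MB of softmax statistics instead")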

### Exercise 2: Causal Attention with Flash Attention

Implement a simple decoder using Flash Attention with causal masking.

In [None]:
# Your code here


In [None]:
# Solution
class CausalFlashAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.attention = FlashMultiHeadAttention(d_model, num_heads)
    
    def forward(self, x):
        # Causal masking via is_causal=True
        return self.attention(x, x, x, is_causal=True)

# Test
model = CausalFlashAttention(d_model=256, num_heads=8).to(device)
x = torch.randn(4, 512, 256, device=device)
output = model(x)

print(f"Input: {x.shape}")
print(f"Output: {output.shape}")
print("\n✅ Causal Flash Attention working!")

---

## Comprehensive Exercise: Flash Attention Benchmark

Create a comprehensive benchmark comparing:
1. Standard attention (manual implementation)
2. PyTorch SDPA with Flash Attention

For sequence lengths [512, 1024, 2048, 4096], measure:
- Forward pass time
- Peak memory usage
- Numerical accuracy (compare outputs)

Plot the results showing speedup and memory savings.

In [None]:
# Your code here


In [None]:
# Solution
def comprehensive_benchmark():
    if not torch.cuda.is_available():
        print("GPU required for meaningful benchmark")
        return
    
    seq_lengths = [512, 1024, 2048, 4096]
    d_model = 512
    num_heads = 8
    batch_size = 8
    
    results = {'seq_len': [], 'standard_time': [], 'flash_time': [], 
               'standard_mem': [], 'flash_mem': [], 'speedup': []}
    
    for seq_len in seq_lengths:
        d_k = d_model // num_heads
        # fp16 so the Flash Attention backend is eligible (it rejects fp32 inputs)
        Q = torch.randn(batch_size, num_heads, seq_len, d_k, device='cuda', dtype=torch.float16)
        K = torch.randn(batch_size, num_heads, seq_len, d_k, device='cuda', dtype=torch.float16)
        V = torch.randn(batch_size, num_heads, seq_len, d_k, device='cuda', dtype=torch.float16)
        
        # Warm up
        for _ in range(3):
            _ = F.scaled_dot_product_attention(Q, K, V)
        torch.cuda.synchronize()
        
        # Flash Attention (SDPA)
        torch.cuda.reset_peak_memory_stats()
        start = time.perf_counter()
        for _ in range(10):
            flash_out = F.scaled_dot_product_attention(Q, K, V)
        torch.cuda.synchronize()
        flash_time = (time.perf_counter() - start) / 10
        flash_mem = torch.cuda.max_memory_allocated() / 1e6
        
        # Standard attention (small sequences only: it materializes all
        # batch x heads x n x n scores). All heads are used, matching SDPA.
        if seq_len <= 2048:
            for _ in range(3):  # warm up
                std_out, _ = standard_attention(Q, K, V)
            torch.cuda.synchronize()
            
            torch.cuda.reset_peak_memory_stats()
            start = time.perf_counter()
            for _ in range(10):
                std_out, _ = standard_attention(Q, K, V)
            torch.cuda.synchronize()
            std_time = (time.perf_counter() - start) / 10
            std_mem = torch.cuda.max_memory_allocated() / 1e6
            
            speedup = std_time / flash_time
            
            # Numerical accuracy (small fp16 rounding differences are expected)
            max_diff = (std_out.float() - flash_out.float()).abs().max().item()
            print(f"seq_len={seq_len}: max |standard - SDPA| = {max_diff:.2e}")
        else:
            std_time = float('nan')
            std_mem = float('nan')
            speedup = float('nan')
        
        results['seq_len'].append(seq_len)
        results['standard_time'].append(std_time * 1000)  # ms
        results['flash_time'].append(flash_time * 1000)
        results['standard_mem'].append(std_mem)
        results['flash_mem'].append(flash_mem)
        results['speedup'].append(speedup)
    
    # Plot results
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    # Time comparison
    axes[0].plot(results['seq_len'][:3], results['standard_time'][:3], 'r-o', label='Standard', linewidth=2)
    axes[0].plot(results['seq_len'], results['flash_time'], 'b-s', label='Flash', linewidth=2)
    axes[0].set_xlabel('Sequence Length', fontsize=12)
    axes[0].set_ylabel('Time (ms)', fontsize=12)
    axes[0].set_title('Forward Pass Time', fontsize=14)
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Memory comparison
    axes[1].plot(results['seq_len'][:3], results['standard_mem'][:3], 'r-o', label='Standard', linewidth=2)
    axes[1].plot(results['seq_len'], results['flash_mem'], 'b-s', label='Flash', linewidth=2)
    axes[1].set_xlabel('Sequence Length', fontsize=12)
    axes[1].set_ylabel('Peak Memory (MB)', fontsize=12)
    axes[1].set_title('Memory Usage', fontsize=14)
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    # Speedup
    valid_speedups = [s for s in results['speedup'][:3] if not np.isnan(s)]
    axes[2].bar(range(len(valid_speedups)), valid_speedups, color='green', alpha=0.7)
    axes[2].set_xticks(range(len(valid_speedups)))
    axes[2].set_xticklabels(results['seq_len'][:len(valid_speedups)])
    axes[2].set_xlabel('Sequence Length', fontsize=12)
    axes[2].set_ylabel('Speedup (x)', fontsize=12)
    axes[2].set_title('Flash Attention Speedup', fontsize=14)
    axes[2].axhline(y=1, color='r', linestyle='--', label='Baseline')
    axes[2].legend()
    axes[2].grid(True, alpha=0.3, axis='y')
    
    plt.tight_layout()
    plt.show()
    
    print("\n📊 Benchmark Results:")
    print("="*70)
    print(f"{'Seq Len':>10} {'Standard (ms)':>15} {'Flash (ms)':>15} {'Speedup':>10} {'Mem Saved':>10}")
    print("="*70)
    for i in range(len(results['seq_len'])):
        if not np.isnan(results['speedup'][i]):
            mem_saved = (1 - results['flash_mem'][i] / results['standard_mem'][i]) * 100
            print(f"{results['seq_len'][i]:>10} {results['standard_time'][i]:>14.2f} {results['flash_time'][i]:>14.2f} "
                  f"{results['speedup'][i]:>9.2f}x {mem_saved:>8.1f}%")
        else:
            print(f"{results['seq_len'][i]:>10} {'N/A':>14} {results['flash_time'][i]:>14.2f} {'N/A':>10} {'N/A':>10}")

comprehensive_benchmark()

---

## Key Takeaways

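1. **Standard attention needs O(n²) memory** for the score matrix, the bottleneck for long sequences
2. **Flash Attention uses tiling, an online softmax, and recomputation** to need only O(n) extra memory while still computing exact attention
3. **PyTorch SDPA uses Flash Attention when the inputs are eligible** (CUDA, fp16/bf16, no explicit mask), with a single function call
4. **FlexAttention (2.5+) enables custom patterns**, compiled into fused kernels with `torch.compile`
5. **cuDNN Fused Flash Attention (H100)** is reported to be up to 75% faster than Flash Attention 2
6. **Default to SDPA for GPU training**, especially for sequences of hundreds of tokens or more
7. **Roughly 2-4x faster and 10-20x less memory** (paper-reported figures), which is what enables longer contexts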

## Modern LLM Usage

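- **LLaMA family**: Commonly trained and served with Flash Attention 2 kernels in open-source stacks
- **Mistral**: Flash Attention combined with GQA and sliding-window attention
- **GPT-4 (32k) and Claude 3 (200k)**: Kernel details are not public, but context windows this long require memory-efficient attention of some kind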

## Next Steps

Continue to: [Topic 12: Grouped Query Attention (GQA)](12_grouped_query_attention.ipynb)

---

## Further Reading

- [Flash Attention Paper](https://arxiv.org/abs/2205.14135)
- [Flash Attention 2 Paper](https://arxiv.org/abs/2307.08691)
- [PyTorch SDPA Documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
- [PyTorch 2.5 FlexAttention](https://pytorch.org/blog/pytorch2-5/)
- [cuDNN Flash Attention](https://docs.nvidia.com/deeplearning/cudnn/)