# Flash Attention

## Paper Reference

**Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022).** *FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.* NeurIPS. [arXiv:2205.14135](https://arxiv.org/abs/2205.14135)

---

## Key Insight

Standard attention is **memory-bound**, not compute-bound. The bottleneck is reading and writing the $O(n^2)$ attention matrix to and from GPU HBM (High Bandwidth Memory), not the matrix multiplications themselves.

### Memory Hierarchy

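Approximate figures for an NVIDIA A100 GPU, as cited in the paper:

- **SRAM** (on-chip): ~20 MB in total across all streaming multiprocessors, ~19 TB/s bandwidth
- **HBM** (GPU memory): ~40 GB, ~1.5 TB/s bandwidth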
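
To make the bandwidth gap concrete, the sketch below estimates how long it takes to stream one $n \times n$ float16 score matrix through each memory level, using the approximate figures listed above. The helper name `transfer_time_us` and the sequence length are illustrative assumptions. Real kernels overlap transfers with compute, so treat the result as a back-of-the-envelope estimate only.

```python
# Back-of-the-envelope estimate using the approximate bandwidths listed above.
HBM_BANDWIDTH = 1.5e12   # bytes/s (~1.5 TB/s)
SRAM_BANDWIDTH = 19e12   # bytes/s (~19 TB/s)
BYTES_PER_FP16 = 2


def transfer_time_us(num_bytes: int, bandwidth: float) -> float:
    """Time in microseconds to move num_bytes at the given bandwidth (bytes/s)."""
    return num_bytes / bandwidth * 1e6


n = 4096  # illustrative sequence length (assumption)
matrix_bytes = n * n * BYTES_PER_FP16

hbm_us = transfer_time_us(matrix_bytes, HBM_BANDWIDTH)
sram_us = transfer_time_us(matrix_bytes, SRAM_BANDWIDTH)

print(f"Score matrix size: {matrix_bytes / 2**20:.0f} MiB")
print(f"One pass through HBM:  {hbm_us:.1f} us")
print(f"One pass through SRAM: {sram_us:.1f} us")
print(f"Bandwidth ratio: {SRAM_BANDWIDTH / HBM_BANDWIDTH:.1f}x")
```

Standard attention reads or writes this matrix several times per layer and per head, which is why avoiding the HBM round trips matters more than reducing FLOPs.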

### FlashAttention Strategy

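1. **Tiling**: Split Q, K, and V into blocks that fit in SRAM and compute attention one block at a time
2. **Recomputation**: Never store the $O(n^2)$ attention matrix; keep only the per-row softmax statistics and recompute attention blocks during the backward pass
3. **Online Softmax**: Maintain a running row maximum and normalizer so the softmax stays exact while blocks are processed incrementally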

### Algorithm

```
For each block of Q:
    Load Q_block to SRAM
    Initialize running max m, sum l, output O
    For each block of K, V:
        Load K_block, V_block to SRAM
        Compute S_block = Q_block @ K_block^T
        Update m, l, O using online softmax
    Write O to HBM
```
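
The "update $m$, $l$, $O$" step can be written out explicitly. The equations below are a sketch of one common formulation, in which $O$ is kept unnormalized and divided by $l$ once at the end. The paper's own algorithm may organize the normalization differently while computing the same result. Here $\operatorname{rowmax}$ and $\operatorname{rowsum}$ act along the key dimension, and the initial state is $m = -\infty$, $l = 0$, $O = 0$.

$$
\begin{aligned}
m^{\text{new}} &= \max\bigl(m,\ \operatorname{rowmax}(S_{\text{block}})\bigr) \\
P &= \exp\bigl(S_{\text{block}} - m^{\text{new}}\bigr) \\
l^{\text{new}} &= e^{\,m - m^{\text{new}}}\, l + \operatorname{rowsum}(P) \\
O^{\text{new}} &= e^{\,m - m^{\text{new}}}\, O + P\, V_{\text{block}}
\end{aligned}
$$

After the last K, V block, the output written to HBM is $O / l$. The factor $e^{\,m - m^{\text{new}}}$ re-expresses everything accumulated so far relative to the new running maximum, which is what keeps the result identical to a softmax computed over the full row.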

### Complexity

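- Time: $O(n^2 d)$ FLOPs (same as standard attention)
- Space: $O(n)$ extra memory beyond the inputs and output, instead of $O(n^2)$
- IO: $\Theta(n^2 d^2 / M)$ HBM accesses for SRAM size $M$, versus $\Theta(nd + n^2)$ for standard attention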
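
To get a feel for the IO saving, the sketch below evaluates the HBM access bounds given in the FlashAttention paper, $\Theta(nd + n^2)$ for standard attention and $\Theta(n^2 d^2 / M)$ for FlashAttention with SRAM size $M$, at illustrative values. The values of `n`, `d`, and `sram_elems` are assumptions, $M$ is counted in elements, and constant factors are ignored, so the printed ratio is only indicative.

```python
def standard_hbm_accesses(n: int, d: int) -> int:
    """Asymptotic HBM access count for standard attention: n*d + n^2."""
    return n * d + n * n


def flash_hbm_accesses(n: int, d: int, sram_elems: int) -> float:
    """Asymptotic HBM access count for FlashAttention: n^2 * d^2 / M."""
    return n * n * d * d / sram_elems


# Illustrative values (assumptions, not measurements).
n = 4096                       # sequence length
d = 64                         # head dimension
sram_elems = 100 * 1024 // 2   # ~100 KB of SRAM holding float16 elements

std = standard_hbm_accesses(n, d)
flash = flash_hbm_accesses(n, d, sram_elems)
print(f"standard: {std:.3e}, flash: {flash:.3e}, ratio ~ {std / flash:.1f}x")
```

For large $n$ the ratio approaches $M / d^2$, so the saving grows with SRAM size and shrinks as the head dimension increases.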

```python
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

plt.style.use('seaborn-v0_8-whitegrid')
```

```python
def online_softmax_attention_block(q_block, k, v, block_size_kv=64):
    """Demonstrate online softmax for FlashAttention on one block of queries.

    q_block: (batch, block_size_q, d_k); k, v: (batch, seq_len_k, d_k).
    """
    seq_len_k = k.size(1)
    d_k = q_block.size(-1)
    batch_size, block_size_q = q_block.size(0), q_block.size(1)

    # Initialize accumulators on the same device and dtype as the inputs
    output = torch.zeros_like(q_block)  # unnormalized weighted sum of V
    max_scores = q_block.new_full((batch_size, block_size_q, 1), float('-inf'))
    sum_exp = q_block.new_zeros((batch_size, block_size_q, 1))

    # Process K, V in blocks
    for kv_start in range(0, seq_len_k, block_size_kv):
        kv_end = min(kv_start + block_size_kv, seq_len_k)
        k_block = k[:, kv_start:kv_end]
        v_block = v[:, kv_start:kv_end]

        # Compute scaled scores for this block
        scores = torch.matmul(q_block, k_block.transpose(-2, -1)) / (d_k ** 0.5)

        # Online softmax update: new running maximum per query row
        block_max = scores.max(dim=-1, keepdim=True).values
        new_max = torch.maximum(max_scores, block_max)

        # Rescale previous accumulators to the new maximum
        # (on the first block this factor is exp(-inf) = 0)
        exp_diff = torch.exp(max_scores - new_max)
        sum_exp = sum_exp * exp_diff
        output = output * exp_diff

        # Add new block contribution
        exp_scores = torch.exp(scores - new_max)
        sum_exp = sum_exp + exp_scores.sum(dim=-1, keepdim=True)
        output = output + torch.matmul(exp_scores, v_block)

        max_scores = new_max

    # Normalize once at the end; the clamp only guards against division by zero
    return output / sum_exp.clamp(min=1e-6)
```

```python
# Verify correctness
torch.manual_seed(42)
batch, seq_len, d = 2, 64, 32

q = torch.randn(batch, seq_len, d)
k = torch.randn(batch, seq_len, d)
v = torch.randn(batch, seq_len, d)

# Standard attention
scores = torch.matmul(q, k.transpose(-2, -1)) / (d ** 0.5)
weights = F.softmax(scores, dim=-1)
standard_output = torch.matmul(weights, v)

# Online softmax (Flash-style)
flash_output = online_softmax_attention_block(q, k, v, block_size_kv=16)

print(f"Max absolute difference: {(standard_output - flash_output).abs().max():.6f}")
print(f"Outputs match: {torch.allclose(standard_output, flash_output, atol=1e-5)}")
```
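
The check above passes all queries as a single block, so only the inner K, V loop of the algorithm is exercised. As a minimal sketch of the full two-level loop, the hypothetical helper `tiled_attention` below also tiles over Q. It runs in plain PyTorch and has no control over SRAM or HBM placement, so it illustrates the arithmetic rather than the performance. The `_demo` tensors and block sizes are assumptions chosen for illustration.

```python
import torch
import torch.nn.functional as F


def tiled_attention(q, k, v, block_size_q=16, block_size_kv=16):
    """Exact attention computed tile by tile with an online softmax.

    q, k: (batch, seq_len, d); v: (batch, seq_len, d_v).
    """
    batch, seq_len_q, d = q.shape
    seq_len_k = k.size(1)
    scale = d ** -0.5
    out = q.new_empty((batch, seq_len_q, v.size(-1)))

    for q_start in range(0, seq_len_q, block_size_q):        # outer loop: Q blocks
        q_blk = q[:, q_start:q_start + block_size_q]
        rows = q_blk.size(1)
        m = q.new_full((batch, rows, 1), float("-inf"))      # running row max
        l = q.new_zeros((batch, rows, 1))                    # running sum of exp
        acc = q.new_zeros((batch, rows, v.size(-1)))         # unnormalized output

        for kv_start in range(0, seq_len_k, block_size_kv):  # inner loop: K/V blocks
            k_blk = k[:, kv_start:kv_start + block_size_kv]
            v_blk = v[:, kv_start:kv_start + block_size_kv]
            s = torch.matmul(q_blk, k_blk.transpose(-2, -1)) * scale
            m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
            alpha = torch.exp(m - m_new)                     # rescales old state
            p = torch.exp(s - m_new)
            l = alpha * l + p.sum(dim=-1, keepdim=True)
            acc = alpha * acc + torch.matmul(p, v_blk)
            m = m_new

        out[:, q_start:q_start + rows] = acc / l             # write this block's O

    return out


torch.manual_seed(0)
# Sequence length 50 is deliberately not a multiple of the block size.
q_demo, k_demo, v_demo = (torch.randn(2, 50, 32) for _ in range(3))
reference = torch.matmul(
    F.softmax(q_demo @ k_demo.transpose(-2, -1) / 32 ** 0.5, dim=-1), v_demo
)
tiled = tiled_attention(q_demo, k_demo, v_demo)
print(f"Max absolute difference: {(reference - tiled).abs().max():.2e}")
```

Each Q block keeps its own `m`, `l`, and accumulator, so the outer iterations are independent of one another.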

```python
# PyTorch 2.0+ fused scaled dot-product attention (SDPA)
if hasattr(F, 'scaled_dot_product_attention'):
    print("PyTorch SDPA available (includes Flash Attention backend)")

    # SDPA expects (batch, heads, seq_len, head_dim), so add a head dimension.
    # It picks a backend (FlashAttention, memory-efficient, or plain math)
    # based on device, dtype, and shapes; FlashAttention is not guaranteed.
    q_t = q.unsqueeze(1)
    k_t = k.unsqueeze(1)
    v_t = v.unsqueeze(1)

    output = F.scaled_dot_product_attention(q_t, k_t, v_t)
    print(f"SDPA output shape: {output.squeeze(1).shape}")
```
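
A common use of SDPA is causal attention, where the `is_causal=True` argument lets the kernel apply the mask internally instead of materializing it. The sketch below assumes PyTorch 2.0 or later and compares the fused call against a reference that builds the masked score matrix explicitly. The tensor names and shapes are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Shape convention for SDPA: (batch, heads, seq_len, head_dim).
q_h = torch.randn(2, 4, 16, 32)
k_h = torch.randn(2, 4, 16, 32)
v_h = torch.randn(2, 4, 16, 32)

# Fused path: the causal mask is applied inside the kernel.
sdpa_causal = F.scaled_dot_product_attention(q_h, k_h, v_h, is_causal=True)

# Reference path: materialize scores and an explicit lower-triangular mask.
scores_h = q_h @ k_h.transpose(-2, -1) / q_h.size(-1) ** 0.5
causal_mask = torch.tril(torch.ones(16, 16, dtype=torch.bool))
scores_h = scores_h.masked_fill(~causal_mask, float("-inf"))
manual_causal = F.softmax(scores_h, dim=-1) @ v_h

print(f"Max absolute difference: {(sdpa_causal - manual_causal).abs().max():.2e}")
```

The two paths should agree up to floating-point rounding; which backend actually runs depends on the device and dtype.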

## Memory Savings

| Method | Attention Matrix Memory | Total Memory |
|--------|-----------------|-------------|
| Standard | $O(n^2)$ in HBM | $O(n^2 + nd)$ |
| FlashAttention | $O(1)$ in $n$ (one block at a time, in SRAM) | $O(nd)$ |

For sequence length 2048 in float16, per head and per sequence in the batch:
- Standard: 2048 × 2048 × 2 bytes = 8 MiB for the attention matrix
- FlashAttention: does not materialize the full matrix in HBM
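
The same arithmetic can be repeated for other sequence lengths to show how quickly the quadratic term grows. This is a small sketch; the helper name `score_matrix_mib` and the chosen lengths are illustrative assumptions.

```python
BYTES_PER_FP16 = 2


def score_matrix_mib(seq_len: int, bytes_per_elem: int = BYTES_PER_FP16) -> float:
    """Size in MiB of one seq_len x seq_len attention matrix (one head, one sequence)."""
    return seq_len * seq_len * bytes_per_elem / 2**20


for seq_len in (512, 2048, 8192, 32768):
    print(f"n = {seq_len:>6}: {score_matrix_mib(seq_len):>8.1f} MiB per head")
```

These figures are per head and per sequence, so the total for standard attention scales further with batch size and head count.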