# Flash Attention: Making Transformers Memory Efficient

This notebook demonstrates the key differences between standard attention 
and Flash Attention, focusing on memory efficiency and computational advantages.

## Learning Objectives:
1. Understand memory bottlenecks in standard attention
2. Learn how Flash Attention solves these problems
3. Compare implementations and performance
4. Visualize attention patterns

## Prerequisites:
- Basic understanding of transformer attention
- Familiarity with PyTorch tensors


# Vanilla Transformer: Attention Mechanism Explained

## What is Attention?

**Attention** is a mechanism that allows a model to focus on different parts of the input when processing each element. Think of it like **selective focus** - when you read a sentence, you might pay more attention to certain words based on what you're trying to understand.

### Core Idea
> *"How much should each word pay attention to every other word?"*

---

## The Attention Formula

The fundamental attention mechanism can be expressed as:

```
Attention(Q, K, V) = softmax(QK^T / √d_k) V
```

### Three Key Components:

- **Query (Q)**: *"What am I looking for?"*
- **Key (K)**: *"What information is available?"*  
- **Value (V)**: *"What is the actual content?"*

---

## Step-by-Step Process

### 1. **Compute Similarity Scores**
```
Scores = Q @ K^T / √d_k
```
- Calculate how similar each query is to each key
- Divide by √d_k so the scores do not grow with the head dimension and saturate the softmax
- Results in an [L × L] score matrix

### 2. **Apply Softmax**
```
Attention_Weights = softmax(Scores)
```
- Convert scores to probabilities
- Each row sums to 1.0
- High scores → high attention weights

### 3. **Weighted Sum of Values**
```
Output = Attention_Weights @ V
```
- Combine values based on attention weights
- Each output is a weighted average of all values
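
The three steps above can be traced on a tiny example. This is a minimal sketch with random inputs; the sizes and the seed are arbitrary choices for illustration, and the learned Q/K/V projections are left out.

```python
import math
import torch

torch.manual_seed(0)
L, d_k = 4, 8                      # illustrative sizes, not taken from the notebook
Q = torch.randn(L, d_k)
K = torch.randn(L, d_k)
V = torch.randn(L, d_k)

# Step 1: similarity scores, divided by sqrt(d_k) -> [L, L]
scores = Q @ K.T / math.sqrt(d_k)

# Step 2: softmax over the key axis, so each row sums to 1
attn_weights = torch.softmax(scores, dim=-1)

# Step 3: weighted sum of values -> [L, d_k]
output = attn_weights @ V

print(scores.shape, attn_weights.sum(dim=-1), output.shape)
```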

---

## Multi-Head Attention

Instead of one attention mechanism, use **multiple heads** in parallel:

### Why Multiple Heads?
- **Different perspectives**: Each head can focus on different types of relationships
- **Richer representations**: Capture various linguistic patterns simultaneously
- **No extra cost**: Each head works in a D/H-dimensional subspace, so total parameters and compute stay roughly the same as for one full-width head

### Process:
1. **Split** embeddings into H heads: `[B, L, D] → [B, H, L, D/H]`
2. **Compute** attention for each head independently
3. **Concatenate** head outputs: `[B, H, L, D/H] → [B, L, D]`
4. **Project** through final linear layer
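
The split and concatenate steps are pure reshapes. The sketch below shows only the shape bookkeeping, with illustrative sizes; in a real layer the Q/K/V projections come before the split and the output projection comes after the merge.

```python
import torch

B, L, D, H = 2, 5, 16, 4           # illustrative sizes
x = torch.randn(B, L, D)

# Split: [B, L, D] -> [B, L, H, D/H] -> [B, H, L, D/H]
heads = x.view(B, L, H, D // H).transpose(1, 2)

# (per-head attention would run here on the [B, H, L, D/H] tensor)

# Concatenate: [B, H, L, D/H] -> [B, L, H, D/H] -> [B, L, D]
merged = heads.transpose(1, 2).contiguous().view(B, L, D)

print(heads.shape, merged.shape)
```

Because no attention is applied in between, `merged` is identical to `x`, which confirms that the two reshapes are inverses of each other.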

---

## Self-Attention vs Cross-Attention

### Self-Attention
- **Q, K, V** all come from the **same sequence**
- *"How do words in this sentence relate to each other?"*
- Used in encoder and decoder blocks

### Cross-Attention
- **Q** from target sequence, **K, V** from source sequence
- *"How does each target word relate to source words?"*
- Used in encoder-decoder architectures (translation, etc.)
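
In code, the only difference between the two is where the keys and values come from. A minimal sketch, assuming a hypothetical helper `attend` and omitting the learned projections:

```python
import math
import torch

torch.manual_seed(0)
d = 8
target = torch.randn(3, d)   # 3 target positions (queries)
source = torch.randn(5, d)   # 5 source positions (keys/values)

def attend(q_in, kv_in):
    # Hypothetical helper: returns attention weights only
    scores = q_in @ kv_in.T / math.sqrt(d)
    return torch.softmax(scores, dim=-1)

self_weights = attend(target, target)    # [3, 3]: sequence attends to itself
cross_weights = attend(target, source)   # [3, 5]: target attends to source
print(self_weights.shape, cross_weights.shape)
```

Note that the cross-attention matrix is not square: its shape is [target length × source length].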

---

## Causal (Masked) Attention

For **autoregressive** tasks (language modeling), prevent looking at future tokens:

### Causal Mask:
```
Mask = triu(ones(L, L), diagonal=1)
Scores.masked_fill_(Mask, -∞)
```

### Result:
- Position `i` can only attend to positions `≤ i`
- Creates lower-triangular attention pattern
- Essential for language generation
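
To see the mask's effect in isolation, the sketch below uses all-zero scores (an artificial choice for clarity), so each position spreads its weight evenly over itself and the positions before it.

```python
import torch

L = 4
scores = torch.zeros(L, L)                         # uniform scores for clarity
mask = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)

# Masked entries become -inf, which softmax maps to a weight of exactly 0
weights = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
print(weights)
```

Row `i` of the result has non-zero entries only in columns `0..i`, which is the lower-triangular pattern described above.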

---

## Memory Complexity: The Problem

### Standard Attention Memory Usage:
- **Attention Matrix**: `O(L²)` - stores all pairwise similarities
- **Major Bottleneck**: For L=4096 in float32, a single head needs 64 MiB (~67 MB) just for the attention weights, multiplied by batch size and number of heads
- **Quadratic Scaling**: Doubling the sequence length quadruples the memory

### Example Memory Usage (float32, one head, batch size 1):
| Sequence Length | Attention Matrix Memory |
|-----------------|-------------------------|
| 512             | 1 MB                    |
| 1024            | 4 MB                    |
| 2048            | 16 MB                   |
| 4096            | 64 MB                   |
| 8192            | 256 MB                  |
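
The table values can be reproduced with simple arithmetic. The sketch below assumes float32 (4 bytes per element), one head, and batch size 1, which is what the numbers above correspond to; the function name is hypothetical.

```python
def attention_matrix_mib(seq_len: int, bytes_per_element: int = 4,
                         batch_size: int = 1, n_heads: int = 1) -> float:
    """Size of the [B, H, L, L] attention matrix in MiB (2**20 bytes)."""
    n_elements = batch_size * n_heads * seq_len * seq_len
    return n_elements * bytes_per_element / 2**20

for L in (512, 1024, 2048, 4096, 8192):
    print(f"L={L:5d}: {attention_matrix_mib(L):6.0f} MiB")
```

Passing `n_heads=16` for L=4096 gives 1024 MiB, which shows how quickly the per-head figure adds up in a real model.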

---

## Attention Intuition

### Example: "The cat sat on the mat"

When processing "sat":
- **High attention** to "cat" (subject-verb relationship)
- **Medium attention** to "mat" (location context)
- **Low attention** to "the" (less semantically important)

### Visualization:
```
    The  cat  sat  on  the  mat
The [0.1][0.1][0.2][0.2][0.2][0.2]
cat [0.1][0.3][0.4][0.1][0.0][0.1]
sat [0.0][0.6][0.2][0.1][0.0][0.1]  ← "sat" pays most attention to "cat"
on  [0.1][0.1][0.1][0.2][0.1][0.4]
the [0.1][0.1][0.1][0.1][0.3][0.3]
mat [0.1][0.2][0.1][0.3][0.1][0.2]
```

---

## Transformer Architecture Context

### Encoder Block:
```
Input → Self-Attention → Add&Norm → Feed-Forward → Add&Norm → Output
```

### Decoder Block:
```
Input → Causal Self-Attention → Add&Norm → Cross-Attention → Add&Norm → Feed-Forward → Add&Norm → Output
```

### Key Components:
- **Residual connections**: `output = layer(input) + input`
- **Layer normalization**: Stabilize training
- **Feed-forward networks**: Position-wise processing
- **Multiple layers**: Stack for increased capacity
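
The encoder block diagram above maps onto a few lines of PyTorch. This is a minimal sketch using the built-in `nn.MultiheadAttention`, not the notebook's `VanillaTransformer`; the class name and layer sizes are illustrative assumptions.

```python
import torch
import torch.nn as nn

class EncoderBlock(nn.Module):
    """Post-norm encoder block: Self-Attention -> Add&Norm -> FFN -> Add&Norm."""

    def __init__(self, d_model=32, n_heads=4, d_ff=64):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)   # self-attention: Q, K, V all from x
        x = self.norm1(x + attn_out)       # residual connection, then LayerNorm
        x = self.norm2(x + self.ff(x))     # position-wise FFN, residual, LayerNorm
        return x

block = EncoderBlock()
y = block(torch.randn(2, 10, 32))          # [batch, seq_len, d_model]
print(y.shape)
```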

---

## Common Variations

### 1. **Scaled Dot-Product Attention** (Standard)
```
Attention(Q,K,V) = softmax(QK^T/√d_k)V
```

### 2. **Multi-Query Attention**
- Share K,V across heads, unique Q per head
- Reduces memory and computation
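
A minimal sketch of the sharing idea, with illustrative sizes: K and V carry a head axis of size 1, and broadcasting reuses them for every query head.

```python
import math
import torch

torch.manual_seed(0)
B, H, L, D = 1, 4, 6, 8
Q = torch.randn(B, H, L, D)   # one set of queries per head
K = torch.randn(B, 1, L, D)   # a single shared key head
V = torch.randn(B, 1, L, D)   # a single shared value head

# Broadcasting over the head axis applies the same K and V to every query head
scores = Q @ K.transpose(-2, -1) / math.sqrt(D)   # [B, H, L, L]
output = torch.softmax(scores, dim=-1) @ V        # [B, H, L, D]

print(output.shape, "K/V elements stored:", K.numel() + V.numel())
```

The saving is in K/V storage (for example the KV cache during generation), which shrinks by a factor of H; the score matrix itself is still [B, H, L, L].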

### 3. **Flash Attention**
- Block-wise computation
- Reduces memory from O(L²) to O(L)

---

## Why Attention Works

### 1. **Flexible Relationships**
- Can model any word-to-word relationship
- Not limited by sequential processing

### 2. **Interpretable Weights**
- Attention weights show what the model "focuses on"
- Useful for debugging and understanding

### 3. **Efficient Parallelization**
- All positions in a sequence are processed in parallel, unlike the step-by-step recurrence of an RNN
- Excellent for modern GPU architectures

### 4. **Universal Approximation**
- With enough heads and layers, can model complex functions
- Foundation for modern language models

---

In [None]:
import torch
import torch.nn.functional as F

def standard_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                       causal: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Standard attention implementation - stores full attention matrix.

    Returns (output, attn_weights) with shapes [B, H, L, D] and [B, H, L, L].
    Memory complexity: O(L²) where L is sequence length
    """
    B, H, L, D = Q.shape
    scale = 1.0 / (D ** 0.5)
    
    # Step 1: Compute attention scores [B, H, L, L]
    scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
    
    # Step 2: Apply causal mask if needed
    if causal:
        mask = torch.triu(torch.ones(L, L, device=Q.device, dtype=torch.bool), diagonal=1)
        scores.masked_fill_(mask, float('-inf'))
    
    # Step 3: Softmax - still O(L²) memory
    attn_weights = F.softmax(scores, dim=-1)
    
    # Step 4: Apply to values
    output = torch.matmul(attn_weights, V)
    
    return output, attn_weights

In [None]:
import pprint
# Test parameters
batch_size = 1
n_heads = 1
seq_len = 4
head_dim = 4

print("Test configuration:")
print(f"  Batch size: {batch_size}")
print(f"  Number of heads: {n_heads}")
print(f"  Sequence length: {seq_len}")
print(f"  Head dimension: {head_dim}")
print()

# Create test data
torch.manual_seed(42)  # For reproducible results
Q = torch.randn(batch_size, n_heads, seq_len, head_dim)
K = torch.randn(batch_size, n_heads, seq_len, head_dim)
V = torch.randn(batch_size, n_heads, seq_len, head_dim)

print("Input tensor shapes:")
print(f"  Q: {Q.shape}")
print(f"  K: {K.shape}")
print(f"  V: {V.shape}")
print()

# Test 1: Regular attention (no causal mask)
print("Test 1: Regular Attention")
print("-" * 30)

output, attn_weights = standard_attention(Q, K, V, causal=False)

print(f"SHAPE: {attn_weights.shape}")
pprint.pprint(attn_weights)

# Vanilla Attention Implementation

In [None]:
from config import MODEL_PARAMS
from basic_attention.vanilla_transformer import VanillaTransformer
from tokenizer import SimpleTokenizer
torch.manual_seed(42)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model parameters
VOCAB_SIZE = MODEL_PARAMS.get('VOCAB', 1000)
D_MODEL = MODEL_PARAMS.get('D_MODEL', 256)
N_HEADS = MODEL_PARAMS.get('N_HEADS', 8)
N_LAYERS = 1
SEQ_LEN = MODEL_PARAMS.get('SEQ_LEN', 64)

transformer = VanillaTransformer(
        vocab_size=VOCAB_SIZE,
        d_model=D_MODEL,
        n_heads=N_HEADS,
        n_layers=N_LAYERS,
        max_seq_len=SEQ_LEN,
        dropout=0.1
    ).to(device)
    
print(f"Model parameters: {sum(p.numel() for p in transformer.parameters()):,}")

# Test data
tokenizer = SimpleTokenizer(VOCAB_SIZE)
text = "Hello world, welcome to this lecture series! -> Alexy"
print(f"\nInput text: '{text}'")

# Step 1: Tokenize
token_ids = tokenizer.encode(text)
print(f"Token IDs: {token_ids}")

# Step 2: Pad sequence to fixed length
if len(token_ids) < SEQ_LEN:
    token_ids.extend([tokenizer.pad_token] * (SEQ_LEN - len(token_ids)))
else:
    token_ids = token_ids[:SEQ_LEN]

# Step 3: Convert to tensor and add batch dimension
tokens = torch.tensor([token_ids], device=device)  # [1, SEQ_LEN]

with torch.no_grad():
    # Regular attention
    context = transformer(tokens, causal=False)
    print(f"Output shape: {context.shape}")
    print(f"Context vector: {context}")


# Flash Attention Implementation
## References
- https://christianjmills.com/posts/cuda-mode-notes/lecture-012/
- https://alexdremov.me/understanding-flash-attention-writing-the-algorithm-from-scratch-in-triton/
- https://gordicaleksa.medium.com/eli5-flash-attention-5c44017022ad
- https://github.com/Dao-AILab/flash-attention

# Flash Attention: How does it work?
## The Problem

**Standard attention** computes and stores the full attention matrix in memory:

```
Attention(Q,K,V) = softmax(QK^T/√d_k)V
```

- **Memory bottleneck**: O(L²) for attention matrix storage
- **Example**: L=4096 in float32 → 64 MiB per head, so ~1 GB once batch size × heads reaches 16
- **Result**: Long sequences quickly exhaust GPU memory

---

## The Flash Attention Solution

**Key insight**: Don't store the full attention matrix. Compute it in blocks.

### Core Innovation: Block-wise Computation

1. **Tile the computation**: Process attention in small blocks
2. **Online softmax**: Update statistics incrementally  
3. **Never materialize**: Full attention matrix never stored
4. **Memory reduction**: O(L²) → O(L)

### Algorithm Overview

```python
# Instead of (standard attention):
scores = Q @ K.transpose(-2, -1) * scale   # [L, L] scores: O(L²) memory
weights = softmax(scores, dim=-1)          # [L, L] weights: O(L²) memory
output = weights @ V

# Flash Attention does (pseudocode):
for query_block in Q_blocks:
    for key_block, value_block in zip(K_blocks, V_blocks):
        # 1. Compute the small [block × block] tile of scores
        # 2. Update running max and sum (online softmax)
        # 3. Rescale and accumulate the partial output
        # The full [L, L] matrix is never stored
        ...
```
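
The loop structure above can be made concrete in plain PyTorch. This is a readability-oriented sketch for a single head with [L, D] inputs and no causal mask; the function name `tiled_attention` is hypothetical, and a real Flash Attention kernel fuses these steps on the GPU rather than looping in Python.

```python
import math
import torch

def tiled_attention(Q, K, V, block_size=2):
    """Block-wise attention for [L, D] inputs; never builds the [L, L] matrix."""
    L, D = Q.shape
    scale = 1.0 / math.sqrt(D)
    out = torch.zeros_like(Q)
    for qs in range(0, L, block_size):
        q = Q[qs:qs + block_size]                        # [Bq, D]
        m = torch.full((q.shape[0], 1), float("-inf"))   # running row max
        l = torch.zeros(q.shape[0], 1)                   # running row sum
        acc = torch.zeros_like(q)                        # unnormalized output
        for ks in range(0, L, block_size):
            k, v = K[ks:ks + block_size], V[ks:ks + block_size]
            s = q @ k.T * scale                          # [Bq, Bk] tile of scores
            m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
            p = torch.exp(s - m_new)
            correction = torch.exp(m - m_new)            # rescale old statistics
            l = correction * l + p.sum(dim=-1, keepdim=True)
            acc = correction * acc + p @ v
            m = m_new
        out[qs:qs + block_size] = acc / l                # normalize once per block
    return out

torch.manual_seed(0)
Q, K, V = (torch.randn(8, 4) for _ in range(3))
reference = torch.softmax(Q @ K.T / math.sqrt(4), dim=-1) @ V
print((tiled_attention(Q, K, V) - reference).abs().max())
```

The largest tile held at any time is [block_size × block_size], so the extra memory depends on the block size rather than on L².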

---

## Mathematical Equivalence

**Crucial property**: Flash Attention computes the **same function** as standard attention; outputs match up to floating-point rounding.

- **Same softmax formula**
- **Same attention weights** (computed differently)
- **Same final output** (mathematically proven)
- **Zero approximation** - exact computation
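
One way to check this in practice, assuming PyTorch 2.0 or later, is to compare explicit attention against `F.scaled_dot_product_attention`, a fused implementation that can dispatch to a Flash Attention kernel on supported GPUs. Which backend actually runs depends on the device and dtype, so treat this as a sanity check rather than a test of a specific kernel.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
Q, K, V = (torch.randn(1, 2, 16, 8) for _ in range(3))   # [B, H, L, D]

# Explicit attention that materializes the [L, L] weights
weights = torch.softmax(Q @ K.transpose(-2, -1) / math.sqrt(Q.shape[-1]), dim=-1)
explicit = weights @ V

# Fused implementation; the backend is chosen automatically
fused = F.scaled_dot_product_attention(Q, K, V, is_causal=False)

print((explicit - fused).abs().max())
```

The printed difference should be at the level of float32 rounding error.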

---

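## Memory Comparison

The standard-attention column corresponds to float32 attention weights with batch size × heads = 16. The Flash Attention figures are illustrative; actual usage depends on block size and implementation.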

| Sequence Length | Standard Attention | Flash Attention | Reduction |
|-----------------|-------------------|-----------------|-----------|
| 1,024          | 64 MB             | 4 MB            | 16x       |
| 4,096          | 1 GB              | 16 MB           | 64x       |
| 8,192          | 4 GB              | 32 MB           | 128x      |

---

## The Online Softmax Trick

**Challenge**: How to compute softmax without storing all values?

**Solution**: Update running statistics incrementally:

```python
# For each new block of scores:
m_new = max(m_old, scores.max())           # Update max
exp_scores = exp(scores - m_new)           # Rescale
l_new = exp(m_old - m_new) * l_old + exp_scores.sum()  # Update sum

# Final: softmax = exp(scores - m_new) / l_new
```

This allows computing exact softmax probabilities without storing the full matrix.
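
The update rule can be checked on a single row of scores using only the standard library. In this sketch the scores arrive in chunks of three (an arbitrary choice), and the streamed statistics are compared with an ordinary softmax that sees everything at once; the function name is hypothetical.

```python
import math

def online_softmax_stats(scores, chunk_size=3):
    """Return (max, sum of exp(x - max)) after streaming over chunks."""
    m, l = float("-inf"), 0.0
    for start in range(0, len(scores), chunk_size):
        chunk = scores[start:start + chunk_size]
        m_new = max(m, max(chunk))
        # Rescale the old sum to the new max, then add this chunk's terms
        l = math.exp(m - m_new) * l + sum(math.exp(x - m_new) for x in chunk)
        m = m_new
    return m, l

scores = [1.0, 3.0, 2.0, 5.0, 0.5, 4.0, 2.5]
m, l = online_softmax_stats(scores)
streamed = [math.exp(x - m) / l for x in scores]

# Reference: ordinary softmax with all scores visible at once
total = sum(math.exp(x - max(scores)) for x in scores)
direct = [math.exp(x - max(scores)) / total for x in scores]
print(max(abs(a - b) for a, b in zip(streamed, direct)))
```

Only two numbers per row (`m` and `l`) are carried between chunks, which is why the full row of scores never needs to be stored.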

---

## Implementation Strategy

### Two-Pass Algorithm:

**Pass 1**: Compute row-wise statistics
- Find maximum score per query
- Compute row-wise normalizing constants

**Pass 2**: Compute actual output
- Recompute scores block-by-block
- Apply exact softmax using Pass 1 statistics
- Accumulate weighted values
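
The two passes can be sketched as follows. This is an illustrative version, not the notebook's `FlashAttentionTransformer`: it handles a single head with [L, D] inputs, tiles only over keys for brevity, and uses the hypothetical name `two_pass_attention`. Fused kernels typically merge both passes by rescaling a running output, but the two-pass form is easier to follow.

```python
import math
import torch

def two_pass_attention(Q, K, V, block_size=2):
    """Two-pass block-wise attention for [L, D] inputs (illustrative sketch)."""
    L, D = Q.shape
    scale = 1.0 / math.sqrt(D)
    m = torch.full((L, 1), float("-inf"))   # per-query maximum score
    l = torch.zeros(L, 1)                   # per-query normalizing constant

    # Pass 1: stream over key blocks to collect row-wise statistics
    for ks in range(0, L, block_size):
        s = Q @ K[ks:ks + block_size].T * scale
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        l = torch.exp(m - m_new) * l + torch.exp(s - m_new).sum(dim=-1, keepdim=True)
        m = m_new

    # Pass 2: recompute each block of scores and accumulate weighted values
    out = torch.zeros_like(Q)
    for ks in range(0, L, block_size):
        s = Q @ K[ks:ks + block_size].T * scale
        out += (torch.exp(s - m) / l) @ V[ks:ks + block_size]
    return out

torch.manual_seed(0)
Q, K, V = (torch.randn(6, 4) for _ in range(3))
reference = torch.softmax(Q @ K.T / math.sqrt(4), dim=-1) @ V
print((two_pass_attention(Q, K, V) - reference).abs().max())
```

Recomputing the scores in Pass 2 trades extra arithmetic for not having to store them, which is the same trade the Key Benefits section describes.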

### Block Size Tuning:
- **Larger blocks**: Better compute efficiency
- **Smaller blocks**: Lower memory usage
- **Typical choice**: 64x64 or 128x128 blocks

---

## Key Benefits

### ✅ **Memory Efficiency**
- **O(L) instead of O(L²)** memory usage
- Enables much longer sequences
- Better GPU memory utilization

### ✅ **Speed Improvements**
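- **Less memory-bound**: Far fewer reads and writes to GPU main memory (HBM)
- Tiles stay in fast on-chip memory, improving cache locality
- Often faster in wall-clock time despite recomputing scores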

### ✅ **Mathematical Exactness**
- **Zero approximation error**
- Drop-in replacement for standard attention
- Same gradients, same training dynamics

### ✅ **Scalability**
- **Widely adopted** in open-source LLM training and inference stacks
- Longer context windows
- Larger batch sizes

---

In [None]:
import torch
from tokenizer import SimpleTokenizer
from config import MODEL_PARAMS
from flash_attention.flash_attention_transformer import FlashAttentionTransformer


torch.manual_seed(42)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model parameters
VOCAB_SIZE = MODEL_PARAMS.get('VOCAB', 1000)
D_MODEL = MODEL_PARAMS.get('D_MODEL', 256)
N_HEADS = MODEL_PARAMS.get('N_HEADS', 8)
N_LAYERS = 1  # Single layer for this demo; increase to stack more layers
SEQ_LEN = MODEL_PARAMS.get('SEQ_LEN', 64)


# Create complete Flash Attention transformer
transformer = FlashAttentionTransformer(
    vocab_size=VOCAB_SIZE,
    d_model=D_MODEL,
    n_heads=N_HEADS,
    n_layers=N_LAYERS,
    dropout=0.1,
    debug=True  # Set to True to see attention visualizations
).to(device)

print(f"Model parameters: {sum(p.numel() for p in transformer.parameters()):,}")
print(f"Model device: {next(transformer.parameters()).device}")

# Test data
tokenizer = SimpleTokenizer(VOCAB_SIZE)
text = "Hello world, welcome to this lecture series from Nice Intelligence"
print(f"\nInput text: '{text}'")

# Tokenize and prepare
token_ids = tokenizer.encode(text)
print(f"Token IDs: {token_ids}")

# Pad to sequence length
if len(token_ids) < SEQ_LEN:
    token_ids.extend([tokenizer.pad_token] * (SEQ_LEN - len(token_ids)))
else:
    token_ids = token_ids[:SEQ_LEN]

tokens = torch.tensor([token_ids], device=device)  # [1, SEQ_LEN]


# Forward pass
with torch.no_grad():
    context = transformer(tokens, causal=False)
    print(f"Context shape: {context.shape}")
    print(f"Context dtype: {context.dtype}")
    print(f"Sample context: {context[0, 0, :6]}")

# Benchmarking Flash Attention
- https://colab.research.google.com/drive/1NgoOfJbQRGOsJh1aBOghj5rTn6m2XhnS#scrollTo=oX42RVTkCiz8