# Efficient Attention: Flash Attention & Memory Optimization

Attention layers become the bottleneck when sequence lengths grow. This notebook examines memory-efficient variants—most notably Flash Attention—and shows how to benchmark, mask, and fall back to chunked implementations when kernels are unavailable.

## Learning Objectives

- Diagnose the computational cost of naive attention.
- Use PyTorch 2.x `scaled_dot_product_attention` to leverage Flash/efficient kernels when available.
- Implement chunked attention as a fallback for long sequences.
- Build an attention wrapper that records backend choice and timing.

## Cost of Naive Attention

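Standard attention materializes the full `L × L` score matrix, so both memory and compute grow as `O(L^2)` in the sequence length `L`. Flash Attention computes the same exact result but processes keys and values in tiles, keeping a running maximum and normalizer for each softmax row (an "online softmax"). Because the score matrix is never stored in full, extra memory drops to `O(L)` and reads and writes to GPU high-bandwidth memory shrink substantially. The FLOP count is still `O(L^2)`.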
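The tiling idea can be illustrated on CPU. The sketch below uses a hypothetical helper, `tiled_attention`, that is not part of any library. It walks over key/value blocks and applies the online-softmax recurrence, so at any moment it holds only an `L × block_size` slice of the scores. Plain PyTorch cannot fuse these steps the way a real Flash kernel does, so this shows the math, not the speedup.

```python
import math
import torch

def tiled_attention(q, k, v, block_size=32):
    """Exact attention over key/value blocks using an online softmax (illustrative sketch)."""
    scale = 1.0 / math.sqrt(q.size(-1))
    stat_shape = q.shape[:-1] + (1,)
    # Running row-wise max, running softmax denominator, and unnormalized output.
    row_max = torch.full(stat_shape, float("-inf"), dtype=q.dtype, device=q.device)
    row_sum = torch.zeros(stat_shape, dtype=q.dtype, device=q.device)
    acc = torch.zeros(*q.shape[:-1], v.size(-1), dtype=q.dtype, device=q.device)

    for start in range(0, k.size(-2), block_size):
        k_blk = k[..., start:start + block_size, :]
        v_blk = v[..., start:start + block_size, :]
        scores = (q @ k_blk.transpose(-2, -1)) * scale  # (..., L_q, block)
        new_max = torch.maximum(row_max, scores.amax(dim=-1, keepdim=True))
        # Rescale the statistics gathered so far to the new running max.
        correction = torch.exp(row_max - new_max)
        p = torch.exp(scores - new_max)
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        acc = acc * correction + p @ v_blk
        row_max = new_max
    return acc / row_sum

torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 100, 16) for _ in range(3))  # 100 is not a multiple of 32
ref = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(16), dim=-1) @ v
out = tiled_attention(q, k, v, block_size=32)
print((out - ref).abs().max())
```

The first block is handled by initializing the running max to `-inf`, which makes the initial correction factor `exp(-inf) = 0`.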

In [None]:
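import math
import time

import matplotlib.pyplot as plt
import torch
import torch.nn as nn  # needed by the FlashMHA module later in the notebook
import torch.nn.functional as F

torch.manual_seed(0)

def naive_attention(q, k, v, mask=None):
    """Reference attention that materializes the full (L, L) score matrix.

    mask: 1/True marks positions that may be attended to; 0/False positions are blocked.
    """
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    return weights @ v

# Shapes: (batch, heads, seq_len, head_dim)
q = torch.randn(2, 8, 128, 64)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

out_naive = naive_attention(q, k, v)
# SDPA dispatches to a flash, memory-efficient, or math kernel depending on inputs and hardware.
out_flash = F.scaled_dot_product_attention(q, k, v)
# Max absolute difference: expect float32 round-off, not exactly zero.
print((out_naive - out_flash).abs().max())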
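Masks need care when switching to `F.scaled_dot_product_attention`. Per the PyTorch documentation, a boolean `attn_mask` uses `True` for positions that take part in attention, a float `attn_mask` is *added* to the scores, and `is_causal=True` applies a lower-triangular mask. The check below, a small self-contained sketch, confirms that a boolean causal mask, `is_causal=True`, and a masked naive implementation agree. A float 0/1 mask, by contrast, would be added to the scores rather than used to block positions.

```python
import math
import torch
import torch.nn.functional as F

def naive_attention(q, k, v, mask=None):
    # mask convention: 1/True = attend, 0/False = blocked
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
q, k, v = (torch.randn(2, 8, 128, 64) for _ in range(3))
L = q.size(-2)
causal = torch.tril(torch.ones(L, L, dtype=torch.bool))  # True on and below the diagonal

out_is_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_bool_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)
out_naive_causal = naive_attention(q, k, v, mask=causal)

print((out_is_causal - out_naive_causal).abs().max(),
      (out_bool_mask - out_naive_causal).abs().max())
```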


### Benchmarking Runtime

On CPU the gap is usually small. On GPUs, SDPA can dispatch to fused Flash or memory-efficient kernels, and the savings grow with sequence length. Benchmarking locally still shows you the trade-offs on your own hardware.

In [None]:
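def benchmark(fn, *args, iters=20, warmup=3):
    """Average wall-clock seconds per call of fn(*args)."""
    with torch.no_grad():
        for _ in range(warmup):  # absorb lazy initialization and kernel selection
            fn(*args)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # CUDA launches are async; drain queued work first
        start = time.perf_counter()
        for _ in range(iters):
            fn(*args)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # wait for the timed kernels to finish
        end = time.perf_counter()
    return (end - start) / iters

naive_time = benchmark(naive_attention, q, k, v)
flash_time = benchmark(F.scaled_dot_product_attention, q, k, v)
print(f"Naive: {naive_time:.6f}s | SDPA: {flash_time:.6f}s (CPU measurement here; gains are typically larger on GPU)")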
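A hand-rolled `perf_counter` loop is easy to get wrong on GPU, because CUDA kernels launch asynchronously. PyTorch ships `torch.utils.benchmark.Timer`, which, according to its documentation, performs warmups and synchronizes CUDA when needed. The sketch below uses it to compare the two paths. The sequence length of 512 is an arbitrary choice, and the code falls back to CPU when no GPU is present.

```python
import math
import torch
import torch.nn.functional as F
from torch.utils import benchmark

def naive_attention(q, k, v):
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    return torch.softmax(scores, dim=-1) @ v

device = "cuda" if torch.cuda.is_available() else "cpu"
q, k, v = (torch.randn(2, 8, 512, 64, device=device) for _ in range(3))

results = {}
for label, fn in [("naive", naive_attention), ("sdpa", F.scaled_dot_product_attention)]:
    timer = benchmark.Timer(
        stmt="fn(q, k, v)",
        globals={"fn": fn, "q": q, "k": k, "v": v},
        label="attention",
        sub_label=label,
        description=device,
    )
    # blocked_autorange chooses how many runs to take within min_run_time seconds.
    results[label] = timer.blocked_autorange(min_run_time=0.2)
    print(results[label])
```

Each returned `Measurement` exposes summary statistics such as `.mean` and `.median`, which are more robust than a single averaged loop.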


### Structured Sparsity Masks

Local attention windows restrict each query to keys within a fixed radius of its own position, which forms a band around the diagonal. This is a common trick in long-sequence models.

In [None]:
seq_len = 64
radius = 8
# mask[i, j] = 1 when |i - j| <= radius (query i may attend to key j), else 0
mask = torch.zeros(seq_len, seq_len)
for i in range(seq_len):
    mask[i, max(0, i - radius): i + radius + 1] = 1

plt.imshow(mask, cmap="Blues")
plt.title("Local attention mask (radius=8)")
plt.xlabel("Key index")
plt.ylabel("Query index")
plt.show()
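The same band can be built without a Python loop and passed to SDPA as a boolean mask. The sketch below checks the result against a manual masked softmax. Note that a dense mask like this still computes the full `L × L` scores. Structured sparsity saves compute only with kernels that skip blocked tiles, which plain SDPA is not assumed to do here.

```python
import math
import torch
import torch.nn.functional as F

seq_len, radius = 64, 8
idx = torch.arange(seq_len)
# True where |query - key| <= radius, the same band as the loop-built mask.
band = (idx[:, None] - idx[None, :]).abs() <= radius

torch.manual_seed(0)
q, k, v = (torch.randn(1, 4, seq_len, 32) for _ in range(3))

# Boolean SDPA mask: True = take part in attention.
out_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=band)

scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
scores = scores.masked_fill(~band, float("-inf"))
out_ref = torch.softmax(scores, dim=-1) @ v
print((out_sdpa - out_ref).abs().max())
```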


## Mini Task – Chunked Attention

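Implement attention in chunks along the query dimension. Each chunk needs only a `chunk_size × L` slice of the score matrix instead of the full `L × L`, which lowers peak memory when Flash kernels are unavailable. Total compute is unchanged.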
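A quick calculation shows why chunking matters. The workload below (batch 1, 16 heads, 16,384 tokens, 2-byte fp16 scores, chunks of 256 queries) is a hypothetical example. The count covers a single score tensor only; softmax outputs and tensors saved for autograd add more in practice.

```python
def score_bytes(batch, heads, q_len, k_len, bytes_per_elem=4):
    """Bytes for one attention score tensor of shape (batch, heads, q_len, k_len)."""
    return batch * heads * q_len * k_len * bytes_per_elem

L = 16_384
chunk = 256
full = score_bytes(1, 16, L, L, bytes_per_elem=2)         # whole L x L matrix at once
chunked = score_bytes(1, 16, chunk, L, bytes_per_elem=2)  # one query chunk at a time

print(f"full scores:        {full / 2**30:.2f} GiB")
print(f"one chunk of {chunk}: {chunked / 2**20:.2f} MiB")
```

The ratio is simply `L / chunk_size` (64 here), so peak score memory falls linearly with the chunk size while the total number of multiply-adds stays the same.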

In [None]:
def chunked_attention(q, k, v, chunk_size=32):
    # TODO: compute attention by iterating over query chunks
    raise NotImplementedError


In [None]:
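def chunked_attention(q, k, v, chunk_size=32):
    """Exact (unmasked) attention computed one block of queries at a time."""
    scale = 1.0 / math.sqrt(q.size(-1))
    outputs = []
    for start in range(0, q.size(-2), chunk_size):
        end = start + chunk_size
        q_chunk = q[..., start:end, :]
        # Scores for this chunk only: (..., chunk_size, L_k) rather than (..., L_q, L_k).
        attn = q_chunk @ k.transpose(-2, -1) * scale
        weights = torch.softmax(attn, dim=-1)
        outputs.append(weights @ v)
    # Reassemble the chunks along the query dimension.
    return torch.cat(outputs, dim=-2)

chunked = chunked_attention(q, k, v)
print((chunked - out_naive).abs().max())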


## Comprehensive Exercise – Flash Attention Wrapper

Create a module `FlashMHA` that:

- Projects the input into per-head Q/K/V tensors and runs attention.
- Uses PyTorch's scaled dot-product kernels when available and falls back to chunked or naive attention otherwise.
- Records which backend was used and the elapsed time.
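Recording "which backend was used" is harder than it looks. A single `F.scaled_dot_product_attention` call may dispatch to a Flash, memory-efficient, or math kernel depending on device, dtype, shapes, and masks. The snippet below only queries PyTorch's global enable switches in `torch.backends.cuda`. It reports which kernel families SDPA is *allowed* to choose, not which one a given call actually ran. Recent PyTorch releases also provide a `torch.nn.attention.sdpa_kernel` context manager for restricting that choice, which is one way to pin a backend while benchmarking.

```python
import torch

# Global switches controlling which CUDA kernel families SDPA may consider.
# They describe what is enabled, not which kernel a particular call used.
flags = {
    "flash": torch.backends.cuda.flash_sdp_enabled(),
    "mem_efficient": torch.backends.cuda.mem_efficient_sdp_enabled(),
    "math": torch.backends.cuda.math_sdp_enabled(),
}
print("CUDA available:", torch.cuda.is_available())
print(flags)
```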

In [None]:
class FlashMHA(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        # TODO: project inputs, choose attention backend, record stats

    def forward(self, x, mask=None):
        raise NotImplementedError


In [None]:
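class FlashMHA(nn.Module):
    """Multi-head self-attention that prefers PyTorch's SDPA kernels and reports timing."""

    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def _reshape(self, tensor):
        # (batch, seq_len, embed_dim) -> (batch, num_heads, seq_len, head_dim)
        bsz, seq_len, _ = tensor.shape
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, mask=None):
        """x: (batch, seq_len, embed_dim).

        mask: optional (seq_len, seq_len) or (batch, seq_len, seq_len) tensor with
        1/True = attend and 0/False = blocked (not an additive float mask).
        """
        bsz, seq_len, _ = x.shape
        q = self._reshape(self.q_proj(x))
        k = self._reshape(self.k_proj(x))
        v = self._reshape(self.v_proj(x))

        attn_mask = mask
        if attn_mask is not None:
            # SDPA adds float masks to the scores, so convert 0/1 masks to boolean.
            attn_mask = attn_mask.bool()
            if attn_mask.dim() == 3:
                attn_mask = attn_mask.unsqueeze(1)  # broadcast over heads

        if x.is_cuda:
            torch.cuda.synchronize()  # keep earlier queued GPU work out of the timing
        start = time.perf_counter()
        # "sdpa" means PyTorch chose among its flash, memory-efficient, and math kernels;
        # this wrapper cannot tell which one actually ran.
        backend = "sdpa"
        try:
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        except RuntimeError:
            if attn_mask is None:
                backend = "chunked"
                out = chunked_attention(q, k, v)
            else:
                backend = "naive"  # naive_attention honors the mask
                out = naive_attention(q, k, v, mask=attn_mask)
        if x.is_cuda:
            torch.cuda.synchronize()  # wait for the attention kernels before stopping the clock
        elapsed_ms = (time.perf_counter() - start) * 1000

        # Merge heads back: (batch, seq_len, embed_dim)
        out = out.transpose(1, 2).contiguous().view(bsz, seq_len, self.embed_dim)
        out = self.out_proj(self.dropout(out))
        stats = {"backend": backend, "elapsed_ms": elapsed_ms}
        return out, stats

flash_mha = FlashMHA(embed_dim=64, num_heads=4)
dummy = torch.randn(2, 32, 64)
out, stats = flash_mha(dummy)
print(out.shape, stats)

# Same module with a causal 0/1 mask (1 = attend).
causal = torch.tril(torch.ones(32, 32))
out_masked, stats_masked = flash_mha(dummy, mask=causal)
print(out_masked.shape, stats_masked)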


## Further Reading

- Dao et al. (2022) “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”
- PyTorch 2.1 release notes for SDPA kernels
- Triton tutorials for custom GPU kernels
- Long-range attention models such as Longformer and Performer