# Sparse Attention

## Paper Reference

**Child, R., Gray, S., Radford, A., & Sutskever, I. (2019).** *Generating Long Sequences with Sparse Transformers.* [arXiv:1904.10509](https://arxiv.org/abs/1904.10509)

---

## Mathematical Derivation

### Motivation

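Standard self-attention compares every position with every other position, which costs $O(n^2 \cdot d)$ time and $O(n^2)$ memory for sequence length $n$ and head dimension $d$. For long sequences this becomes prohibitive. Sparse attention reduces the cost by letting each position attend to only a subset of positions.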

### Sparsity Patterns

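**1. Strided Pattern**: With stride $k$, position $i$ attends to every $k$-th earlier position, $\{i-k, i-2k, \ldots\}$ down to $i \bmod k$, plus a local window of the most recent positions.

**2. Fixed Pattern**: Position $i$ attends to the positions in the fixed-size block that contains $i$, i.e. positions $j$ with $\lfloor j/k \rfloor = \lfloor i/k \rfloor$. In the paper, this is paired with attention to a few designated summary positions at the end of each block.

**3. Combined**: The two patterns are used together, either by alternating them across layers or heads, or by merging them into a single mask.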
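
The strided and fixed patterns can be written out as explicit index sets. The following is a minimal sketch, assuming a causal (autoregressive) setting and a local window equal to the stride `k`. The helper names `strided_indices` and `fixed_indices` are hypothetical, and the fixed variant covers only the block-local part of the pattern.

```python
def strided_indices(i: int, k: int) -> list[int]:
    """Keys attended by query i: the last k positions plus every k-th position back."""
    local = set(range(max(0, i - k + 1), i + 1))   # local window, including i itself
    strided = set(range(i % k, i + 1, k))          # i, i-k, i-2k, ... down to i % k
    return sorted(local | strided)


def fixed_indices(i: int, k: int) -> list[int]:
    """Keys attended by query i: positions up to i inside the block containing i."""
    block_start = (i // k) * k
    return list(range(block_start, i + 1))


print(strided_indices(10, 4))  # [2, 6, 7, 8, 9, 10]
print(fixed_indices(10, 4))    # [8, 9, 10]
```

For query 10 with `k = 4`, the strided set reaches back to position 2 in three hops, while the fixed set never leaves the block that starts at position 8.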

### Complexity

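- Time: $O(n \cdot \sqrt{n} \cdot d)$ when the block size (or stride) is about $\sqrt{n}$, since each position then attends to $O(\sqrt{n})$ others
- Space: $O(n \cdot \sqrt{n})$ for the attention weights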
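
To make the savings concrete, the sketch below counts query-key pairs for dense and sparse attention. It assumes each query attends to roughly $\sqrt{n}$ local keys plus $\sqrt{n}$ strided keys. The function name `attention_pairs` and the factor of 2 are illustrative assumptions, not figures from the paper.

```python
import math


def attention_pairs(n: int) -> tuple[int, int]:
    """Return (dense, sparse) query-key pair counts for sequence length n."""
    k = math.isqrt(n)      # block size / stride of about sqrt(n)
    dense = n * n          # every query attends to every key
    sparse = n * 2 * k     # assumption: ~k local keys + ~k strided keys per query
    return dense, sparse


for n in (1024, 4096, 16384):
    dense, sparse = attention_pairs(n)
    print(f"n={n}: dense={dense}, sparse={sparse}, ratio={dense / sparse:.0f}x")
```

Under these assumptions the reduction factor is $\sqrt{n}/2$, so it grows with sequence length: 16x at $n = 1024$, 32x at $n = 4096$, and 64x at $n = 16384$.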

In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns

plt.style.use('seaborn-v0_8-whitegrid')

In [None]:
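def create_sparse_mask(seq_len: int, block_size: int, pattern: str = 'combined') -> torch.Tensor:
    """Create a boolean sparse attention mask of shape (seq_len, seq_len).

    True marks a masked (blocked) query-key pair, False an attended one.
    `pattern` is one of 'fixed', 'strided', or 'combined'. The mask is
    bidirectional: no causal masking is applied.
    """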
    mask = torch.ones(seq_len, seq_len, dtype=torch.bool)
    
    if pattern in ('fixed', 'combined'):
        # Local block attention
        for i in range(seq_len):
            block_start = (i // block_size) * block_size
            block_end = min(block_start + block_size, seq_len)
            mask[i, block_start:block_end] = False
    
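    if pattern in ('strided', 'combined'):
        # Strided attention: query i attends to every key j whose offset
        # from i is a multiple of block_size, i.e. (i - j) % block_size == 0
        for i in range(seq_len):
            for j in range(i % block_size, seq_len, block_size):
                mask[i, j] = False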
    
    return mask
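
A boolean mask of this kind (True = masked) is typically applied to the attention scores before the softmax. The sketch below shows one way to do that. The function name `masked_attention` is hypothetical, and the block-local mask is built inline only to keep the example self-contained. It computes the full score matrix and then masks it, so it illustrates the semantics but still costs $O(n^2)$; an efficient implementation would avoid computing the masked entries at all.

```python
import torch
import torch.nn.functional as F


def masked_attention(q, k, v, mask):
    """Scaled dot-product attention; `mask` is True where attention is NOT allowed."""
    d = q.size(-1)
    scores = q @ k.transpose(-2, -1) / d ** 0.5        # (seq_len, seq_len)
    scores = scores.masked_fill(mask, float('-inf'))   # blocked pairs get zero weight
    weights = F.softmax(scores, dim=-1)
    return weights @ v, weights


torch.manual_seed(0)
seq_len, block_size, d = 8, 4, 16

# Block-local mask built inline (assumption: same True = masked convention)
block_id = torch.arange(seq_len) // block_size
mask = block_id[:, None] != block_id[None, :]

q, k, v = (torch.randn(seq_len, d) for _ in range(3))
out, weights = masked_attention(q, k, v, mask)
print(out.shape)                    # torch.Size([8, 16])
print(weights[mask].sum().item())   # 0.0
```

Every row must keep at least one unmasked key, otherwise the softmax over a row of `-inf` produces NaN. The block-local mask satisfies this because each query can always see its own block.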

In [None]:
# Visualize different sparsity patterns
seq_len, block_size = 16, 4
patterns = ['fixed', 'strided', 'combined']

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for idx, pattern in enumerate(patterns):
    mask = create_sparse_mask(seq_len, block_size, pattern)
    attention_pattern = (~mask).float().numpy()
    
    sns.heatmap(attention_pattern, ax=axes[idx], cmap='Blues', square=True, cbar=False)
    axes[idx].set_title(f'{pattern.capitalize()} Pattern')
    axes[idx].set_xlabel('Key Position')
    axes[idx].set_ylabel('Query Position')

plt.suptitle(f'Sparse Attention Patterns (block_size={block_size})')
plt.tight_layout()
plt.show()

In [None]:
# Count attended positions
for pattern in patterns:
    mask = create_sparse_mask(seq_len, block_size, pattern)
    attended = (~mask).sum().item()
    full = seq_len * seq_len
    print(f"{pattern}: {attended}/{full} positions ({100*attended/full:.1f}%)")

## When to Use Sparse Attention

| Scenario | Recommendation |
|----------|----------------|
| Sequence length > 1024 tokens | Consider sparse attention |
| Sequence length > 4096 tokens | Sparse attention recommended |
| Need global context | Add global tokens |
| Local patterns sufficient | Use the fixed pattern |
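
The "global tokens" row can be read as follows: a small number of positions attend to, and are attended by, every other position, which restores a path for global information through an otherwise local mask. The sketch below is one possible interpretation, not a method from the paper. The helper name `add_global_tokens` and the choice of the first `num_global` positions as global are assumptions.

```python
import torch


def add_global_tokens(mask: torch.Tensor, num_global: int) -> torch.Tensor:
    """Return a copy of `mask` (True = masked) in which the first `num_global`
    positions attend to, and are attended by, every position."""
    mask = mask.clone()
    mask[:, :num_global] = False   # every query can read the global tokens
    mask[:num_global, :] = False   # global tokens can read every key
    return mask


seq_len, block_size = 16, 4
block_id = torch.arange(seq_len) // block_size
local_mask = block_id[:, None] != block_id[None, :]   # block-local mask, True = masked
global_mask = add_global_tokens(local_mask, num_global=1)

print((~local_mask).sum().item())    # 64
print((~global_mask).sum().item())   # 88
```

One global token adds only 24 attended pairs here (12 in its column, 12 in its row), yet any two positions can now exchange information in two attention steps via that token.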