# Attention Variants: MQA, GQA & Sliding Window

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/deep-learning-abc/blob/main/attention_variants.ipynb)

This notebook implements from scratch the key attention mechanism variants that improve upon standard Multi-Head Attention (MHA):

1. **Multi-Head Attention (MHA)** — the original (baseline)
2. **Multi-Query Attention (MQA)** — all heads share one K,V projection
3. **Grouped-Query Attention (GQA)** — groups of heads share K,V projections
4. **Sliding Window Attention** — each token attends only to a local window

We compare memory usage, KV cache sizes, and attention patterns.

In [None]:
!pip install torch matplotlib

In [None]:
import torch
import matplotlib.pyplot as plt
import math

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

## 0. Mathematical Foundations

### The KV Cache Problem

During autoregressive generation (generating one token at a time), the model must store the Key and Value projections for all previous tokens — this is the **KV cache**.

For standard Multi-Head Attention:
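- KV cache entries per layer (per sequence) = $2 \times n_{\text{heads}} \times \text{seq\_len} \times d_k$
- For a LLaMA-70B-sized model with full MHA (64 heads, $d_k = 128$, 8192 tokens): $2 \times 64 \times 8192 \times 128 \approx 134$M entries, i.e. 256 MiB per layer in FP16 (2 bytes per entry)
- With 80 layers: **20 GiB** per sequence just for the KV cache! (The released LLaMA 2 70B avoids this by using GQA.)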
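
The KV cache arithmetic is easy to get wrong by a factor of two (entries versus bytes), so a small sanity check may help. This is only a sketch: `kv_cache_bytes` is a hypothetical helper, and the 64-head, 80-layer, 8192-token setup is an illustrative full-MHA configuration, not the configuration of any released checkpoint.

```python
def kv_cache_bytes(n_kv_heads, seq_len, d_k, n_layers=1, bytes_per_entry=2):
    """Bytes needed to cache K and V for one sequence (FP16 by default)."""
    entries = 2 * n_kv_heads * seq_len * d_k * n_layers  # 2 = one K plus one V
    return entries * bytes_per_entry

MIB, GIB = 1024 ** 2, 1024 ** 3

# Illustrative full-MHA configuration (assumption, see text)
per_layer = kv_cache_bytes(n_kv_heads=64, seq_len=8192, d_k=128)
all_layers = kv_cache_bytes(n_kv_heads=64, seq_len=8192, d_k=128, n_layers=80)

print(f'Per layer:  {per_layer / MIB:.0f} MiB')
print(f'80 layers:  {all_layers / GIB:.0f} GiB')
```

Passing a smaller `n_kv_heads` to the same helper gives the MQA and GQA cache sizes discussed next.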

### Multi-Query Attention (MQA)

All heads share **one** set of K,V projections, but each head still has its own Q projection:

$$Q_h = XW_h^Q, \quad K = XW^K, \quad V = XW^V$$

The KV cache shrinks by a factor of $n_{\text{heads}}$.
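
One way to see why a single shared K works with several query heads is PyTorch's batched-matmul broadcasting. The snippet below is a minimal sketch with made-up sizes. It compares broadcasting a singleton head axis against explicitly expanding K, and the two should agree.

```python
import torch

torch.manual_seed(0)
batch, n_heads, seq_len, d_k = 2, 4, 8, 16  # illustrative sizes (assumption)

Q = torch.randn(batch, n_heads, seq_len, d_k)  # one query tensor per head
K = torch.randn(batch, seq_len, d_k)           # a single shared key tensor

# Option 1: add a singleton head axis and let matmul broadcast over heads
scores_broadcast = torch.matmul(Q, K.unsqueeze(1).transpose(-2, -1))

# Option 2: view the shared K once per head (expand creates a view, not a copy)
K_expanded = K.unsqueeze(1).expand(batch, n_heads, seq_len, d_k)
scores_expanded = torch.matmul(Q, K_expanded.transpose(-2, -1))

print(scores_broadcast.shape)
print(torch.allclose(scores_broadcast, scores_expanded, atol=1e-5))
```

Only the un-expanded `K` of shape `(batch, seq_len, d_k)` would need to be cached, which is where the memory saving comes from.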

### Grouped-Query Attention (GQA)

Compromise: divide heads into $G$ groups. Each group shares one K,V:

$$Q_h = XW_h^Q, \quad K_g = XW_g^K, \quad V_g = XW_g^V \quad (\text{where } g = \lfloor h \cdot G / n_{\text{heads}} \rfloor)$$

The KV cache shrinks by a factor of $n_{\text{heads}} / G$.
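
The group index formula can be checked with a few lines of plain Python. This sketch assumes zero-based head and group indices; `kv_group_of_head` is a hypothetical helper name.

```python
def kv_group_of_head(h, n_heads, n_groups):
    """Zero-based KV group index for zero-based query head h."""
    return (h * n_groups) // n_heads  # integer form of floor(h * G / n_heads)

n_heads, n_groups = 8, 2
mapping = [kv_group_of_head(h, n_heads, n_groups) for h in range(n_heads)]
print(mapping)  # [0, 0, 0, 0, 1, 1, 1, 1]
```

With 8 heads and 2 groups, heads 0 to 3 read from group 0 and heads 4 to 7 from group 1.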

| Method | KV heads | KV cache size | Quality |
|--------|----------|---------------|--------|
| MHA | $n_{\text{heads}}$ | $1\times$ | Best |
| GQA | $G$ (e.g., 8) | $G/n_{\text{heads}}$ | Near-MHA |
| MQA | 1 | $1/n_{\text{heads}}$ | Slightly worse |

### Sliding Window Attention

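Instead of attending to all positions, each token attends only to a local window. This notebook implements the causal variant used in decoder-only models, where token $i$ attends to itself and the $w$ tokens before it:

$$\text{mask}_{ij} = \begin{cases} 0 & \text{if } 0 \leq i - j \leq w \\ -\infty & \text{otherwise} \end{cases}$$

(A bidirectional encoder would use the symmetric condition $|i - j| \leq w/2$ instead.)

Complexity drops from $O(n^2)$ to $O(n \cdot w)$, provided the implementation skips masked positions; the dense masked implementation in Section 4 still computes all $n^2$ scores. With stacked layers, the effective receptive field grows: after $L$ layers a token can "see" roughly $L \times w$ earlier tokens.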
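
To make the window concrete, the following sketch lists which key positions a query may attend to and counts how many (query, key) pairs survive the mask. It assumes the causal convention (a query sees itself plus up to `window_size` earlier tokens); a symmetric window would give different sets. Both helper names are hypothetical.

```python
def allowed_keys(i, window_size):
    """Key positions that query i may attend to under a causal window."""
    return list(range(max(0, i - window_size), i + 1))

def count_scored_pairs(seq_len, window_size):
    """Number of (query, key) pairs that are not masked out."""
    return sum(len(allowed_keys(i, window_size)) for i in range(seq_len))

seq_len, window_size = 16, 4
print(allowed_keys(2, window_size))   # [0, 1, 2]
print(allowed_keys(10, window_size))  # [6, 7, 8, 9, 10]
print(count_scored_pairs(seq_len, window_size), 'of', seq_len * seq_len, 'pairs')
```

The surviving pair count grows linearly in `seq_len` once `seq_len` exceeds the window, which is the source of the $O(n \cdot w)$ figure.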

## 1. Multi-Head Attention (MHA) — Baseline

In [None]:
def multi_head_attention(X, W_Q, W_K, W_V, W_O, n_heads):
    """Standard Multi-Head Attention.
    
    Each head has its own Q, K, V projections.
    
    Args:
        X: (batch, seq_len, d_model)
        W_Q, W_K, W_V: (d_model, d_model) — projections for all heads concatenated
        W_O: (d_model, d_model) — output projection
        n_heads: number of attention heads
    """
    batch, seq_len, d_model = X.shape
    d_k = d_model // n_heads
    
    # Project Q, K, V
    Q = X @ W_Q  # (batch, seq_len, d_model)
    K = X @ W_K
    V = X @ W_V
    
    # Reshape to (batch, n_heads, seq_len, d_k)
    Q = Q.view(batch, seq_len, n_heads, d_k).transpose(1, 2)
    K = K.view(batch, seq_len, n_heads, d_k).transpose(1, 2)
    V = V.view(batch, seq_len, n_heads, d_k).transpose(1, 2)
    
    # Scaled dot-product attention
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    weights = torch.softmax(scores, dim=-1)
    attn_out = torch.matmul(weights, V)  # (batch, n_heads, seq_len, d_k)
    
    # Concatenate heads and project
    attn_out = attn_out.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
    output = attn_out @ W_O
    
    return output, weights

# Test
torch.manual_seed(42)
batch, seq_len, d_model, n_heads = 2, 8, 32, 4
d_k = d_model // n_heads

X = torch.randn(batch, seq_len, d_model, device=device)
W_Q = torch.randn(d_model, d_model, device=device) * 0.1
W_K = torch.randn(d_model, d_model, device=device) * 0.1
W_V = torch.randn(d_model, d_model, device=device) * 0.1
W_O = torch.randn(d_model, d_model, device=device) * 0.1

out_mha, w_mha = multi_head_attention(X, W_Q, W_K, W_V, W_O, n_heads)
print('MHA output shape:', out_mha.shape)
print('MHA weights shape:', w_mha.shape)

# KV cache size
kv_cache_mha = 2 * n_heads * seq_len * d_k
print(f'\nKV cache entries per layer: {kv_cache_mha}')
print(f'  = 2 (K+V) x {n_heads} heads x {seq_len} tokens x {d_k} dims')

## 2. Multi-Query Attention (MQA)

All heads share **one** K and V projection. Each head still has its own Q.

In [None]:
def multi_query_attention(X, W_Q, W_K, W_V, W_O, n_heads):
    """Multi-Query Attention: all heads share one K,V.
    
    Args:
        X: (batch, seq_len, d_model)
        W_Q: (d_model, d_model) — per-head Q projections concatenated
        W_K: (d_model, d_k) — SINGLE shared K projection
        W_V: (d_model, d_k) — SINGLE shared V projection
        W_O: (d_model, d_model) — output projection
    """
    batch, seq_len, d_model = X.shape
    d_k = d_model // n_heads
    
    # Q: each head has its own projection
    Q = X @ W_Q  # (batch, seq_len, d_model)
    Q = Q.view(batch, seq_len, n_heads, d_k).transpose(1, 2)  # (batch, n_heads, seq_len, d_k)
    
    # K, V: SHARED across all heads
    K = X @ W_K  # (batch, seq_len, d_k)
    V = X @ W_V  # (batch, seq_len, d_k)
    
    # Expand K,V to match Q's head dimension for broadcasting
    K = K.unsqueeze(1)  # (batch, 1, seq_len, d_k) — broadcasts over n_heads
    V = V.unsqueeze(1)  # (batch, 1, seq_len, d_k)
    
    # Attention
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    weights = torch.softmax(scores, dim=-1)
    attn_out = torch.matmul(weights, V)  # (batch, n_heads, seq_len, d_k)
    
    # Concatenate and project
    attn_out = attn_out.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
    output = attn_out @ W_O
    
    return output, weights

# Test MQA
torch.manual_seed(42)
W_Q_mqa = torch.randn(d_model, d_model, device=device) * 0.1
W_K_mqa = torch.randn(d_model, d_k, device=device) * 0.1  # Only d_k output dims!
W_V_mqa = torch.randn(d_model, d_k, device=device) * 0.1
W_O_mqa = torch.randn(d_model, d_model, device=device) * 0.1

out_mqa, w_mqa = multi_query_attention(X, W_Q_mqa, W_K_mqa, W_V_mqa, W_O_mqa, n_heads)
print('MQA output shape:', out_mqa.shape)
print('MQA weights shape:', w_mqa.shape)

# KV cache size
kv_cache_mqa = 2 * 1 * seq_len * d_k  # Only 1 KV head
print(f'\nKV cache entries per layer: {kv_cache_mqa}')
print(f'  = 2 (K+V) x 1 KV head x {seq_len} tokens x {d_k} dims')
print(f'  → {n_heads}x smaller than MHA!')

## 3. Grouped-Query Attention (GQA)

GQA is the sweet spot: groups of Q heads share K,V projections. With $G$ groups and $H$ heads, each group has $H/G$ query heads sharing one K,V pair.

- $G = H$ → Standard MHA
- $G = 1$ → MQA
- $1 < G < H$ → GQA (the useful range)
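
The three cases can be captured in a small helper that also rejects group counts that do not divide the number of heads evenly. This is a sketch; `describe_kv_sharing` is a hypothetical name, and the divisibility requirement reflects how the grouping is defined here rather than any particular library's API.

```python
def describe_kv_sharing(n_heads, n_kv_groups):
    """Name the variant implied by (n_heads, n_kv_groups)."""
    if n_kv_groups < 1 or n_heads % n_kv_groups != 0:
        raise ValueError('n_heads must be a positive multiple of n_kv_groups')
    if n_kv_groups == n_heads:
        name = 'MHA'
    elif n_kv_groups == 1:
        name = 'MQA'
    else:
        name = 'GQA'
    return name, n_heads // n_kv_groups  # (variant, query heads per KV head)

for g in (8, 4, 2, 1):
    print(g, describe_kv_sharing(8, g))
```

The check matters because an integer division such as `n_heads // n_kv_groups` silently rounds down when the counts do not divide evenly.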

In [None]:
def grouped_query_attention(X, W_Q, W_K, W_V, W_O, n_heads, n_kv_groups):
    """Grouped-Query Attention.
    
    Args:
        X: (batch, seq_len, d_model)
        W_Q: (d_model, d_model) — all Q heads
        W_K: (d_model, n_kv_groups * d_k) — KV for each group
        W_V: (d_model, n_kv_groups * d_k)
        W_O: (d_model, d_model)
        n_heads: number of Q heads
        n_kv_groups: number of KV groups
    """
    batch, seq_len, d_model = X.shape
    d_k = d_model // n_heads
    assert n_heads % n_kv_groups == 0, 'n_heads must be divisible by n_kv_groups'
    heads_per_group = n_heads // n_kv_groups
    
    # Q: per-head projections
    Q = X @ W_Q  # (batch, seq_len, d_model)
    Q = Q.view(batch, seq_len, n_heads, d_k).transpose(1, 2)  # (batch, n_heads, seq_len, d_k)
    
    # K, V: per-group projections
    K = X @ W_K  # (batch, seq_len, n_kv_groups * d_k)
    V = X @ W_V
    K = K.view(batch, seq_len, n_kv_groups, d_k).transpose(1, 2)  # (batch, n_kv_groups, seq_len, d_k)
    V = V.view(batch, seq_len, n_kv_groups, d_k).transpose(1, 2)
    
    # Repeat K,V for each head in the group
    # (batch, n_kv_groups, seq_len, d_k) → (batch, n_heads, seq_len, d_k)
    K = K.repeat_interleave(heads_per_group, dim=1)
    V = V.repeat_interleave(heads_per_group, dim=1)
    
    # Standard attention
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    weights = torch.softmax(scores, dim=-1)
    attn_out = torch.matmul(weights, V)
    
    # Concatenate and project
    attn_out = attn_out.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
    output = attn_out @ W_O
    
    return output, weights

# Test GQA with 8 heads and 2 KV groups
torch.manual_seed(42)
n_heads_gqa = 8
n_kv_groups = 2
d_model_gqa = 64
d_k_gqa = d_model_gqa // n_heads_gqa  # 8

X_gqa = torch.randn(2, 8, d_model_gqa, device=device)
W_Q_gqa = torch.randn(d_model_gqa, d_model_gqa, device=device) * 0.1
W_K_gqa = torch.randn(d_model_gqa, n_kv_groups * d_k_gqa, device=device) * 0.1
W_V_gqa = torch.randn(d_model_gqa, n_kv_groups * d_k_gqa, device=device) * 0.1
W_O_gqa = torch.randn(d_model_gqa, d_model_gqa, device=device) * 0.1

out_gqa, w_gqa = grouped_query_attention(X_gqa, W_Q_gqa, W_K_gqa, W_V_gqa, W_O_gqa, n_heads_gqa, n_kv_groups)
print('GQA output shape:', out_gqa.shape)
print('GQA weights shape:', w_gqa.shape)
print(f'\n{n_heads_gqa} Q heads, {n_kv_groups} KV groups')
print(f'  → Heads 0-3 share KV group 0')
print(f'  → Heads 4-7 share KV group 1')

# KV cache
kv_cache_gqa = 2 * n_kv_groups * seq_len * d_k_gqa
kv_cache_mha_full = 2 * n_heads_gqa * seq_len * d_k_gqa
print(f'\nKV cache: {kv_cache_gqa} entries vs {kv_cache_mha_full} for MHA ({kv_cache_mha_full // kv_cache_gqa}x smaller)')

In [None]:
# Compare heads within and across groups: heads in the same group share K,V but use different Q,
# so their attention patterns still differ even though they read the same values
print('Attention weights comparison within GQA groups:')
print(f'  Heads 0 and 1 (same group): weight diff = {(w_gqa[0,0] - w_gqa[0,1]).abs().mean():.6f}')
print(f'  Heads 0 and 4 (diff groups): weight diff = {(w_gqa[0,0] - w_gqa[0,4]).abs().mean():.6f}')
print('\n→ Same group = same K,V but different Q → different attention patterns')
print('→ Different groups = different K,V AND different Q → more diverse patterns')

## 4. Sliding Window Attention

Instead of attending to all tokens, each token only attends to a fixed window of nearby tokens. This reduces complexity from $O(n^2)$ to $O(n \cdot w)$.

In [None]:
def sliding_window_mask(seq_len, window_size, device=None):
    """Create a sliding window attention mask.
    
    Tokens can only attend to positions within window_size distance.
    Combined with causal masking (can't attend to future).
    
    Returns:
        mask: (seq_len, seq_len) — True where attention is BLOCKED
    """
    # dist[i, j] = i - j: how far key j lies behind query i (negative = future)
    positions = torch.arange(seq_len, device=device)
    dist = positions.unsqueeze(1) - positions.unsqueeze(0)  # (seq_len, seq_len)
    
    # Block: future tokens (causal) OR too-distant past tokens (window)
    causal_mask = dist < 0  # key lies in the future (j > i)
    window_mask = dist > window_size  # key is more than window_size tokens back
    mask = causal_mask | window_mask
    
    return mask

def sliding_window_attention(X, W_Q, W_K, W_V, W_O, n_heads, window_size):
    """Multi-head attention with sliding window mask."""
    batch, seq_len, d_model = X.shape
    d_k = d_model // n_heads
    
    Q = (X @ W_Q).view(batch, seq_len, n_heads, d_k).transpose(1, 2)
    K = (X @ W_K).view(batch, seq_len, n_heads, d_k).transpose(1, 2)
    V = (X @ W_V).view(batch, seq_len, n_heads, d_k).transpose(1, 2)
    
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    
    # Apply sliding window + causal mask
    mask = sliding_window_mask(seq_len, window_size, device=X.device)
    scores = scores.masked_fill(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
    
    weights = torch.softmax(scores, dim=-1)
    attn_out = torch.matmul(weights, V)
    
    attn_out = attn_out.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
    output = attn_out @ W_O
    
    return output, weights

# Test
torch.manual_seed(42)
seq_len_sw = 16
window_size = 4
X_sw = torch.randn(1, seq_len_sw, d_model, device=device)

out_sw, w_sw = sliding_window_attention(X_sw, W_Q, W_K, W_V, W_O, n_heads, window_size)
print(f'Sliding Window Attention (window={window_size}):')
print(f'  Output shape: {out_sw.shape}')
print(f'  Attention weights shape: {w_sw.shape}')

In [None]:
# Visualize masks and attention patterns
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Full causal mask
causal = torch.triu(torch.ones(seq_len_sw, seq_len_sw), diagonal=1).bool()
axes[0].imshow(~causal.cpu().numpy(), cmap='Blues')
axes[0].set_title(f'Full Causal Mask\n(all past positions visible)')
axes[0].set_xlabel('Key position')
axes[0].set_ylabel('Query position')

# Sliding window mask
sw_mask = sliding_window_mask(seq_len_sw, window_size, device=device)
axes[1].imshow(~sw_mask.cpu().numpy(), cmap='Blues')
axes[1].set_title(f'Sliding Window Mask (w={window_size})\n(only nearby past visible)')
axes[1].set_xlabel('Key position')
axes[1].set_ylabel('Query position')

# Actual attention weights
axes[2].imshow(w_sw[0, 0].detach().cpu().numpy(), cmap='Blues')
axes[2].set_title(f'Attention Weights (head 0)\n(sliding window, w={window_size})')
axes[2].set_xlabel('Key position')
axes[2].set_ylabel('Query position')

plt.tight_layout()
plt.show()

### Effective Receptive Field

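With sliding window attention, stacking $L$ layers gives an effective receptive field of roughly $L \times w$ tokens (more precisely $L \cdot w + 1$ positions counting the token itself, capped at the sequence length), because information propagates one window further back with each layer.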
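
The growth can also be written in closed form, which is a handy cross-check for any simulation. The sketch assumes each layer lets a token see itself plus `window_size` earlier tokens; `receptive_field` is a hypothetical helper.

```python
def receptive_field(n_layers, window_size, seq_len):
    """Positions visible to the last token after n_layers causal window layers."""
    return min(seq_len, n_layers * window_size + 1)  # +1 counts the token itself

for n_layers in range(1, 5):
    print(n_layers, receptive_field(n_layers, window_size=4, seq_len=32))
```

For a window of 4 this gives 5, 9, 13 and 17 visible positions after one to four layers, until the count saturates at the sequence length.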

In [None]:
# Visualize how receptive field grows with layers
def compute_receptive_field(n_layers, window_size, seq_len):
    """Simulate information flow through stacked sliding window layers."""
    # Start: token at position (seq_len-1) can only see itself
    visible = torch.zeros(n_layers + 1, seq_len)
    visible[0, seq_len - 1] = 1.0  # token we're tracking
    
    for layer in range(n_layers):
        prev = visible[layer]
        new = prev.clone()
        for pos in range(seq_len):
            if prev[pos] > 0:
                # This position can see window_size tokens back
                start = max(0, pos - window_size)
                new[start:pos + 1] = 1.0
        visible[layer + 1] = new
    
    return visible

n_layers = 4
window_size = 4
vis_seq_len = 32
receptive = compute_receptive_field(n_layers, window_size, vis_seq_len)

fig, axes = plt.subplots(1, n_layers + 1, figsize=(20, 3))
for layer in range(n_layers + 1):
    axes[layer].imshow(receptive[layer].unsqueeze(0).numpy(), cmap='Blues', aspect='auto')
    visible_count = int(receptive[layer].sum().item())
    if layer == 0:
        axes[layer].set_title(f'Input\n({visible_count} pos visible)')
    else:
        axes[layer].set_title(f'After Layer {layer}\n({visible_count} pos visible)')
    axes[layer].set_yticks([])
    axes[layer].set_xlabel('Position')

plt.suptitle(f'Receptive Field Growth: window={window_size}, effective = layers × window', fontsize=13)
plt.tight_layout()
plt.show()

print(f'Window size: {window_size}')
for layer in range(1, n_layers + 1):
    print(f'After {layer} layer(s): can see {int(receptive[layer].sum().item())} of {vis_seq_len} positions')

## 5. Comprehensive Comparison

In [None]:
# Parameter count and KV cache comparison
d_model = 64
n_heads = 8
d_k = d_model // n_heads

print('=' * 75)
print('PARAMETER & MEMORY COMPARISON')
print(f'd_model={d_model}, n_heads={n_heads}, d_k={d_k}')
print('=' * 75)

configs = [
    ('MHA', n_heads, n_heads),
    ('GQA (4 groups)', n_heads, 4),
    ('GQA (2 groups)', n_heads, 2),
    ('MQA', n_heads, 1),
]

print(f'{"Method":<20} {"Q params":<12} {"KV params":<12} {"Total":<12} {"KV cache/token":<15}')
print('-' * 75)

for name, n_q, n_kv in configs:
    q_params = d_model * (n_q * d_k)  # Q projection
    kv_params = d_model * (n_kv * d_k) * 2  # K and V projections
    total = q_params + kv_params + d_model * d_model  # + output projection
    kv_cache = 2 * n_kv * d_k  # per token
    print(f'{name:<20} {q_params:<12} {kv_params:<12} {total:<12} {kv_cache:<15}')

print('-' * 75)
print(f'\nMQA saves {n_heads}x KV cache vs MHA')
print(f'GQA(2) saves {n_heads//2}x KV cache vs MHA')

In [None]:
# Visualize KV cache scaling with sequence length
seq_lengths = [512, 1024, 2048, 4096, 8192, 16384, 32768]
d_model_real = 4096  # realistic LLM size
n_heads_real = 32
d_k_real = d_model_real // n_heads_real
n_layers = 32
bytes_per_param = 2  # FP16

def kv_cache_gb(seq_len, n_kv_heads):
    return 2 * n_kv_heads * seq_len * d_k_real * n_layers * bytes_per_param / (1024**3)

fig, ax = plt.subplots(figsize=(10, 6))

methods_real = [
    ('MHA (32 KV heads)', 32, 'C0'),
    ('GQA (8 KV groups)', 8, 'C1'),
    ('GQA (4 KV groups)', 4, 'C2'),
    ('MQA (1 KV head)', 1, 'C3'),
]

for name, n_kv, color in methods_real:
    cache_sizes = [kv_cache_gb(sl, n_kv) for sl in seq_lengths]
    ax.plot(seq_lengths, cache_sizes, 'o-', linewidth=2, markersize=6, label=name, color=color)

ax.set_xlabel('Sequence Length')
ax.set_ylabel('KV Cache Size (GiB)')
ax.set_title(f'KV Cache Memory vs Sequence Length\n(d_model={d_model_real}, {n_layers} layers, FP16)')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_xscale('log', base=2)
ax.set_yscale('log', base=2)
plt.tight_layout()
plt.show()

In [None]:
# Complexity comparison for sliding window
seq_lengths_complexity = list(range(64, 2049, 64))
window_size = 256

full_attn_ops = [s * s for s in seq_lengths_complexity]  # O(n^2)
window_attn_ops = [s * min(s, window_size) for s in seq_lengths_complexity]  # O(n*w)

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(seq_lengths_complexity, full_attn_ops, linewidth=2, label='Full Attention O(n²)')
ax.plot(seq_lengths_complexity, window_attn_ops, linewidth=2, label=f'Sliding Window O(n·w), w={window_size}')
ax.set_xlabel('Sequence Length (n)')
ax.set_ylabel('Operations (proportional)')
ax.set_title('Attention Complexity: Full vs Sliding Window')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print(f'At seq_len=2048 with window={window_size}:')
print(f'  Full attention: {2048*2048:,} ops')
print(f'  Sliding window: {2048*window_size:,} ops')
print(f'  → {2048*2048 / (2048*window_size):.1f}x reduction')

In [None]:
# Summary table
print('=' * 90)
print('COMPARISON: Attention Variants')
print('=' * 90)
print(f'{"Property":<25} {"MHA":<16} {"MQA":<16} {"GQA":<16} {"Sliding Window":<16}')
print('-' * 90)
print(f'{"KV heads":<25} {"H":<16} {"1":<16} {"G (1<G<H)":<16} {"H":<16}')
print(f'{"KV cache":<25} {"2·H·n·d_k":<16} {"2·n·d_k":<16} {"2·G·n·d_k":<16} {"2·H·w·d_k":<16}')
print(f'{"Complexity":<25} {"O(n²)":<16} {"O(n²)":<16} {"O(n²)":<16} {"O(n·w)":<16}')
print(f'{"Quality":<25} {"Best":<16} {"Slightly less":<16} {"Near-MHA":<16} {"Good (local)":<16}')
print(f'{"Decode speedup (rough)":<25} {"1x":<16} {"4-8x":<16} {"~2-4x":<16} {"n/w x":<16}')
print(f'{"Used in":<25} {"Original TF":<16} {"PaLM,Falcon":<16} {"LLaMA2,Mistral":<16} {"Mistral":<16}')
print('=' * 90)

## Summary

In this notebook we implemented from scratch:

1. **MHA (Multi-Head Attention)** — each head has its own Q, K, V projections. Full quality, full KV cache cost.

2. **MQA (Multi-Query Attention)** — all heads share one K, V. Reduces KV cache by $H\times$, slightly lower quality.

3. **GQA (Grouped-Query Attention)** — groups of heads share K, V. The Goldilocks solution: quality close to MHA with large KV cache savings. **Widely adopted in recent open models** (e.g., LLaMA 2 70B, LLaMA 3, Mistral 7B).

4. **Sliding Window Attention** — each token attends only to $w$ nearby tokens. $O(n \cdot w)$ instead of $O(n^2)$. Stacking $L$ layers gives $L \times w$ effective receptive field.

**Key insight:** Modern production models often combine these — e.g., Mistral 7B uses both GQA and sliding window attention. The trend is toward maximizing quality per byte of KV cache memory.
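
As a rough illustration of how the two savings compound, the sketch below estimates the KV cache for a Mistral-7B-like shape. The numbers (32 layers, 32 query heads, 8 KV heads, head dimension 128, window 4096) are assumptions for illustration. It also assumes a rolling buffer that keeps only the most recent `window` tokens. `kv_cache_gib` is a hypothetical helper.

```python
def kv_cache_gib(seq_len, n_kv_heads, d_k, n_layers, window=None, bytes_per_entry=2):
    """KV cache size in GiB for one sequence; window caps the cached tokens."""
    cached_tokens = seq_len if window is None else min(seq_len, window)
    entries = 2 * n_kv_heads * cached_tokens * d_k * n_layers  # K plus V
    return entries * bytes_per_entry / 1024 ** 3

# Assumed Mistral-7B-like shape; treat these values as illustrative only
cfg = dict(d_k=128, n_layers=32)
seq_len = 32768

mha_full = kv_cache_gib(seq_len, n_kv_heads=32, **cfg)
gqa_full = kv_cache_gib(seq_len, n_kv_heads=8, **cfg)
gqa_window = kv_cache_gib(seq_len, n_kv_heads=8, window=4096, **cfg)

print(f'MHA, full cache:    {mha_full:.2f} GiB')
print(f'GQA, full cache:    {gqa_full:.2f} GiB')
print(f'GQA + window cache: {gqa_window:.2f} GiB')
```

Under these assumptions GQA alone gives a 4x reduction, and capping the cache at the window gives a further 8x at this sequence length.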