# Linear Attention

## Paper Reference

**Katharopoulos, A., Vyas, A., Pappas, N., & Fleuret, F. (2020).** *Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention.* ICML. [arXiv:2006.16236](https://arxiv.org/abs/2006.16236)

---

## Mathematical Derivation

### Key Insight

Standard attention, $\text{softmax}(QK^T/\sqrt{d})V$, must materialize the $n \times n$ score matrix $QK^T$, so its time and memory grow quadratically with the sequence length $n$.

Linear attention replaces the softmax with a kernel. If the similarity between a query and a key can be written as a dot product of feature maps $\phi$:

$$\text{sim}(q, k) = \phi(q)^T \phi(k)$$

Then:

$$\text{Attention}_i = \frac{\sum_j \phi(q_i)^T \phi(k_j) v_j}{\sum_j \phi(q_i)^T \phi(k_j)}$$

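$$= \frac{\phi(q_i)^T \sum_j \phi(k_j) v_j^T}{\phi(q_i)^T \sum_j \phi(k_j)}$$

Because $\phi(q_i)$ does not depend on $j$, it factors out of both sums. The $d \times d_v$ matrix $\sum_j \phi(k_j) v_j^T$ and the vector $\sum_j \phi(k_j)$ are computed once and shared by every query, so the $n \times n$ similarity matrix is never formed.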
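The two expressions can be checked numerically. This is a sketch that assumes the ELU+1 feature map, evaluates a single query position $i$ with explicit loops over $j$, and compares the summation form with the factored form; `phi`, `S` and `z` are illustrative names, not taken from the paper.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def phi(x: torch.Tensor) -> torch.Tensor:
    # ELU + 1 feature map (strictly positive)
    return F.elu(x) + 1

n, d, d_v = 6, 4, 3
q = torch.randn(n, d, dtype=torch.float64)
k = torch.randn(n, d, dtype=torch.float64)
v = torch.randn(n, d_v, dtype=torch.float64)
i = 2  # query position to check

# Summation form: scalar similarity phi(q_i)^T phi(k_j) weights each v_j
sims = torch.stack([phi(q[i]) @ phi(k[j]) for j in range(n)])  # (n,)
lhs = (sims[:, None] * v).sum(dim=0) / sims.sum()

# Factored form: phi(q_i) pulled out of both sums
S = sum(torch.outer(phi(k[j]), v[j]) for j in range(n))  # (d, d_v)
z = sum(phi(k[j]) for j in range(n))                     # (d,)
rhs = (phi(q[i]) @ S) / (phi(q[i]) @ z)

print(torch.allclose(lhs, rhs))  # True
```

`S` and `z` do not depend on `i`, so in a batched implementation they are computed once for all queries.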

### Order of Operations

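Standard: $(QK^T)V \rightarrow O(n^2 d)$

Linear: $\phi(Q)\left(\phi(K)^T V\right) \rightarrow O(n d^2)$

The reordering is only possible once the softmax is gone: softmax acts on each entry of $QK^T$, so the full $n \times n$ product has to be formed before it can be multiplied by $V$. Without it, matrix multiplication is associative and the cheaper grouping can be used.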
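The cost gap follows from the matrix shapes alone. This sketch counts multiply-accumulates for both groupings with a hypothetical helper `matmul_macs`, using $n = 2048$ and $d = 64$ as example sizes, and then confirms that the two groupings give the same result when no softmax sits between the products.

```python
import torch

def matmul_macs(m: int, k: int, p: int) -> int:
    """Multiply-accumulate count for an (m x k) @ (k x p) product."""
    return m * k * p

n, d = 2048, 64

# (Q K^T) V: (n x d)(d x n) -> (n x n), then (n x n)(n x d)
standard = matmul_macs(n, d, n) + matmul_macs(n, n, d)
# Q (K^T V): (d x n)(n x d) -> (d x d), then (n x d)(d x d)
linear = matmul_macs(d, n, d) + matmul_macs(n, d, d)

print(standard, linear, standard // linear)  # 536870912 16777216 32

# Without a softmax in between, both groupings agree (associativity)
torch.manual_seed(0)
Q, K, V = (torch.randn(n, d, dtype=torch.float64) for _ in range(3))
print(torch.allclose((Q @ K.T) @ V, Q @ (K.T @ V)))  # True
```

The ratio is $2n^2d / 2nd^2 = n/d$, which is why the linear grouping only pays off once $n > d$.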

### Feature Maps

- **ELU+1**: $\phi(x) = \text{ELU}(x) + 1$ (ensures positivity)
- **ReLU**: $\phi(x) = \text{ReLU}(x)$ (non-negative, but can be exactly zero, so the denominator needs an $\epsilon$ guard)
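A small sketch of why positivity matters for the normalizer $\phi(q_i)^T \sum_j \phi(k_j)$. The query and key values are arbitrary examples chosen so that every query entry is negative.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-3.0, -0.5, 0.0, 2.0])
elu_plus_one = F.elu(x) + 1  # ~[0.0498, 0.6065, 1.0, 3.0]: strictly positive
relu = F.relu(x)             # [0.0, 0.0, 0.0, 2.0]: zero for non-positive inputs

# A query with all-negative entries gets an all-zero ReLU feature vector,
# so the normalizer phi(q)^T sum_j phi(k_j) collapses to 0.
q = torch.tensor([-1.0, -2.0])
k = torch.tensor([[0.5, 1.0], [1.5, -0.3]])
den_relu = F.relu(q) @ F.relu(k).sum(dim=0)
den_elu = (F.elu(q) + 1) @ (F.elu(k) + 1).sum(dim=0)

print(den_relu.item(), den_elu.item() > 0)  # 0.0 True
```

With ReLU, dividing by `den_relu` would produce `nan` or `inf`, which is what the `eps` clamp in the implementation guards against.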

### Complexity

- Time: $O(n \cdot d^2)$
- Space: $O(n \cdot d + d^2)$
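The space bound is driven by the largest intermediate each grouping creates; the $n \cdot d$ term is the inputs and outputs themselves. A sketch with example sizes $n = 2048$, $d = 64$ in float32 (the feature map is elementwise and does not change any shapes, so it is omitted):

```python
import torch

n, d = 2048, 64
q, k, v = (torch.randn(n, d) for _ in range(3))  # inputs: O(n * d) each

scores = q @ k.T  # (n, n) intermediate of the standard grouping
kv = k.T @ v      # (d, d) intermediate of the linear grouping

print(scores.shape, kv.shape)  # torch.Size([2048, 2048]) torch.Size([64, 64])
print(scores.numel() * scores.element_size(), kv.numel() * kv.element_size())  # 16777216 16384
```

Doubling $n$ quadruples `scores` but leaves `kv` unchanged.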

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# The 'seaborn-v0_8-*' style names require matplotlib >= 3.6
plt.style.use('seaborn-v0_8-whitegrid')
```

```python
def elu_feature_map(x: torch.Tensor) -> torch.Tensor:
    """ELU + 1 feature map for linear attention (output is strictly positive)."""
    return F.elu(x) + 1

def linear_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Non-causal linear attention in O(n * d_k * d_v) time.

    q, k: (batch, n, d_k); v: (batch, n, d_v). Returns (batch, n, d_v).
    """
    # Apply the feature map phi to queries and keys
    q = elu_feature_map(q)
    k = elu_feature_map(k)

    # phi(K)^T V: (batch, d_k, d_v), aggregated over the sequence once
    kv = torch.einsum('bnd,bnv->bdv', k, v)

    # phi(Q) (phi(K)^T V): (batch, n, d_v)
    numerator = torch.einsum('bnd,bdv->bnv', q, kv)

    # Normalizer phi(q_i)^T sum_j phi(k_j); the clamp guards against division by zero
    k_sum = k.sum(dim=1)  # (batch, d_k)
    denominator = torch.einsum('bnd,bd->bn', q, k_sum).unsqueeze(-1).clamp(min=eps)

    return numerator / denominator
```
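`linear_attention` is non-causal: every position sees every key. For autoregressive models, position $i$ may only attend to $j \le i$, which turns both sums in the derivation into prefix sums. A minimal sketch, assuming the same ELU+1 feature map, implements this with `torch.cumsum` and checks it against an explicitly masked $O(n^2)$ computation. The function names are hypothetical, and this version stores all $n$ prefix states, so it needs $O(n \, d_k d_v)$ memory per batch element; it illustrates the math rather than an efficient kernel.

```python
import torch
import torch.nn.functional as F

def elu_feature_map(x: torch.Tensor) -> torch.Tensor:
    """ELU + 1 feature map (strictly positive)."""
    return F.elu(x) + 1

def causal_linear_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                            eps: float = 1e-6) -> torch.Tensor:
    """Causal linear attention: position i attends only to positions j <= i.

    q, k: (batch, n, d_k); v: (batch, n, d_v). Returns (batch, n, d_v).
    Memory-hungry: stores a (batch, n, d_k, d_v) tensor of prefix sums.
    """
    q = elu_feature_map(q)
    k = elu_feature_map(k)
    # S_i = sum_{j<=i} phi(k_j) v_j^T and z_i = sum_{j<=i} phi(k_j), as prefix sums over n
    kv = torch.einsum('bnd,bnv->bndv', k, v).cumsum(dim=1)  # (batch, n, d_k, d_v)
    k_cum = k.cumsum(dim=1)                                  # (batch, n, d_k)
    numerator = torch.einsum('bnd,bndv->bnv', q, kv)
    denominator = torch.einsum('bnd,bnd->bn', q, k_cum).unsqueeze(-1).clamp(min=eps)
    return numerator / denominator

def masked_quadratic_reference(q, k, v):
    """O(n^2) check: explicit kernel matrix with a lower-triangular (causal) mask."""
    sim = torch.tril(elu_feature_map(q) @ elu_feature_map(k).transpose(-2, -1))
    return (sim @ v) / sim.sum(dim=-1, keepdim=True)

torch.manual_seed(0)
q, k, v = (torch.randn(2, 16, 8, dtype=torch.float64) for _ in range(3))
out = causal_linear_attention(q, k, v)
print(out.shape, torch.allclose(out, masked_quadratic_reference(q, k, v)))
# torch.Size([2, 16, 8]) True
```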

```python
# Compare wall-clock scaling (CPU tensors, so no device synchronization is needed)
import time

def standard_attention(q, k, v):
    """Softmax attention; materializes the (batch, n, n) score matrix."""
    scores = torch.matmul(q, k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, v)

def time_ms(fn, *args, repeats: int = 10) -> float:
    """Average wall-clock time of fn(*args) in ms, measured after one warm-up call."""
    fn(*args)  # warm-up: keeps one-time allocation costs out of the measurement
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args)
    return (time.perf_counter() - start) / repeats * 1000

seq_lengths = [128, 256, 512, 1024, 2048]
d_model = 64
batch_size = 4

standard_times = []
linear_times = []

for seq_len in seq_lengths:
    q = torch.randn(batch_size, seq_len, d_model)
    k = torch.randn(batch_size, seq_len, d_model)
    v = torch.randn(batch_size, seq_len, d_model)

    standard_times.append(time_ms(standard_attention, q, k, v))
    linear_times.append(time_ms(linear_attention, q, k, v))

plt.figure(figsize=(10, 6))
plt.plot(seq_lengths, standard_times, 'o-', label='Standard $O(n^2 d)$')
plt.plot(seq_lengths, linear_times, 's-', label='Linear $O(n d^2)$')
plt.xscale('log', base=2)  # sequence lengths double at each step
plt.xlabel('Sequence Length')
plt.ylabel('Time (ms)')
plt.title('Standard vs Linear Attention Scaling')
plt.legend()
plt.grid(True)
plt.show()
```

## Trade-offs

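| Aspect | Standard | Linear |
|--------|----------|--------|
| Time | $O(n^2 d)$ | $O(n d^2)$ |
| Largest intermediate | $n \times n$ score matrix | $d \times d$ key-value summary |
| Cheaper when | $n < d$ | $n > d$ |
| Similarity | Exact softmax | Kernel $\phi(q)^T \phi(k)$, not identical to softmax |
| Causal support | Mask on the $n \times n$ scores | Prefix sums, i.e. an RNN with fixed-size state |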
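The RNN-style view can be made concrete: under causal masking, the prefix sums become a recurrent state $S_i = S_{i-1} + \phi(k_i) v_i^T$ and $z_i = z_{i-1} + \phi(k_i)$, and each output is $\phi(q_i)^T S_i / \phi(q_i)^T z_i$. This is a sketch assuming ELU+1 and unbatched inputs; `recurrent_linear_attention` is a hypothetical name. It processes one token at a time with constant-size state and is checked against a masked quadratic reference.

```python
import torch
import torch.nn.functional as F

def phi(x: torch.Tensor) -> torch.Tensor:
    # ELU + 1 feature map (strictly positive)
    return F.elu(x) + 1

def recurrent_linear_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Causal linear attention computed step by step with a fixed-size state.

    q, k: (n, d_k); v: (n, d_v). Returns (n, d_v).
    """
    n, d_k = k.shape
    d_v = v.shape[-1]
    S = torch.zeros(d_k, d_v, dtype=v.dtype)  # running sum of phi(k_j) v_j^T
    z = torch.zeros(d_k, dtype=v.dtype)       # running sum of phi(k_j)
    outputs = []
    for i in range(n):
        fk = phi(k[i])
        S = S + torch.outer(fk, v[i])  # state update, O(d_k * d_v) per step
        z = z + fk
        fq = phi(q[i])
        outputs.append((fq @ S) / (fq @ z))
    return torch.stack(outputs)

torch.manual_seed(0)
n, d = 10, 4
q, k, v = (torch.randn(n, d, dtype=torch.float64) for _ in range(3))

# Quadratic reference: explicit kernel matrix with a causal (lower-triangular) mask
sim = torch.tril(phi(q) @ phi(k).T)
reference = (sim @ v) / sim.sum(dim=-1, keepdim=True)

out = recurrent_linear_attention(q, k, v)
print(torch.allclose(out, reference))  # True
```

Each step costs $O(d_k d_v)$ no matter how many tokens came before, whereas masked softmax attention has to revisit the whole prefix for every new token.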