# Sliding Window Attention

## Paper Reference

**Beltagy, I., Peters, M. E., & Cohan, A. (2020).** *Longformer: The Long-Document Transformer.* [arXiv:2004.05150](https://arxiv.org/abs/2004.05150)

---

## Key Insight

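Not every task needs full global attention; for many NLP tasks, local context is sufficient. Sliding window attention restricts each query position to attend only to keys within a fixed-size window around it. This reduces the cost of attention from quadratic to linear in sequence length.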

### Formula

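Here $w$ is the one-sided window width. The Longformer paper instead uses $w$ for the full window, with $\tfrac{1}{2}w$ tokens on each side. For query position $i$, the attended key positions (clipped to valid indices $0 \le j < n$, so an interior query sees $2w + 1$ keys) are: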

$$\text{Attend}(i) = \{j : |i - j| \leq w\}$$
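As a quick check of the set definition, here is a minimal pure-Python sketch. `attend_set` is a hypothetical helper, not part of Longformer. It lists the key positions for a few queries and shows the $2w + 1$ keys in the interior and the truncation at both ends of the sequence.

```python
def attend_set(i: int, seq_len: int, window_size: int) -> list[int]:
    """Key positions j with |i - j| <= window_size, clipped to 0 <= j < seq_len."""
    lo = max(0, i - window_size)
    hi = min(seq_len - 1, i + window_size)
    return list(range(lo, hi + 1))


print(attend_set(0, 8, 2))  # [0, 1, 2]        (left edge: truncated)
print(attend_set(4, 8, 2))  # [2, 3, 4, 5, 6]  (interior: 2w + 1 = 5 keys)
print(attend_set(7, 8, 2))  # [5, 6, 7]        (right edge: truncated)
```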

### Effective Receptive Field

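Each layer lets information travel $w$ more positions. With $L$ stacked layers of one-sided window $w$, the top layer therefore has a receptive field of $L \times w$ positions on each side ($2Lw + 1$ positions in total), even though each individual layer only looks locally.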
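The stacking argument can be checked numerically by composing the one-layer attention pattern with itself. This sketch treats "position $i$ can read position $j$" as a boolean adjacency matrix and multiplies it $L$ times. `receptive_field` is a hypothetical helper written for illustration. It tracks only which positions can influence which, not the attention weights.

```python
import torch


def receptive_field(seq_len: int, window_size: int, num_layers: int) -> torch.Tensor:
    """reach[i, j] is True if position i after `num_layers` layers depends on input j."""
    pos = torch.arange(seq_len)
    # One attention layer: position i reads from positions j with |i - j| <= w.
    one_layer = ((pos[:, None] - pos[None, :]).abs() <= window_size).float()
    reach = torch.eye(seq_len)
    for _ in range(num_layers):
        # Composing layers: i reaches j if some k in i's window already reaches j.
        reach = ((one_layer @ reach) > 0).float()
    return reach.bool()


seq_len, w, center = 64, 3, 32
for num_layers in (1, 2, 4):
    deps = receptive_field(seq_len, w, num_layers)[center].nonzero().squeeze(1)
    # Prints: layers, reach on one side, total positions seen.
    print(num_layers, (deps.max() - center).item(), deps.numel())
# 1 3 7
# 2 6 13
# 4 12 25
```

The one-sided reach grows as $L \cdot w$ and the total as $2Lw + 1$, as long as the window does not hit a sequence boundary.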

### Dilated Windows

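Longformer also supports dilated sliding windows, which insert gaps of $d - 1$ positions between attended keys (dilation $d$). Each query still attends to $2w + 1$ keys, but they now reach $w \cdot d$ positions on each side: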

$$\text{Attend}_{\text{dilated}}(i) = \{j : |i - j| \leq w \cdot d, (i-j) \mod d = 0\}$$
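The dilated set translates directly into a mask. This sketch uses the same convention as `create_sliding_window_mask` in the code cell below, where `True` means blocked. `create_dilated_window_mask` is a hypothetical name. With `dilation=1` it reduces to the plain sliding window.

```python
import torch


def create_dilated_window_mask(seq_len: int, window_size: int, dilation: int) -> torch.Tensor:
    """Boolean (seq_len, seq_len) mask, True where attention is BLOCKED.

    Query i may attend to key j if |i - j| <= window_size * dilation and
    (i - j) is a multiple of dilation.
    """
    pos = torch.arange(seq_len)
    diff = pos[None, :] - pos[:, None]  # diff[i, j] = j - i
    allowed = (diff.abs() <= window_size * dilation) & (diff % dilation == 0)
    return ~allowed


mask = create_dilated_window_mask(seq_len=32, window_size=2, dilation=3)
# Query 16 sees 2w + 1 = 5 keys, spaced 3 apart, reaching w * d = 6 on each side.
print((~mask[16]).nonzero().squeeze(1).tolist())  # [10, 13, 16, 19, 22]
```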

### Complexity

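- Time: $O(n \cdot w \cdot d_k)$, where $d_k$ is the per-head dimension (not the dilation $d$ above). Dilation does not change the cost, since each query still scores $2w + 1$ keys.
- Space: $O(n \cdot w)$ for the attention scores, compared with $O(n^2)$ for full attention.

These bounds require a banded implementation that never forms the full score matrix. The masked reference implementation below reproduces the attention pattern but still materializes all $n \times n$ scores.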
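To put the space bound in concrete terms, this sketch counts the bytes needed to store one head's attention scores for a single sequence. It assumes float32 scores (4 bytes each), $n = 4096$, and $w = 256$. The helper names are hypothetical, and the count ignores activations and any kernel workspace.

```python
BYTES_PER_FLOAT32 = 4


def full_score_bytes(n: int) -> int:
    """Full attention stores an n x n score matrix."""
    return n * n * BYTES_PER_FLOAT32


def banded_score_bytes(n: int, w: int) -> int:
    """Banded storage keeps 2w + 1 slots per query (edge padding slots included)."""
    return n * (2 * w + 1) * BYTES_PER_FLOAT32


n, w = 4096, 256
print(f"full:   {full_score_bytes(n) / 2**20:.1f} MiB")      # full:   64.0 MiB
print(f"banded: {banded_score_bytes(n, w) / 2**20:.1f} MiB")  # banded: 8.0 MiB
```

Doubling $n$ quadruples the full figure but only doubles the banded one.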

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns

# The 'seaborn-v0_8-*' style names require matplotlib >= 3.6.
plt.style.use('seaborn-v0_8-whitegrid')
```

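```python
def create_sliding_window_mask(seq_len: int, window_size: int) -> torch.Tensor:
    """Return a (seq_len, seq_len) boolean mask that is True where attention is BLOCKED.

    Query i may attend to key j only if |i - j| <= window_size (one-sided width).
    """
    positions = torch.arange(seq_len)
    diff = positions.unsqueeze(0) - positions.unsqueeze(1)  # diff[i, j] = j - i
    return torch.abs(diff) > window_size


class SlidingWindowAttention(nn.Module):
    """Multi-head sliding window attention for local context modeling.

    Reference implementation: it computes the full (seq_len x seq_len) score
    matrix and masks out-of-window entries, so time and memory are still
    O(n^2). It reproduces the attention pattern, not the O(n * w) cost.
    """

    def __init__(self, d_model: int, num_heads: int = 8, window_size: int = 256) -> None:
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.d_model = d_model
        self.window_size = window_size
        self.scale = 1.0 / math.sqrt(self.d_k)

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """x: (batch, seq_len, d_model) -> (output, attention weights)."""
        batch, seq_len = x.size(0), x.size(1)

        # Project and split into heads: (batch, num_heads, seq_len, d_k)
        q = self.w_q(x).view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        k = self.w_k(x).view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        v = self.w_v(x).view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled scores for every query/key pair: (batch, num_heads, seq_len, seq_len)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # Block keys outside each query's window (mask broadcasts over batch and heads)
        window_mask = create_sliding_window_mask(seq_len, self.window_size).to(x.device)
        scores = scores.masked_fill(window_mask.unsqueeze(0).unsqueeze(0), float('-inf'))

        # The diagonal is always inside the window, so no row is fully masked;
        # nan_to_num only guards against a negative window_size.
        attn = F.softmax(scores, dim=-1)
        attn = torch.nan_to_num(attn, nan=0.0)

        # Weighted sum of values, then merge heads back to (batch, seq_len, d_model)
        attended = torch.matmul(attn, v)
        attended = attended.transpose(1, 2).contiguous().view(batch, seq_len, self.d_model)

        return self.w_o(attended), attn
```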
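`SlidingWindowAttention` masks a full $n \times n$ score matrix, so its real cost is quadratic even though only $2w + 1$ entries per row survive. Below is a minimal sketch of a banded alternative. It assumes `q`, `k`, `v` are already split into heads with shape `(batch, heads, seq_len, d_k)`. It pads keys and values by $w$, uses `Tensor.unfold` to gather each query's $2w + 1$ neighbours, and scores only those. `banded_sliding_window_attention` is a hypothetical helper, not Longformer's code; the paper relies on its own chunked and custom CUDA kernels.

```python
import torch
import torch.nn.functional as F


def banded_sliding_window_attention(q, k, v, window_size: int) -> torch.Tensor:
    """Sliding window attention that scores only 2w + 1 keys per query.

    q, k, v: (batch, heads, seq_len, d_k). Returns (batch, heads, seq_len, d_k).
    """
    _, _, n, d_k = q.shape
    w = window_size
    span = 2 * w + 1
    # Pad the sequence axis by w on both sides: (batch, heads, n + 2w, d_k).
    k_pad = F.pad(k, (0, 0, w, w))
    v_pad = F.pad(v, (0, 0, w, w))
    # unfold -> (batch, heads, n, d_k, span); slot s of query t is key t - w + s.
    k_win = k_pad.unfold(2, span, 1)
    v_win = v_pad.unfold(2, span, 1)
    scores = torch.einsum('bhnd,bhnds->bhns', q, k_win) / d_k ** 0.5
    # Block slots that fall on padding (original key index < 0 or >= n).
    offsets = torch.arange(-w, w + 1, device=q.device)
    key_idx = torch.arange(n, device=q.device).unsqueeze(1) + offsets  # (n, span)
    invalid = (key_idx < 0) | (key_idx >= n)
    scores = scores.masked_fill(invalid, float('-inf'))
    attn = F.softmax(scores, dim=-1)  # (batch, heads, n, span)
    return torch.einsum('bhns,bhnds->bhnd', attn, v_win)


torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 12, 8) for _ in range(3))
out = banded_sliding_window_attention(q, k, v, window_size=3)
print(out.shape)  # torch.Size([1, 2, 12, 8])
```

The score tensor now holds $n \cdot (2w + 1)$ entries per head instead of $n^2$. `einsum` may still copy the unfolded key and value views into an $n \times d_k \times (2w + 1)$ intermediate, which is linear in $n$ but not free. Production kernels avoid that copy.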

```python
# Visualize sliding window patterns
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
seq_len = 16

for idx, window in enumerate([2, 4, 8]):
    mask = create_sliding_window_mask(seq_len, window)
    pattern = (~mask).float().numpy()

    sns.heatmap(pattern, ax=axes[idx], cmap='Blues', square=True, cbar=False)
    axes[idx].set_title(f'Window Size = {window}')
    axes[idx].set_xlabel('Key Position')
    axes[idx].set_ylabel('Query Position')

plt.suptitle('Sliding Window Attention Patterns')
plt.tight_layout()
plt.show()
```

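```python
# Multiply-add counts for the QK^T product alone. With a one-sided window w,
# each interior query scores 2w + 1 keys; windowed_pairs counts edges exactly.
seq_lengths = [512, 1024, 2048, 4096, 8192]
window_size = 256
d = 64  # per-head dimension d_k


def windowed_pairs(n: int, w: int) -> int:
    """Number of (query, key) pairs with |i - j| <= w, clipped to 0 <= j < n."""
    return sum(min(n, i + w + 1) - max(0, i - w) for i in range(n))


print(f"Complexity comparison (window={window_size}):")
print("-" * 70)
print(f"{'Seq Len':<10} {'Full n^2*d':<20} {'Window ~n*(2w+1)*d':<20} {'Ratio':<10}")
print("-" * 70)

for n in seq_lengths:
    full = n * n * d
    windowed = windowed_pairs(n, window_size) * d
    print(f"{n:<10} {full:<20,} {windowed:<20,} {full / windowed:.1f}x")
```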

## When to Use Sliding Window

| Task | Use sliding window? |
|------|---------------------|
| Long documents | Yes |
| Local patterns (NER, POS tagging) | Yes |
| Global reasoning needed | Yes, plus global attention on selected tokens (e.g. `[CLS]`) |
| Short sequences (< 512 tokens) | No; full attention is affordable |
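Longformer's global attention is symmetric: a global token attends to every position, and every position attends to it. A minimal sketch of the combined attention pattern as a mask follows, using the notebook's convention that `True` means blocked. `create_window_plus_global_mask` is a hypothetical helper. It covers only the mask; Longformer also uses separate projections for global attention, which this sketch does not model.

```python
import torch


def create_window_plus_global_mask(seq_len: int, window_size: int,
                                   global_positions: list[int]) -> torch.Tensor:
    """Boolean (seq_len, seq_len) mask, True where attention is BLOCKED.

    Allowed: the local band |i - j| <= window_size, plus every row and
    every column belonging to a global position.
    """
    pos = torch.arange(seq_len)
    allowed = (pos[None, :] - pos[:, None]).abs() <= window_size
    is_global = torch.zeros(seq_len, dtype=torch.bool)
    is_global[torch.tensor(global_positions, dtype=torch.long)] = True
    allowed |= is_global[:, None]  # global queries see all keys
    allowed |= is_global[None, :]  # all queries see global keys
    return ~allowed


# Treat position 0 as a [CLS]-style global token in a 16-token sequence.
mask = create_window_plus_global_mask(16, window_size=2, global_positions=[0])
print((~mask).sum().item())  # number of allowed (query, key) pairs
```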