# Longformer from Scratch (Sparse Attention)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/deep-learning-abc/blob/main/longformer.ipynb)

Longformer scales Transformers to long sequences (e.g., 4096 tokens) by replacing $O(N^2)$ full attention with **sparse attention** patterns.

It combines three attention patterns (two local, one global):
1. **Sliding Window:** Each token attends to a fixed number of neighboring tokens on either side.
2. **Dilated Sliding Window:** Attend to neighbors with gaps.
3. **Global Attention:** Specific tokens (like `[CLS]`) attend to everything, and everything attends to them.

In [None]:
!pip install torch matplotlib

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

## 1. Sliding Window Attention (Naive Implementation)

Efficient implementations require custom CUDA kernels. Here we simulate sliding-window attention by masking the full attention matrix, so this version still costs $O(N^2)$ time and memory.

In [None]:
def create_sliding_window_mask(seq_len, window_size, device):
    """Build an additive sliding-window (band) attention mask.

    Entry ``(i, j)`` is ``0.0`` when ``abs(i - j) <= window_size // 2`` and
    ``-inf`` otherwise, so adding the mask to raw attention scores before the
    softmax limits every token to ``window_size // 2`` neighbors on each side
    (plus itself).

    Args:
        seq_len: Number of positions; the mask is ``seq_len x seq_len``.
        window_size: Total window width. Only ``window_size // 2`` is used,
            as the one-sided reach.
        device: Torch device on which the mask is created.

    Returns:
        A float32 tensor of shape ``[seq_len, seq_len]`` holding ``0.0`` for
        allowed pairs and ``-inf`` for masked pairs.
    """
    positions = torch.arange(seq_len, device=device)
    # |i - j| for every (row, column) pair via broadcasting.
    distance = (positions[:, None] - positions[None, :]).abs()
    outside_window = distance > (window_size // 2)

    # Start from "everything allowed" (0.0) and knock out the far-away pairs
    # in a single pass; no throwaway full-size tensors are allocated.
    allowed = torch.zeros((seq_len, seq_len), dtype=torch.float32, device=device)
    return allowed.masked_fill(outside_window, float('-inf'))

# Demo: band mask for a 10-token sequence with a total window width of 4
# (two neighbors on each side of every token).
window_size = 4
mask = create_sliding_window_mask(seq_len=10, window_size=window_size, device=device)

print(f"Sliding Window Mask (Window={window_size}):", mask, sep="\n")

## 2. Global Attention (for [CLS])

We add global attention indices. Typically index 0 (`[CLS]`) is global.

In [None]:
def add_global_attention(mask, global_indices):
    """Unmask the rows and columns of the global tokens in an additive mask.

    For every index in ``global_indices`` the matching row and column of
    ``mask`` are set to ``0.0`` (i.e. "not masked"). The mask is modified
    in place and the same object is returned for convenience.
    """
    for token in global_indices:
        mask[:, token] = 0.0  # every position may attend to the global token
        mask[token, :] = 0.0  # the global token may attend to every position
    return mask

# Demo: make token 0 ([CLS]) global on top of the sliding-window mask.
mask = add_global_attention(mask, global_indices=[0])
print("Sliding Window + Global Attention at [0]:", mask, sep="\n")

## 3. Longformer Self-Attention Layer

Combines local (sliding window) and global attention.

In [None]:
class LongformerSelfAttention(nn.Module):
    """Naive (dense) Longformer self-attention: sliding window + global tokens.

    Two attention results are computed over the full ``T x T`` score matrix
    and merged per token:

    * Non-global tokens use the ``query``/``key``/``value`` projections and
      attend to their sliding window plus every global token.
    * Global tokens use the ``*_global`` projections and attend to the whole
      sequence.

    This is a readability-first simulation; it does not reduce the quadratic
    cost of full attention.

    Args:
        d_model: Embedding size of the input and output.
        n_heads: Number of attention heads; must divide ``d_model``.
        window_size: Total sliding-window width, passed to
            ``create_sliding_window_mask``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads, window_size):
        super().__init__()
        # Fail early with a clear message instead of a shape error in forward().
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.window_size = window_size

        # Projections used by ordinary (non-global) tokens.
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

        # Global projections (often separate parameters in Longformer)
        self.query_global = nn.Linear(d_model, d_model)
        self.key_global = nn.Linear(d_model, d_model)
        self.value_global = nn.Linear(d_model, d_model)

        self.out = nn.Linear(d_model, d_model)

    def _split_heads(self, projection, x):
        """Project ``x`` and reshape it to ``[batch, heads, seq_len, head_dim]``."""
        B, T, _ = x.shape
        return projection(x).view(B, T, self.n_heads, -1).transpose(1, 2)

    def forward(self, x, global_mask):
        """Apply sliding-window attention with global tokens.

        Args:
            x: Input of shape ``[batch, seq_len, d_model]``.
            global_mask: ``[batch, seq_len]`` tensor; entries ``> 0`` mark
                global tokens, everything else is treated as local.

        Returns:
            Tensor of shape ``[batch, seq_len, d_model]``.
        """
        B, T, D = x.shape
        scale = math.sqrt(D // self.n_heads)
        # reshape (not view) so non-contiguous masks are accepted too.
        is_global = global_mask.reshape(B, T) > 0

        # 1. Local Attention (Sliding Window)
        q = self._split_heads(self.query, x)
        k = self._split_heads(self.key, x)
        v = self._split_heads(self.value, x)

        # Additive band mask [T, T]: 0 inside the window, -inf outside (naive).
        local_mask = create_sliding_window_mask(T, self.window_size, x.device)
        # Global attention is symmetric: every token must also be able to
        # attend to the global tokens, so unmask their key columns for each
        # batch element. Result is [B, 1, T, T] and broadcasts over heads.
        local_mask = local_mask[None, None].expand(B, 1, T, T).masked_fill(
            is_global[:, None, None, :], 0.0
        )

        scores_local = (q @ k.transpose(-2, -1)) / scale + local_mask
        attn_local = torch.softmax(scores_local, dim=-1)
        out_local = attn_local @ v

        # 2. Global Attention
        # For simplicity, we compute full attention for every token with the
        # global projections and keep only the rows of global tokens below.
        # Real Longformer implements this efficiently without the full matrix.
        q_g = self._split_heads(self.query_global, x)
        k_g = self._split_heads(self.key_global, x)
        v_g = self._split_heads(self.value_global, x)

        scores_global = (q_g @ k_g.transpose(-2, -1)) / scale
        attn_global = torch.softmax(scores_global, dim=-1)
        out_global = attn_global @ v_g

        # 3. Merge
        # If a token is global, take global output. Else, take local.
        # [B, T] -> [B, 1, T, 1] broadcasts over heads and head_dim.
        out = torch.where(is_global[:, None, :, None], out_global, out_local)
        out = out.transpose(1, 2).contiguous().view(B, T, D)

        return self.out(out)

In [None]:
# Smoke test: a single 10-token sequence whose first token ([CLS]) is global.
model = LongformerSelfAttention(d_model=64, n_heads=4, window_size=4).to(device)
x = torch.randn((1, 10, 64), device=device)

# 1 marks a global token, 0 a local one.
g_mask = torch.zeros((1, 10), device=device)
g_mask[0, 0] = 1

out = model(x, g_mask)
print(f"Longformer Attention Output: {out.shape}")