# Linear Attention Mechanisms: Efficient Alternatives to Quadratic Attention

## Overview

This notebook explores **Linear Attention Mechanisms** - efficient alternatives to the standard quadratic attention found in Transformers. We'll implement and compare various approaches including Performer, LinFormer, and other linear attention variants.

### Key Topics Covered:
1. **Background**: Understanding the quadratic complexity problem
2. **Mathematical Foundations**: How linear attention works
3. **Implementations**: Performer, LinFormer, and other variants
4. **Comparisons**: Attention quality vs computational complexity
5. **Scaling Analysis**: Performance characteristics
6. **Visualizations**: Attention pattern analysis

### Trade-offs We'll Explore:
- **Attention Quality** vs **Computational Complexity**
- **Memory Usage** vs **Approximation Accuracy**
- **Training Speed** vs **Model Performance**

In [1]:
# Import required libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import time
import math
from typing import Optional, Tuple
from dataclasses import dataclass
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Configure plotting
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print("Environment setup complete!")
print(f"PyTorch version: {torch.__version__}")
print(f"Device available: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

Environment setup complete!
PyTorch version: 2.7.1+cu126
Device available: CUDA


## 1. Background: The Quadratic Attention Problem

### Standard Self-Attention

The standard self-attention mechanism in Transformers computes:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Where:
- $Q \in \mathbb{R}^{n \times d_k}$ (queries)
- $K \in \mathbb{R}^{n \times d_k}$ (keys)  
- $V \in \mathbb{R}^{n \times d_v}$ (values)
- $n$ is the sequence length

### The Problem: O(n²) Complexity

Computing $QK^T$ requires $O(n^2 d_k)$ operations, making it prohibitive for long sequences.

### Memory Usage
The attention matrix requires $O(n^2)$ memory, which becomes massive for long sequences.

In [2]:
class StandardAttention(nn.Module):
    """Standard O(n²) self-attention mechanism."""
    
    def __init__(self, d_model: int, n_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)
    
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, seq_len, d_model = x.shape
        
        # Linear projections
        Q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.w_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        
        # Attention computation - O(n²) complexity
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        
        # Apply attention to values
        out = torch.matmul(attention_weights, V)
        
        # Reshape and project
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        out = self.w_o(out)
        
        return out, attention_weights

# Demonstrate complexity
def analyze_complexity(seq_lengths, d_model=512):
    """Analyze time and memory complexity of standard attention."""
    results = {'seq_len': [], 'time': [], 'memory': []}
    
    attention = StandardAttention(d_model)
    
    for seq_len in seq_lengths:
        x = torch.randn(1, seq_len, d_model)
        
        # Measure time
        start_time = time.time()
        with torch.no_grad():
            out, weights = attention(x)
        end_time = time.time()
        
        # Calculate theoretical memory for attention matrix
        attention_memory = seq_len ** 2 * 4  # 4 bytes per float32
        
        results['seq_len'].append(seq_len)
        results['time'].append(end_time - start_time)
        results['memory'].append(attention_memory)
    
    return results

# Analyze complexity for different sequence lengths
seq_lengths = [128, 256, 512, 1024, 2048]
complexity_results = analyze_complexity(seq_lengths)

print("Standard Attention Complexity Analysis:")
print("Seq Length | Time (s) | Memory (MB)")
print("-" * 40)
for i, seq_len in enumerate(complexity_results['seq_len']):
    time_ms = complexity_results['time'][i] * 1000
    memory_mb = complexity_results['memory'][i] / (1024 * 1024)
    print(f"{seq_len:8d} | {time_ms:6.2f}ms | {memory_mb:8.2f}MB")

Standard Attention Complexity Analysis:
Seq Length | Time (s) | Memory (MB)
----------------------------------------
     128 |   4.14ms |     0.06MB
     256 |   6.19ms |     0.25MB
     512 |  23.08ms |     1.00MB
    1024 |  99.74ms |     4.00MB
    2048 | 337.77ms |    16.00MB


## 2. Linear Attention: Mathematical Foundation

### Key Insight: Kernel Trick

Linear attention methods reformulate attention using the **kernel trick**:

$$\text{Attention}(Q, K, V) = \phi(Q) \left( \phi(K)^T V \right)$$

Where $\phi: \mathbb{R}^{d_k} \rightarrow \mathbb{R}^{d_\phi}$ is a feature mapping function.

### Complexity Reduction

Instead of computing $O(n^2)$ attention matrix:
1. Compute $\phi(K)^T V$ first: $O(n d_\phi d_v)$
2. Then multiply by $\phi(Q)$: $O(n d_\phi d_v)$

**Total complexity: $O(n d_\phi d_v)$ - Linear in sequence length!**

In [3]:
class LinearAttentionBase(nn.Module):
    """Base class for linear attention mechanisms."""
    
    def __init__(self, d_model: int, n_heads: int = 8, feature_dim: int = 256):
        super().__init__()
        assert d_model % n_heads == 0
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.feature_dim = feature_dim
        
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model)
    
    def feature_map(self, x: torch.Tensor) -> torch.Tensor:
        """Feature mapping function φ(x). To be implemented by subclasses."""
        raise NotImplementedError
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, d_model = x.shape
        
        # Linear projections
        Q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.w_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        
        # Apply feature mapping
        Q_feat = self.feature_map(Q)  # [batch, heads, seq_len, feature_dim]
        K_feat = self.feature_map(K)  # [batch, heads, seq_len, feature_dim]
        
        # Linear attention computation: φ(Q) * (φ(K)^T * V)
        # Step 1: Compute φ(K)^T * V
        KV = torch.matmul(K_feat.transpose(-2, -1), V)  # [batch, heads, feature_dim, d_k]
        
        # Step 2: Compute φ(Q) * (φ(K)^T * V)
        out = torch.matmul(Q_feat, KV)  # [batch, heads, seq_len, d_k]
        
        # Normalization (important for stable training)
        normalizer = torch.matmul(Q_feat, K_feat.sum(dim=-2, keepdim=True).transpose(-2, -1))
        out = out / (normalizer + 1e-8)
        
        # Reshape and project
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        out = self.w_o(out)
        
        return out

print("Linear attention base class defined!")
print("Key insight: O(n²) → O(n) complexity through kernel trick")

Linear attention base class defined!
Key insight: O(n²) → O(n) complexity through kernel trick


## 3. Performer: Random Feature Attention

### The Performer Approach

**Paper**: "Rethinking Attention with Performers" (Choromanski et al., 2020)

### Key Innovation: FAVOR+ Algorithm

Performer approximates the softmax kernel using random Fourier features:

$$\phi(x) = \frac{1}{\sqrt{m}} \left[ \exp(w_1^T x), \exp(w_2^T x), \ldots, \exp(w_m^T x) \right]$$

Where $w_i \sim \mathcal{N}(0, I)$ are random Gaussian vectors.

### Mathematical Properties:
- **Unbiased**: $\mathbb{E}[\phi(q)^T \phi(k)] = \exp(q^T k)$
- **Positive**: All features are positive (crucial for attention)
- **Fast**: Linear complexity in sequence length

In [4]:
class PerformerAttention(LinearAttentionBase):
    """Performer attention using FAVOR+ algorithm with random features."""
    
    def __init__(self, d_model: int, n_heads: int = 8, feature_dim: int = 256, 
                 redraw_features: bool = True):
        super().__init__(d_model, n_heads, feature_dim)
        self.redraw_features = redraw_features
        self.register_buffer('random_features', None)
        self._create_projection_matrix()
    
    def _create_projection_matrix(self):
        """Create random projection matrix for FAVOR+."""
        # Random Gaussian matrix
        random_matrix = torch.randn(self.feature_dim, self.d_k)
        
        # Optional: Use structured random matrices (faster)
        # This is an advanced optimization from the paper
        self.register_buffer('projection_matrix', random_matrix)
    
    def feature_map(self, x: torch.Tensor) -> torch.Tensor:
        """FAVOR+ feature mapping with positive random features."""
        # x shape: [batch, heads, seq_len, d_k]
        
        # Project to random features
        projected = torch.matmul(x, self.projection_matrix.T)  # [batch, heads, seq_len, feature_dim]
        
        # Apply nonlinearity: exp(proj - max) for numerical stability
        max_proj = torch.max(projected, dim=-1, keepdim=True)[0]
        projected = projected - max_proj
        
        # Positive random features
        features = torch.exp(projected)
        
        # Normalization for unbiased estimation
        norm_factor = math.sqrt(self.feature_dim)
        features = features / norm_factor
        
        return features
    
    def redraw_projection_matrix(self):
        """Redraw random features (useful during training)."""
        self._create_projection_matrix()


class PerformerWithOrthogonal(PerformerAttention):
    """Enhanced Performer with orthogonal random features for better approximation."""
    
    def _create_projection_matrix(self):
        """Create orthogonal random projection matrix."""
        # Generate random matrix
        random_matrix = torch.randn(self.feature_dim, self.d_k)
        
        # Make it orthogonal using QR decomposition
        if self.feature_dim >= self.d_k:
            q, _ = torch.qr(random_matrix.T)
            random_matrix = q.T[:self.feature_dim]
        
        self.register_buffer('projection_matrix', random_matrix)


# Test Performer implementation
def test_performer():
    """Test Performer attention implementation."""
    d_model = 512
    seq_len = 1024
    batch_size = 2
    
    # Create models
    standard_attn = StandardAttention(d_model)
    performer_attn = PerformerAttention(d_model, feature_dim=256)
    performer_ortho = PerformerWithOrthogonal(d_model, feature_dim=256)
    
    # Test input
    x = torch.randn(batch_size, seq_len, d_model)
    
    print("Testing Performer Attention:")
    print(f"Input shape: {x.shape}")
    
    # Time standard attention
    start_time = time.time()
    with torch.no_grad():
        out_std, weights_std = standard_attn(x)
    std_time = time.time() - start_time
    
    # Time Performer
    start_time = time.time()
    with torch.no_grad():
        out_perf = performer_attn(x)
    perf_time = time.time() - start_time
    
    # Time Orthogonal Performer
    start_time = time.time()
    with torch.no_grad():
        out_ortho = performer_ortho(x)
    ortho_time = time.time() - start_time
    
    print(f"Standard Attention: {std_time:.4f}s")
    print(f"Performer: {perf_time:.4f}s ({std_time/perf_time:.1f}x speedup)")
    print(f"Performer (Orthogonal): {ortho_time:.4f}s ({std_time/ortho_time:.1f}x speedup)")
    
    # Check output similarity
    mse_perf = F.mse_loss(out_std, out_perf)
    mse_ortho = F.mse_loss(out_std, out_ortho)
    
    print(f"\nOutput similarity (MSE with standard):")
    print(f"Performer: {mse_perf:.6f}")
    print(f"Performer (Orthogonal): {mse_ortho:.6f}")

test_performer()

Testing Performer Attention:
Input shape: torch.Size([2, 1024, 512])
Standard Attention: 0.1759s
Performer: 0.0294s (6.0x speedup)
Performer (Orthogonal): 0.0111s (15.9x speedup)

Output similarity (MSE with standard):
Performer: 0.004882
Performer (Orthogonal): 0.001634


## 4. LinFormer: Low-Rank Attention

### The LinFormer Approach

**Paper**: "Linformer: Self-Attention with Linear Complexity" (Wang et al., 2020)

### Key Innovation: Projected Keys and Values

LinFormer reduces attention complexity by projecting keys and values to a lower dimension:

$$\text{LinFormer}(Q, K, V) = \text{softmax}\left(\frac{Q(PK)^T}{\sqrt{d_k}}\right)(PV)$$

Where $P \in \mathbb{R}^{k \times n}$ is a projection matrix with $k \ll n$.

### Complexity Analysis:
- **Standard**: $O(n^2 d)$
- **LinFormer**: $O(nkd)$ where $k \ll n$
- **Memory**: Reduced from $O(n^2)$ to $O(nk)$

In [5]:
class LinFormerAttention(nn.Module):
    """LinFormer attention with linear complexity through low-rank projection."""
    
    def __init__(self, d_model: int, n_heads: int = 8, projected_dim: int = 256, 
                 projection_type: str = 'linear'):
        super().__init__()
        assert d_model % n_heads == 0
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.projected_dim = projected_dim
        self.projection_type = projection_type
        
        # Standard Q, K, V projections
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model)
        
        # Projection matrices for keys and values
        if projection_type == 'linear':
            self.proj_k = nn.Linear(self.d_k, projected_dim, bias=False)
            self.proj_v = nn.Linear(self.d_k, projected_dim, bias=False)
        elif projection_type == 'conv':
            self.proj_k = nn.Conv1d(self.d_k, projected_dim, 1, bias=False)
            self.proj_v = nn.Conv1d(self.d_k, projected_dim, 1, bias=False)
        
        self.scale = math.sqrt(self.d_k)
        self.dropout = nn.Dropout(0.1)
    
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size, seq_len, d_model = x.shape
        
        # Linear projections
        Q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.w_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        
        # Project keys and values to lower dimension
        if self.projection_type == 'linear':
            # Apply projection along sequence dimension
            K_proj = self.proj_k(K.transpose(-2, -1)).transpose(-2, -1)  # [batch, heads, projected_dim, d_k]
            V_proj = self.proj_v(V.transpose(-2, -1)).transpose(-2, -1)  # [batch, heads, projected_dim, d_k]
        elif self.projection_type == 'conv':
            # Reshape for conv1d
            K_reshaped = K.view(-1, self.d_k, seq_len)
            V_reshaped = V.view(-1, self.d_k, seq_len)
            
            K_proj = self.proj_k(K_reshaped).view(batch_size, self.n_heads, self.projected_dim, -1).transpose(-2, -1)
            V_proj = self.proj_v(V_reshaped).view(batch_size, self.n_heads, self.projected_dim, -1).transpose(-2, -1)
        
        # Attention computation with reduced complexity
        scores = torch.matmul(Q, K_proj.transpose(-2, -1)) / self.scale  # [batch, heads, seq_len, projected_dim]
        
        if mask is not None:
            # Adapt mask for projected dimension
            projected_mask = mask[:, :self.projected_dim] if mask.size(-1) > self.projected_dim else mask
            scores = scores.masked_fill(projected_mask.unsqueeze(1).unsqueeze(1) == 0, -1e9)
        
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        
        # Apply attention to projected values
        out = torch.matmul(attention_weights, V_proj)  # [batch, heads, seq_len, d_k]
        
        # Reshape and project
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        out = self.w_o(out)
        
        return out


class AdaptiveLinFormer(LinFormerAttention):
    """LinFormer with adaptive projection dimension based on sequence length."""
    
    def __init__(self, d_model: int, n_heads: int = 8, max_projected_dim: int = 512,
                 compression_ratio: float = 0.25):
        # Start with max dimension, will be adjusted in forward pass
        super().__init__(d_model, n_heads, max_projected_dim)
        self.max_projected_dim = max_projected_dim
        self.compression_ratio = compression_ratio
    
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        seq_len = x.size(1)
        # Adaptive projection dimension
        self.projected_dim = min(self.max_projected_dim, int(seq_len * self.compression_ratio))
        self.projected_dim = max(self.projected_dim, 64)  # Minimum dimension
        
        return super().forward(x, mask)


# Test LinFormer implementation
def test_linformer():
    """Test LinFormer attention implementation."""
    d_model = 512
    batch_size = 2
    
    # Test different sequence lengths
    seq_lengths = [256, 512, 1024, 2048]
    
    print("LinFormer Complexity Analysis:")
    print("Seq Len | Proj Dim | Standard Time | LinFormer Time | Speedup")
    print("-" * 70)
    
    for seq_len in seq_lengths:
        x = torch.randn(batch_size, seq_len, d_model)
        projected_dim = min(256, seq_len // 4)  # 25% compression
        
        # Standard attention
        standard_attn = StandardAttention(d_model)
        start_time = time.time()
        with torch.no_grad():
            out_std, _ = standard_attn(x)
        std_time = time.time() - start_time
        
        # LinFormer attention
        linformer_attn = LinFormerAttention(d_model, projected_dim=projected_dim)
        start_time = time.time()
        with torch.no_grad():
            out_lin = linformer_attn(x)
        lin_time = time.time() - start_time
        
        speedup = std_time / lin_time if lin_time > 0 else float('inf')
        
        print(f"{seq_len:7d} | {projected_dim:8d} | {std_time:11.4f}s | {lin_time:12.4f}s | {speedup:6.1f}x")
        
        # Check output similarity
        mse = F.mse_loss(out_std, out_lin)
        print(f"        | MSE vs Standard: {mse:.6f}")

test_linformer()

LinFormer Complexity Analysis:
Seq Len | Proj Dim | Standard Time | LinFormer Time | Speedup
----------------------------------------------------------------------


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1024x256 and 64x64)

## 5. Other Linear Attention Variants

### FNet: Fourier Transform Attention
Replaces attention entirely with FFT operations.

### Synthesizer: Learned Synthetic Attention
Learns attention patterns without content-based interactions.

### Linear Transformer: Causal Linear Attention
Efficient linear attention for autoregressive tasks.

In [None]:
class FNetBlock(nn.Module):
    """FNet: Replace attention with 2D Fourier Transform."""
    
    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        self.norm = nn.LayerNorm(d_model)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Apply 2D FFT
        # x shape: [batch, seq_len, d_model]
        x_fft = torch.fft.fft2(x.float())
        x_real = x_fft.real
        
        return self.norm(x_real)


class SynthesizerAttention(nn.Module):
    """Synthesizer: Learned synthetic attention patterns."""
    
    def __init__(self, d_model: int, n_heads: int = 8, max_seq_len: int = 2048,
                 synthesizer_type: str = 'dense'):
        super().__init__()
        assert d_model % n_heads == 0
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.max_seq_len = max_seq_len
        self.synthesizer_type = synthesizer_type
        
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model)
        
        if synthesizer_type == 'dense':
            # Dense synthesizer: learned attention weights
            self.synthetic_attn = nn.Parameter(
                torch.randn(n_heads, max_seq_len, max_seq_len) * 0.02
            )
        elif synthesizer_type == 'random':
            # Random synthesizer: fixed random attention
            self.register_buffer(
                'synthetic_attn',
                torch.randn(n_heads, max_seq_len, max_seq_len)
            )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, d_model = x.shape
        
        # Only compute values (no queries or keys needed!)
        V = self.w_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        
        # Use synthetic attention patterns
        attn_weights = self.synthetic_attn[:, :seq_len, :seq_len]
        attn_weights = F.softmax(attn_weights, dim=-1)
        
        # Apply synthetic attention
        out = torch.matmul(attn_weights.unsqueeze(0), V)  # Broadcast across batch
        
        # Reshape and project
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        out = self.w_o(out)
        
        return out


class LinearTransformerAttention(LinearAttentionBase):
    """Linear Transformer with causal masking for autoregressive tasks."""
    
    def __init__(self, d_model: int, n_heads: int = 8, feature_dim: int = 256,
                 causal: bool = True):
        super().__init__(d_model, n_heads, feature_dim)
        self.causal = causal
    
    def feature_map(self, x: torch.Tensor) -> torch.Tensor:
        """ELU-based feature mapping for positive features."""
        return F.elu(x) + 1
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.causal:
            return super().forward(x)
        
        # Causal version using cumulative sums
        batch_size, seq_len, d_model = x.shape
        
        # Linear projections
        Q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.w_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        
        # Apply feature mapping
        Q_feat = self.feature_map(Q)
        K_feat = self.feature_map(K)
        
        # Causal linear attention using cumulative sums
        KV = K_feat.unsqueeze(-1) * V.unsqueeze(-2)  # Outer product
        KV_cumsum = torch.cumsum(KV, dim=2)  # Cumulative sum over sequence
        
        # Compute output
        out = torch.sum(Q_feat.unsqueeze(-1) * KV_cumsum, dim=-2)
        
        # Normalization
        normalizer = torch.cumsum(K_feat, dim=2)
        normalizer = torch.sum(Q_feat * normalizer, dim=-1, keepdim=True)
        out = out / (normalizer + 1e-8)
        
        # Reshape and project
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        out = self.w_o(out)
        
        return out


# Test all variants
def test_all_variants():
    """Test all linear attention variants."""
    d_model = 512
    seq_len = 1024
    batch_size = 2
    x = torch.randn(batch_size, seq_len, d_model)
    
    models = {
        'FNet': FNetBlock(d_model),
        'Synthesizer (Dense)': SynthesizerAttention(d_model, synthesizer_type='dense'),
        'Synthesizer (Random)': SynthesizerAttention(d_model, synthesizer_type='random'),
        'Linear Transformer': LinearTransformerAttention(d_model, causal=True),
    }
    
    print("Testing Alternative Linear Attention Variants:")
    print("Model\t\t\t| Time (ms) | Output Shape")
    print("-" * 50)
    
    for name, model in models.items():
        start_time = time.time()
        with torch.no_grad():
            output = model(x)
        elapsed = (time.time() - start_time) * 1000
        
        print(f"{name:20s} | {elapsed:7.2f}ms | {tuple(output.shape)}")

test_all_variants()

## 6. Comprehensive Comparison and Analysis

### Attention Quality vs Computational Complexity

Let's systematically compare all attention mechanisms across multiple dimensions:
1. **Computational Complexity**
2. **Memory Usage**
3. **Attention Quality (approximation error)**
4. **Scaling Behavior**

In [None]:
import pandas as pd
from collections import defaultdict

def comprehensive_benchmark():
    """Comprehensive benchmark of all attention mechanisms."""
    
    d_model = 512
    batch_size = 4
    seq_lengths = [128, 256, 512, 1024, 2048]
    
    # Initialize models
    models = {
        'Standard': StandardAttention(d_model),
        'Performer': PerformerAttention(d_model, feature_dim=256),
        'Performer (Ortho)': PerformerWithOrthogonal(d_model, feature_dim=256),
        'LinFormer': LinFormerAttention(d_model, projected_dim=256),
        'Linear Transformer': LinearTransformerAttention(d_model, causal=False),
        'FNet': FNetBlock(d_model),
    }
    
    results = defaultdict(list)
    
    print("Running comprehensive benchmark...")
    
    for seq_len in seq_lengths:
        print(f"\nTesting sequence length: {seq_len}")
        x = torch.randn(batch_size, seq_len, d_model)
        
        # Get standard attention output as reference
        with torch.no_grad():
            if seq_len <= 1024:  # Only compute for manageable sizes
                ref_output, ref_weights = models['Standard'](x)
            else:
                ref_output = None
        
        for name, model in models.items():
            try:
                # Measure time
                times = []
                for _ in range(5):  # Multiple runs for stability
                    start_time = time.time()
                    with torch.no_grad():
                        if name == 'Standard':
                            output, _ = model(x)
                        else:
                            output = model(x)
                    times.append(time.time() - start_time)
                
                avg_time = np.mean(times) * 1000  # Convert to ms
                
                # Calculate approximation error
                if ref_output is not None and name != 'Standard':
                    mse_error = F.mse_loss(output, ref_output).item()
                    cosine_sim = F.cosine_similarity(
                        output.flatten(), ref_output.flatten(), dim=0
                    ).item()
                else:
                    mse_error = 0.0 if name == 'Standard' else float('nan')
                    cosine_sim = 1.0 if name == 'Standard' else float('nan')
                
                # Memory estimation (theoretical)
                if name == 'Standard':
                    memory_mb = (seq_len ** 2) * 4 / (1024 ** 2)  # Attention matrix
                elif name in ['Performer', 'Performer (Ortho)']:
                    memory_mb = seq_len * 256 * 4 / (1024 ** 2)  # Feature dimension
                elif name == 'LinFormer':
                    memory_mb = seq_len * 256 * 4 / (1024 ** 2)  # Projected dimension
                else:
                    memory_mb = seq_len * d_model * 4 / (1024 ** 2)  # Linear in seq_len
                
                # Store results
                results['Model'].append(name)
                results['Seq_Length'].append(seq_len)
                results['Time_ms'].append(avg_time)
                results['Memory_MB'].append(memory_mb)
                results['MSE_Error'].append(mse_error)
                results['Cosine_Sim'].append(cosine_sim)
                
            except Exception as e:
                print(f"Error with {name}: {e}")
                continue
    
    return pd.DataFrame(results)

# Run benchmark
benchmark_df = comprehensive_benchmark()
print("\nBenchmark completed!")
print(f"Total results: {len(benchmark_df)} measurements")

In [None]:
# Visualization of results
def plot_comprehensive_analysis(df):
    """Create comprehensive visualizations of attention mechanism comparison."""
    
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle('Linear Attention Mechanisms: Comprehensive Analysis', fontsize=16, fontweight='bold')
    
    # 1. Time vs Sequence Length
    ax1 = axes[0, 0]
    for model in df['Model'].unique():
        model_data = df[df['Model'] == model]
        ax1.plot(model_data['Seq_Length'], model_data['Time_ms'], 'o-', label=model, linewidth=2, markersize=6)
    ax1.set_xlabel('Sequence Length')
    ax1.set_ylabel('Time (ms)')
    ax1.set_title('Computational Time vs Sequence Length')
    ax1.set_yscale('log')
    ax1.set_xscale('log')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # 2. Memory vs Sequence Length
    ax2 = axes[0, 1]
    for model in df['Model'].unique():
        model_data = df[df['Model'] == model]
        ax2.plot(model_data['Seq_Length'], model_data['Memory_MB'], 's-', label=model, linewidth=2, markersize=6)
    ax2.set_xlabel('Sequence Length')
    ax2.set_ylabel('Memory (MB)')
    ax2.set_title('Memory Usage vs Sequence Length')
    ax2.set_yscale('log')
    ax2.set_xscale('log')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    # 3. Approximation Quality (MSE Error)
    ax3 = axes[0, 2]
    quality_df = df[df['Model'] != 'Standard'].dropna(subset=['MSE_Error'])
    if not quality_df.empty:
        for model in quality_df['Model'].unique():
            model_data = quality_df[quality_df['Model'] == model]
            ax3.plot(model_data['Seq_Length'], model_data['MSE_Error'], '^-', label=model, linewidth=2, markersize=6)
        ax3.set_xlabel('Sequence Length')
        ax3.set_ylabel('MSE Error vs Standard')
        ax3.set_title('Approximation Quality (Lower is Better)')
        ax3.set_yscale('log')
        ax3.legend()
        ax3.grid(True, alpha=0.3)
    
    # 4. Cosine Similarity
    ax4 = axes[1, 0]
    similarity_df = df[df['Model'] != 'Standard'].dropna(subset=['Cosine_Sim'])
    if not similarity_df.empty:
        for model in similarity_df['Model'].unique():
            model_data = similarity_df[similarity_df['Model'] == model]
            ax4.plot(model_data['Seq_Length'], model_data['Cosine_Sim'], 'd-', label=model, linewidth=2, markersize=6)
        ax4.set_xlabel('Sequence Length')
        ax4.set_ylabel('Cosine Similarity with Standard')
        ax4.set_title('Output Similarity (Higher is Better)')
        ax4.set_ylim(0, 1.1)
        ax4.legend()
        ax4.grid(True, alpha=0.3)
    
    # 5. Time vs Quality Trade-off (at seq_len=1024)
    ax5 = axes[1, 1]
    tradeoff_df = df[(df['Seq_Length'] == 1024) & (df['Model'] != 'Standard')].dropna(subset=['MSE_Error'])
    if not tradeoff_df.empty:
        scatter = ax5.scatter(tradeoff_df['Time_ms'], tradeoff_df['MSE_Error'], 
                            s=100, alpha=0.7, c=range(len(tradeoff_df)), cmap='viridis')
        for i, model in enumerate(tradeoff_df['Model']):
            ax5.annotate(model, (tradeoff_df.iloc[i]['Time_ms'], tradeoff_df.iloc[i]['MSE_Error']),
                        xytext=(5, 5), textcoords='offset points', fontsize=9)
        ax5.set_xlabel('Time (ms)')
        ax5.set_ylabel('MSE Error')
        ax5.set_title('Speed vs Quality Trade-off (Seq Len=1024)')
        ax5.set_yscale('log')
        ax5.grid(True, alpha=0.3)
    
    # 6. Scaling Analysis (Time complexity)
    ax6 = axes[1, 2]
    
    # Fit polynomial to estimate complexity
    for model in df['Model'].unique():
        model_data = df[df['Model'] == model].sort_values('Seq_Length')
        if len(model_data) >= 3:
            x = np.log(model_data['Seq_Length'].values)
            y = np.log(model_data['Time_ms'].values)
            coeffs = np.polyfit(x, y, 1)
            complexity_order = coeffs[0]
            
            ax6.plot(model_data['Seq_Length'], model_data['Time_ms'], 'o-', 
                    label=f'{model} (O(n^{complexity_order:.1f}))', linewidth=2, markersize=6)
    
    ax6.set_xlabel('Sequence Length')
    ax6.set_ylabel('Time (ms)')
    ax6.set_title('Empirical Complexity Analysis')
    ax6.set_yscale('log')
    ax6.set_xscale('log')
    ax6.legend()
    ax6.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

# Create visualizations
plot_comprehensive_analysis(benchmark_df)

In [None]:
# Summary statistics and insights
def generate_insights(df):
    """Generate key insights from the benchmark results."""
    
    print("📊 KEY INSIGHTS FROM LINEAR ATTENTION ANALYSIS")
    print("=" * 60)
    
    # 1. Speed Analysis
    print("\n🚀 SPEED ANALYSIS:")
    max_seq = df['Seq_Length'].max()
    speed_data = df[df['Seq_Length'] == max_seq]
    standard_time = speed_data[speed_data['Model'] == 'Standard']['Time_ms'].iloc[0]
    
    print(f"At sequence length {max_seq}:")
    for _, row in speed_data.iterrows():
        if row['Model'] != 'Standard':
            speedup = standard_time / row['Time_ms']
            print(f"  • {row['Model']:20s}: {speedup:5.1f}x speedup")
    
    # 2. Memory Analysis
    print("\n💾 MEMORY EFFICIENCY:")
    memory_data = df[df['Seq_Length'] == max_seq]
    standard_memory = memory_data[memory_data['Model'] == 'Standard']['Memory_MB'].iloc[0]
    
    print(f"Memory reduction at sequence length {max_seq}:")
    for _, row in memory_data.iterrows():
        if row['Model'] != 'Standard':
            reduction = standard_memory / row['Memory_MB']
            print(f"  • {row['Model']:20s}: {reduction:5.1f}x less memory")
    
    # 3. Quality Analysis
    print("\n🎯 APPROXIMATION QUALITY:")
    quality_data = df[(df['Seq_Length'] == 1024) & (df['Model'] != 'Standard')].dropna(subset=['MSE_Error'])
    
    if not quality_data.empty:
        print("MSE Error vs Standard Attention (at seq_len=1024):")
        for _, row in quality_data.sort_values('MSE_Error').iterrows():
            print(f"  • {row['Model']:20s}: {row['MSE_Error']:.6f}")
        
        print("\nCosine Similarity with Standard (at seq_len=1024):")
        for _, row in quality_data.sort_values('Cosine_Sim', ascending=False).iterrows():
            print(f"  • {row['Model']:20s}: {row['Cosine_Sim']:.4f}")
    
    # 4. Scaling Analysis
    print("\n📈 SCALING BEHAVIOR:")
    print("Theoretical vs Empirical complexity:")
    
    complexity_theory = {
        'Standard': 'O(n²)',
        'Performer': 'O(n)',
        'Performer (Ortho)': 'O(n)',
        'LinFormer': 'O(n)',
        'Linear Transformer': 'O(n)',
        'FNet': 'O(n log n)'
    }
    
    for model in df['Model'].unique():
        model_data = df[df['Model'] == model].sort_values('Seq_Length')
        if len(model_data) >= 3:
            x = np.log(model_data['Seq_Length'].values)
            y = np.log(model_data['Time_ms'].values)
            coeffs = np.polyfit(x, y, 1)
            empirical_order = coeffs[0]
            theoretical = complexity_theory.get(model, 'Unknown')
            print(f"  • {model:20s}: {theoretical:8s} → O(n^{empirical_order:.1f}) empirical")
    
    # 5. Recommendations
    print("\n✅ RECOMMENDATIONS:")
    print("\n🎯 For SPEED priority:")
    fastest = speed_data.loc[speed_data[speed_data['Model'] != 'Standard']['Time_ms'].idxmin()]
    print(f"  → Use {fastest['Model']} (fastest linear attention)")
    
    print("\n🎯 For QUALITY priority:")
    if not quality_data.empty:
        best_quality = quality_data.loc[quality_data['MSE_Error'].idxmin()]
        print(f"  → Use {best_quality['Model']} (best approximation quality)")
    
    print("\n🎯 For BALANCED performance:")
    if not quality_data.empty:
        # Normalize metrics and compute composite score
        normalized_time = quality_data['Time_ms'] / quality_data['Time_ms'].max()
        normalized_error = quality_data['MSE_Error'] / quality_data['MSE_Error'].max()
        composite_score = normalized_time + normalized_error  # Lower is better
        balanced_idx = composite_score.idxmin()
        balanced_model = quality_data.loc[balanced_idx]
        print(f"  → Use {balanced_model['Model']} (best speed/quality trade-off)")
    
    print("\n🎯 For LONG sequences (>4K tokens):")
    print("  → Use Performer or LinFormer (proven scalability)")
    print("  → Avoid Standard Attention (memory constraints)")

generate_insights(benchmark_df)

## 7. Attention Pattern Visualization

### Understanding What Linear Attention "Sees"

Let's visualize and compare the attention patterns produced by different mechanisms to understand how the linear approximations differ from standard attention.

In [None]:
def visualize_attention_patterns():
    """Visualize and compare attention patterns across different mechanisms."""
    
    # Create synthetic data with clear patterns
    seq_len = 64  # Smaller for visualization
    d_model = 128
    
    # Create structured input with different patterns
    x = torch.zeros(1, seq_len, d_model)
    
    # Pattern 1: Periodic signal
    for i in range(seq_len):
        x[0, i, :d_model//4] = torch.sin(torch.arange(d_model//4, dtype=torch.float) * i * 0.1)
    
    # Pattern 2: Local dependencies
    for i in range(seq_len//2, seq_len):
        x[0, i, d_model//4:d_model//2] = x[0, i-1, d_model//4:d_model//2] * 0.9 + 0.1
    
    # Pattern 3: Long-range dependencies
    x[0, seq_len//4:seq_len//2, d_model//2:3*d_model//4] = x[0, 0:seq_len//4, d_model//2:3*d_model//4]
    
    # Initialize models
    models = {
        'Standard': StandardAttention(d_model, n_heads=1),  # Single head for clearer visualization
        'Performer': PerformerAttention(d_model, n_heads=1, feature_dim=64),
        'LinFormer': LinFormerAttention(d_model, n_heads=1, projected_dim=32),
    }
    
    # Get attention patterns
    attention_patterns = {}
    outputs = {}
    
    with torch.no_grad():
        for name, model in models.items():
            if name == 'Standard':
                output, attn_weights = model(x)
                attention_patterns[name] = attn_weights[0, 0].numpy()  # [seq_len, seq_len]
            else:
                output = model(x)
                # For linear attention, we'll compute approximate attention for visualization
                attention_patterns[name] = compute_approximate_attention(model, x)
            
            outputs[name] = output
    
    # Visualization
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    fig.suptitle('Attention Pattern Comparison', fontsize=16, fontweight='bold')
    
    # Plot attention matrices
    for i, (name, pattern) in enumerate(attention_patterns.items()):
        ax = axes[0, i]
        im = ax.imshow(pattern, cmap='Blues', aspect='auto')
        ax.set_title(f'{name} Attention')
        ax.set_xlabel('Key Position')
        ax.set_ylabel('Query Position')
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    
    # Plot attention distributions for specific queries
    query_positions = [16, 32, 48]  # Different positions
    colors = ['red', 'green', 'blue']
    
    for i, (name, pattern) in enumerate(attention_patterns.items()):
        ax = axes[1, i]
        
        for j, (pos, color) in enumerate(zip(query_positions, colors)):
            ax.plot(pattern[pos], label=f'Query {pos}', color=color, alpha=0.7, linewidth=2)
        
        ax.set_title(f'{name} Attention Distribution')
        ax.set_xlabel('Key Position')
        ax.set_ylabel('Attention Weight')
        ax.legend()
        ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    return attention_patterns, outputs


def compute_approximate_attention(model, x):
    """Compute approximate attention weights for linear attention models."""
    # This is a simplified approximation for visualization
    # In practice, linear attention doesn't compute explicit attention matrices
    
    batch_size, seq_len, d_model = x.shape
    
    # Get Q, K projections
    Q = model.w_q(x).view(batch_size, seq_len, 1, -1).transpose(1, 2)  # Single head
    K = model.w_k(x).view(batch_size, seq_len, 1, -1).transpose(1, 2)
    
    if hasattr(model, 'feature_map'):
        # For Performer-like models
        Q_feat = model.feature_map(Q)
        K_feat = model.feature_map(K)
        
        # Approximate attention as normalized dot product in feature space
        attn_approx = torch.matmul(Q_feat, K_feat.transpose(-2, -1))
        attn_approx = F.softmax(attn_approx, dim=-1)
        
        return attn_approx[0, 0].numpy()
    
    elif hasattr(model, 'proj_k'):
        # For LinFormer-like models
        K_proj = model.proj_k(K.transpose(-2, -1)).transpose(-2, -1)
        
        # Compute attention in projected space
        scores = torch.matmul(Q, K_proj.transpose(-2, -1)) / math.sqrt(Q.size(-1))
        attn_weights = F.softmax(scores, dim=-1)
        
        # Pad to original sequence length for visualization
        full_attn = torch.zeros(1, 1, seq_len, seq_len)
        proj_dim = K_proj.size(-2)
        full_attn[:, :, :, :proj_dim] = attn_weights
        
        return full_attn[0, 0].numpy()
    
    else:
        # Fallback: compute standard attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.size(-1))
        attn_weights = F.softmax(scores, dim=-1)
        return attn_weights[0, 0].numpy()


# Run attention pattern visualization
attention_patterns, attention_outputs = visualize_attention_patterns()

In [None]:
def analyze_attention_properties():
    """Analyze specific properties of different attention mechanisms."""
    
    print("🔍 ATTENTION PATTERN ANALYSIS")
    print("=" * 50)
    
    for name, pattern in attention_patterns.items():
        print(f"\n📊 {name} Attention Properties:")
        
        # 1. Sparsity (how concentrated is attention?)
        entropy = -np.sum(pattern * np.log(pattern + 1e-8), axis=-1).mean()
        print(f"  • Average Entropy: {entropy:.3f} (lower = more focused)")
        
        # 2. Locality (how much attention to nearby positions?)
        seq_len = pattern.shape[0]
        local_attention = 0
        for i in range(seq_len):
            # Attention to positions within ±5 of current position
            start = max(0, i - 5)
            end = min(seq_len, i + 6)
            local_attention += pattern[i, start:end].sum()
        local_attention /= seq_len
        print(f"  • Local Attention: {local_attention:.3f} (attention to nearby positions)")
        
        # 3. Long-range dependencies
        long_range = 0
        for i in range(seq_len):
            # Attention to positions >10 away
            far_positions = np.concatenate([
                np.arange(0, max(0, i - 10)),
                np.arange(min(seq_len, i + 11), seq_len)
            ])
            if len(far_positions) > 0:
                long_range += pattern[i, far_positions].sum()
        long_range /= seq_len
        print(f"  • Long-range Attention: {long_range:.3f} (attention to distant positions)")
        
        # 4. Uniformity (how uniform is the attention distribution?)
        uniformity = 1.0 / (1.0 + np.var(pattern.mean(axis=0)))
        print(f"  • Uniformity: {uniformity:.3f} (higher = more uniform)")
        
        # 5. Diagonal dominance (attention to same position)
        diagonal_strength = np.diag(pattern).mean()
        print(f"  • Self-attention: {diagonal_strength:.3f} (attention to same position)")
    
    # Compare outputs
    print("\n🎯 OUTPUT SIMILARITY ANALYSIS:")
    reference_output = attention_outputs['Standard']
    
    for name, output in attention_outputs.items():
        if name != 'Standard':
            # Cosine similarity
            cos_sim = F.cosine_similarity(
                reference_output.flatten(), output.flatten(), dim=0
            ).item()
            
            # MSE
            mse = F.mse_loss(reference_output, output).item()
            
            # Correlation coefficient
            ref_flat = reference_output.flatten().numpy()
            out_flat = output.flatten().numpy()
            correlation = np.corrcoef(ref_flat, out_flat)[0, 1]
            
            print(f"\n{name} vs Standard:")
            print(f"  • Cosine Similarity: {cos_sim:.4f}")
            print(f"  • MSE: {mse:.6f}")
            print(f"  • Correlation: {correlation:.4f}")

analyze_attention_properties()

## 8. Practical Implementation Guidelines

### When to Use Each Mechanism

Based on our analysis, here are practical guidelines for choosing linear attention mechanisms:

#### 🎯 **Decision Framework**

In [None]:
def attention_recommendation_system():
    """Interactive system to recommend attention mechanisms based on requirements."""
    
    print("🤖 ATTENTION MECHANISM RECOMMENDATION SYSTEM")
    print("=" * 55)
    print("Answer a few questions to get personalized recommendations!\n")
    
    # Define decision tree
    recommendations = {
        'sequence_length': {
            'short': '<= 512 tokens',
            'medium': '513-2048 tokens', 
            'long': '2049-8192 tokens',
            'very_long': '> 8192 tokens'
        },
        'priority': {
            'speed': 'Maximum computational speed',
            'memory': 'Minimal memory usage',
            'quality': 'Best approximation quality',
            'balanced': 'Balance of speed/memory/quality'
        },
        'task_type': {
            'classification': 'Text classification, sentiment analysis',
            'generation': 'Text generation, language modeling',
            'understanding': 'Reading comprehension, QA',
            'translation': 'Machine translation, seq2seq'
        }
    }
    
    # Recommendation logic
    def get_recommendation(seq_len, priority, task):
        recommendations = []
        
        if seq_len == 'short':
            recommendations.append(("Standard Attention", "Manageable complexity, best quality", "⭐⭐⭐"))
            if priority in ['speed', 'memory']:
                recommendations.append(("Performer", "Good approximation with speed benefits", "⭐⭐"))
        
        elif seq_len == 'medium':
            if priority == 'quality':
                recommendations.append(("Performer (Orthogonal)", "Best linear approximation", "⭐⭐⭐"))
            elif priority == 'speed':
                recommendations.append(("LinFormer", "Excellent speed-quality balance", "⭐⭐⭐"))
            elif priority == 'memory':
                recommendations.append(("Linear Transformer", "Memory efficient for generation", "⭐⭐⭐"))
            else:  # balanced
                recommendations.append(("Performer", "Good all-around performance", "⭐⭐⭐"))
        
        elif seq_len == 'long':
            recommendations.append(("LinFormer", "Proven scalability, good quality", "⭐⭐⭐"))
            recommendations.append(("Performer", "Fast approximation", "⭐⭐"))
            if task == 'generation':
                recommendations.append(("Linear Transformer", "Causal attention for generation", "⭐⭐"))
        
        else:  # very_long
            recommendations.append(("LinFormer", "Only viable option for very long sequences", "⭐⭐⭐"))
            if priority == 'speed':
                recommendations.append(("FNet", "Ultra-fast for specific tasks", "⭐⭐"))
        
        return recommendations
    
    # Example scenarios
    scenarios = [
        ('medium', 'balanced', 'classification', "Document Classification (1024 tokens)"),
        ('long', 'speed', 'generation', "Long-form Text Generation (4096 tokens)"),
        ('short', 'quality', 'understanding', "Reading Comprehension (256 tokens)"),
        ('very_long', 'memory', 'translation', "Long Document Translation (16K tokens)")
    ]
    
    for seq_len, priority, task, description in scenarios:
        print(f"📋 Scenario: {description}")
        print(f"   Sequence Length: {recommendations['sequence_length'][seq_len]}")
        print(f"   Priority: {recommendations['priority'][priority]}")
        print(f"   Task: {recommendations['task_type'][task]}")
        print("\n   🎯 Recommendations:")
        
        recs = get_recommendation(seq_len, priority, task)
        for i, (method, reason, rating) in enumerate(recs, 1):
            print(f"      {i}. {method} {rating}")
            print(f"         → {reason}")
        print("-" * 60)

attention_recommendation_system()

## 9. Implementation Tips and Best Practices

### Training Considerations

Here are key implementation tips for successful deployment of linear attention mechanisms:

In [None]:
# Final demonstration: Simple usage example
def simple_usage_example():
    """Demonstrate simple usage of KV caching."""
    
    print("=== Simple Usage Example ===")
    print("This example shows how to use KV caching in your own projects.\n")
    
    # Initialize model
    model = CachedAttention(d_model=256, n_heads=8)
    
    # Generate text step by step
    batch_size, d_model = 1, 256
    kv_cache = None
    
    print("Generating sequence step by step:")
    
    for step in range(5):
        # Simulate new token
        new_token = torch.randn(batch_size, 1, d_model)
        
        # Process with caching
        output, kv_cache = model(
            new_token, 
            kv_cache=kv_cache, 
            use_cache=True
        )
        
        cache_info = kv_cache.get_cache_info() if kv_cache else "No cache"
        print(f"Step {step + 1}: Processed token, {cache_info}")
    
    print("\n✅ Successfully demonstrated KV caching!")
    print("\n🎯 Next steps:")
    print("   • Integrate into your transformer model")
    print("   • Benchmark on your specific use case")
    print("   • Consider advanced optimizations based on your requirements")

simple_usage_example()