# FlashAttention: Fast and Memory-Efficient Exact Attention

## 🎯 Overview

FlashAttention is an IO-aware algorithm for computing exact attention in transformers. It reduces the extra memory attention needs from O(N²) to O(N) in the sequence length N, because it never materializes the full N×N attention matrix in GPU main memory. This makes training on much longer sequences practical.

**Key Innovation**: IO-aware computation. Block-wise processing (tiling) and recomputation minimize the number of reads and writes between slow GPU high-bandwidth memory (HBM) and fast on-chip SRAM.

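**Impact**: Widely adopted in major frameworks. For example, PyTorch's `torch.nn.functional.scaled_dot_product_attention` can dispatch to a FlashAttention kernel on supported GPUs, so much longer sequences fit in the same memory.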

## 📚 Background & Motivation

### The Memory Wall Problem
- Standard attention requires O(N²) memory for the attention matrix
- Memory access is often the bottleneck, not computation
- GPU memory hierarchy: HBM (large, relatively slow) vs on-chip SRAM (small, fast)
- Naive attention writes the full N×N score and probability matrices to HBM and then reads them back

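To put numbers on the quadratic term, the sketch below computes the size of one materialized N×N matrix. It assumes fp16 (2 bytes per element), a single sequence and 16 heads; these are illustrative values, not tied to a particular model.

```python
def attention_matrix_bytes(seq_len: int, batch: int = 1, heads: int = 1,
                           bytes_per_elem: int = 2) -> int:
    """Bytes for one N x N score matrix per (sequence, head); fp16 by default."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


for n in (1024, 4096, 16384):
    gib = attention_matrix_bytes(n, heads=16) / 2**30
    print(f"N={n:>6}: {gib:8.3f} GiB")
```

Each 4x increase in N multiplies this one matrix's footprint by 16.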
### The FlashAttention Solution
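- Block-wise (tiled) computation whose working set fits in fast SRAM
- Online softmax that normalizes each row incrementally, block by block, so the full row never has to be materialized
- Recomputation of attention blocks in the backward pass instead of storing the N×N matrix
- Exact attention computation (not an approximation)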

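The recomputation idea can be sketched in numpy. The sketch assumes the forward pass saves only the output and one log-sum-exp value per query row (FlashAttention keeps per-row softmax statistics; FlashAttention-2 stores them as a log-sum-exp). The backward pass can then rebuild any tile of P exactly instead of reading a stored N×N matrix. For brevity the forward pass here still forms S in full; what matters is what gets saved. Sizes and seed are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 6, 4
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
scale = 1.0 / np.sqrt(d)

# Forward pass: compute O, then keep only O and one log-sum-exp value per row
S = (Q @ K.T) * scale
row_max = S.max(axis=1, keepdims=True)
lse = row_max + np.log(np.exp(S - row_max).sum(axis=1, keepdims=True))  # (N, 1)
O = np.exp(S - lse) @ V
del S, row_max  # the N x N scores are discarded, not saved for backward

# Backward pass: rebuild any tile of P from Q, K and the saved lse
rows, cols = slice(0, 3), slice(2, 6)
P_tile = np.exp((Q[rows] @ K[cols].T) * scale - lse[rows])

# Check against the softmax computed from the full matrix
S_ref = (Q @ K.T) * scale
P_ref = np.exp(S_ref - S_ref.max(axis=1, keepdims=True))
P_ref /= P_ref.sum(axis=1, keepdims=True)
print(np.allclose(P_tile, P_ref[rows, cols]), np.allclose(O, P_ref @ V))  # True True
```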
In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import seaborn as sns
from typing import Tuple, Optional
import math

# Set style
plt.style.use('default')
sns.set_palette("husl")
np.random.seed(42)
torch.manual_seed(42)

print("📦 Libraries imported successfully!")
print(f"🔢 NumPy version: {np.__version__}")
print(f"🔥 PyTorch version: {torch.__version__}")
print(f"💾 CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"🚀 GPU: {torch.cuda.get_device_name(0)}")

## 🧮 Mathematical Foundation

### Standard Attention
The standard attention computation involves:

1. **Compute attention scores**: S = QK^T / √d, where d is the head dimension
2. **Apply softmax**: P = softmax(S)
3. **Apply to values**: O = PV

**Memory complexity**: O(N²) for storing S and P

### FlashAttention Key Insights

1. **Online Softmax**: Compute softmax incrementally without storing full matrix
2. **Block-wise Processing**: Process attention in blocks that fit in SRAM
3. **Recomputation**: Trade computation for memory in backward pass

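The block sizes in step 2 are derived from the SRAM capacity. The small sketch below applies the rule from Algorithm 1 of the original FlashAttention paper, where M is the SRAM size measured in elements. The value of M used here is illustrative, not a property of any specific GPU.

```python
import math


def flash_block_sizes(sram_elems: int, head_dim: int) -> tuple[int, int]:
    """Return (B_r, B_c): B_c = ceil(M / 4d) key/value rows per block and
    B_r = min(ceil(M / 4d), d) query rows per block."""
    b_c = math.ceil(sram_elems / (4 * head_dim))
    b_r = min(b_c, head_dim)
    return b_r, b_c


M = 16_384  # illustrative SRAM budget in elements
for d in (64, 128):
    b_r, b_c = flash_block_sizes(M, d)
    # Q, K, V and O tiles each hold at most about M / 4 elements
    used = (2 * b_r + 2 * b_c) * d
    print(f"d={d}: B_r={b_r}, B_c={b_c}, Q/K/V/O tiles use {used} of {M} elements")
```

Larger head dimensions force smaller tiles. This is one reason the best block size depends on both the GPU and the model.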
### Online Softmax Algorithm

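For a stream of scores x₁, x₂, …, keep a running max m (initially -∞) and a running denominator d (initially 0). For each new score x_new:
- **m_new = max(m_old, x_new)**
- **d_new = d_old × exp(m_old - m_new) + exp(x_new - m_new)**
- **Rescale every previously accumulated term by exp(m_old - m_new)**, so all terms stay relative to the current max

After the last element, softmax(x)_i = exp(x_i - m) / d.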

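The same update works at chunk granularity. Two chunks processed independently can be merged using only their (max, denominator) pairs, which is what lets a row be handled block by block. Below is a minimal pure-Python sketch; the helper names are hypothetical.

```python
import math


def chunk_stats(xs):
    """(max, sum of exp(x - max)) for one chunk of scores."""
    m = max(xs)
    return m, sum(math.exp(x - m) for x in xs)


def merge_stats(a, b):
    """Combine the stats of two chunks into the stats of their concatenation."""
    (m_a, d_a), (m_b, d_b) = a, b
    m = max(m_a, m_b)
    # Rescale each partial denominator to the shared max before adding
    return m, d_a * math.exp(m_a - m) + d_b * math.exp(m_b - m)


scores = [1.0, 3.0, -2.0, 0.5, 4.0, 2.5]
m, d = merge_stats(chunk_stats(scores[:3]), chunk_stats(scores[3:]))
m_ref, d_ref = chunk_stats(scores)            # one pass over everything
softmax = [math.exp(x - m) / d for x in scores]
print(m == m_ref, math.isclose(d, d_ref))     # True True
```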
In [None]:
def naive_attention(Q, K, V, mask=None):
    """
    Standard attention implementation - O(N²) memory.
    """
    batch_size, seq_len, head_dim = Q.shape
    scale = 1.0 / math.sqrt(head_dim)
    
    # Compute attention scores - O(N²) memory
    scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
    
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    # Softmax - stores full attention matrix
    attn_weights = F.softmax(scores, dim=-1)
    
    # Apply to values
    output = torch.matmul(attn_weights, V)
    
    return output, attn_weights

def online_softmax(x):
    """
    Demonstrate online softmax over the last dimension, one element at a time.
    Returns (online_result, reference result from F.softmax).
    """
    # Standard softmax for comparison
    standard_softmax = F.softmax(x, dim=-1)
    
    # Initialize running max (m), running denominator (d) and unnormalized outputs
    m = torch.full_like(x[..., :1], -float('inf'))
    d = torch.zeros_like(x[..., :1])
    output = torch.zeros_like(x)
    
    # Process each element of the last dimension
    for i in range(x.size(-1)):
        x_i = x[..., i:i+1]
        
        # Update max
        m_new = torch.maximum(m, x_i)
        
        # Update denominator
        d_new = d * torch.exp(m - m_new) + torch.exp(x_i - m_new)
        
        # Update previous outputs
        if i > 0:
            correction = torch.exp(m - m_new)
            output[..., :i] = output[..., :i] * correction
        
        # Compute current output
        output[..., i:i+1] = torch.exp(x_i - m_new)
        
        # Update states
        m = m_new
        d = d_new
    
    # Final normalization
    output = output / d
    
    return output, standard_softmax

# Test online softmax
test_input = torch.randn(2, 8)
online_result, standard_result = online_softmax(test_input)

print("🧮 Online Softmax Test:")
print(f"   Max difference: {torch.max(torch.abs(online_result - standard_result)).item():.2e}")
print(f"   Results match: {torch.allclose(online_result, standard_result, atol=1e-6)}")

# Visualize the difference
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 4))

# Input
im1 = ax1.imshow(test_input.numpy(), aspect='auto', cmap='coolwarm')
ax1.set_title('Input')
ax1.set_xlabel('Sequence Position')
ax1.set_ylabel('Batch')
plt.colorbar(im1, ax=ax1)

# Standard softmax
im2 = ax2.imshow(standard_result.numpy(), aspect='auto', cmap='viridis')
ax2.set_title('Standard Softmax')
ax2.set_xlabel('Sequence Position')
plt.colorbar(im2, ax=ax2)

# Difference
diff = torch.abs(online_result - standard_result).numpy()
im3 = ax3.imshow(diff, aspect='auto', cmap='Reds')
ax3.set_title('Absolute Difference')
ax3.set_xlabel('Sequence Position')
plt.colorbar(im3, ax=ax3)

plt.tight_layout()
plt.show()

## ⚡ FlashAttention Algorithm Implementation

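Let's implement a simplified version of FlashAttention to understand the core concepts. This pure-PyTorch version reproduces the tiling and online-softmax math exactly, but not the speed. Without a fused GPU kernel, every tile still passes through HBM and the Python loops add overhead, so expect it to be slower than naive attention.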

In [None]:
def flash_attention_forward(Q, K, V, block_size=64, mask=None):
    """
    Simplified FlashAttention forward pass.
    
    Args:
        Q, K, V: Query, Key, Value tensors [batch, seq_len, head_dim]
        block_size: Number of rows per query block and per key/value block
        mask: Optional mask [batch, seq_len, seq_len]; positions where mask == 0 are masked out
    
    Returns:
        output: Attention output [batch, seq_len, head_dim]
        max_memory: Peak number of elements held in the per-tile score
            matrices S_ij and P_ij (a rough proxy, not measured GPU memory)
    """
    batch_size, seq_len, head_dim = Q.shape
    scale = 1.0 / math.sqrt(head_dim)
    
    # Initialize output and running statistics on Q's device and dtype
    O = torch.zeros_like(Q)
    l = torch.zeros(batch_size, seq_len, 1, dtype=Q.dtype, device=Q.device)  # Row sums
    m = torch.full((batch_size, seq_len, 1), -float('inf'), dtype=Q.dtype, device=Q.device)  # Row maxes
    
    max_memory_used = 0
    
    # Process in blocks
    for j in range(0, seq_len, block_size):
        j_end = min(j + block_size, seq_len)
        K_j = K[:, j:j_end, :]
        V_j = V[:, j:j_end, :]
        
        for i in range(0, seq_len, block_size):
            i_end = min(i + block_size, seq_len)
            Q_i = Q[:, i:i_end, :]
            
            # Compute block attention scores
            S_ij = torch.matmul(Q_i, K_j.transpose(-2, -1)) * scale
            
            # Apply mask if provided
            if mask is not None:
                mask_block = mask[:, i:i_end, j:j_end]
                S_ij = S_ij.masked_fill(mask_block == 0, -1e9)
            
            # Online softmax update
            m_prev = m[:, i:i_end, :].clone()
            l_prev = l[:, i:i_end, :].clone()
            
            # Update row maxes
            m_new = torch.maximum(m_prev, S_ij.max(dim=-1, keepdim=True)[0])
            
            # Compute attention weights for this block
            P_ij = torch.exp(S_ij - m_new)
            
            # Update row sums
            l_new = torch.exp(m_prev - m_new) * l_prev + P_ij.sum(dim=-1, keepdim=True)
            
            # Update output
            correction = torch.exp(m_prev - m_new) * l_prev / l_new
            O[:, i:i_end, :] = O[:, i:i_end, :] * correction + \
                              torch.matmul(P_ij, V_j) / l_new
            
            # Update statistics
            m[:, i:i_end, :] = m_new
            l[:, i:i_end, :] = l_new
            
            # Track memory usage (simplified)
            current_memory = S_ij.numel() + P_ij.numel()
            max_memory_used = max(max_memory_used, current_memory)
    
    return O, max_memory_used

def compare_attention_methods(seq_lengths, head_dim=64, batch_size=2, block_size=64):
    """
    Compare naive attention vs FlashAttention in terms of memory and speed.
    """
    results = {
        'seq_lengths': [],
        'naive_memory': [],
        'flash_memory': [],
        'naive_time': [],
        'flash_time': [],
        'memory_reduction': [],
        'max_diff': []
    }
    
    for seq_len in seq_lengths:
        print(f"\n🔄 Testing sequence length: {seq_len}")
        
        # Generate test data
        Q = torch.randn(batch_size, seq_len, head_dim)
        K = torch.randn(batch_size, seq_len, head_dim)
        V = torch.randn(batch_size, seq_len, head_dim)
        
        # Naive attention
        start_time = time.time()
        naive_output, naive_weights = naive_attention(Q, K, V)
        naive_time = time.time() - start_time
        naive_memory = 2 * seq_len * seq_len * batch_size  # scores + attention weights, both N x N (matches S_ij + P_ij count)
        
        # FlashAttention
        start_time = time.time()
        flash_output, flash_memory = flash_attention_forward(Q, K, V, block_size)
        flash_time = time.time() - start_time
        
        # Compare outputs
        max_diff = torch.max(torch.abs(naive_output - flash_output)).item()
        
        # Store results
        results['seq_lengths'].append(seq_len)
        results['naive_memory'].append(naive_memory)
        results['flash_memory'].append(flash_memory)
        results['naive_time'].append(naive_time)
        results['flash_time'].append(flash_time)
        results['memory_reduction'].append(naive_memory / flash_memory)
        results['max_diff'].append(max_diff)
        
        print(f"   Naive memory: {naive_memory:,} elements")
        print(f"   Flash memory: {flash_memory:,} elements")
        print(f"   Memory reduction: {naive_memory / flash_memory:.1f}x")
        print(f"   Max difference: {max_diff:.2e}")
        print(f"   Outputs match: {max_diff < 1e-4}")
    
    return results

# Test with different sequence lengths
test_seq_lengths = [128, 256, 512, 1024]
comparison_results = compare_attention_methods(test_seq_lengths)

print("\n✅ FlashAttention comparison completed!")

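To see why the output update in `flash_attention_forward` is exact, here is the same rule applied by hand to a single query row and two key blocks, in numpy. Sizes and seed are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(1)
d = 4
q = rng.standard_normal(d)          # one query row
K = rng.standard_normal((8, d))     # 8 keys, processed as two blocks of 4
V = rng.standard_normal((8, d))
scale = 1.0 / np.sqrt(d)

o = np.zeros(d)                     # normalized output so far
row_sum, row_max = 0.0, -np.inf
for start in (0, 4):
    s = (K[start:start + 4] @ q) * scale
    new_max = max(row_max, s.max())
    p = np.exp(s - new_max)
    new_sum = np.exp(row_max - new_max) * row_sum + p.sum()
    # Same rule as flash_attention_forward: rescale the old output, add the new block
    o = o * (np.exp(row_max - new_max) * row_sum / new_sum) + (p @ V[start:start + 4]) / new_sum
    row_max, row_sum = new_max, new_sum

s_all = (K @ q) * scale
ref = np.exp(s_all - s_all.max())
ref = (ref / ref.sum()) @ V
print(np.allclose(o, ref))  # True
```

Both terms of the update are divided by the new row sum. The old output is therefore first multiplied back by its old sum (shifted to the new max), which undoes the earlier normalization.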
## 📊 Memory and Performance Analysis

In [None]:
# Visualize comparison results
def plot_attention_comparison(results):
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))
    
    seq_lengths = results['seq_lengths']
    
    # 1. Memory usage comparison
    ax1.plot(seq_lengths, results['naive_memory'], 'o-', label='Naive Attention', linewidth=2)
    ax1.plot(seq_lengths, results['flash_memory'], 's-', label='FlashAttention', linewidth=2)
    ax1.set_xlabel('Sequence Length')
    ax1.set_ylabel('Memory Usage (elements)')
    ax1.set_title('Memory Usage Comparison')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    ax1.set_yscale('log')
    
    # 2. Memory reduction factor
    ax2.plot(seq_lengths, results['memory_reduction'], 'o-', color='green', linewidth=2)
    ax2.set_xlabel('Sequence Length')
    ax2.set_ylabel('Memory Reduction Factor')
    ax2.set_title('Memory Reduction (Naive / Flash)')
    ax2.grid(True, alpha=0.3)
    
    # Add annotations
    for i, (x, y) in enumerate(zip(seq_lengths, results['memory_reduction'])):
        ax2.annotate(f'{y:.1f}x', (x, y), textcoords="offset points", 
                    xytext=(0,10), ha='center')
    
    # 3. Time comparison
    ax3.plot(seq_lengths, results['naive_time'], 'o-', label='Naive Attention', linewidth=2)
    ax3.plot(seq_lengths, results['flash_time'], 's-', label='FlashAttention', linewidth=2)
    ax3.set_xlabel('Sequence Length')
    ax3.set_ylabel('Time (seconds)')
    ax3.set_title('Computation Time Comparison')
    ax3.legend()
    ax3.grid(True, alpha=0.3)
    
    # 4. Accuracy (max difference)
    ax4.semilogy(seq_lengths, results['max_diff'], 'o-', color='red', linewidth=2)
    ax4.axhline(y=1e-4, color='orange', linestyle='--', label='Tolerance (1e-4)')
    ax4.set_xlabel('Sequence Length')
    ax4.set_ylabel('Max Absolute Difference')
    ax4.set_title('Numerical Accuracy')
    ax4.legend()
    ax4.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

plot_attention_comparison(comparison_results)

# Create theoretical analysis
def theoretical_memory_analysis():
    """
    Analyze theoretical memory complexity.
    """
    seq_lengths = np.array([128, 256, 512, 1024, 2048, 4096, 8192])
    block_size = 64
    
    # Naive attention: O(N²)
    naive_memory = seq_lengths ** 2
    
    # FlashAttention: O(N) extra memory; N * block_size is an illustrative proxy, not a measured value
    flash_memory = seq_lengths * block_size
    
    # Memory reduction factor
    reduction_factor = naive_memory / flash_memory
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    
    # Memory complexity
    ax1.loglog(seq_lengths, naive_memory, 'o-', label='Naive O(N²)', linewidth=2)
    ax1.loglog(seq_lengths, flash_memory, 's-', label='Flash O(N)', linewidth=2)
    ax1.set_xlabel('Sequence Length')
    ax1.set_ylabel('Memory Usage (relative)')
    ax1.set_title('Theoretical Memory Complexity')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Reduction factor
    ax2.semilogx(seq_lengths, reduction_factor, 'o-', color='green', linewidth=2)
    ax2.set_xlabel('Sequence Length')
    ax2.set_ylabel('Memory Reduction Factor')
    ax2.set_title('Theoretical Memory Reduction')
    ax2.grid(True, alpha=0.3)
    
    # Add trend line
    z = np.polyfit(np.log(seq_lengths), np.log(reduction_factor), 1)
    trend = np.exp(z[1]) * seq_lengths ** z[0]
    ax2.plot(seq_lengths, trend, '--', color='red', alpha=0.7, 
             label=f'Trend: {np.exp(z[1]):.1f} × N^{z[0]:.2f}')
    ax2.legend()
    
    plt.tight_layout()
    plt.show()
    
    return seq_lengths, naive_memory, flash_memory, reduction_factor

seq_lens, naive_mem, flash_mem, reduction = theoretical_memory_analysis()

print("\n📊 Theoretical Memory Analysis:")
print("=" * 60)
print(f"{'Seq Len':<10} {'Naive (elems)':<15} {'Flash (elems)':<15} {'Reduction':<12}")
print("=" * 60)
for i, seq_len in enumerate(seq_lens):
    print(f"{seq_len:<10} {naive_mem[i]:<15,.0f} {flash_mem[i]:<15,.0f} {reduction[i]:.1f}x")

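Memory footprint is only half the story; the FlashAttention paper also analyzes HBM traffic. Standard attention needs Θ(Nd + N²) HBM accesses, while FlashAttention needs Θ(N²d²/M), where M is the SRAM size. The sketch below evaluates both expressions with constant factors dropped, so the ratio is only indicative. M = 100,000 elements is an illustrative value.

```python
def hbm_accesses(seq_len: int, head_dim: int, sram_elems: int):
    """Asymptotic HBM access counts, constant factors dropped:
    standard attention ~ N*d + N^2, FlashAttention ~ N^2 * d^2 / M."""
    n, d, m = seq_len, head_dim, sram_elems
    standard = n * d + n ** 2
    flash = n ** 2 * d ** 2 / m
    return standard, flash


M = 100_000  # illustrative SRAM capacity in elements
for d in (64, 128):
    for n in (1024, 4096):
        std, fl = hbm_accesses(n, d, M)
        print(f"N={n:>5}, d={d:>3}: standard / flash ~ {std / fl:5.1f}x")
```

Both counts grow as N², so the ratio is roughly M/d² and hardly depends on N. The savings come from d² being much smaller than M, and they shrink as the head dimension grows.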
## 🔧 FlashAttention-2 Improvements

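FlashAttention-2 introduced several optimizations:
1. **Fewer non-matmul FLOPs**: e.g., the output accumulator stays unnormalized inside the loop and is divided by the row sum once at the end
2. **Better parallelization** across sequence length: the loops are reordered so the outer loop runs over query blocks, and query blocks are spread over thread blocks rather than parallelizing only over batch and heads
3. **Improved work partitioning** between warps within each thread block, reducing communication through shared memory

The `FlashAttentionModule` class in this section is a multi-head wrapper around the simplified FlashAttention-1-style forward pass defined earlier; it does not implement these kernel-level optimizations.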

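Two of these ideas can be sketched in numpy: the reordered loops (outer loop over query blocks) and the deferred normalization. This is not the FlashAttention-2 CUDA kernel, since warp- and thread-block-level parallelism cannot be expressed in numpy. The function name and block sizes are illustrative.

```python
import numpy as np


def fa2_style_forward(Q, K, V, block_q=4, block_k=4):
    """Outer loop over query blocks, inner loop over key/value blocks,
    and a single division by the row sum after the inner loop."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.empty_like(Q)
    lse = np.empty(N)                            # per-row stats for a backward pass
    for i in range(0, N, block_q):
        q = Q[i:i + block_q]
        acc = np.zeros_like(q)                   # unnormalized output
        row_max = np.full(len(q), -np.inf)
        row_sum = np.zeros(len(q))
        for j in range(0, N, block_k):
            s = (q @ K[j:j + block_k].T) * scale
            new_max = np.maximum(row_max, s.max(axis=1))
            alpha = np.exp(row_max - new_max)    # rescales earlier contributions
            p = np.exp(s - new_max[:, None])
            row_sum = alpha * row_sum + p.sum(axis=1)
            acc = alpha[:, None] * acc + p @ V[j:j + block_k]
            row_max = new_max
        O[i:i + block_q] = acc / row_sum[:, None]   # normalize once per query block
        lse[i:i + block_q] = row_max + np.log(row_sum)
    return O, lse


rng = np.random.default_rng(2)
Q, K, V = (rng.standard_normal((10, 8)) for _ in range(3))
O, lse = fa2_style_forward(Q, K, V)

S = Q @ K.T / np.sqrt(8)
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
print(np.allclose(O, P @ V))  # True
```

Compared with the rescaling by l_prev / l_new in `flash_attention_forward`, the inner loop here only multiplies by `alpha` and never divides.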
In [None]:
class FlashAttentionModule(nn.Module):
    """
    Multi-head self-attention that routes long sequences through the simplified
    block-wise flash_attention_forward defined above.
    """
    
    def __init__(self, embed_dim, num_heads, block_size=64, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.block_size = block_size
        self.scale = 1.0 / math.sqrt(self.head_dim)
        
        assert embed_dim % num_heads == 0
        
        # Projections
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, attention_mask=None, use_flash=True):
        batch_size, seq_len, embed_dim = x.shape
        
        # Project to Q, K, V
        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)
        
        # Reshape for multi-head attention
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        if use_flash and seq_len > self.block_size:
            # Use FlashAttention for long sequences
            attn_output = self._flash_attention(Q, K, V, attention_mask)
        else:
            # Use standard attention for short sequences
            attn_output = self._standard_attention(Q, K, V, attention_mask)
        
        # Reshape back
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, embed_dim
        )
        
        # Output projection
        output = self.out_proj(attn_output)
        
        return output
    
    def _standard_attention(self, Q, K, V, mask=None):
        """Standard attention implementation."""
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        output = torch.matmul(attn_weights, V)
        return output
    
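    def _flash_attention(self, Q, K, V, mask=None):
        """Block-wise attention via flash_attention_forward.
        Note: attention dropout is not applied on this path."""
        batch_size, num_heads, seq_len, head_dim = Q.shape
        
        # Flatten batch and heads for processing. Q/K/V are non-contiguous
        # after transpose(1, 2), so .reshape (which copies when needed) is
        # required; .view would raise a RuntimeError here.
        Q_flat = Q.reshape(-1, seq_len, head_dim)
        K_flat = K.reshape(-1, seq_len, head_dim)
        V_flat = V.reshape(-1, seq_len, head_dim)
        
        # Match the flattened layout (mask assumed broadcastable to [batch, heads, seq, seq])
        if mask is not None:
            mask = mask.expand(batch_size, num_heads, seq_len, seq_len).reshape(-1, seq_len, seq_len)
        
        # Apply FlashAttention
        output_flat, _ = flash_attention_forward(
            Q_flat, K_flat, V_flat, self.block_size, mask
        )
        
        # Reshape back
        output = output_flat.view(batch_size, num_heads, seq_len, head_dim)
        
        return output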
    
    def get_memory_stats(self, seq_len):
        """Rough element counts for score storage per batch element (ignores Q, K, V, O)."""
        # Standard attention: one full N x N score matrix per head
        standard_memory = self.num_heads * seq_len * seq_len
        
        # FlashAttention: one block_size x block_size score tile per head
        flash_memory = self.num_heads * self.block_size * self.block_size
        
        return {
            'standard': standard_memory,
            'flash': flash_memory,
            'reduction': standard_memory / flash_memory
        }

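# Test FlashAttention module
flash_attn = FlashAttentionModule(
    embed_dim=512,
    num_heads=8,
    block_size=64
)
# Eval mode disables dropout, which only the standard path applies;
# otherwise the two outputs would differ for reasons unrelated to the algorithm
flash_attn.eval()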

# Test with different sequence lengths
test_lengths = [128, 256, 512, 1024]

print("🧪 FlashAttention Module Test:")
print("=" * 50)

for seq_len in test_lengths:
    x = torch.randn(2, seq_len, 512)
    
    # Test both modes
    start_time = time.time()
    output_standard = flash_attn(x, use_flash=False)
    time_standard = time.time() - start_time
    
    start_time = time.time()
    output_flash = flash_attn(x, use_flash=True)
    time_flash = time.time() - start_time
    
    # Compare outputs
    max_diff = torch.max(torch.abs(output_standard - output_flash)).item()
    
    # Memory stats
    memory_stats = flash_attn.get_memory_stats(seq_len)
    
    print(f"\nSequence length: {seq_len}")
    print(f"  Time - Standard: {time_standard:.4f}s, Flash: {time_flash:.4f}s")
    print(f"  Max difference: {max_diff:.2e}")
    print(f"  Memory reduction: {memory_stats['reduction']:.1f}x")
    print(f"  Outputs match: {max_diff < 1e-4}")

## 📈 Scaling Analysis: Long Sequences

FlashAttention's advantage grows with sequence length: the naive N×N matrix grows quadratically, while the per-tile working set does not depend on N.

In [None]:
def analyze_long_sequence_scaling():
    """
    Analyze how FlashAttention scales with very long sequences.
    """
    # Sequence lengths to test (some very long)
    seq_lengths = [512, 1024, 2048, 4096, 8192, 16384]
    block_size = 128
    head_dim = 64
    # Illustrative workload (assumption): one N x N matrix is materialized per
    # (sequence, head), so all counts are multiplied by batch_size * num_heads
    batch_size = 8
    num_heads = 16
    batch_heads = batch_size * num_heads
    
    results = {
        'seq_lengths': [],
        'naive_memory_gb': [],
        'flash_memory_gb': [],
        'memory_reduction': [],
        'naive_feasible': [],
        'flash_feasible': [],
        'batch_heads': batch_heads,
        'head_dim': head_dim
    }
    
    # Assume we have 24GB GPU memory
    gpu_memory_gb = 24
    bytes_per_float = 4
    
    for seq_len in seq_lengths:
        # Naive attention memory (just the attention matrices)
        naive_elements = batch_heads * seq_len * seq_len
        naive_memory_gb = (naive_elements * bytes_per_float) / (1024**3)
        
        # FlashAttention memory (one score tile per sequence and head)
        flash_elements = batch_heads * block_size * block_size
        flash_memory_gb = (flash_elements * bytes_per_float) / (1024**3)
        
        # Check feasibility (simplified: score storage only)
        naive_feasible = naive_memory_gb < gpu_memory_gb * 0.5  # Leave room for other tensors
        flash_feasible = flash_memory_gb < gpu_memory_gb * 0.5
        
        results['seq_lengths'].append(seq_len)
        results['naive_memory_gb'].append(naive_memory_gb)
        results['flash_memory_gb'].append(flash_memory_gb)
        results['memory_reduction'].append(naive_memory_gb / flash_memory_gb)
        results['naive_feasible'].append(naive_feasible)
        results['flash_feasible'].append(flash_feasible)
    
    return results

# Analyze scaling
scaling_results = analyze_long_sequence_scaling()

# Visualize scaling results
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

seq_lengths = scaling_results['seq_lengths']

# 1. Memory usage in GB
ax1.semilogy(seq_lengths, scaling_results['naive_memory_gb'], 'o-', 
             label='Naive Attention', linewidth=2, markersize=8)
ax1.semilogy(seq_lengths, scaling_results['flash_memory_gb'], 's-', 
             label='FlashAttention', linewidth=2, markersize=8)
ax1.axhline(y=12, color='red', linestyle='--', alpha=0.7, label='50% of 24GB GPU')
ax1.set_xlabel('Sequence Length')
ax1.set_ylabel('Memory Usage (GB)')
ax1.set_title('Memory Usage vs Sequence Length')
ax1.legend()
ax1.grid(True, alpha=0.3)

# 2. Memory reduction factor
ax2.semilogx(seq_lengths, scaling_results['memory_reduction'], 'o-', 
             color='green', linewidth=2, markersize=8)
ax2.set_xlabel('Sequence Length')
ax2.set_ylabel('Memory Reduction Factor')
ax2.set_title('Memory Reduction (Log Scale)')
ax2.grid(True, alpha=0.3)

# Add annotations for key points
for i, (x, y) in enumerate(zip(seq_lengths, scaling_results['memory_reduction'])):
    if i % 2 == 0:  # Annotate every other point
        ax2.annotate(f'{y:.0f}x', (x, y), textcoords="offset points", 
                    xytext=(0,10), ha='center', fontsize=10)

# 3. Feasibility analysis
feasible_data = []
colors = []
labels = []

for i, seq_len in enumerate(seq_lengths):
    if scaling_results['naive_feasible'][i] and scaling_results['flash_feasible'][i]:
        feasible_data.append((seq_len, 'Both Feasible', 'green'))
    elif scaling_results['flash_feasible'][i]:
        feasible_data.append((seq_len, 'Flash Only', 'orange'))
    else:
        feasible_data.append((seq_len, 'Neither Feasible', 'red'))

# Create feasibility plot
for i, (seq_len, status, color) in enumerate(feasible_data):
    ax3.bar(i, seq_len, color=color, alpha=0.7, label=status if status not in labels else "")
    if status not in labels:
        labels.append(status)

ax3.set_xlabel('Sequence Index')
ax3.set_ylabel('Sequence Length')
ax3.set_title('Feasibility on 24GB GPU')
ax3.set_xticks(range(len(seq_lengths)))
ax3.set_xticklabels([f'{s/1024:g}K' for s in seq_lengths])
ax3.legend()
ax3.grid(True, alpha=0.3, axis='y')

# 4. Maximum achievable sequence length per memory budget
memory_budgets = [8, 16, 24, 40, 80]  # Different GPU memory sizes in GB
# Workload recorded by the analysis above (defaults: one head, head_dim 64)
batch_heads = scaling_results.get('batch_heads', 1)
head_dim = scaling_results.get('head_dim', 64)
max_seq_naive = []
max_seq_flash = []

for budget_gb in memory_budgets:
    budget_elements = (budget_gb * 0.5 * 1024**3) / 4  # 50% of memory, 4 bytes per float
    
    # Naive: batch_heads * N² elements for the attention matrices
    max_naive = int(np.sqrt(budget_elements / batch_heads))
    
    # Flash: no N x N matrix, but Q, K, V and O (4 tensors of N x head_dim
    # per sequence and head) still grow linearly with N
    max_flash = int(budget_elements / (4 * batch_heads * head_dim))
    
    max_seq_naive.append(max_naive)
    max_seq_flash.append(max_flash)

ax4.plot(memory_budgets, max_seq_naive, 'o-', label='Naive Attention', linewidth=2)
ax4.plot(memory_budgets, max_seq_flash, 's-', color='green', label='FlashAttention', linewidth=2)
ax4.set_xlabel('GPU Memory (GB)')
ax4.set_ylabel('Max Sequence Length')
ax4.set_title('Maximum Achievable Sequence Length')
ax4.legend()
ax4.grid(True, alpha=0.3)
ax4.set_yscale('log')

plt.tight_layout()
plt.show()

# Print summary table
print("\n📊 Long Sequence Scaling Analysis:")
print("=" * 80)
print(f"{'Seq Len':<10} {'Naive GB':<12} {'Flash GB':<12} {'Reduction':<12} {'Feasible':<15}")
print("=" * 80)

for i, seq_len in enumerate(scaling_results['seq_lengths']):
    naive_gb = scaling_results['naive_memory_gb'][i]
    flash_gb = scaling_results['flash_memory_gb'][i]
    reduction = scaling_results['memory_reduction'][i]
    
    if scaling_results['naive_feasible'][i] and scaling_results['flash_feasible'][i]:
        feasible = "Both"
    elif scaling_results['flash_feasible'][i]:
        feasible = "Flash Only"
    else:
        feasible = "Neither"
    
    print(f"{seq_len:<10} {naive_gb:<12.3f} {flash_gb:<12.6f} {f'{reduction:.0f}x':<12} {feasible:<15}")

## 🎯 Practical Exercises

### Exercise 1: Implement Block-wise Attention
Implement a simplified version of block-wise attention computation.

In [None]:
def exercise_blockwise_attention():
    """
    Exercise: Implement and test block-wise attention.
    """
    print("🧪 Exercise: Block-wise Attention Implementation")
    print("=" * 60)
    
    def blockwise_attention_exercise(Q, K, V, block_size):
        """
        YOUR TASK: Complete this block-wise attention implementation.
        
        Hints:
        1. Process Q in blocks of size block_size
        2. For each Q block, process all K blocks
        3. Use online softmax to maintain numerical stability
        4. Accumulate results properly
        """
        batch_size, seq_len, head_dim = Q.shape
        scale = 1.0 / math.sqrt(head_dim)
        
        # Initialize output
        O = torch.zeros_like(Q)
        
        # TODO: Implement block-wise processing
        # For now, we'll provide a working solution
        
        for i in range(0, seq_len, block_size):
            i_end = min(i + block_size, seq_len)
            Q_block = Q[:, i:i_end, :]
            
            # Initialize block statistics
            block_max = torch.full((batch_size, i_end - i, 1), -float('inf'), device=Q.device)
            block_sum = torch.zeros((batch_size, i_end - i, 1), device=Q.device)
            block_output = torch.zeros_like(Q_block)
            
            for j in range(0, seq_len, block_size):
                j_end = min(j + block_size, seq_len)
                K_block = K[:, j:j_end, :]
                V_block = V[:, j:j_end, :]
                
                # Compute attention scores for this block pair
                scores = torch.matmul(Q_block, K_block.transpose(-2, -1)) * scale
                
                # Update statistics (simplified online softmax)
                block_max_new = torch.maximum(block_max, scores.max(dim=-1, keepdim=True)[0])
                
                # Compute softmax for this block
                exp_scores = torch.exp(scores - block_max_new)
                
                # Update running sum
                correction = torch.exp(block_max - block_max_new)
                block_sum = block_sum * correction + exp_scores.sum(dim=-1, keepdim=True)
                
                # Update output
                block_output = block_output * correction + torch.matmul(exp_scores, V_block)
                block_max = block_max_new
            
            # Normalize the block output
            O[:, i:i_end, :] = block_output / block_sum
        
        return O
    
    # Test the implementation
    seq_len = 256
    head_dim = 64
    batch_size = 2
    block_size = 32
    
    Q = torch.randn(batch_size, seq_len, head_dim)
    K = torch.randn(batch_size, seq_len, head_dim)
    V = torch.randn(batch_size, seq_len, head_dim)
    
    # Compare with standard attention
    standard_output, _ = naive_attention(Q, K, V)
    blockwise_output = blockwise_attention_exercise(Q, K, V, block_size)
    
    # Measure accuracy
    max_diff = torch.max(torch.abs(standard_output - blockwise_output)).item()
    mean_diff = torch.mean(torch.abs(standard_output - blockwise_output)).item()
    
    print(f"✅ Block-wise attention test:")
    print(f"   Sequence length: {seq_len}")
    print(f"   Block size: {block_size}")
    print(f"   Max difference: {max_diff:.2e}")
    print(f"   Mean difference: {mean_diff:.2e}")
    print(f"   Results match: {max_diff < 1e-3}")
    
    # Test different block sizes
    block_sizes = [16, 32, 64, 128]
    differences = []
    
    print(f"\n📊 Block size analysis:")
    for bs in block_sizes:
        if bs <= seq_len:
            output = blockwise_attention_exercise(Q, K, V, bs)
            diff = torch.max(torch.abs(standard_output - output)).item()
            differences.append(diff)
            print(f"   Block size {bs}: max diff = {diff:.2e}")
    
    return block_sizes[:len(differences)], differences

# Run the exercise
block_sizes, differences = exercise_blockwise_attention()

# Visualize block size effects
if len(differences) > 0:
    plt.figure(figsize=(10, 6))
    plt.semilogy(block_sizes, differences, 'o-', linewidth=2, markersize=8)
    plt.xlabel('Block Size')
    plt.ylabel('Max Absolute Difference')
    plt.title('Numerical Accuracy vs Block Size')
    plt.grid(True, alpha=0.3)
    plt.axhline(y=1e-6, color='red', linestyle='--', alpha=0.7, label='Target Precision')
    plt.legend()
    plt.show()

## 💡 Key Takeaways

### FlashAttention Advantages:
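1. **Memory Efficiency**: Reduces the extra memory of attention from O(N²) to O(N) in sequence length
2. **Exact Computation**: Mathematically equivalent to standard attention (differences are only floating-point rounding)
3. **Speed Improvements**: Fused GPU kernels are often 2-4x faster than standard attention, depending on hardware and sequence length; the pure-PyTorch version in this notebook is not faster, since it lacks kernel fusion
4. **Long Sequences**: Sequence lengths whose N×N matrix would not fit in memory become feasible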
5. **Hardware Awareness**: Optimized for GPU memory hierarchy

### Core Innovations:
1. **Block-wise Computation**: Process attention in blocks that fit in SRAM
2. **Online Softmax**: Compute softmax incrementally without storing full matrix
3. **Recomputation**: Trade computation for memory in backward pass
4. **IO Optimization**: Minimize slow HBM memory access

### When to Use FlashAttention:
- **Long Sequences**: The memory savings grow with sequence length, so the benefit is largest for long contexts
- **Memory Constraints**: When standard attention causes OOM
- **Large Models**: When every bit of memory efficiency matters
- **Production Deployment**: For improved inference speed

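In practice you would call a fused kernel rather than a hand-written loop. The minimal usage sketch below uses PyTorch 2.x's `torch.nn.functional.scaled_dot_product_attention`. That function selects a backend based on device, dtype and inputs, and it can use a FlashAttention kernel on supported CUDA GPUs, though there is no guarantee it does so for a given call. Shapes and values here are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, head_dim = 2, 4, 128, 64
q, k, v = (torch.randn(batch, heads, seq_len, head_dim) for _ in range(3))

# PyTorch picks the attention backend; the call itself is backend-agnostic
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Reference: explicit causal attention that materializes the N x N matrix
scores = q @ k.transpose(-2, -1) / head_dim ** 0.5
causal = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()
scores = scores.masked_fill(~causal, float("-inf"))
ref = torch.softmax(scores, dim=-1) @ v
print(torch.allclose(out, ref, atol=1e-5))
```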
### Implementation Considerations:
1. **Block Size**: Balance between memory and computation efficiency
2. **Hardware**: Different optimal block sizes for different GPUs
3. **Precision**: Kernels typically run in FP16/BF16 for speed; the running max in online softmax keeps exp() from overflowing in these formats
4. **Backward Pass**: Recomputation strategy affects memory vs speed trade-off

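Low precision shows why the max subtraction in online softmax matters. fp16's largest finite value is 65504, so exp(x) overflows for x above roughly 11. The logits below are illustrative.

```python
import numpy as np

x = np.array([2.0, 10.0, 12.0], dtype=np.float16)  # illustrative logits

# Direct softmax: exp(12) is about 1.6e5, above fp16's max of 65504, so it overflows
with np.errstate(over="ignore", invalid="ignore"):
    e = np.exp(x)
    naive = e / e.sum()

# Max-subtracted softmax (what the running max m provides): all exponents are <= 0
e_stable = np.exp(x - x.max())
stable = e_stable / e_stable.sum()

print(naive)   # the overflowed entry becomes inf / inf = nan
print(stable)
```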
## 🚀 Next Steps

1. **Study FlashAttention-2**: Its loop reordering, parallelization over sequence length, and warp-level work partitioning
2. **Explore Ring Attention**: Distributed attention for extremely long sequences
3. **Try PagedAttention**: Paged KV-cache memory management for LLM serving (introduced with vLLM)
4. **Implement in Practice**: Use with actual transformer models
5. **Hardware Optimization**: Learn about kernel optimization and CUDA programming

**FlashAttention has fundamentally changed how we think about attention computation, making long-context models practical and efficient. It's an essential technique for modern transformer architectures!** 🎯