# Day 32: FlashAttention - Part 4

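In this notebook, we'll explore FlashAttention, an exact attention algorithm designed around the GPU memory hierarchy. It never materializes the full attention matrix in high-bandwidth memory, which cuts memory usage from quadratic to linear in sequence length and reduces the memory traffic that dominates attention's runtime.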

## Overview

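1. Understanding the attention bottleneck
2. A standard attention baseline
3. How FlashAttention works
4. Implementing a simplified version of FlashAttention
5. Measuring performance and scaling with sequence length
6. FlashAttention-2 and using FlashAttention from PyTorch and Hugging Face Transformers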

## 1. Understanding the Attention Bottleneck

The standard attention computation in transformer models faces two main challenges:

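1. **Memory Bottleneck**: The full N×N matrix of attention scores (and then of softmax probabilities) is materialized in GPU high-bandwidth memory (HBM), so memory grows quadratically with the sequence length N
2. **I/O Bound**: Each step (matrix multiply, masking, softmax, second matrix multiply) reads its input from HBM and writes its result back, and this memory traffic, rather than arithmetic, dominates the runtime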

These challenges limit the maximum sequence length and batch size that can be processed efficiently.
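To make the memory bottleneck concrete, here is a back-of-the-envelope calculation of the size of one materialized score matrix versus one Q tensor. The configuration (batch 1, 32 heads, 8,192 tokens, head dimension 128, 2-byte fp16 elements) is a hypothetical example chosen for illustration, not a measurement.

```python
def attention_matrix_bytes(batch, heads, seq_len, bytes_per_elem=2):
    """Bytes needed to materialize one N x N score (or probability) matrix per head."""
    return batch * heads * seq_len * seq_len * bytes_per_elem

def qkv_tensor_bytes(batch, heads, seq_len, head_dim, bytes_per_elem=2):
    """Bytes for a single Q, K, or V tensor of shape (batch, heads, seq_len, head_dim)."""
    return batch * heads * seq_len * head_dim * bytes_per_elem

GiB, MiB = 1024 ** 3, 1024 ** 2

# Hypothetical configuration: batch 1, 32 heads, 8,192 tokens, head_dim 128, fp16 (2 bytes)
scores = attention_matrix_bytes(1, 32, 8192)
q_size = qkv_tensor_bytes(1, 32, 8192, 128)
print(f"score matrix: {scores / GiB:.1f} GiB, Q tensor: {q_size / MiB:.1f} MiB")

# Doubling the sequence length quadruples the score matrix but only doubles Q
print(attention_matrix_bytes(1, 32, 16384) // scores, qkv_tensor_bytes(1, 32, 16384, 128) // q_size)
```

The score matrix here is 64 times larger than Q (8,192 / 128), and the gap widens linearly as sequences get longer.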

In [None]:
# Import necessary libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
import time
import gc

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 2. Standard Attention Implementation

Let's first implement the standard attention computation to understand the baseline.

In [None]:
def standard_attention(q, k, v, scale=None, mask=None):
    """Standard attention computation.
    
    Args:
        q: Query tensor of shape (batch_size, seq_len, head_dim)
        k: Key tensor of shape (batch_size, seq_len, head_dim)
        v: Value tensor of shape (batch_size, seq_len, head_dim)
        scale: Scaling factor for attention scores
        mask: Optional attention mask
        
    Returns:
        Output tensor of shape (batch_size, seq_len, head_dim)
    """
    # Get dimensions
    batch_size, seq_len, head_dim = q.shape
    
    # Set scale if not provided
    if scale is None:
        scale = 1.0 / np.sqrt(head_dim)
    
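    # Compute attention scores: (batch_size, seq_len, seq_len)
    # This full N x N matrix is what FlashAttention avoids materializing
    scores = torch.bmm(q, k.transpose(1, 2)) * scale
    
    # Apply mask if provided (0 = masked). Use the dtype's most negative finite value,
    # since a literal -1e9 is out of range for float16
    if mask is not None:
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)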
    
    # Apply softmax to get attention weights
    attn_weights = torch.nn.functional.softmax(scores, dim=-1)
    
    # Compute output: (batch_size, seq_len, head_dim)
    output = torch.bmm(attn_weights, v)
    
    return output
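As a sanity check, the same math is available as PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention` (PyTorch 2.0 and later). The sketch below re-defines a compact version of standard attention so it runs on its own and compares it with the built-in under a causal mask. The tensor sizes and the use of `torch.tril` to build the mask are illustrative choices, not part of the original code.

```python
import math
import torch
import torch.nn.functional as F

def reference_attention(q, k, v, mask=None):
    """softmax(Q K^T / sqrt(d)) V, materializing the full (seq_len x seq_len) score matrix."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
# Shape (batch, heads, seq_len, head_dim)
q, k, v = (torch.randn(2, 4, 128, 64) for _ in range(3))
causal_mask = torch.tril(torch.ones(128, 128))  # 1 = attend, 0 = masked (future positions)

ours = reference_attention(q, k, v, causal_mask)
builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(ours, builtin, atol=1e-5))
```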

## 3. How FlashAttention Works

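FlashAttention computes exactly the same result as standard attention, but restructures the computation around the GPU memory hierarchy using three techniques:

1. **Tiling (block-wise processing)**: Q, K, and V are split into blocks small enough to fit in fast on-chip SRAM. An *online softmax*, which keeps a running maximum and a running sum of exponentials for each query row, lets the softmax be computed block by block without ever holding a full row of scores
2. **Kernel fusion**: The matrix multiplications, masking, softmax, and output accumulation run inside a single GPU kernel, so intermediate results stay in SRAM instead of round-tripping through HBM
3. **Recomputation**: The N×N attention matrix is not stored for the backward pass; instead, the needed blocks are recomputed from Q, K, V and the saved softmax statistics, trading extra computation for reduced memory usage

Because the N×N matrix is never written to HBM, extra memory grows linearly rather than quadratically with sequence length, and the reduction in HBM reads and writes is what makes FlashAttention faster in practice.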
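The piece that makes block-wise softmax possible is the online softmax: each block contributes a local maximum and a local sum of exponentials, and these statistics can be merged without revisiting earlier blocks. For one row split into two blocks with statistics $(m_1, \ell_1)$ and $(m_2, \ell_2)$, the merged statistics are $m = \max(m_1, m_2)$ and $\ell = e^{m_1 - m}\ell_1 + e^{m_2 - m}\ell_2$. A minimal numerical check of that identity (the random row and the 6/4 split are arbitrary):

```python
import torch

def block_stats(x):
    """Local max and sum of exp(x - max) for one block of scores."""
    m = x.max()
    return m, torch.exp(x - m).sum()

def merge_stats(m1, l1, m2, l2):
    """Combine two blocks' (max, sum) statistics into statistics for their concatenation."""
    m = torch.maximum(m1, m2)
    return m, torch.exp(m1 - m) * l1 + torch.exp(m2 - m) * l2

torch.manual_seed(0)
row = torch.randn(10) * 5          # one row of attention scores
left, right = row[:6], row[6:]     # two key blocks of unequal size

m, l = merge_stats(*block_stats(left), *block_stats(right))
online = torch.exp(row - m) / l    # softmax using only the merged statistics
print(torch.allclose(online, torch.softmax(row, dim=0)))
```

Merging is associative, so the same update can be applied one block at a time, left to right, which is exactly the order a tiled kernel visits key blocks.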

## 4. Implementing a Simplified FlashAttention

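Let's implement a simplified version of FlashAttention in plain PyTorch to demonstrate its core idea of tiling over query blocks and key/value blocks. This is a pedagogical implementation, not a production one: the loops run in Python and every block is a separate PyTorch operation, so there is no kernel fusion and no control over SRAM. Expect it to be *slower* than standard attention, while using less peak memory because no full N×N matrix is ever allocated.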

In [None]:
def simplified_flash_attention(q, k, v, block_size=64, scale=None, mask=None):
    """Simplified implementation of FlashAttention (tiling + online softmax).
    
    Args:
        q: Query tensor of shape (batch_size, seq_len, head_dim)
        k: Key tensor of shape (batch_size, seq_len, head_dim)
        v: Value tensor of shape (batch_size, seq_len, head_dim)
        block_size: Size of blocks for tiled computation
        scale: Scaling factor for attention scores
        mask: Optional attention mask of shape (batch_size, seq_len, seq_len), 0 = masked
        
    Returns:
        Output tensor of shape (batch_size, seq_len, head_dim)
    """
    # Get dimensions
    batch_size, seq_len, head_dim = q.shape
    
    # Set scale if not provided
    if scale is None:
        scale = 1.0 / np.sqrt(head_dim)
    
    output = torch.empty_like(q)
    
    # Outer loop over query blocks (slicing handles a shorter final block)
    for i in range(0, seq_len, block_size):
        q_block = q[:, i:i + block_size, :]
        rows = q_block.shape[1]
        
        # Online-softmax state for each query row in this block
        row_max = torch.full((batch_size, rows, 1), float("-inf"), device=q.device, dtype=q.dtype)
        row_sum = torch.zeros((batch_size, rows, 1), device=q.device, dtype=q.dtype)
        acc = torch.zeros_like(q_block)  # unnormalized weighted sum of values
        
        # Inner loop over key/value blocks
        for j in range(0, seq_len, block_size):
            k_block = k[:, j:j + block_size, :]
            v_block = v[:, j:j + block_size, :]
            
            # Scores for this tile only: (batch_size, q_block_size, kv_block_size)
            scores = torch.bmm(q_block, k_block.transpose(1, 2)) * scale
            
            if mask is not None:
                block_mask = mask[:, i:i + block_size, j:j + block_size]
                # Finite minimum (not -inf) keeps fully masked tiles from producing nan
                scores = scores.masked_fill(block_mask == 0, torch.finfo(scores.dtype).min)
            
            # Update the running max, then rescale the old sum and accumulator to it
            new_max = torch.maximum(row_max, scores.amax(dim=-1, keepdim=True))
            correction = torch.exp(row_max - new_max)
            p = torch.exp(scores - new_max)  # max-subtracted, so exp cannot overflow
            
            row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
            acc = acc * correction + torch.bmm(p, v_block)
            row_max = new_max
        
        # Normalize once, after all key/value blocks have been seen
        output[:, i:i + block_size, :] = acc / row_sum
    
    return output
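A practical detail any tiled softmax must handle: exponentiating raw scores overflows as soon as a score exceeds roughly 88 in float32 (where `exp` exceeds the largest representable float), and the normalized result becomes `nan`. FlashAttention kernels therefore subtract a running row maximum before exponentiating. The scores below are deliberately large, hypothetical values chosen to trigger the overflow; subtracting the row maximum leaves the softmax mathematically unchanged and keeps every exponent at or below zero.

```python
import torch

scores = torch.tensor([[100.0, 101.0, 99.0]])  # float32; exp(100) overflows to inf

# Naive: exponentiate first, normalize later
naive = torch.exp(scores) / torch.exp(scores).sum(dim=-1, keepdim=True)

# Stable: shift by the row max so the largest exponent is exactly 0
shifted = scores - scores.amax(dim=-1, keepdim=True)
stable = torch.exp(shifted) / torch.exp(shifted).sum(dim=-1, keepdim=True)

print(naive)   # every entry is nan (inf / inf)
print(stable)
```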

## 5. Measuring Performance

Let's compare the performance of standard attention and our simplified FlashAttention implementation.

In [None]:
def _benchmark(attn_fn, q, k, v, num_runs):
    """Time attn_fn(q, k, v) and track its peak extra GPU memory.
    
    Returns (average seconds, average peak extra bytes, output of the last run).
    Memory is only tracked on CUDA; on CPU it is reported as 0.
    """
    use_cuda = q.is_cuda
    attn_fn(q, k, v)  # warm-up so one-time CUDA/cuBLAS initialization isn't timed
    if use_cuda:
        torch.cuda.synchronize()
    
    times, peak_memory = [], []
    for _ in range(num_runs):
        if use_cuda:
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            mem_before = torch.cuda.memory_allocated()
        
        start_time = time.perf_counter()
        output = attn_fn(q, k, v)
        if use_cuda:
            torch.cuda.synchronize()  # kernels run asynchronously; wait before stopping the clock
        times.append(time.perf_counter() - start_time)
        
        if use_cuda:
            peak_memory.append(torch.cuda.max_memory_allocated() - mem_before)
    
    avg_memory = sum(peak_memory) / len(peak_memory) if peak_memory else 0
    return sum(times) / len(times), avg_memory, output


def measure_attention_performance(batch_size, seq_len, head_dim, num_runs=5):
    """Measure time and peak memory of standard attention vs. simplified FlashAttention."""
    # Create random query, key, value tensors
    q = torch.randn(batch_size, seq_len, head_dim, device=device)
    k = torch.randn(batch_size, seq_len, head_dim, device=device)
    v = torch.randn(batch_size, seq_len, head_dim, device=device)
    
    standard_time, standard_memory, output_standard = _benchmark(standard_attention, q, k, v, num_runs)
    flash_time, flash_memory, output_flash = _benchmark(simplified_flash_attention, q, k, v, num_runs)
    
    # Both functions compute exact attention, so the outputs should match up to
    # floating-point error on any device; report the largest absolute difference
    error = (output_standard - output_flash).abs().max().item()
    
    mb = 1024 * 1024
    return {
        "standard_time": standard_time,
        "flash_time": flash_time,
        "speedup": standard_time / flash_time,
        "standard_memory": standard_memory / mb,  # Convert to MB
        "flash_memory": flash_memory / mb,  # Convert to MB
        "memory_reduction": standard_memory / flash_memory if flash_memory > 0 else 0,
        "error": error,
    }
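Hand-rolled timing loops are easy to get subtly wrong (asynchronous CUDA launches, warm-up, timer resolution). An alternative worth knowing is `torch.utils.benchmark.Timer`, which performs warm-ups and synchronizes CUDA when necessary. In this sketch the timed statement is only a stand-in workload (the QKᵀ score computation with arbitrary sizes), and the variable names are illustrative.

```python
import torch
import torch.utils.benchmark as benchmark

bench_device = "cuda" if torch.cuda.is_available() else "cpu"
q = torch.randn(4, 512, 64, device=bench_device)
k = torch.randn(4, 512, 64, device=bench_device)

timer = benchmark.Timer(
    stmt="torch.bmm(q, k.transpose(1, 2))",   # stand-in workload: the Q K^T score computation
    globals={"torch": torch, "q": q, "k": k},
)
measurement = timer.timeit(20)   # run the statement 20 times
print(f"{measurement.mean * 1e3:.3f} ms per call")
```

To time the notebook's functions instead, pass them through `globals` and change `stmt` accordingly.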

In [None]:
# Measure performance for a specific configuration
results = measure_attention_performance(
    batch_size=4,
    seq_len=1024,
    head_dim=64
)

print("Performance Comparison:")
print(f"Standard Attention Time: {results['standard_time']:.4f} seconds")
print(f"FlashAttention Time: {results['flash_time']:.4f} seconds")
print(f"Speedup: {results['speedup']:.2f}x")
print(f"\nStandard Attention Memory: {results['standard_memory']:.2f} MB")
print(f"FlashAttention Memory: {results['flash_memory']:.2f} MB")
print(f"Memory Reduction: {results['memory_reduction']:.2f}x")
print(f"\nOutput Error: {results['error']:.6f}")

## 6. Scaling with Sequence Length

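One of the key benefits of FlashAttention is how it scales with sequence length: its extra memory grows linearly instead of quadratically. The number of FLOPs is still quadratic, since the result is exact attention. Let's measure how time and peak memory change as the sequence length increases.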
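The distinction between FLOPs and memory can be tabulated by hand. This back-of-the-envelope sketch counts 2 FLOPs per multiply-add for the two matrix multiplications of one head, assumes 4-byte fp32 elements, and uses a hypothetical 64×64 tile as the largest score buffer in the tiled version.

```python
def attention_forward_flops(seq_len, head_dim):
    """Matmul FLOPs for one head: Q K^T and P V, each 2 * N * N * d (multiply + add)."""
    return 2 * (2 * seq_len * seq_len * head_dim)

def extra_memory_bytes(seq_len, tile=64, bytes_per_elem=4):
    """Largest score buffer: full N x N for standard attention, one tile x tile block when tiled."""
    return seq_len * seq_len * bytes_per_elem, tile * tile * bytes_per_elem

for n in [1024, 2048, 4096]:
    standard_mem, tiled_mem = extra_memory_bytes(n)
    print(f"N={n:5d}  GFLOPs={attention_forward_flops(n, 64) / 1e9:6.2f}  "
          f"standard scores={standard_mem / 2**20:6.1f} MiB  tile={tiled_mem / 2**10:.0f} KiB")
```

Both the FLOPs and the standard score buffer quadruple each time N doubles, while the tile stays the same size.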

In [None]:
def measure_scaling_with_seq_len(batch_size, head_dim, seq_lengths):
    """Measure how performance scales with sequence length."""
    standard_times = []
    flash_times = []
    standard_memory = []
    flash_memory = []
    
    for seq_len in seq_lengths:
        print(f"Testing sequence length: {seq_len}")
        try:
            result = measure_attention_performance(batch_size, seq_len, head_dim)
            standard_times.append(result["standard_time"])
            flash_times.append(result["flash_time"])
            standard_memory.append(result["standard_memory"])
            flash_memory.append(result["flash_memory"])
        except RuntimeError as e:
            print(f"Error at sequence length {seq_len}: {e}")
            # If we run out of memory, stop the experiment
            break
    
    return standard_times, flash_times, standard_memory, flash_memory

In [None]:
# Test different sequence lengths
seq_lengths = [128, 256, 512, 1024, 2048]
standard_times, flash_times, standard_memory, flash_memory = measure_scaling_with_seq_len(
    batch_size=2,
    head_dim=64,
    seq_lengths=seq_lengths
)

In [None]:
# Plot the results
plt.figure(figsize=(12, 5))

# Plot execution time
plt.subplot(1, 2, 1)
plt.plot(seq_lengths[:len(standard_times)], standard_times, marker='o', label="Standard Attention")
plt.plot(seq_lengths[:len(flash_times)], flash_times, marker='s', label="FlashAttention")
plt.xlabel("Sequence Length")
plt.ylabel("Time (seconds)")
plt.title("Execution Time vs. Sequence Length")
plt.legend()
plt.grid(True, alpha=0.3)

# Plot memory usage
plt.subplot(1, 2, 2)
plt.plot(seq_lengths[:len(standard_memory)], standard_memory, marker='o', label="Standard Attention")
plt.plot(seq_lengths[:len(flash_memory)], flash_memory, marker='s', label="FlashAttention")
plt.xlabel("Sequence Length")
plt.ylabel("Memory Usage (MB)")
plt.title("Memory Usage vs. Sequence Length")
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 7. FlashAttention-2 Improvements

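FlashAttention-2 keeps the same exact-attention algorithm but reorganizes the work to use the GPU more efficiently:
- Fewer non-matmul FLOPs, for example by keeping the output unnormalized inside the loop and rescaling it only once at the end, since GPUs execute matrix multiplications on Tensor Cores much faster than other operations
- Parallelization over the sequence-length dimension in addition to batch and heads, which keeps the GPU busy for long sequences with small batches
- Better partitioning of work between warps within a thread block, which reduces communication through shared memory
- Support for larger head dimensions (up to 256) and for multi-query and grouped-query attention

The FlashAttention-2 paper reports roughly a 2× speedup over the original FlashAttention.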
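One of the FlashAttention-2 changes, deferring normalization, can be illustrated for a single query row: keep an unnormalized output accumulator, rescale it by $e^{m_\text{old} - m_\text{new}}$ whenever the running maximum changes, and divide by the running sum $\ell$ only once at the end. This is a plain PyTorch sketch of the arithmetic, not the CUDA kernel; the function name and the block size of 16 are arbitrary choices for illustration.

```python
import torch

def row_attention_deferred(q_row, k, v, block_size=16):
    """Attention output for one query row, visiting K/V in blocks and normalizing once at the end."""
    scale = q_row.shape[-1] ** -0.5
    m = torch.tensor(float("-inf"))     # running max of scores
    l = torch.tensor(0.0)               # running sum of exp(score - m)
    acc = torch.zeros(v.shape[-1])      # unnormalized output accumulator
    for j in range(0, k.shape[0], block_size):
        s = (k[j:j + block_size] @ q_row) * scale   # scores for this block, shape (block,)
        m_new = torch.maximum(m, s.max())
        correction = torch.exp(m - m_new)           # rescales everything accumulated so far
        p = torch.exp(s - m_new)
        l = l * correction + p.sum()
        acc = acc * correction + p @ v[j:j + block_size]
        m = m_new
    return acc / l                                   # single division at the end

torch.manual_seed(0)
q_row, k, v = torch.randn(64), torch.randn(100, 64), torch.randn(100, 64)
reference = torch.softmax((k @ q_row) * 64 ** -0.5, dim=0) @ v
print(torch.allclose(row_attention_deferred(q_row, k, v), reference, atol=1e-5))
```

With 100 keys and blocks of 16, the last block holds only 4 keys, which exercises the uneven-tail case.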

## 8. Using FlashAttention in PyTorch

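The optimized FlashAttention CUDA kernels are distributed in the standalone `flash-attn` package (imported as `flash_attn`), which requires a recent NVIDIA GPU and fp16 or bf16 inputs. PyTorch 2.0 and later also include a FlashAttention kernel as one of the backends of `torch.nn.functional.scaled_dot_product_attention`. Let's see how to use the `flash-attn` package if it's installed.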

In [None]:
# Try to import flash_attn
try:
    from flash_attn import flash_attn_func
    flash_attn_available = True
    print("FlashAttention is available!")
except ImportError:
    flash_attn_available = False
    print("FlashAttention is not installed. You can install it with:")
    print("pip install flash-attn --no-build-isolation")

In [None]:
# Example usage of FlashAttention if available
if flash_attn_available and torch.cuda.is_available():
    # Create random query, key, value tensors
    batch_size = 2
    seq_len = 1024
    num_heads = 8
    head_dim = 64
    
    # flash_attn_func expects shape (batch_size, seq_len, num_heads, head_dim)
    # and only supports fp16/bf16 tensors on a CUDA device
    q = torch.randn(batch_size, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn(batch_size, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
    v = torch.randn(batch_size, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
    
    # Use FlashAttention; causal=True applies an autoregressive (causal) mask
    output = flash_attn_func(q, k, v, causal=True)
    print(f"Output shape: {output.shape}")
else:
    print("Skipping FlashAttention example: it needs the flash_attn package and a CUDA GPU.")
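Without the `flash-attn` package, PyTorch's own `scaled_dot_product_attention` can still dispatch to a FlashAttention kernel on supported GPUs. Note the layout difference: SDPA expects `(batch, num_heads, seq_len, head_dim)`, while `flash_attn_func` uses `(batch, seq_len, num_heads, head_dim)`. This sketch assumes PyTorch 2.3 or later for the `torch.nn.attention.sdpa_kernel` context manager; inside it, dispatch is restricted to the FlashAttention backend and errors out if that backend can't handle the inputs. The helper name `sdpa_attention` is hypothetical.

```python
import torch
import torch.nn.functional as F

def sdpa_attention(q, k, v, causal=False):
    """Hypothetical helper: q, k, v in flash_attn layout (batch, seq_len, num_heads, head_dim).

    SDPA expects (batch, num_heads, seq_len, head_dim), so transpose in and out.
    """
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=causal
    )
    return out.transpose(1, 2)

if torch.cuda.is_available():
    from torch.nn.attention import SDPBackend, sdpa_kernel  # assumed PyTorch 2.3+
    q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    # Only allow the FlashAttention backend inside this block
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        out = sdpa_attention(q, k, v, causal=True)
else:
    q = torch.randn(2, 128, 8, 64)
    k, v = torch.randn_like(q), torch.randn_like(q)
    out = sdpa_attention(q, k, v, causal=True)  # on CPU, PyTorch chooses the backend itself
print(out.shape)
```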

## 9. FlashAttention in Hugging Face Transformers

Hugging Face Transformers can use FlashAttention-2 for model architectures that support it, selected with the `attn_implementation` argument of `from_pretrained`. Here's how you can enable it:

In [None]:
# Example of using FlashAttention-2 in Hugging Face Transformers
import torch
from transformers import AutoModelForCausalLM

# Requires the flash_attn package, a CUDA GPU, and a model architecture
# with FlashAttention-2 support in your installed Transformers version
def load_model_with_flash_attn(model_name="gpt2"):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        attn_implementation="flash_attention_2",  # replaces the deprecated use_flash_attention_2 flag
        torch_dtype=torch.float16,  # FlashAttention-2 only supports fp16 and bf16
    )
    return model.to("cuda")

print("Note: load_model_with_flash_attn() needs the flash_attn package and a CUDA GPU to run.")
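In practice it helps to fall back gracefully when FlashAttention-2 isn't usable. The sketch below picks a value for `attn_implementation`: `"flash_attention_2"` when the `flash_attn` package is installed and a CUDA GPU is present, otherwise `"sdpa"` (PyTorch's `scaled_dot_product_attention`). The helper name and the selection rule are our own assumptions, not Transformers' internal logic, and whether a given model supports each value depends on its architecture and your Transformers version.

```python
import importlib.util

import torch

def pick_attn_implementation():
    """Hypothetical helper: choose an `attn_implementation` value for from_pretrained."""
    has_flash_attn = importlib.util.find_spec("flash_attn") is not None
    if has_flash_attn and torch.cuda.is_available():
        return "flash_attention_2"
    return "sdpa"

print(pick_attn_implementation())
```

The result can then be passed as `attn_implementation=pick_attn_implementation()` to `AutoModelForCausalLM.from_pretrained`.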

## Conclusion

In this notebook, we've explored FlashAttention, an exact attention algorithm that restructures the computation around the GPU memory hierarchy. We implemented a simplified, block-wise version to demonstrate its core ideas and compared its runtime and peak memory with standard attention.

Key takeaways:

1. FlashAttention reduces memory usage and HBM traffic through tiling with an online softmax, kernel fusion, and recomputation in the backward pass
2. Its extra memory grows linearly with sequence length instead of quadratically, although the FLOPs remain quadratic
3. FlashAttention-2 further improves speed with fewer non-matmul FLOPs, parallelism over the sequence dimension, and better work partitioning between warps
4. These optimizations enable longer sequences and larger batch sizes, which matters for both training and inference of LLMs

For production use, rely on optimized implementations such as the `flash-attn` package, PyTorch's `torch.nn.functional.scaled_dot_product_attention`, or Hugging Face Transformers' built-in support.