# Day 32: Paged Attention - Part 3

In this notebook, we'll explore paged attention, a memory management technique that significantly improves the efficiency of KV caching for large language models.

## Overview

1. Understanding paged attention
2. The problem with traditional KV caching
3. Implementing a simplified paged attention mechanism
4. Comparing memory efficiency
5. Impact of page size
6. Simulating concurrent requests

## 1. Understanding Paged Attention

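Paged attention is a technique introduced by the vLLM project that applies operating-system-style virtual memory paging to KV caching in transformer models. Instead of reserving one contiguous region per sequence, it stores each sequence's keys and values in fixed-size pages that can sit anywhere in memory. This addresses the memory fragmentation and poor memory utilization of traditional KV caching.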

Key concepts:
1. **Pages**: Fixed-size blocks of memory for storing KV cache
2. **Block Table**: Maps logical positions to physical memory locations
3. **Non-contiguous Allocation**: Allows flexible memory management
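
To make the block table concrete, here is a minimal sketch of the address translation it performs. The page size, the page ids, and the `locate` helper are hypothetical values chosen for illustration, not part of any library.

```python
PAGE_SIZE = 4  # tokens per page (hypothetical, kept small for readability)

# Hypothetical block table for one sequence: logical page i is stored in
# physical page block_table[i]. The physical pages need not be adjacent.
block_table = [7, 2, 9]

def locate(token_pos, block_table, page_size=PAGE_SIZE):
    """Translate a logical token position into (physical_page, offset)."""
    logical_page, offset = divmod(token_pos, page_size)
    return block_table[logical_page], offset

for pos in (0, 3, 4, 10):
    print(pos, "->", locate(pos, block_table))
```

Tokens 0 to 3 land in physical page 7, token 4 starts physical page 2, and token 10 sits at offset 2 of physical page 9, even though those pages are scattered in memory.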

In [None]:
# Import necessary libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
import time
import gc

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 2. The Problem with Traditional KV Caching

Traditional KV caching allocates contiguous memory blocks for each sequence, which leads to several issues:

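1. **Memory Fragmentation**: Each sequence reserves space for the maximum sequence length up front, so shorter sequences leave most of their reservation unused (internal fragmentation), and sequences that complete at different times leave gaps between reservations (external fragmentation)
2. **Inefficient Memory Utilization**: These gaps can only be reused by a request whose full reservation fits into them
3. **Limited Concurrent Requests**: Because every sequence is charged for its maximum length, fewer sequences fit in memory at the same time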
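
To get a feel for the sizes involved, the sketch below computes the KV cache footprint of a single sequence. The model dimensions (32 layers, 32 KV heads, head dimension 128, 2-byte elements) and the two sequence lengths are hypothetical values picked for round numbers, not measurements of a specific model.

```python
def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, bytes_per_elem):
    """Bytes of KV cache needed for one token."""
    # Factor 2: one key vector and one value vector per layer and head.
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem

# Hypothetical model configuration (assumption for illustration only).
per_token = kv_bytes_per_token(num_layers=32, num_kv_heads=32, head_dim=128, bytes_per_elem=2)

max_len, actual_len = 2048, 200   # reserved length vs. length actually generated
reserved = per_token * max_len
used = per_token * actual_len

print(f"per token: {per_token / 2**20:.2f} MiB")
print(f"reserved:  {reserved / 2**30:.2f} GiB")
print(f"used:      {used / 2**20:.1f} MiB ({used / reserved:.1%} of the reservation)")
```

Under these assumptions a contiguous reservation for 2048 tokens costs 1 GiB, while a 200-token sequence only ever touches about a tenth of it.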

In [None]:
# Simulate traditional KV caching memory allocation
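def simulate_traditional_kv_cache(num_sequences, max_seq_length, hidden_dim):
    """Simulate traditional KV cache memory allocation.

    Memory is counted in elements (tokens * hidden_dim), not bytes, and the
    totals are cumulative over every sequence ever allocated, including
    sequences that have already completed.
    """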
    # Initialize memory usage tracking
    total_allocated = 0
    wasted_memory = 0
    active_sequences = []
    
    # Simulate sequence processing with random lengths
    for i in range(num_sequences):
        # Random sequence length between 10 and max_seq_length
        seq_length = np.random.randint(10, max_seq_length + 1)
        
        # Memory required for this sequence
        memory_required = seq_length * hidden_dim
        
        # Allocate memory (always allocate max_seq_length for traditional approach)
        allocated_memory = max_seq_length * hidden_dim
        total_allocated += allocated_memory
        
        # Calculate wasted memory
        wasted = allocated_memory - memory_required
        wasted_memory += wasted
        
        # Add to active sequences
        active_sequences.append({
            "id": i,
            "length": seq_length,
            "allocated": allocated_memory,
            "wasted": wasted
        })
        
        # Randomly complete some sequences
        if i > 0 and np.random.random() < 0.3:
            # Remove a random active sequence
            idx = np.random.randint(0, len(active_sequences))
            active_sequences.pop(idx)
    
    # Calculate efficiency
    efficiency = 1 - (wasted_memory / total_allocated) if total_allocated > 0 else 0
    
    return {
        "total_allocated": total_allocated,
        "wasted_memory": wasted_memory,
        "efficiency": efficiency,
        "active_sequences": len(active_sequences)
    }

In [None]:
# Run the simulation
traditional_results = simulate_traditional_kv_cache(
    num_sequences=100,
    max_seq_length=1024,
    hidden_dim=64
)

print("Traditional KV Cache Simulation Results:")
print(f"Total allocated memory: {traditional_results['total_allocated']:,}")
print(f"Wasted memory: {traditional_results['wasted_memory']:,}")
print(f"Memory efficiency: {traditional_results['efficiency']:.2%}")
print(f"Active sequences at end: {traditional_results['active_sequences']}")

## 3. Implementing a Simplified Paged Attention Mechanism

Now, let's implement a simplified version of paged attention to demonstrate its memory efficiency benefits.
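
The simulator in this section only tracks bookkeeping. As a complement, here is a minimal NumPy sketch of the attention computation itself reading keys and values through a block table, for a single head and a single query. All names and sizes are assumptions for illustration. Production kernels typically read the pages in place rather than gathering them into a contiguous copy as done here.

```python
import numpy as np

rng = np.random.default_rng(0)
page_size, head_dim, seq_len = 4, 8, 10

# Contiguous reference K/V for one sequence and one attention head.
k = rng.standard_normal((seq_len, head_dim))
v = rng.standard_normal((seq_len, head_dim))
q = rng.standard_normal(head_dim)

# Physical page pool; this sequence occupies the scattered pages 5, 1, 6.
num_physical_pages = 8
k_pool = np.zeros((num_physical_pages, page_size, head_dim))
v_pool = np.zeros((num_physical_pages, page_size, head_dim))
block_table = [5, 1, 6]
for pos in range(seq_len):
    page, offset = block_table[pos // page_size], pos % page_size
    k_pool[page, offset] = k[pos]
    v_pool[page, offset] = v[pos]

def attend(q, k, v):
    """Scaled dot-product attention for a single query vector."""
    scores = k @ q / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ v

def paged_attend(q, k_pool, v_pool, block_table, seq_len):
    """Gather pages in logical order, drop unused tail slots, then attend."""
    k_seq = k_pool[block_table].reshape(-1, k_pool.shape[-1])[:seq_len]
    v_seq = v_pool[block_table].reshape(-1, v_pool.shape[-1])[:seq_len]
    return attend(q, k_seq, v_seq)

out_ref = attend(q, k, v)
out_paged = paged_attend(q, k_pool, v_pool, block_table, seq_len)
print("paged output matches contiguous output:", np.allclose(out_ref, out_paged))
```

The result is the same as attention over the contiguous cache, which is the point: paging changes where the data lives, not what is computed.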

In [None]:
class PagedAttentionSimulator:
    def __init__(self, page_size, num_pages, hidden_dim):
        """Initialize a paged attention simulator.
        
        Args:
            page_size: Number of tokens per page
            num_pages: Total number of pages in memory
            hidden_dim: Hidden dimension size
        """
        self.page_size = page_size
        self.num_pages = num_pages
        self.hidden_dim = hidden_dim
        
        # Initialize memory pages
        self.pages = [None] * num_pages  # None means the page is free
        
        # Block tables for each sequence
        self.block_tables = {}
        
        # Statistics
        self.total_tokens_stored = 0
        self.total_pages_allocated = 0
    
    def allocate_sequence(self, seq_id, seq_length):
        """Allocate memory for a new sequence."""
        # Calculate number of pages needed
        num_pages_needed = (seq_length + self.page_size - 1) // self.page_size
        
        # Find free pages
        free_pages = [i for i, page in enumerate(self.pages) if page is None]
        
        if len(free_pages) < num_pages_needed:
            return False  # Not enough free pages
        
        # Allocate pages
        allocated_pages = free_pages[:num_pages_needed]
        for page_idx in allocated_pages:
            self.pages[page_idx] = seq_id
        
        # Create block table
        self.block_tables[seq_id] = allocated_pages
        
        # Update statistics
        self.total_tokens_stored += seq_length
        self.total_pages_allocated += num_pages_needed
        
        return True
    
    def free_sequence(self, seq_id):
        """Free memory for a completed sequence."""
        if seq_id not in self.block_tables:
            return False
        
        # Free pages
        for page_idx in self.block_tables[seq_id]:
            self.pages[page_idx] = None
        
        # Remove block table
        del self.block_tables[seq_id]
        
        return True
    
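    def get_memory_stats(self):
        """Get memory usage statistics.

        The allocation counters are cumulative: free_sequence() does not
        decrement them, so "used_capacity" counts every page ever handed out
        and can exceed "total_capacity" once freed pages have been reused.
        """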
        total_capacity = self.num_pages * self.page_size * self.hidden_dim
        used_capacity = self.total_pages_allocated * self.page_size * self.hidden_dim
        wasted_capacity = used_capacity - self.total_tokens_stored * self.hidden_dim
        efficiency = 1 - (wasted_capacity / used_capacity) if used_capacity > 0 else 0
        
        return {
            "total_capacity": total_capacity,
            "used_capacity": used_capacity,
            "wasted_capacity": wasted_capacity,
            "efficiency": efficiency,
            "active_sequences": len(self.block_tables),
            "free_pages": self.pages.count(None),
            "total_pages": self.num_pages
        }
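
The simulator above allocates all pages for a sequence at once, given its final length. During generation the final length is not known in advance, so pages are usually claimed on demand as tokens arrive. The following sketch illustrates that idea with a hypothetical `GrowingSequence` helper; it is an illustration under these assumptions, not part of the simulator.

```python
class GrowingSequence:
    """Hypothetical helper: a block table that grows one token at a time."""

    def __init__(self, free_pages, page_size):
        self.free_pages = free_pages   # shared pool of free physical page ids
        self.page_size = page_size
        self.block_table = []          # physical pages, in logical order
        self.length = 0                # tokens stored so far

    def append_token(self):
        """Store one more token, claiming a new page only when needed."""
        if self.length % self.page_size == 0:   # last page is full (or none yet)
            if not self.free_pages:
                return False                    # pool exhausted
            self.block_table.append(self.free_pages.pop())
        self.length += 1
        return True

free_pages = list(range(8))
seq = GrowingSequence(free_pages, page_size=4)
for _ in range(9):
    seq.append_token()

print("tokens:", seq.length)
print("block table:", seq.block_table)
print("free pages left:", len(free_pages))
```

After 9 tokens with a page size of 4 the sequence holds three pages, and only the last one is partially filled.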

In [None]:
# Simulate paged attention memory allocation
def simulate_paged_attention(num_sequences, max_seq_length, hidden_dim, page_size=16):
    """Simulate paged attention memory allocation."""
    # Calculate total pages needed (with some extra capacity)
    total_tokens = num_sequences * max_seq_length * 0.6  # Assuming 60% average utilization
    num_pages = int((total_tokens + page_size - 1) // page_size * 1.2)  # 20% extra capacity
    
    # Initialize paged attention simulator
    simulator = PagedAttentionSimulator(page_size, num_pages, hidden_dim)
    
    # Simulate sequence processing
    active_seqs = []
    for i in range(num_sequences):
        # Random sequence length between 10 and max_seq_length
        seq_length = np.random.randint(10, max_seq_length + 1)
        
        # Try to allocate memory for this sequence; a sequence that does not
        # fit into the free pages is dropped rather than queued
        success = simulator.allocate_sequence(i, seq_length)
        
        if success:
            active_seqs.append(i)
        
        # Randomly complete some sequences
        if len(active_seqs) > 0 and np.random.random() < 0.3:
            # Remove a random active sequence
            idx = np.random.randint(0, len(active_seqs))
            seq_id = active_seqs.pop(idx)
            simulator.free_sequence(seq_id)
    
    # Get memory statistics
    return simulator.get_memory_stats()

In [None]:
# Run the paged attention simulation
paged_results = simulate_paged_attention(
    num_sequences=100,
    max_seq_length=1024,
    hidden_dim=64,
    page_size=16
)

print("Paged Attention Simulation Results:")
print(f"Total capacity: {paged_results['total_capacity']:,}")
print(f"Cumulative allocated capacity: {paged_results['used_capacity']:,}")
print(f"Cumulative wasted capacity: {paged_results['wasted_capacity']:,}")
print(f"Memory efficiency: {paged_results['efficiency']:.2%}")
print(f"Active sequences at end: {paged_results['active_sequences']}")
print(f"Free pages: {paged_results['free_pages']} / {paged_results['total_pages']}")

## 4. Comparing Memory Efficiency

Now, let's compare the memory efficiency of traditional KV caching and paged attention across different scenarios.
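
Before running the random simulations, it helps to know what to expect. The sketch below enumerates every sequence length the simulations can draw (10 to 1024, each equally likely) and computes the efficiency of both schemes exactly. It assumes the same defaults as the cells in this notebook (`max_seq_length=1024`, `page_size=16`); the hidden dimension cancels out of the ratio, so it is omitted.

```python
import math

max_seq_length, page_size = 1024, 16
lengths = range(10, max_seq_length + 1)   # every length the simulations can draw

stored = sum(lengths)                                  # tokens actually stored
contiguous_reserved = max_seq_length * len(lengths)    # max length reserved each time
paged_reserved = sum(math.ceil(n / page_size) * page_size for n in lengths)

contiguous_eff = stored / contiguous_reserved
paged_eff = stored / paged_reserved

print(f"contiguous: {contiguous_eff:.2%}")
print(f"paged:      {paged_eff:.2%}")
```

The contiguous scheme comes out at roughly 50% and the paged scheme at roughly 98 to 99%, so the simulated values should scatter around those levels.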

In [None]:
def compare_memory_efficiency(num_sequences_list, max_seq_length, hidden_dim, page_size=16):
    """Compare memory efficiency of traditional KV caching and paged attention."""
    traditional_efficiency = []
    paged_efficiency = []
    
    for num_sequences in num_sequences_list:
        # Run traditional simulation
        trad_result = simulate_traditional_kv_cache(num_sequences, max_seq_length, hidden_dim)
        traditional_efficiency.append(trad_result["efficiency"])
        
        # Run paged attention simulation
        paged_result = simulate_paged_attention(num_sequences, max_seq_length, hidden_dim, page_size)
        paged_efficiency.append(paged_result["efficiency"])
    
    return traditional_efficiency, paged_efficiency

In [None]:
# Compare memory efficiency for different numbers of sequences
num_sequences_list = [10, 20, 50, 100, 200]
traditional_efficiency, paged_efficiency = compare_memory_efficiency(
    num_sequences_list,
    max_seq_length=1024,
    hidden_dim=64
)

# Plot the results
plt.figure(figsize=(10, 6))
plt.plot(num_sequences_list, [e * 100 for e in traditional_efficiency], marker='o', label="Traditional KV Cache")
plt.plot(num_sequences_list, [e * 100 for e in paged_efficiency], marker='s', label="Paged Attention")
plt.xlabel("Number of Sequences")
plt.ylabel("Memory Efficiency (%)")
plt.title("Memory Efficiency Comparison")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

## 5. Impact of Page Size

The page size is an important parameter in paged attention. Let's examine how it affects memory efficiency.
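
The trade-off can be seen on a single sequence. The sketch below uses a hypothetical 1000-token sequence and a hypothetical `page_stats` helper to show how many pages (and therefore block table entries) each page size requires, and how many token slots are left unused in the last page.

```python
def page_stats(seq_length, page_size):
    """Return (pages needed, unused token slots in the last page)."""
    num_pages = -(-seq_length // page_size)   # ceiling division
    wasted = num_pages * page_size - seq_length
    return num_pages, wasted

seq_length = 1000  # hypothetical sequence length in tokens
for page_size in (4, 8, 16, 32, 64, 128):
    num_pages, wasted = page_stats(seq_length, page_size)
    efficiency = seq_length / (num_pages * page_size)
    print(f"page_size={page_size:>3}  pages={num_pages:>3}  "
          f"wasted_slots={wasted:>2}  efficiency={efficiency:.2%}")
```

Larger pages mean fewer block table entries per sequence but more unused slots in the final page; the waste is always smaller than one page.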

In [None]:
def analyze_page_size_impact(page_sizes, num_sequences, max_seq_length, hidden_dim, seed=0):
    """Analyze the impact of page size on memory efficiency."""
    efficiency_results = []
    
    for page_size in page_sizes:
        # Re-seed NumPy's global RNG so every page size starts from the same random state
        np.random.seed(seed)
        result = simulate_paged_attention(num_sequences, max_seq_length, hidden_dim, page_size)
        efficiency_results.append(result["efficiency"])
    
    return efficiency_results

In [None]:
# Analyze the impact of page size
page_sizes = [4, 8, 16, 32, 64, 128]
efficiency_by_page_size = analyze_page_size_impact(
    page_sizes,
    num_sequences=100,
    max_seq_length=1024,
    hidden_dim=64
)

# Plot the results
plt.figure(figsize=(10, 6))
plt.plot(page_sizes, [e * 100 for e in efficiency_by_page_size], marker='o')
plt.xlabel("Page Size (tokens)")
plt.ylabel("Memory Efficiency (%)")
plt.title("Impact of Page Size on Memory Efficiency")
plt.grid(True, alpha=0.3)
plt.show()

## 6. Simulating Concurrent Requests

One of the key benefits of paged attention is supporting more concurrent requests. Let's simulate this scenario.
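
As a complement to the closed-form estimate in this section, the sketch below admits a stream of random-length requests into a fixed budget until the next one no longer fits. The `admitted` helper, the budget of 64 maximum-length reservations, and the request stream are assumptions for illustration. For simplicity it charges each paged request for its final length up front, whereas a real server would claim pages as tokens are generated.

```python
import numpy as np

def admitted(lengths, budget_tokens, reserve):
    """Count requests admitted in order until the next one does not fit."""
    used = count = 0
    for n in lengths:
        need = reserve(n)
        if used + need > budget_tokens:
            break
        used += need
        count += 1
    return count

rng = np.random.default_rng(0)
max_seq_length, page_size = 1024, 16
lengths = rng.integers(10, max_seq_length + 1, size=1000)  # random request lengths
budget_tokens = 64 * max_seq_length                        # hypothetical memory budget

# Contiguous: every request is charged the maximum length.
contiguous = admitted(lengths, budget_tokens, lambda n: max_seq_length)
# Paged: every request is charged its length rounded up to whole pages.
paged = admitted(lengths, budget_tokens, lambda n: -(-int(n) // page_size) * page_size)

print("contiguous:", contiguous)
print("paged:     ", paged)
```

The contiguous scheme admits exactly 64 requests by construction, and the paged scheme can never admit fewer, because no request is charged more than the maximum length.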

In [None]:
def simulate_concurrent_requests(max_memory, max_seq_length, hidden_dim, page_size=16):
    """Simulate how many concurrent requests can be handled with limited memory."""
    # Traditional approach
    trad_max_sequences = max_memory // (max_seq_length * hidden_dim)
    
    # Paged attention approach
    # Assuming average sequence length is 60% of max_seq_length
    avg_seq_length = int(max_seq_length * 0.6)  # whole tokens, keeps the arithmetic in integers
    pages_per_seq = (avg_seq_length + page_size - 1) // page_size  # ceiling division
    total_pages = max_memory // (page_size * hidden_dim)
    paged_max_sequences = total_pages // pages_per_seq
    
    return trad_max_sequences, paged_max_sequences

In [None]:
# Simulate concurrent requests with different memory sizes
memory_sizes = [1e6, 2e6, 5e6, 1e7, 2e7, 5e7]  # KV cache budgets in elements (tokens * hidden_dim)
trad_concurrent = []
paged_concurrent = []

for memory_size in memory_sizes:
    trad, paged = simulate_concurrent_requests(
        max_memory=int(memory_size),
        max_seq_length=1024,
        hidden_dim=64
    )
    trad_concurrent.append(trad)
    paged_concurrent.append(paged)

# Plot the results
plt.figure(figsize=(10, 6))
plt.bar(range(len(memory_sizes)), trad_concurrent, width=0.4, label="Traditional KV Cache", align="edge")
plt.bar([x + 0.4 for x in range(len(memory_sizes))], paged_concurrent, width=0.4, label="Paged Attention", align="edge")
plt.xlabel("Memory Size (elements)")
plt.ylabel("Max Concurrent Sequences")
plt.title("Maximum Concurrent Sequences by Memory Size")
plt.xticks([x + 0.2 for x in range(len(memory_sizes))], [f"{int(m/1e6)}M" for m in memory_sizes])
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

## Conclusion

In this notebook, we simulated paged attention alongside traditional contiguous KV caching and compared their memory efficiency, the effect of page size, and the number of concurrent sequences each approach can hold.

Key takeaways:

1. Paged attention divides the KV cache into fixed-size pages and maps them through a per-sequence block table, enabling non-contiguous memory allocation
2. Waste is limited to the unused tail of each sequence's last page, instead of the gap between a sequence's actual length and the maximum length
3. The choice of page size is a trade-off: smaller pages reduce internal fragmentation but increase bookkeeping overhead, since each sequence needs more block table entries
4. Because each sequence holds only the pages it needs, more concurrent requests fit in the same amount of memory

These benefits make paged attention a critical optimization for deploying large language models in production, especially for serving multiple users simultaneously.