# Chapter 6: Advanced Inference Optimization

## 🎯 Learning Objectives

By the end of this chapter, you will:
- **Master the vLLM architecture** and the PagedAttention algorithm
- **Understand continuous batching** vs static batching tradeoffs
- **Implement KV cache optimization** strategies
- **Build production inference systems** with advanced optimizations
- **Analyze inference bottlenecks** and optimization opportunities

---

## 🚀 The Inference Challenge in Production LLMs

### **Why Inference Optimization Matters**

LLM inference in production faces unique challenges:

#### **Scale Requirements**
- **ChatGPT**: Serves 100M+ requests daily
- **Claude**: Handles millions of conversations
- **GitHub Copilot**: Processes billions of code completions

#### **Performance Demands**
- **Latency**: < 100ms time-to-first-token for interactive use
- **Throughput**: Thousands of generated tokens per second per GPU
- **Cost**: $0.001 per 1K tokens or lower
- **Availability**: 99.9% uptime requirements

#### **Technical Constraints**
- **Memory bound**: KV cache grows linearly with sequence length and batch size
- **Variable length**: Requests have diverse output lengths
- **Batching complexity**: How to group requests efficiently?
- **Hardware utilization**: Keep expensive GPUs busy

---

## 🧠 Traditional Inference Limitations

### **Static Batching Problems**

Traditional inference uses **static batching** with severe limitations:

#### **The Batch Straggler Problem**
```
Request A: "Hello" → "Hello world!" (3 tokens, 50ms)
Request B: "Explain" → "Explain quantum physics..." (200 tokens, 2000ms)
Request C: "Hi" → "Hi there!" (2 tokens, 30ms)

Static Batch Processing:
┌─────────────────────────────────────────────────────┐
│ All requests wait for slowest (Request B: 2000ms)  │
│ GPU utilization drops as requests complete early   │
│ Memory allocated for max length across entire batch│
└─────────────────────────────────────────────────────┘
```
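
To put a number on the straggler effect, here is a minimal sketch that counts how many "slot-steps" of a static batch produce a real token. The function name `static_batch_utilization` is hypothetical, and the model assumes one token per request per decode step with no early exit from the batch.

```python
def static_batch_utilization(output_lengths):
    """Fraction of slot-steps doing useful work when the whole batch
    runs until its longest request finishes."""
    steps = max(output_lengths)                  # batch runs this many decode steps
    total_slot_steps = steps * len(output_lengths)
    useful_slot_steps = sum(output_lengths)      # steps that produce a real token
    return useful_slot_steps / total_slot_steps

# Output lengths from the example above: requests A, B, C
lengths = [3, 200, 2]
print(f"Useful work: {static_batch_utilization(lengths):.1%}")
```

Under these assumptions only about a third of the batch's slot-steps generate tokens; the rest are spent waiting on Request B.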

#### **Memory Waste in Static Batching**
```python
# Static batching allocates for worst case
batch_size = 8
max_seq_len = 2048  # Must handle longest possible sequence
hidden_dim = 4096
num_layers = 32
bytes_per_elem = 2  # fp16

# KV cache allocation (even for short sequences!)
# The leading 2 accounts for the separate K and V tensors
kv_cache_memory = 2 * batch_size * max_seq_len * hidden_dim * num_layers * bytes_per_elem
# = 2 * 8 * 2048 * 4096 * 32 * 2 bytes ≈ 8.6 GB (8 GiB) per batch!
```
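
The gap between reserved and used memory can be estimated with a short sketch. It assumes standard multi-head attention (K and V each as wide as `hidden_dim`, no grouped-query attention) and fp16 storage; the sequence lengths and the helper names `kv_bytes_per_token` and `reserved_vs_used` are hypothetical.

```python
def kv_bytes_per_token(num_layers=32, hidden_dim=4096, bytes_per_elem=2):
    """Bytes of KV cache one token needs: K and V, for every layer."""
    return 2 * num_layers * hidden_dim * bytes_per_elem

def reserved_vs_used(seq_lens, max_seq_len=2048):
    """Compare worst-case reservation with what the sequences actually fill."""
    per_token = kv_bytes_per_token()
    reserved = len(seq_lens) * max_seq_len * per_token
    used = sum(seq_lens) * per_token
    return reserved, used

seq_lens = [120, 35, 900, 60, 300, 15, 480, 210]  # hypothetical actual lengths
reserved, used = reserved_vs_used(seq_lens)
print(f"reserved: {reserved / 1e9:.2f} GB, used: {used / 1e9:.2f} GB "
      f"({used / reserved:.1%} of the reservation)")
```

Each token costs 0.5 MiB of KV cache under these assumptions, so every padded position in a static batch is paid for in full.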

#### **GPU Underutilization**
- **Padding overhead**: Short sequences padded to max length
- **Idle compute**: GPU cores sit idle waiting for slowest sequence
- **Memory fragmentation**: Allocated but unused memory

---

## 🌟 vLLM: The Inference Revolution

### **Core Innovations**

**vLLM**, an open-source LLM inference and serving engine, introduces three breakthrough concepts:

#### **1. Continuous Batching**
- **Dynamic requests**: Add/remove requests during generation
- **No batch stragglers**: Completed requests immediately replaced
- **High utilization**: GPU stays busy throughout generation
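
The scheduling difference can be sketched by counting decode steps under both policies. This is a toy model, not vLLM's scheduler: it assumes every active request emits one token per step, ignores prompt processing, and uses hypothetical function names and output lengths.

```python
def static_steps(lengths, batch_size):
    """Decode steps when each fixed batch runs until its longest request ends."""
    return sum(max(lengths[i:i + batch_size])
               for i in range(0, len(lengths), batch_size))

def continuous_steps(lengths, batch_size):
    """Decode steps when a finished request's slot is refilled immediately."""
    waiting = list(lengths)[::-1]   # pop() then takes requests in arrival order
    active = []                     # remaining tokens for each running request
    steps = 0
    while waiting or active:
        # Refill free slots before every step (iteration-level scheduling)
        while waiting and len(active) < batch_size:
            active.append(waiting.pop())
        # One decode step: every active request emits one token
        active = [r - 1 for r in active if r - 1 > 0]
        steps += 1
    return steps

lengths = [3, 200, 2, 50, 10, 120, 5, 80]   # hypothetical output lengths
print(static_steps(lengths, 4), continuous_steps(lengths, 4))
```

For this workload the static schedule needs 320 steps while the continuous one needs 200, because short requests free their slots for waiting work instead of idling behind the 200-token request.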

#### **2. PagedAttention**
- **Virtual memory**: Borrows concepts from OS memory management
- **Block allocation**: KV cache stored in fixed-size blocks
- **Near-zero fragmentation**: Waste is limited to the last, partially filled block of each sequence

#### **3. Optimized CUDA Kernels**
- **Fused operations**: Combine multiple GPU operations
- **Memory coalescing**: Optimize memory access patterns
- **Kernel specialization**: Different kernels for different scenarios

### **Performance Impact**
These figures are indicative and depend strongly on model, hardware, and workload:
- **2-24x higher throughput** than static batching
- **55% lower latency** for interactive workloads
- **90% GPU utilization** vs 30-40% with static batching

---

## 🔍 PagedAttention Deep Dive

### **The Memory Management Revolution**

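PagedAttention manages the **KV cache the way an operating system manages virtual memory**: each sequence sees a contiguous logical cache, while the physical storage is a set of fixed-size blocks that can live anywhere in GPU memory.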

#### **Traditional Attention Memory Layout**
```
Sequence 1: [████████████████████████████████] (allocated for max)
Sequence 2: [██████░░░░░░░░░░░░░░░░░░░░░░░░░░░░] (partially used)
Sequence 3: [███░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░░] (mostly wasted)

Problem: Fixed allocation leads to fragmentation
```

#### **PagedAttention Memory Layout**
```
Physical Memory Blocks:
Block 0: [████] Block 1: [████] Block 2: [████] Block 3: [████]
Block 4: [████] Block 5: [░░░░] Block 6: [░░░░] Block 7: [████]

Sequence 1 mapping: Block 0 → Block 1 → Block 2
Sequence 2 mapping: Block 3 → Block 7
Sequence 3 mapping: Block 4

Benefit: Memory allocated one block at a time, as needed
```
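
The mappings above act as per-sequence page tables. A minimal sketch of the lookup follows, assuming four tokens per block to match the diagram; the `block_tables` dictionary and the `locate` helper are hypothetical names for illustration.

```python
BLOCK_SIZE = 4  # tokens per block, matching the four-cell blocks drawn above

# Hypothetical block tables: logical block index -> physical block id
block_tables = {
    "seq1": [0, 1, 2],
    "seq2": [3, 7],
    "seq3": [4],
}

def locate(seq_id, token_pos, block_size=BLOCK_SIZE):
    """Translate a token position into (physical_block, offset_in_block)."""
    logical_block, offset = divmod(token_pos, block_size)
    return block_tables[seq_id][logical_block], offset

print(locate("seq2", 5))   # token 5 of seq2 lives in its second block
```

The attention kernel performs this kind of translation for every block it reads, which is why a sequence's blocks do not need to be adjacent in memory.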

### **Block Management Algorithm**

```python
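class PagedAttention:
    def __init__(self, num_blocks=1024, block_size=16):  # num_blocks: illustrative default
        self.block_size = block_size               # tokens per block
        self.physical_blocks = [None] * num_blocks # placeholder for per-block KV tensors
        self.free_blocks = set(range(num_blocks))  # ids of available blocks
        self.sequence_tables = {}                  # virtual → physical mapping
        self.sequence_lengths = {}                 # seq_id → tokens stored so far

    def _blocks_for(self, num_tokens):
        # Ceiling division: a partially filled block still occupies a whole block
        return (num_tokens + self.block_size - 1) // self.block_size

    def allocate_sequence(self, seq_id, initial_length):
        blocks_needed = self._blocks_for(initial_length)
        if blocks_needed > len(self.free_blocks):
            return None  # Out of memory; nothing was allocated

        allocated_blocks = [self.free_blocks.pop() for _ in range(blocks_needed)]
        self.sequence_tables[seq_id] = allocated_blocks
        self.sequence_lengths[seq_id] = initial_length
        return allocated_blocks

    def extend_sequence(self, seq_id, new_tokens):
        # Allocate additional blocks only when the last block overflows
        current_blocks = self.sequence_tables[seq_id]
        new_length = self.sequence_lengths[seq_id] + new_tokens
        extra = self._blocks_for(new_length) - len(current_blocks)
        if extra > len(self.free_blocks):
            return False  # Out of memory; sequence left unchanged

        current_blocks.extend(self.free_blocks.pop() for _ in range(max(extra, 0)))
        self.sequence_lengths[seq_id] = new_length
        return True
```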
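
The block count in `allocate_sequence` is a ceiling division, which also bounds the waste per sequence. The sketch below works through a few lengths with the default block size of 16; `blocks_needed` and `wasted_slots` are hypothetical helper names.

```python
def blocks_needed(num_tokens, block_size=16):
    """Ceiling division: smallest block count that holds num_tokens."""
    return (num_tokens + block_size - 1) // block_size

def wasted_slots(num_tokens, block_size=16):
    """Unused token slots in the final, partially filled block."""
    return blocks_needed(num_tokens, block_size) * block_size - num_tokens

for n in (1, 16, 17, 100):
    print(f"{n:>3} tokens -> {blocks_needed(n)} blocks, {wasted_slots(n)} unused slots")
```

However long the sequence, the unused space never exceeds `block_size - 1` slots, in contrast to a contiguous reservation sized for the maximum length.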

Let's implement a simplified vLLM-style inference system that simulates the model's forward pass:

In [None]:
import torch
import torch.nn as nn
import numpy as np
import time
import threading
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, field
from collections import deque
import random
import matplotlib.pyplot as plt
from concurrent.futures import ThreadPoolExecutor
import queue
import uuid
from contextlib import contextmanager
import warnings
warnings.filterwarnings('ignore')

@dataclass
class InferenceRequest:
    """Represents a single inference request with comprehensive metadata"""
    request_id: str
    prompt: str
    max_tokens: int
    temperature: float = 1.0
    
    # Timestamps
    arrival_time: float = field(default_factory=time.time)
    start_time: Optional[float] = None
    first_token_time: Optional[float] = None
    completion_time: Optional[float] = None
    
    # Generation state
    generated_tokens: List[int] = field(default_factory=list)
    is_complete: bool = False
    
    # Performance metrics
    tokens_generated: int = 0
    
    def get_latency(self) -> Optional[float]:
        """End-to-end latency"""
        if self.completion_time and self.arrival_time:
            return self.completion_time - self.arrival_time
        return None
    
    def get_time_to_first_token(self) -> Optional[float]:
        """Time to first token (TTFT)"""
        if self.first_token_time and self.arrival_time:
            return self.first_token_time - self.arrival_time
        return None
    
    def get_inter_token_latency(self) -> Optional[float]:
        """Average time between tokens"""
        if self.completion_time and self.first_token_time and self.tokens_generated > 1:
            return (self.completion_time - self.first_token_time) / (self.tokens_generated - 1)
        return None

class KVCacheBlock:
    """Fixed-size block for storing Key-Value cache"""
    
    def __init__(self, block_id: int, block_size: int = 16):
        self.block_id = block_id
        self.block_size = block_size  # number of tokens
        self.allocated_tokens = 0
        self.sequence_id: Optional[str] = None
        self.is_free = True
        
        # Simulated KV data (in practice, this would be actual tensors)
        self.k_cache = None  # [block_size, num_heads, head_dim]
        self.v_cache = None  # [block_size, num_heads, head_dim]
    
    def allocate(self, sequence_id: str, num_tokens: int) -> bool:
        """Allocate tokens in this block"""
        if self.allocated_tokens + num_tokens <= self.block_size:
            self.allocated_tokens += num_tokens
            self.sequence_id = sequence_id
            self.is_free = False
            return True
        return False
    
    def deallocate(self):
        """Free this block"""
        self.allocated_tokens = 0
        self.sequence_id = None
        self.is_free = True
        self.k_cache = None
        self.v_cache = None
    
    def get_utilization(self) -> float:
        """Get block utilization percentage"""
        return self.allocated_tokens / self.block_size

class PagedAttentionManager:
    """
    Implementation of PagedAttention memory management
    
    Educational Focus:
    This class demonstrates the core innovation of vLLM:
    managing the KV cache like virtual memory, with fixed-size
    blocks mapped through a per-sequence page table.
    """
    
    def __init__(self, total_blocks: int = 1000, block_size: int = 16):
        self.total_blocks = total_blocks
        self.block_size = block_size
        
        # Initialize physical blocks
        self.physical_blocks = [
            KVCacheBlock(block_id, block_size) 
            for block_id in range(total_blocks)
        ]
        
        # Free block management
        self.free_blocks = deque(range(total_blocks))
        
        # Sequence to blocks mapping (page table)
        self.sequence_page_tables: Dict[str, List[int]] = {}
        
        # Thread safety
        self.lock = threading.Lock()
    
    def allocate_sequence(self, sequence_id: str, initial_tokens: int) -> bool:
        """Allocate blocks for a new sequence"""
        with self.lock:
            blocks_needed = (initial_tokens + self.block_size - 1) // self.block_size
            
            if len(self.free_blocks) < blocks_needed:
                return False  # Out of memory
            
            # Allocate blocks
            allocated_blocks = []
            remaining_tokens = initial_tokens
            
            for _ in range(blocks_needed):
                block_id = self.free_blocks.popleft()
                block = self.physical_blocks[block_id]
                
                tokens_to_allocate = min(self.block_size, remaining_tokens)
                block.allocate(sequence_id, tokens_to_allocate)
                
                allocated_blocks.append(block_id)
                remaining_tokens -= tokens_to_allocate
            
            self.sequence_page_tables[sequence_id] = allocated_blocks
            return True
    
    def extend_sequence(self, sequence_id: str, additional_tokens: int) -> bool:
        """Extend sequence with additional tokens (during generation)"""
        with self.lock:
            if sequence_id not in self.sequence_page_tables:
                return False
            
            current_blocks = self.sequence_page_tables[sequence_id]
            
            # Only the last block of a sequence can have free slots
            last_block = (self.physical_blocks[current_blocks[-1]]
                          if current_blocks else None)
            available_space = (last_block.block_size - last_block.allocated_tokens
                               if last_block is not None else 0)
            
            # Check capacity before mutating anything, so a failed call
            # leaves the sequence unchanged
            tokens_to_add = min(available_space, additional_tokens)
            remaining_tokens = additional_tokens - tokens_to_add
            blocks_needed = (remaining_tokens + self.block_size - 1) // self.block_size
            
            if len(self.free_blocks) < blocks_needed:
                return False  # Out of memory
            
            # Fill the last block first
            if last_block is not None:
                last_block.allocated_tokens += tokens_to_add
            
            # Allocate new blocks for the remaining tokens
            for _ in range(blocks_needed):
                block_id = self.free_blocks.popleft()
                block = self.physical_blocks[block_id]
                
                tokens_to_allocate = min(self.block_size, remaining_tokens)
                block.allocate(sequence_id, tokens_to_allocate)
                
                current_blocks.append(block_id)
                remaining_tokens -= tokens_to_allocate
            
            return True
    
    def deallocate_sequence(self, sequence_id: str):
        """Free all blocks for a completed sequence"""
        with self.lock:
            if sequence_id in self.sequence_page_tables:
                block_ids = self.sequence_page_tables[sequence_id]
                
                for block_id in block_ids:
                    self.physical_blocks[block_id].deallocate()
                    self.free_blocks.append(block_id)
                
                del self.sequence_page_tables[sequence_id]
    
    def get_memory_stats(self) -> Dict[str, Any]:
        """Get comprehensive memory utilization statistics"""
        with self.lock:
            free_blocks = len(self.free_blocks)
            used_blocks = self.total_blocks - free_blocks
            
            # Calculate fragmentation
            total_allocated_tokens = 0
            total_capacity_tokens = 0
            partially_filled_blocks = 0
            
            for block in self.physical_blocks:
                if not block.is_free:
                    total_allocated_tokens += block.allocated_tokens
                    total_capacity_tokens += block.block_size
                    
                    if block.allocated_tokens < block.block_size:
                        partially_filled_blocks += 1
            
            fragmentation_ratio = 0
            if total_capacity_tokens > 0:
                wasted_tokens = total_capacity_tokens - total_allocated_tokens
                fragmentation_ratio = wasted_tokens / total_capacity_tokens
            
            return {
                'total_blocks': self.total_blocks,
                'free_blocks': free_blocks,
                'used_blocks': used_blocks,
                'memory_utilization': used_blocks / self.total_blocks,
                'fragmentation_ratio': fragmentation_ratio,
                'partially_filled_blocks': partially_filled_blocks,
                'active_sequences': len(self.sequence_page_tables),
                'avg_blocks_per_sequence': used_blocks / len(self.sequence_page_tables) if self.sequence_page_tables else 0
            }
    
    def print_memory_status(self):
        """Print detailed memory status"""
        stats = self.get_memory_stats()
        
        print(f"\n💾 PagedAttention Memory Status:")
        print(f"   Total blocks: {stats['total_blocks']}")
        print(f"   Used blocks: {stats['used_blocks']} ({stats['memory_utilization']*100:.1f}%)")
        print(f"   Free blocks: {stats['free_blocks']}")
        print(f"   Active sequences: {stats['active_sequences']}")
        print(f"   Fragmentation: {stats['fragmentation_ratio']*100:.1f}%")
        print(f"   Avg blocks/sequence: {stats['avg_blocks_per_sequence']:.1f}")

print("✅ PagedAttention Implementation Complete!")
print("🧠 Ready to build advanced inference systems")
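
The latency metrics defined on `InferenceRequest` in the cell above can be checked by hand with a small standalone sketch. The function `latency_metrics` and the timestamps are hypothetical; the formulas mirror `get_time_to_first_token`, `get_inter_token_latency`, and `get_latency`.

```python
def latency_metrics(arrival, first_token, completion, tokens_generated):
    """Return (TTFT, mean inter-token latency, end-to-end latency) in seconds."""
    ttft = first_token - arrival
    e2e = completion - arrival
    # n tokens are separated by n - 1 gaps; undefined for a single token
    itl = ((completion - first_token) / (tokens_generated - 1)
           if tokens_generated > 1 else None)
    return ttft, itl, e2e

ttft, itl, e2e = latency_metrics(arrival=0.0, first_token=0.25,
                                 completion=1.25, tokens_generated=11)
print(f"TTFT={ttft}s  ITL={itl}s  E2E={e2e}s")
```

With 11 tokens there are 10 gaps after the first token, so one second of decoding gives an inter-token latency of 0.1 s.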

## 🔄 Continuous Batching Engine Implementation

Now let's implement the complete continuous batching system:

In [None]:
class ContinuousBatchingEngine:
    """
    Advanced continuous batching engine implementing vLLM-style optimizations
    
    Educational Focus:
    This implementation demonstrates the key algorithms behind modern
    high-throughput LLM inference systems used in production.
    """
    
    def __init__(self, 
                 max_batch_size: int = 32,
                 max_sequence_length: int = 2048,
                 block_size: int = 16,
                 total_blocks: int = 1000):
        
        self.max_batch_size = max_batch_size
        self.max_sequence_length = max_sequence_length
        
        # Memory management
        self.memory_manager = PagedAttentionManager(total_blocks, block_size)
        
        # Request management
        self.request_queue = queue.Queue()
        self.active_requests: Dict[str, InferenceRequest] = {}
        self.completed_requests: List[InferenceRequest] = []
        
        # Processing control
        self.is_running = False
        self.processing_thread: Optional[threading.Thread] = None
        
        # Performance tracking
        self.stats = {
            'total_requests': 0,
            'completed_requests': 0,
            'total_tokens_generated': 0,
            'processing_steps': 0,
            'avg_batch_size': 0,
            'throughput_history': [],
            'latency_history': [],
            'gpu_utilization_history': []
        }
        
        # Thread safety
        self.lock = threading.Lock()
    
    def add_request(self, request: InferenceRequest) -> bool:
        """Add a new inference request to the queue"""
        try:
            self.request_queue.put_nowait(request)
            with self.lock:
                self.stats['total_requests'] += 1
            return True
        except queue.Full:
            return False
    
    def start_processing(self):
        """Start the continuous processing loop"""
        if self.is_running:
            return
        
        self.is_running = True
        self.processing_thread = threading.Thread(target=self._processing_loop, daemon=True)
        self.processing_thread.start()
        
        print(f"🚀 Continuous batching engine started")
        print(f"   Max batch size: {self.max_batch_size}")
        print(f"   Total memory blocks: {self.memory_manager.total_blocks}")
    
    def stop_processing(self):
        """Stop the processing loop"""
        self.is_running = False
        if self.processing_thread:
            self.processing_thread.join(timeout=5.0)
        print("⏹️ Continuous batching engine stopped")
    
    def _processing_loop(self):
        """Main continuous processing loop"""
        
        print("🔄 Starting continuous processing loop...")
        
        while self.is_running:
            step_start_time = time.time()
            
            # 1. Add new requests to active batch
            self._admit_new_requests()
            
            # 2. Process current batch (if any active requests)
            if self.active_requests:
                self._process_active_batch()
            
            # 3. Remove completed requests
            self._remove_completed_requests()
            
            # 4. Update statistics
            self._update_statistics(step_start_time)
            
            # Small delay to prevent busy waiting
            time.sleep(0.001)  # 1ms
        
        print("🏁 Processing loop completed")
    
    def _admit_new_requests(self):
        """Admit new requests up to batch capacity"""
        
        while (len(self.active_requests) < self.max_batch_size and 
               not self.request_queue.empty()):
            
            try:
                request = self.request_queue.get_nowait()
            except queue.Empty:
                break
            
            # Estimate prompt length (whitespace split as a stand-in for a tokenizer)
            initial_tokens = len(request.prompt.split())
            
            # Allocate blocks for the prompt only; blocks for generated tokens are
            # added step by step in extend_sequence, so nothing is reserved up front
            if self.memory_manager.allocate_sequence(request.request_id, 
                                                   initial_tokens):
                # Successfully allocated, add to active batch
                request.start_time = time.time()
                self.active_requests[request.request_id] = request
            else:
                # Out of memory, put back in queue
                self.request_queue.put(request)
                break  # No point trying more requests
    
    def _process_active_batch(self):
        """Process one step of the active batch"""
        
        # Simulate model forward pass for active requests
        # In practice, this would be actual transformer computation
        
        processing_start = time.time()
        
        # Simulate different processing times based on batch size
        batch_size = len(self.active_requests)
        base_time = 0.005  # 5ms base time
        batch_overhead = batch_size * 0.0005  # 0.5ms per request
        processing_time = base_time + batch_overhead
        
        time.sleep(processing_time)
        
        # Update each request in the batch
        for request in list(self.active_requests.values()):
            # Generate one token (simplified)
            new_token = random.randint(0, 1000)  # Dummy token
            request.generated_tokens.append(new_token)
            request.tokens_generated += 1
            
            # Record first token time
            if request.tokens_generated == 1 and request.first_token_time is None:
                request.first_token_time = time.time()
            
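            # Extend KV cache for this token; if no block is free, end the
            # request early (a real engine would preempt and retry instead)
            if not self.memory_manager.extend_sequence(request.request_id, 1):
                request.is_complete = True
                request.completion_time = time.time()
                continue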
            
            # Check completion conditions
            if (request.tokens_generated >= request.max_tokens or
                random.random() < 0.02):  # 2% chance to end naturally
                
                request.is_complete = True
                request.completion_time = time.time()
    
    def _remove_completed_requests(self):
        """Remove completed requests and free their memory"""
        
        completed_ids = []
        
        for request_id, request in self.active_requests.items():
            if request.is_complete:
                completed_ids.append(request_id)
                self.completed_requests.append(request)
                
                # Free memory
                self.memory_manager.deallocate_sequence(request_id)
                
                # Update stats
                with self.lock:
                    self.stats['completed_requests'] += 1
                    self.stats['total_tokens_generated'] += request.tokens_generated
        
        # Remove from active requests
        for request_id in completed_ids:
            del self.active_requests[request_id]
    
    def _update_statistics(self, step_start_time: float):
        """Update performance statistics"""
        
        step_duration = time.time() - step_start_time
        
        with self.lock:
            self.stats['processing_steps'] += 1
            
            # Track batch size over time
            current_batch_size = len(self.active_requests)
            total_batches = self.stats['processing_steps']
            self.stats['avg_batch_size'] = ((self.stats['avg_batch_size'] * (total_batches - 1) + 
                                           current_batch_size) / total_batches)
            
            # Calculate current throughput (tokens/second)
            if step_duration > 0:
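                # Approximation: requests that finished this step were already
                # removed, so this slightly undercounts the tokens produced
                tokens_this_step = current_batch_size  # One token per request per step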
                throughput = tokens_this_step / step_duration
                self.stats['throughput_history'].append(throughput)
                
                # Keep history manageable
                if len(self.stats['throughput_history']) > 1000:
                    self.stats['throughput_history'] = self.stats['throughput_history'][-500:]
            
            # Proxy for GPU utilization: fraction of batch slots occupied
            # (not a hardware measurement)
            gpu_util = min(100, current_batch_size / self.max_batch_size * 100)
            self.stats['gpu_utilization_history'].append(gpu_util)
            
            if len(self.stats['gpu_utilization_history']) > 1000:
                self.stats['gpu_utilization_history'] = self.stats['gpu_utilization_history'][-500:]
    
    def get_performance_stats(self) -> Dict[str, Any]:
        """Get comprehensive performance statistics"""
        
        with self.lock:
            stats = self.stats.copy()
        
        # Calculate additional metrics
        if self.completed_requests:
            latencies = [r.get_latency() for r in self.completed_requests if r.get_latency()]
            ttfts = [r.get_time_to_first_token() for r in self.completed_requests if r.get_time_to_first_token()]
            itls = [r.get_inter_token_latency() for r in self.completed_requests if r.get_inter_token_latency()]
            
            if latencies:
                stats['avg_latency_ms'] = np.mean(latencies) * 1000
                stats['p95_latency_ms'] = np.percentile(latencies, 95) * 1000
                stats['p99_latency_ms'] = np.percentile(latencies, 99) * 1000
            
            if ttfts:
                stats['avg_ttft_ms'] = np.mean(ttfts) * 1000
                stats['p95_ttft_ms'] = np.percentile(ttfts, 95) * 1000
            
            if itls:
                stats['avg_inter_token_latency_ms'] = np.mean(itls) * 1000
        
        # Current throughput
        if stats['throughput_history']:
            recent_throughput = stats['throughput_history'][-100:]  # Last 100 steps
            stats['current_throughput_tps'] = np.mean(recent_throughput)
        
        # GPU utilization
        if stats['gpu_utilization_history']:
            recent_gpu_util = stats['gpu_utilization_history'][-100:]
            stats['current_gpu_utilization'] = np.mean(recent_gpu_util)
        
        # Memory stats
        stats['memory_stats'] = self.memory_manager.get_memory_stats()
        
        return stats
    
    def print_status(self):
        """Print current engine status"""
        
        stats = self.get_performance_stats()
        
        print(f"\n🔄 Continuous Batching Engine Status:")
        print(f"   Active requests: {len(self.active_requests)}")
        print(f"   Queue depth: {self.request_queue.qsize()}")
        print(f"   Completed requests: {stats['completed_requests']}")
        print(f"   Average batch size: {stats['avg_batch_size']:.1f}")
        
        if 'current_throughput_tps' in stats:
            print(f"   Current throughput: {stats['current_throughput_tps']:.1f} tokens/sec")
        
        if 'current_gpu_utilization' in stats:
            print(f"   GPU utilization: {stats['current_gpu_utilization']:.1f}%")
        
        if 'avg_latency_ms' in stats:
            print(f"   Avg latency: {stats['avg_latency_ms']:.1f} ms")
            print(f"   P95 latency: {stats['p95_latency_ms']:.1f} ms")
        
        if 'avg_ttft_ms' in stats:
            print(f"   Avg TTFT: {stats['avg_ttft_ms']:.1f} ms")
        
        # Memory status
        mem_stats = stats['memory_stats']
        print(f"   Memory usage: {mem_stats['memory_utilization']*100:.1f}% "
              f"({mem_stats['used_blocks']}/{mem_stats['total_blocks']} blocks)")

print("✅ Continuous Batching Engine Implementation Complete!")
print("🚀 Ready for high-performance inference testing")
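
The `avg_batch_size` update in `_update_statistics` is a running mean that avoids storing every sample. A standalone sketch of the same formula, with a hypothetical helper name and made-up batch sizes:

```python
def update_running_mean(mean, count, new_value):
    """Fold new_value into the mean of `count` earlier samples."""
    count += 1
    return (mean * (count - 1) + new_value) / count, count

batch_sizes = [4, 8, 16, 12]   # hypothetical per-step batch sizes
mean, count = 0.0, 0
for b in batch_sizes:
    mean, count = update_running_mean(mean, count, b)

# Agrees with the plain average up to floating-point rounding
print(mean, sum(batch_sizes) / len(batch_sizes))
```

The trade-off is that the engine keeps only two numbers instead of a growing list, at the cost of small rounding differences over many steps.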

## 🧪 Comprehensive Inference Benchmark

Let's run a comprehensive experiment comparing static vs continuous batching:

In [None]:
def run_inference_comparison_experiment():
    """
    Comprehensive experiment comparing different inference strategies
    
    Educational Focus:
    This experiment demonstrates the practical benefits of advanced
    inference optimizations in realistic workload scenarios.
    """
    
    print("🧪 Starting Comprehensive Inference Comparison")
    print("=" * 60)
    
    # Generate realistic inference workload
    def generate_realistic_requests(num_requests: int = 50) -> List[InferenceRequest]:
        """Generate requests with realistic length distribution"""
        
        request_templates = [
            ("Write a short story about", (50, 150)),    # Short creative
            ("Explain the concept of", (100, 300)),      # Medium explanatory 
            ("Generate Python code for", (20, 100)),     # Short code
            ("Summarize this article", (50, 200)),       # Medium summary
            ("Translate to French:", (10, 50)),          # Short translation
            ("Write a detailed analysis", (200, 500)),   # Long analysis
            ("What is", (10, 30)),                       # Very short QA
            ("Create a comprehensive guide", (300, 800)) # Very long guide
        ]
        
        requests = []
        
        for i in range(num_requests):
            template, (min_tokens, max_tokens) = random.choice(request_templates)
            
            request = InferenceRequest(
                request_id=f"req_{i:03d}",
                prompt=f"{template} topic {i}",
                max_tokens=random.randint(min_tokens, max_tokens),
                temperature=random.uniform(0.7, 1.3)
            )
            
            requests.append(request)
        
        return requests
    
    # Generate workload
    test_requests = generate_realistic_requests(100)
    
    print(f"📊 Generated {len(test_requests)} realistic requests")
    token_distribution = [r.max_tokens for r in test_requests]
    print(f"   Token range: {min(token_distribution)} - {max(token_distribution)}")
    print(f"   Average tokens: {np.mean(token_distribution):.1f}")
    print(f"   Median tokens: {np.median(token_distribution):.1f}")
    
    results = {}
    
    # Test 1: Continuous Batching
    print(f"\n🔄 Testing Continuous Batching Engine...")
    
    continuous_engine = ContinuousBatchingEngine(
        max_batch_size=16,
        max_sequence_length=2048,
        block_size=16,
        total_blocks=2000
    )
    
    try:
        # Start processing
        continuous_start = time.time()
        continuous_engine.start_processing()
        
        # Add requests with realistic timing
        for i, request in enumerate(test_requests):
            # Create fresh request to avoid state issues
            fresh_request = InferenceRequest(
                request_id=f"continuous_{request.request_id}",
                prompt=request.prompt,
                max_tokens=request.max_tokens,
                temperature=request.temperature
            )
            
            continuous_engine.add_request(fresh_request)
            
            # Print progress every 10 requests
            if i % 10 == 0:
                print(f"   Added {i+1}/{len(test_requests)} requests")
            
            # Uniform random delay between requests (a rough stand-in for
            # Poisson arrivals, which would use exponential gaps)
            time.sleep(random.uniform(0.01, 0.05))
        
        print("   All requests submitted, waiting for completion...")
        
        # Wait for completion with status updates
        wait_start = time.time()
        max_wait_time = 30.0  # 30 seconds max wait
        
        while (len(continuous_engine.completed_requests) < len(test_requests) and 
               (time.time() - wait_start) < max_wait_time):
            
            time.sleep(1.0)  # Check every second
            
            # Status update every 5 seconds
            if int(time.time() - wait_start) % 5 == 0:
                completed = len(continuous_engine.completed_requests)
                active = len(continuous_engine.active_requests)
                print(f"   Progress: {completed}/{len(test_requests)} completed, {active} active")
        
        continuous_duration = time.time() - continuous_start
        continuous_engine.stop_processing()
        
        # Collect results
        continuous_stats = continuous_engine.get_performance_stats()
        results['continuous_batching'] = {
            'engine': continuous_engine,
            'duration': continuous_duration,
            'stats': continuous_stats,
            'completed_requests': len(continuous_engine.completed_requests)
        }
        
        print(f"   ✅ Continuous batching completed in {continuous_duration:.1f}s")
        print(f"   📊 Completed {len(continuous_engine.completed_requests)}/{len(test_requests)} requests")
        
    except Exception as e:
        print(f"   ❌ Continuous batching failed: {e}")
        results['continuous_batching'] = {'error': str(e)}
    
    # Test 2: Simulated Static Batching
    print(f"\n📦 Testing Static Batching (Simulated)...")
    
    try:
        static_start = time.time()
        
        # Simulate static batching behavior
        batch_size = 8
        static_completed = 0
        static_total_tokens = 0
        static_latencies = []  # Per-request latency in seconds
        
        # Process in static batches
        for batch_start in range(0, len(test_requests), batch_size):
            batch_end = min(batch_start + batch_size, len(test_requests))
            batch = test_requests[batch_start:batch_end]
            
            # Find maximum tokens in batch (padding requirement)
            max_tokens_in_batch = max(r.max_tokens for r in batch)
            
            # Simulate processing time (slower due to padding overhead)
            # Base time + overhead for padding + stragglers
            base_time = len(batch) * 0.01  # 10ms per request
            padding_overhead = len(batch) * 0.005  # 5ms padding overhead per request
            straggler_penalty = max_tokens_in_batch * 0.0001  # Penalty for waiting for longest
            
            total_batch_time = base_time + padding_overhead + straggler_penalty
            time.sleep(total_batch_time)
            
            # All requests in batch complete together. Assuming every request
            # arrived at static_start, each one's latency is the time until its
            # batch finished, including time queued behind earlier batches.
            batch_finish_latency = time.time() - static_start
            static_latencies.extend([batch_finish_latency] * len(batch))
            static_completed += len(batch)
            static_total_tokens += sum(r.max_tokens for r in batch)
            
            print(f"   Batch {batch_start//batch_size + 1}: {len(batch)} requests, "
                  f"max tokens: {max_tokens_in_batch}, time: {total_batch_time:.2f}s")
        
        static_duration = time.time() - static_start
        
        results['static_batching'] = {
            'duration': static_duration,
            'completed_requests': static_completed,
            'total_tokens': static_total_tokens,
            'throughput_tps': static_total_tokens / static_duration,
            # Mean of per-request latencies, not total time / request count
            'avg_latency_ms': sum(static_latencies) / len(static_latencies) * 1000
        }
        
        print(f"   ✅ Static batching completed in {static_duration:.1f}s")
        print(f"   📊 Processed {static_completed} requests")
        
    except Exception as e:
        print(f"   ❌ Static batching failed: {e}")
        results['static_batching'] = {'error': str(e)}
    
    return results

def analyze_inference_results(results: Dict[str, Any]):
    """
    Comprehensive analysis of inference comparison results
    """
    
    print("\n📊 INFERENCE PERFORMANCE ANALYSIS")
    print("=" * 60)
    
    if 'continuous_batching' not in results or 'static_batching' not in results:
        print("❌ Incomplete results - cannot perform comparison")
        return
    
    continuous = results['continuous_batching']
    static = results['static_batching']
    
    if 'error' in continuous or 'error' in static:
        print("❌ One or more tests failed")
        if 'error' in continuous:
            print(f"   Continuous batching error: {continuous['error']}")
        if 'error' in static:
            print(f"   Static batching error: {static['error']}")
        return
    
    # Extract metrics
    continuous_stats = continuous['stats']
    
    print(f"\n🔄 Continuous Batching Results:")
    print(f"   Duration: {continuous['duration']:.1f}s")
    print(f"   Completed: {continuous['completed_requests']} requests")
    print(f"   Tokens generated: {continuous_stats['total_tokens_generated']}")
    
    if 'current_throughput_tps' in continuous_stats:
        print(f"   Throughput: {continuous_stats['current_throughput_tps']:.1f} tokens/sec")
    
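    if 'avg_latency_ms' in continuous_stats:
        print(f"   Avg latency: {continuous_stats['avg_latency_ms']:.1f} ms")
    
    # Guard P95 separately so a missing key cannot raise KeyError
    if 'p95_latency_ms' in continuous_stats:
        print(f"   P95 latency: {continuous_stats['p95_latency_ms']:.1f} ms")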
    
    if 'avg_ttft_ms' in continuous_stats:
        print(f"   Avg TTFT: {continuous_stats['avg_ttft_ms']:.1f} ms")
    
    print(f"   Avg batch size: {continuous_stats['avg_batch_size']:.1f}")
    
    print(f"\n📦 Static Batching Results:")
    print(f"   Duration: {static['duration']:.1f}s")
    print(f"   Completed: {static['completed_requests']} requests")
    print(f"   Tokens generated: {static['total_tokens']}")
    print(f"   Throughput: {static['throughput_tps']:.1f} tokens/sec")
    print(f"   Avg latency: {static['avg_latency_ms']:.1f} ms")
    
    # Performance comparison
    print(f"\n📈 Performance Comparison:")
    
    # Throughput improvement
    if 'current_throughput_tps' in continuous_stats:
        throughput_improvement = (continuous_stats['current_throughput_tps'] / 
                                static['throughput_tps'] - 1) * 100
        print(f"   Throughput improvement: {throughput_improvement:+.1f}%")
    
    # Latency improvement
    if 'avg_latency_ms' in continuous_stats:
        latency_improvement = (1 - continuous_stats['avg_latency_ms'] / 
                             static['avg_latency_ms']) * 100
        print(f"   Latency improvement: {latency_improvement:+.1f}%")
    
    # Total time comparison
    time_improvement = (1 - continuous['duration'] / static['duration']) * 100
    print(f"   Total time improvement: {time_improvement:+.1f}%")
    
    # Memory efficiency
    if 'memory_stats' in continuous_stats:
        mem_stats = continuous_stats['memory_stats']
        print(f"\n💾 Memory Efficiency (Continuous):")
        print(f"   Memory utilization: {mem_stats['memory_utilization']*100:.1f}%")
        print(f"   Fragmentation: {mem_stats['fragmentation_ratio']*100:.1f}%")
        print(f"   Active sequences: {mem_stats['active_sequences']}")
        
        if mem_stats['fragmentation_ratio'] < 0.1:
            print(f"   ✅ Excellent memory efficiency!")
        elif mem_stats['fragmentation_ratio'] < 0.2:
            print(f"   🟡 Good memory efficiency")
        else:
            print(f"   🔴 High fragmentation - consider optimization")
    
    # Overall assessment
    print(f"\n🎯 Overall Assessment:")
    
    if (continuous['completed_requests'] > static['completed_requests'] * 0.9 and 
        continuous['duration'] < static['duration']):
        print(f"   ✅ Continuous batching shows clear advantages")
        print(f"   📊 Recommendation: Use continuous batching for production")
    else:
        print(f"   🟡 Results are mixed - further analysis needed")
        print(f"   🔍 Consider workload characteristics and hardware constraints")

# Run the comprehensive inference comparison
print("🚀 Starting Comprehensive Inference Comparison Experiment")
inference_results = run_inference_comparison_experiment()
analyze_inference_results(inference_results)
```

## 🎯 Key Takeaways from Advanced Inference Optimization

### **Continuous Batching is Game-Changing**
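- **Large throughput gains** over static batching: published benchmarks report roughly 2x to over 20x, depending on the baseline and on how much output lengths vary
- **Lower latency** for interactive workloads (reductions of 50% or more are possible when output lengths vary widely), because short requests return as soon as they finish
- **Higher GPU utilization**: freed batch slots are refilled immediately, so utilization can reach 90%+ versus the 30-40% often seen with static batching of mixed-length requests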
- **Essential for production** LLM serving

### **PagedAttention Eliminates Memory Waste**
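- **Near-zero fragmentation**: waste is limited to the last, partially filled block of each sequence, versus 50%+ of KV cache memory in systems that preallocate for the maximum length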
- **Dynamic allocation** based on actual sequence lengths
- **Block-based management** enables flexible memory reuse
- **Virtual memory concepts** applied to the KV cache: fixed-size blocks act as pages, and a per-sequence block table acts as the page table

### **Production Implications**
- **Cost reduction**: Serve more users per GPU
- **Better user experience**: Lower latency, higher throughput
- **Resource efficiency**: Maximum hardware utilization
- **Scalability**: Handle variable workload patterns

### **Implementation Considerations**
- **Complexity**: Requires sophisticated request management
- **Memory management**: Need efficient block allocation algorithms
- **Load balancing**: Dynamic request scheduling
- **Monitoring**: More metrics to track and optimize

---

## 🚀 Advanced Inference Techniques

### **Speculative Decoding**
```python
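# Use a smaller draft model to propose several tokens, then verify them
# with the large model.
# NOTE: generate_candidates() and verify_parallel() are placeholder method
# names for illustration, not the API of a particular library.
def speculative_decode(large_model, small_model, input_ids):
    # Small model generates candidate tokens quickly
    candidates = small_model.generate_candidates(input_ids, num_tokens=4)
    
    # Large model scores all candidates in one forward pass and keeps the
    # longest prefix it agrees with, plus its own token at the first mismatch
    verified_tokens = large_model.verify_parallel(input_ids, candidates)
    
    return verified_tokens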
```
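To make the accept/reject step concrete, here is a minimal runnable sketch of greedy speculative decoding. It rests on several assumptions: the "models" are toy functions that map a token prefix to the next token id, `toy_target` and `toy_draft` are invented for illustration, and verification is written as a loop even though a real system would score all proposed positions in one batched forward pass. With greedy decoding, the output should match what the target model would produce on its own.

```python
from typing import Callable, List

# A "model" here is just a function from a token prefix to the next token id.
NextToken = Callable[[List[int]], int]


def speculative_step(target: NextToken, draft: NextToken,
                     tokens: List[int], k: int = 4) -> List[int]:
    """One greedy speculative step; returns the tokens to append."""
    # 1. The draft model proposes k tokens autoregressively (cheap).
    proposal: List[int] = []
    for _ in range(k):
        proposal.append(draft(tokens + proposal))

    # 2. The target model checks each proposed position. A real system
    #    does this in a single forward pass; the loop only mimics the result.
    accepted: List[int] = []
    for i in range(k):
        expected = target(tokens + proposal[:i])
        if expected != proposal[i]:
            accepted.append(expected)  # target's token replaces the mismatch
            return accepted
        accepted.append(proposal[i])

    # 3. Every proposal was accepted, so the target adds one bonus token.
    accepted.append(target(tokens + proposal))
    return accepted


def generate(target: NextToken, draft: NextToken, prompt: List[int],
             max_new_tokens: int, k: int = 4) -> List[int]:
    """Repeat speculative steps until max_new_tokens have been produced."""
    tokens = list(prompt)
    while len(tokens) - len(prompt) < max_new_tokens:
        tokens.extend(speculative_step(target, draft, tokens, k))
    return tokens[:len(prompt) + max_new_tokens]


def greedy(target: NextToken, prompt: List[int], n: int) -> List[int]:
    """Reference: plain one-token-at-a-time decoding with the target."""
    tokens = list(prompt)
    for _ in range(n):
        tokens.append(target(tokens))
    return tokens


# Hypothetical toy models (illustration only).
def toy_target(prefix: List[int]) -> int:
    return (sum(prefix) * 31 + len(prefix)) % 50


def toy_draft(prefix: List[int]) -> int:
    # Agrees with the target except when the prefix length is a multiple of 5.
    guess = toy_target(prefix)
    return guess if len(prefix) % 5 else (guess + 1) % 50
```

The speedup in practice comes from step 2: each accepted draft token saves one sequential forward pass of the large model, so the benefit depends on how often the draft agrees with the target.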

### **Parallel Sampling**
```python
# Generate multiple sequences in parallel
# NOTE: encode() and parallel_decode() are placeholder method names.
def parallel_sampling(model, prompt, num_samples=4):
    # Prefill the prompt once; every sample reuses this KV cache
    shared_prefix = model.encode(prompt)
    
    # Branch at sampling points: each sample appends only its own tokens
    branches = model.parallel_decode(shared_prefix, num_samples)
    
    return branches
```
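The prefix-sharing idea can be sketched without a real model. The example below is a toy under stated assumptions: `toy_kv` stands in for the key/value tensors a model would cache, `toy_next_token` is an invented sampler, and the branches run one after another rather than as a single batch. The structural point it illustrates is that the prompt's cache is built once and read by every branch, while each branch keeps only its own suffix.

```python
import random
from typing import List, Tuple

VOCAB_SIZE = 16
KVEntry = Tuple[int, int]


def toy_kv(token: int) -> KVEntry:
    """Stand-in for the key/value vectors cached for one token."""
    return (token, token * token)


def toy_next_token(prefix_kv, suffix_kv, rng: random.Random) -> int:
    """Hypothetical sampler: weights depend on everything cached so far."""
    state = sum(k + v for k, v in prefix_kv) + sum(k + v for k, v in suffix_kv)
    weights = [1 + (state + t) % 5 for t in range(VOCAB_SIZE)]
    return rng.choices(range(VOCAB_SIZE), weights=weights, k=1)[0]


def parallel_sampling(prompt: List[int], num_samples: int = 4,
                      max_new_tokens: int = 5, seed: int = 0):
    # Prefill once: this tuple is shared (read-only) by every branch.
    shared_prefix = tuple(toy_kv(t) for t in prompt)

    samples = []
    for i in range(num_samples):
        rng = random.Random(seed + i)    # independent randomness per branch
        suffix_kv: List[KVEntry] = []    # private, per-branch cache
        output: List[int] = []
        for _ in range(max_new_tokens):
            tok = toy_next_token(shared_prefix, suffix_kv, rng)
            output.append(tok)
            suffix_kv.append(toy_kv(tok))
        samples.append(output)
    return shared_prefix, samples


def cached_entries(prompt_len: int, num_samples: int,
                   new_tokens: int, share_prefix: bool) -> int:
    """Count cache entries with and without prefix sharing."""
    prefix_copies = 1 if share_prefix else num_samples
    return prefix_copies * prompt_len + num_samples * new_tokens
```

For a 100-token prompt with 4 samples of 10 new tokens each, `cached_entries` gives 140 entries with sharing and 440 without, which is why shared prefixes matter most for long prompts and many samples.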

### **Quantized Inference**
```python
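# INT8 inference with calibration
# NOTE: quantize_model() and inference_optimized() are placeholders for
# whatever quantization toolkit is in use; they are not defined here.
def quantized_inference(model, calibration_data):
    # Calibrate quantization parameters (scales) on representative inputs
    quantized_model = quantize_model(model, calibration_data)
    
    # Use INT8 kernels for inference
    return quantized_model.inference_optimized()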
```
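What "calibrate, then run in INT8" means numerically can be shown in a few lines of plain Python. This is a sketch of one common scheme, symmetric per-tensor quantization, chosen here as an assumption; production toolkits typically use per-channel scales and fused kernels. All function names are invented for the example.

```python
from typing import List, Sequence


def calibrate_scale(calibration_data: Sequence[Sequence[float]]) -> float:
    """Symmetric per-tensor scale: map the largest observed |x| to 127."""
    max_abs = max(abs(x) for row in calibration_data for x in row)
    return max_abs / 127.0 if max_abs > 0 else 1.0


def quantize(values: Sequence[float], scale: float) -> List[int]:
    """Round to the nearest INT8 step and clamp to [-127, 127]."""
    return [max(-127, min(127, round(v / scale))) for v in values]


def dequantize(q: Sequence[int], scale: float) -> List[float]:
    """Map INT8 codes back to approximate float values."""
    return [x * scale for x in q]


def int8_dot(qa: Sequence[int], qb: Sequence[int],
             scale_a: float, scale_b: float) -> float:
    """Integer dot product (wide accumulator), rescaled to float once."""
    acc = sum(a * b for a, b in zip(qa, qb))
    return acc * scale_a * scale_b


# Usage: quantize two small vectors and compare against the float result.
a = [0.5, -1.0, 2.0]
b = [1.5, 0.25, -0.75]
scale_a = calibrate_scale([a])
scale_b = calibrate_scale([b])
approx = int8_dot(quantize(a, scale_a), quantize(b, scale_b), scale_a, scale_b)
exact = sum(x * y for x, y in zip(a, b))
print(f"exact={exact:.4f} approx={approx:.4f}")
```

Values inside the calibrated range are off by at most half a quantization step after a round trip, so calibration data that covers the real activation range is what keeps the error small.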

---

## 🔬 Production Deployment Patterns

### **Multi-Tier Architecture**
```
Load Balancer
     ↓
Request Router (determines model size needed)
     ↓
┌─────────────┬─────────────┬─────────────┐
│ Small Model │ Medium Model│ Large Model │
│ (Fast)      │ (Balanced)  │ (Accurate)  │
│ vLLM Engine │ vLLM Engine │ vLLM Engine │
└─────────────┴─────────────┴─────────────┘
```
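The request router in the diagram could be as simple as a rule that picks the cheapest tier able to handle a request. The sketch below is hypothetical throughout: the tier names, the token limits, and the `needs_reasoning` flag are assumptions chosen for illustration, and a production router might instead use a classifier or per-customer policy.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Tier:
    name: str
    max_prompt_tokens: int  # largest prompt this tier should be handed


# Hypothetical tiers, ordered from cheapest to most capable.
TIERS = [
    Tier("small", 256),
    Tier("medium", 2048),
    Tier("large", 32768),
]


def route_request(prompt_tokens: int, needs_reasoning: bool = False) -> str:
    """Return the name of the cheapest tier that can serve the request."""
    if needs_reasoning:
        # Requests flagged as hard go straight to the most capable tier.
        return TIERS[-1].name
    for tier in TIERS:
        if prompt_tokens <= tier.max_prompt_tokens:
            return tier.name
    raise ValueError(f"prompt of {prompt_tokens} tokens exceeds every tier")


print(route_request(100))                        # small
print(route_request(1000))                       # medium
print(route_request(100, needs_reasoning=True))  # large
```

Keeping the routing rule in one small function makes it easy to log each decision and to adjust the limits as traffic patterns change.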

### **Auto-scaling Strategy**
```python
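# Scale up on queue depth, tail latency, or GPU saturation
# (thresholds are illustrative; tune them per deployment)
def should_scale_up(metrics):
    return (
        metrics.queue_depth > 50 or
        metrics.p95_latency > 2000 or  # 2s, in milliseconds
        metrics.gpu_utilization > 90   # percent; high utilization means saturation
    )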
```
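A scaling policy usually needs a scale-down rule as well, with a gap between the two sets of thresholds so the replica count does not oscillate. The following sketch assumes a hypothetical `Metrics` record and illustrative thresholds, and it treats high GPU utilization as a sign of saturation; none of these values come from a specific autoscaler.

```python
from dataclasses import dataclass


@dataclass
class Metrics:
    queue_depth: int        # requests waiting to be scheduled
    p95_latency: float      # milliseconds
    gpu_utilization: float  # percent, 0-100


def scaling_decision(m: Metrics) -> int:
    """Return +1 (add a replica), -1 (remove one) or 0 (hold)."""
    overloaded = (
        m.queue_depth > 50 or
        m.p95_latency > 2000 or
        m.gpu_utilization > 90
    )
    # Scale down only when every signal is well below the scale-up
    # thresholds; the gap between the two acts as hysteresis.
    idle = (
        m.queue_depth == 0 and
        m.p95_latency < 500 and
        m.gpu_utilization < 30
    )
    if overloaded:
        return 1
    if idle:
        return -1
    return 0


print(scaling_decision(Metrics(queue_depth=80, p95_latency=300.0, gpu_utilization=50.0)))
print(scaling_decision(Metrics(queue_depth=0, p95_latency=100.0, gpu_utilization=10.0)))
print(scaling_decision(Metrics(queue_depth=10, p95_latency=800.0, gpu_utilization=60.0)))
```

In practice the decision would also be rate-limited by a cooldown period, since loading a large model onto a new GPU can take minutes.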

---

**Next: Chapter 7 - Distributed Training Strategies** 🌐

*In the next chapter, we'll explore how to scale training across multiple GPUs and nodes, covering data parallelism, model parallelism, and pipeline parallelism strategies.*