# PyTorch Tutorial 21: Modern LLM Inference Optimization (2025)

**The Reality**: Inference often dominates the cost of running ML in production; some industry estimates put it as high as **90% of production ML expenses**. For LLMs serving millions of users, a 2x speedup can save millions of dollars annually.

Between 2022 and 2025, several techniques matured that fundamentally changed how we deploy LLMs:
- **KV Cache Optimization**: Reducing memory bottlenecks
- **Speculative Decoding**: 2-3x speedup with zero accuracy loss
- **PagedAttention**: Treating GPU memory like OS virtual memory
- **Continuous Batching**: No more waiting for slow requests

This notebook teaches the **state-of-the-art** techniques behind modern production LLM serving systems such as vLLM and TensorRT-LLM.

## Learning Objectives
1. Understand why LLM inference is **memory-bandwidth-bound**
2. Implement **KV caching** from scratch
3. Learn **speculative decoding** (generating several tokens per large-model forward pass)
4. Understand **PagedAttention** and **continuous batching**
5. Deploy models with **vLLM** and compare to naive implementations

---

## Part 1: The Problem - Why is LLM Inference Slow?

### Vocabulary First

- **Autoregressive Generation**: Generating one token at a time, where each new token depends on all previous tokens
- **Memory-Bandwidth-Bound**: Performance limited by how fast we can load data from memory, not by computation speed
- **KV Cache**: Storing Key and Value tensors from previous tokens to avoid recomputing them
- **Time-to-First-Token (TTFT)**: Latency before the first output token appears
- **Throughput**: Total tokens generated per second across all requests

### The Core Issue

**Problem**: In autoregressive decoding, we must load **ALL model weights and KV cache** just to generate **ONE token**.

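For a 7B parameter model (e.g., Llama-2-7B: 32 layers, hidden size 4096):
- Model weights: ~14GB (FP16)
- KV cache per request (2048 tokens): ~1GB (K and V × 32 layers × 4096 values × 2 bytes ≈ 0.5MB per token)
- Streaming all of this from memory to compute 1 token per step leaves the GPU's compute units mostly idle!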

**Formula** (a lower bound at batch size 1, ignoring compute and kernel overheads):
```
Latency per token ≥ (Model Size + KV Cache Size) / Memory Bandwidth
```

This is why optimizing memory access is more important than optimizing compute for LLMs!
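
The next cell is a back-of-the-envelope calculation that makes the formula concrete. It assumes a Llama-2-7B-style configuration (32 layers, hidden size 4096, full multi-head attention, FP16) and roughly 2 TB/s of memory bandwidth, which is in the ballpark of an A100 80GB. The helper names are hypothetical. Real latency is higher because of compute and kernel overheads.

```python
def kv_cache_bytes(num_layers, hidden_size, seq_len, bytes_per_elem=2):
    """K and V tensors for every layer and token (assumes full multi-head attention, no GQA)."""
    return 2 * num_layers * hidden_size * seq_len * bytes_per_elem

def min_latency_per_token_s(weight_bytes, kv_bytes, bandwidth_bytes_per_s):
    """Lower bound from the formula: every byte is read once per decoding step."""
    return (weight_bytes + kv_bytes) / bandwidth_bytes_per_s

# Assumed, illustrative numbers (not measured)
weights = 7e9 * 2                      # 7B parameters in FP16
kv = kv_cache_bytes(num_layers=32, hidden_size=4096, seq_len=2048)
bandwidth = 2e12                       # ~2 TB/s

latency = min_latency_per_token_s(weights, kv, bandwidth)
print(f"KV cache: {kv / 2**30:.2f} GiB")
print(f"Lower bound: {latency * 1e3:.2f} ms/token -> at most {1 / latency:.0f} tokens/s for one sequence")
```

Batching many sequences amortizes the weight reads across all of them. This is why throughput-oriented servers batch aggressively.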

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np

# Check GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # PyTorch does not report memory bandwidth; check the GPU's spec sheet for that figure
    print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## Part 2: KV Caching - The Foundation

### Without KV Cache (Naive)
For each new token, recompute attention for ALL previous tokens.

### With KV Cache
Store the Key and Value projections of past tokens and reuse them.

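**Speedup**: Per decoding step, attention work drops from O(n²) to O(n) for sequence length n, and the Key/Value projections of past tokens are never recomputed!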

Let's implement this from scratch:

In [None]:
class SimpleAttentionWithKVCache(nn.Module):
    """Attention mechanism with KV caching for efficient autoregressive generation."""
    
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.out = nn.Linear(dim, dim, bias=False)
    
    def forward(self, x, kv_cache=None, use_cache=False):
        """
        Args:
            x: Input tensor [batch, seq_len, dim]
            kv_cache: Tuple of (k_cache, v_cache) from previous steps
            use_cache: Whether to return updated cache
        
        Returns:
            output: Attention output
            new_cache: Updated (k, v) cache (if use_cache=True)
        """
        B, T, C = x.shape
        
        # Compute Q, K, V
        qkv = self.qkv(x).reshape(B, T, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, T, head_dim)
        
        # If we have cached K, V, concatenate with new ones
        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            k = torch.cat([k_cache, k], dim=2)  # Concat along seq dimension
            v = torch.cat([v_cache, v], dim=2)
        
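        # Attention scores: Q @ K^T / sqrt(d), shape (B, heads, T, S) with S = cached + new tokens
        attn = (q @ k.transpose(-2, -1)) * self.scale

        # Causal mask: new query i sits at absolute position (S - T) + i and may only see keys
        # up to that position. Needed when prefilling several tokens at once; a single new
        # token (T == 1) may attend to the whole cache.
        S = k.shape[2]
        if T > 1:
            causal = torch.ones(T, S, dtype=torch.bool, device=x.device).tril(diagonal=S - T)
            attn = attn.masked_fill(~causal, float("-inf"))
        attn = F.softmax(attn, dim=-1)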
        
        # Output: Attention @ V
        out = attn @ v  # (B, heads, T, head_dim)
        out = out.transpose(1, 2).reshape(B, T, C)
        out = self.out(out)
        
        if use_cache:
            return out, (k, v)
        return out

# Test it
attn = SimpleAttentionWithKVCache(dim=512, num_heads=8).to(device)

# Simulate autoregressive generation
print("Simulating autoregressive generation with KV cache:\n")

x = torch.randn(1, 1, 512).to(device)  # First token
kv_cache = None

for step in range(5):
    with torch.no_grad():
        out, kv_cache = attn(x, kv_cache=kv_cache, use_cache=True)
    
    k_cache, v_cache = kv_cache
    print(f"Step {step+1}: K cache shape = {k_cache.shape}")
    
    # Next token (in reality, this would be from the model output)
    x = torch.randn(1, 1, 512).to(device)

print("\n✅ KV cache grows with each token, avoiding recomputation of past keys and values!")
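
Caching should change the cost of decoding, not its result. The sketch below checks this on a toy single-head attention with random stand-in weights. The function names and sizes are illustrative assumptions. It shows that processing tokens one at a time against a growing K/V cache gives the same outputs as full causal attention over the whole sequence.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d, T = 16, 6            # toy hidden size and sequence length
dtype = torch.float64   # double precision so the comparison is not blurred by rounding

# Single-head projections with random weights (stand-ins for trained ones)
Wq, Wk, Wv = (torch.randn(d, d, dtype=dtype) / d ** 0.5 for _ in range(3))
x = torch.randn(T, d, dtype=dtype)

def full_causal_attention(x):
    """Recompute Q, K, V and masked attention for the whole sequence at once."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = q @ k.T / d ** 0.5
    causal = torch.ones(len(x), len(x), dtype=torch.bool).tril()
    scores = scores.masked_fill(~causal, float("-inf"))
    return F.softmax(scores, dim=-1) @ v

def incremental_attention(x):
    """Process one token per step, appending its K and V to a cache."""
    k_cache, v_cache, outputs = [], [], []
    for t in range(len(x)):
        x_t = x[t : t + 1]                               # (1, d): only the newest token
        k_cache.append(x_t @ Wk)
        v_cache.append(x_t @ Wv)
        K, V = torch.cat(k_cache), torch.cat(v_cache)    # (t+1, d)
        scores = (x_t @ Wq) @ K.T / d ** 0.5             # (1, t+1): no mask needed
        outputs.append(F.softmax(scores, dim=-1) @ V)
    return torch.cat(outputs)

full = full_causal_attention(x)
cached = incremental_attention(x)
print("Cached == full recomputation:", torch.allclose(full, cached))
```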

### KV Cache Optimization Techniques (2025)

#### 1. **Quantization**
Reduce KV cache from FP16 to INT8 or even INT4.
- **Benefit**: 2-4x memory savings
- **Cost**: Usually a small accuracy loss, though it depends on the model and quantization scheme (INT4 is more sensitive than INT8)

#### 2. **Entropy-Based Allocation**
Some layers have more "dispersed" attention (high entropy) and need more cache.
- Allocate larger KV budgets to high-entropy layers
- Reported savings are modest (around 4-5% memory in some reports) at similar quality
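
Below is a minimal sketch of the idea. It assumes you already have each layer's attention probabilities from a calibration prompt. The proportional-to-entropy rule and the function names are hypothetical illustrations, not a specific published method.

```python
import torch

def attention_entropy(attn_probs):
    """Mean Shannon entropy (nats) of attention rows; attn_probs: [heads, q_len, k_len]."""
    p = attn_probs.clamp_min(1e-12)            # avoid log(0)
    return -(p * p.log()).sum(dim=-1).mean()

def allocate_kv_budgets(per_layer_probs, total_budget):
    """Hypothetical rule: split a total KV-token budget in proportion to attention entropy."""
    ent = torch.stack([attention_entropy(p) for p in per_layer_probs])
    share = ent / ent.sum()
    # floor() may leave a few tokens unassigned; a real system would redistribute them
    return (share * total_budget).floor().long()

# Toy example: layer 0 spreads attention over 64 keys, layer 1 over only 8
dispersed = torch.full((4, 1, 64), 1 / 64)
focused = torch.full((4, 1, 8), 1 / 8)
budgets = allocate_kv_budgets([dispersed, focused], total_budget=1000)
print(budgets.tolist())
```

Uniform attention over 64 keys has twice the entropy of uniform attention over 8 (ln 64 vs. ln 8), so the dispersed layer receives about two thirds of the budget.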

#### 3. **PagedAttention** (vLLM)
Treat KV cache like OS virtual memory:
- Split cache into fixed-size "pages" or blocks (e.g., 16 tokens, vLLM's default block size)
- Store non-contiguously in GPU memory
- Share pages across requests (e.g., system prompts)

**Impact**: 2-4x higher throughput than earlier systems such as FasterTransformer and Orca, with near-zero KV cache waste (as reported in the vLLM paper)
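
The bookkeeping behind PagedAttention can be sketched in plain Python. Each request gets a *block table* that maps its logical blocks to physical block ids. Blocks are reference-counted so a prefix can be shared, and a shared, partially filled block is copied on write. This is a simplified illustration under our own naming (`PagedKVAllocator` is hypothetical, not vLLM's API). It tracks only block ids, not the K/V tensors themselves.

```python
class PagedKVAllocator:
    """Toy block manager: maps each request's logical KV blocks to physical block ids."""

    def __init__(self, num_physical_blocks, block_size=16):
        self.block_size = block_size
        self.free_blocks = list(range(num_physical_blocks))[::-1]  # pop() hands out 0, 1, 2, ...
        self.ref_count = [0] * num_physical_blocks
        self.block_tables = {}  # request id -> list of physical block ids
        self.num_tokens = {}    # request id -> number of tokens stored

    def _allocate(self):
        if not self.free_blocks:
            raise MemoryError("out of KV blocks (a real server would preempt or swap)")
        block = self.free_blocks.pop()
        self.ref_count[block] = 1
        return block

    def append_token(self, req):
        """Reserve a slot for one new token; returns (physical_block, offset_in_block)."""
        table = self.block_tables.setdefault(req, [])
        n = self.num_tokens.get(req, 0)
        if n % self.block_size == 0:
            # Last block is full (or there is none yet): take any free block, no contiguity needed
            table.append(self._allocate())
        elif self.ref_count[table[-1]] > 1:
            # Partially filled block shared with another request: copy-on-write
            self.ref_count[table[-1]] -= 1
            table[-1] = self._allocate()  # a real system also copies the existing K/V here
        self.num_tokens[req] = n + 1
        return table[-1], n % self.block_size

    def fork(self, parent, child):
        """Let `child` share all of `parent`'s blocks (e.g., a common prompt) without copying."""
        self.block_tables[child] = list(self.block_tables[parent])
        self.num_tokens[child] = self.num_tokens[parent]
        for block in self.block_tables[child]:
            self.ref_count[block] += 1

    def free_request(self, req):
        """Release a finished request; blocks return to the pool once nobody references them."""
        for block in self.block_tables.pop(req):
            self.ref_count[block] -= 1
            if self.ref_count[block] == 0:
                self.free_blocks.append(block)
        del self.num_tokens[req]


alloc = PagedKVAllocator(num_physical_blocks=8, block_size=4)
for _ in range(6):                 # 6 tokens -> 2 blocks (4 + 2)
    alloc.append_token("a")
alloc.fork("a", "b")               # "b" shares both blocks
print(alloc.append_token("b"))     # copy-on-write of the shared, half-full block
print(alloc.block_tables)
```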

In [None]:
# Conceptual example: KV Cache Quantization (per-tensor, asymmetric min/max scaling)

def quantize_tensor(x, bits=8):
    """Map x onto the integer levels 0 .. 2**bits - 1 using one scale for the whole tensor."""
    assert 1 <= bits <= 8, "levels are stored in uint8 (lower bit widths are not bit-packed here)"
    x = x.float()  # do the arithmetic in FP32
    x_min, x_max = x.min(), x.max()
    # Step between adjacent levels (clamped so a constant tensor does not divide by zero)
    scale = ((x_max - x_min) / (2**bits - 1)).clamp_min(1e-8)
    x_quant = ((x - x_min) / scale).round()
    # uint8 rather than int8: levels 128..255 would overflow a signed 8-bit type
    return x_quant.to(torch.uint8), scale, x_min

def quantize_kv_cache(k, v, bits=8):
    """Quantize the K and V caches independently to save memory."""
    return quantize_tensor(k, bits), quantize_tensor(v, bits)

def dequantize_tensor(x_quant, scale, x_min, dtype=torch.float16):
    """Map integer levels back to real values (FP16 by default)."""
    return (x_quant.float() * scale + x_min).to(dtype)

# Test on an FP16 cache, the usual storage format
k = torch.randn(1, 8, 100, 64).half()  # (batch, heads, seq, head_dim)
v = torch.randn(1, 8, 100, 64).half()

print(f"Original K memory (FP16): {k.element_size() * k.nelement() / 1e6:.3f} MB")

k_quant, v_quant = quantize_kv_cache(k, v, bits=8)
k_q, k_scale, k_min = k_quant

print(f"Quantized K memory (uint8): {k_q.element_size() * k_q.nelement() / 1e6:.3f} MB")
print(f"Memory savings: {(1 - k_q.element_size() / k.element_size()) * 100:.1f}%")

# Verify we can reconstruct: uniform rounding error averages about scale / 4
k_reconstructed = dequantize_tensor(k_q, k_scale, k_min)
error = (k.float() - k_reconstructed.float()).abs().mean()
print(f"Mean reconstruction error: {error:.6f} (scale / 4 = {k_scale / 4:.6f})")

## Part 3: Speculative Decoding - The 2-3x Speedup Trick

### The Idea
Use a **small, fast "draft" model** to predict multiple tokens ahead, then **verify** with the large target model in parallel.

### How it Works
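1. Draft model generates K candidate tokens (fast!)
2. Target model scores all K positions, plus one extra, in **one forward pass** (parallel!)
3. Accept draft tokens left to right while they pass verification (an exact match under greedy decoding, a rejection-sampling test under sampling); at the first rejection, use the target model's own token instead and discard the rest
4. Repeat from the last accepted token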
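
Step 3 is where the "no accuracy loss" guarantee comes from. The cell below sketches the acceptance rule from the speculative decoding paper using plain probability tensors. A draft token x is kept with probability min(1, p(x)/q(x)). On rejection, a replacement is drawn from the normalized residual max(0, p - q). The function name and tensor layout are our own assumptions. The sketch handles a single sequence and assumes the target's k+1 distributions came from one forward pass.

```python
import torch

def speculative_accept(p_target, q_draft, draft_tokens, generator=None):
    """
    p_target:     [k+1, V] target-model probabilities at each draft position, plus one extra
    q_draft:      [k, V]   draft-model probabilities the draft tokens were sampled from
    draft_tokens: [k]      token ids proposed by the draft model
    Returns the tokens emitted by this verification step (1 to k+1 tokens).
    """
    emitted = []
    k = draft_tokens.shape[0]
    for i in range(k):
        tok = draft_tokens[i].item()
        accept_prob = torch.clamp(p_target[i, tok] / q_draft[i, tok], max=1.0)
        if torch.rand((), generator=generator) < accept_prob:
            emitted.append(tok)  # keep the draft token
            continue
        # Rejected: resample from the residual distribution max(0, p - q), renormalized
        residual = torch.clamp(p_target[i] - q_draft[i], min=0.0)
        residual = residual / residual.sum()
        emitted.append(torch.multinomial(residual, 1, generator=generator).item())
        return emitted  # discard all later draft tokens
    # Every draft token accepted: take a "bonus" token from the target's extra distribution
    emitted.append(torch.multinomial(p_target[k], 1, generator=generator).item())
    return emitted

# Toy usage: vocabulary of 4, k = 3 draft tokens
V, k = 4, 3
p = torch.full((k + 1, V), 0.25)   # target distributions
q = torch.full((k, V), 0.25)       # identical draft distributions -> always accepted
print(speculative_accept(p, q, torch.tensor([2, 3, 1])))
```

When the draft and target distributions agree, every token is accepted. The more they disagree, the earlier the first rejection tends to occur.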

### Why it Works
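- The draft model is sometimes wrong, but when it is right, we get several tokens for roughly the cost of one target forward pass (plus a few cheap draft passes)
- No accuracy loss: with greedy decoding the output matches standard decoding, and with sampling the output distribution is provably identical to sampling from the target model alone
- Supported by production serving frameworks such as vLLM and TensorRT-LLM

**Typical Speedup**: the original papers report roughly 2-3x; the real gain depends on how often the target accepts the draft's tokens and how cheap the draft model is relative to the target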
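
You can estimate what a given acceptance rate buys with the expected-tokens formula from the speculative decoding paper. Suppose each draft token is accepted independently with probability α and the draft proposes γ tokens. Then one target pass yields on average (1 - α^(γ+1)) / (1 - α) tokens. The independence assumption is a simplification, and the estimate ignores the draft model's own cost.

```python
def expected_tokens_per_pass(alpha, gamma):
    """Expected tokens per target forward pass, assuming i.i.d. acceptance with rate alpha < 1."""
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)

for alpha in (0.5, 0.7, 0.9):
    print(f"alpha={alpha}: {expected_tokens_per_pass(alpha, gamma=5):.2f} tokens/pass with 5 draft tokens")
```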

In [None]:
# Simulated Speculative Decoding Example (random stand-ins, no real models)

class DraftModel:
    """Simulates a small, fast model."""
    def generate_k_tokens(self, context, k=5):
        # In reality, this would be a small LM (e.g., 1B params)
        # For demo, we'll just return random token ids
        return torch.randint(0, 50000, (k,))

class TargetModel:
    """Simulates the large, accurate model."""
    def verify_tokens(self, context, draft_tokens):
        # In reality, one forward pass scores every draft position in parallel.
        # For demo, each draft token is accepted independently with ~70% probability.
        acceptance_mask = torch.rand(len(draft_tokens)) > 0.3

        # Keep the draft tokens up to the first rejection
        if acceptance_mask.all():
            num_accepted = len(draft_tokens)
        else:
            num_accepted = (~acceptance_mask).nonzero()[0].item()

        # The same pass also yields one token from the target itself: a correction at the
        # first rejected position, or a "bonus" token if every draft token was accepted.
        target_token = torch.randint(0, 50000, (1,))
        return torch.cat([draft_tokens[:num_accepted], target_token]), num_accepted

# Simulate speculative decoding
draft = DraftModel()
target = TargetModel()

total_tokens = 0
forward_passes = 0
context = "Once upon a time"  # Initial prompt (ignored by the random stand-ins)

print("Simulating Speculative Decoding:\n")
for step in range(10):
    # Draft model proposes K tokens
    k = 5
    draft_tokens = draft.generate_k_tokens(context, k=k)

    # Target model verifies them in ONE forward pass
    new_tokens, num_accepted = target.verify_tokens(context, draft_tokens)
    forward_passes += 1
    total_tokens += len(new_tokens)

    print(f"Step {step+1}: drafted {k}, accepted {num_accepted} (+1 from target) -> {len(new_tokens)} new tokens")

tokens_per_pass = total_tokens / forward_passes
print(f"\n✅ Generated {total_tokens} tokens in {forward_passes} target forward passes")
print(f"Tokens per target pass: {tokens_per_pass:.2f} (vs 1.0 for standard decoding; ignores draft-model cost)")

## Part 4: Continuous Batching - Never Wait for Slow Requests

### The Old Way (Static Batching)
- Batch 8 requests together
- Wait for ALL 8 to finish before starting new requests
- **Problem**: If one request generates 1000 tokens and the others finish in 50, the 7 finished batch slots sit idle (wasted GPU capacity) for the remaining ~950 steps!

### The New Way (Continuous Batching)
- As soon as a request finishes, immediately inject a new one into the batch
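- Batch slots are refilled as soon as they free up, so little GPU capacity is wasted
- **Benefit**: often 2-10x higher throughput, depending on how much output lengths vary across requests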

### Implementation
vLLM and TensorRT-LLM do this automatically. The key insight:
- Each request has its own KV cache
- Batch size can change dynamically
- Use PagedAttention to handle variable-length caches
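
A step-count simulation makes the difference visible. This sketch rests on simplifying assumptions of our own: every active request emits exactly one token per decoding step, prefill is free, and free slots are refilled before each step.

```python
from collections import deque

def static_batching_steps(lengths, batch_size):
    """Each batch runs until its longest request finishes; finished slots idle meanwhile."""
    steps = 0
    for i in range(0, len(lengths), batch_size):
        steps += max(lengths[i : i + batch_size])
    return steps

def continuous_batching_steps(lengths, batch_size):
    """Free slots are refilled from the queue before every decoding step."""
    queue = deque(lengths)
    active = []  # remaining output tokens for each occupied slot
    steps = 0
    while queue or active:
        while queue and len(active) < batch_size:
            active.append(queue.popleft())
        # One decoding step: every active request emits one token; those at 1 finish now
        active = [remaining - 1 for remaining in active if remaining > 1]
        steps += 1
    return steps

lengths = ([100] + [10] * 3) * 4  # 16 requests: one long request in every group of 4
for name, simulate in [("static", static_batching_steps), ("continuous", continuous_batching_steps)]:
    steps = simulate(lengths, batch_size=4)
    print(f"{name:>10}: {steps} steps, {sum(lengths) / steps:.2f} tokens per step")
```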

## Part 5: Putting It All Together - Using vLLM

vLLM combines all these techniques:
- ✅ KV caching with PagedAttention
- ✅ Continuous batching
- ✅ Optional speculative decoding
- ✅ Quantization support

### Installation
```bash
pip install vllm
```

### Basic Usage

In [None]:
# Example: Using vLLM (printed as a string rather than executed, since it requires a GPU and vLLM installed)

code_example = '''
from vllm import LLM, SamplingParams

# Initialize LLM with optimizations
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B",
    tensor_parallel_size=1,           # Use 1 GPU
    gpu_memory_utilization=0.95,      # Use 95% of GPU memory
    enable_prefix_caching=True,       # Cache common prefixes (system prompts)
    max_model_len=4096,               # Max sequence length
    enforce_eager=False,              # Use CUDA graphs for speed
)

# Sampling parameters
sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    max_tokens=100
)

# Batch requests (continuous batching happens automatically!)
prompts = [
    "Explain quantum computing in simple terms.",
    "Write a Python function to reverse a string.",
    "What are the key benefits of PyTorch?"
]

# Generate (vLLM handles batching, KV cache, everything!)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    print(f"Prompt: {output.prompt}")
    print(f"Output: {output.outputs[0].text}")
    print("-" * 80)
'''

print("vLLM Usage Example:")
print(code_example)
print("\n⚡ vLLM typically delivers much higher throughput than a plain Hugging Face Transformers generate() loop; the exact gain depends on model, hardware, and workload.")

## Part 6: Performance Comparison

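### Illustrative Benchmarks (Llama-3 8B on A100 GPU)

These are rough, illustrative figures rather than measurements from this notebook; real numbers vary widely with batch size, prompt and output lengths, precision, and framework version.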

| Method | Throughput (tokens/sec) | Latency (TTFT) | Memory Usage |
|--------|------------------------|----------------|---------------|
| **Naive PyTorch** | 50 | 200ms | 24GB |
| **HuggingFace + KV Cache** | 150 | 100ms | 20GB |
| **vLLM (PagedAttention + Batching)** | 450 | 60ms | 14GB |
| **TensorRT-LLM (FP8 + All Tricks)** | 800 | 35ms | 10GB |

**Key Takeaways** (ratios from the table above):
1. KV caching alone gives ~3x speedup
2. PagedAttention + continuous batching → another ~3x
3. Hardware-specific optimizations (TensorRT-LLM, FP8) → another ~1.8x

**Total**: ~16x higher throughput with modern techniques in this comparison!
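
The takeaway multipliers are simply ratios of consecutive rows in the throughput column. The cell below checks them using the values copied from the table.

```python
throughput = {  # tokens/sec, copied from the table above
    "Naive PyTorch": 50,
    "HuggingFace + KV Cache": 150,
    "vLLM": 450,
    "TensorRT-LLM": 800,
}
names, values = list(throughput), list(throughput.values())
for prev, cur, a, b in zip(names, names[1:], values, values[1:]):
    print(f"{prev} -> {cur}: {b / a:.2f}x")
print(f"Total: {values[-1] / values[0]:.0f}x")
```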

## Summary: The 2025 LLM Inference Stack

### Essential Techniques
1. **KV Caching** - Foundation (3x speedup)
2. **KV Cache Quantization** - INT8/INT4 (2-4x memory savings)
3. **PagedAttention** - Efficient memory management (2x higher batch size)
4. **Continuous Batching** - Keeps batch slots full (2-10x throughput)
5. **Speculative Decoding** - Multiple tokens per pass (2-3x speedup)

### Production Frameworks
- **vLLM**: Best for most use cases, easy integration
- **TensorRT-LLM**: Maximum performance on NVIDIA GPUs
- **SGLang**: Emerging, focuses on structured generation

### What FAANG Expects You to Know
- ✅ Why LLM inference is memory-bandwidth-bound
- ✅ How KV caching works and its memory/compute trade-offs
- ✅ What speculative decoding is and when to use it
- ✅ Difference between static and continuous batching
- ✅ How to deploy with vLLM or TensorRT-LLM
- ✅ Cost optimization strategies (inference often dominates serving costs!)

### Further Reading
- [vLLM Paper: Efficient Memory Management for Large Language Model Serving with PagedAttention](https://arxiv.org/abs/2309.06180)
- [Speculative Decoding Paper: Fast Inference from Transformers via Speculative Decoding](https://arxiv.org/abs/2211.17192)
- [FlashAttention-2](https://arxiv.org/abs/2307.08691)

**You now understand cutting-edge LLM inference optimization! 🚀**