# Prefix Caching: Reusing KV Cache Between Requests

---

## What You'll Learn

1. **What prefix caching is** and why it matters for inference speed
2. **How KV cache is built** during prefill and reused across requests
3. **Measuring TTFT improvement** with vs without prefix caching
4. **Context engineering**: put unique tokens late, shared tokens early
5. **The wrong way**: unique tokens first kills cache reuse
6. **Real-world applications**: system prompts, multi-turn chat, code completion
7. **Memory savings** from prefix caching

---

### The Core Idea

When multiple requests share the same **prefix** (e.g., the same system prompt), we can compute the KV cache for that prefix **once** and reuse it for all subsequent requests.

Without prefix caching:
```
Request 1: [system prompt] + [user query 1] -> compute ALL from scratch
Request 2: [system prompt] + [user query 2] -> compute ALL from scratch (duplicate work!)
```

With prefix caching:
```
Request 1: [system prompt] + [user query 1] -> compute & CACHE system prompt KV
Request 2: [system prompt] + [user query 2] -> REUSE cached KV, only compute user query
```
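
To make the saving concrete, the sketch below counts how many prompt tokens go through prefill in each mode. The function name and the token counts are hypothetical, chosen only for illustration.

```python
def tokens_prefilled(prefix_len, query_lens, use_prefix_cache):
    """Count the prompt tokens that must be run through prefill."""
    if use_prefix_cache:
        # The prefix is prefilled once; each request then prefills only its query.
        return prefix_len + sum(query_lens)
    # Every request prefills the prefix and its query from scratch.
    return sum(prefix_len + q for q in query_lens)


query_lens = [5, 6, 4]  # hypothetical query lengths, in tokens
without_cache = tokens_prefilled(60, query_lens, use_prefix_cache=False)
with_cache = tokens_prefilled(60, query_lens, use_prefix_cache=True)
print(f"Without cache: {without_cache} tokens, with cache: {with_cache} tokens")
```

The gap widens as the prefix gets longer or as more requests share it.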

In [None]:
# Install dependencies
!pip install torch transformers matplotlib numpy -q

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import time
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from typing import Optional, Tuple, List

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")

torch.manual_seed(42)

## Part 1: Understanding KV Cache in Prefill

Before we can cache prefixes, let's understand exactly what the KV cache contains and how it's built during prefill.

In [None]:
# Load GPT-2 (small, runs easily on Colab free tier)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)
model.eval()

print(f"Model: GPT-2")
print(f"  Layers: {model.config.n_layer}")
print(f"  Heads: {model.config.n_head}")
print(f"  d_model: {model.config.n_embd}")
print(f"  d_head: {model.config.n_embd // model.config.n_head}")

In [None]:
def get_kv_cache(model, input_ids):
    """Run prefill and extract the KV cache."""
    with torch.no_grad():
        outputs = model(input_ids, use_cache=True)
    # past_key_values holds one (key, value) pair per layer
    # (a tuple of tuples in older transformers versions, a Cache object in newer ones)
    # Each key/value: (batch, n_heads, seq_len, d_head)
    return outputs.past_key_values, outputs.logits


# Build KV cache for a simple prompt
prompt = "The weather in San Francisco is"
input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

kv_cache, logits = get_kv_cache(model, input_ids)

print(f"Prompt: '{prompt}'")
print(f"Token count: {input_ids.shape[1]}")
print(f"\nKV Cache structure:")
print(f"  Number of layers: {len(kv_cache)}")
print(f"  Key shape per layer: {kv_cache[0][0].shape}")
print(f"  Value shape per layer: {kv_cache[0][1].shape}")

# Calculate memory
total_bytes = 0
for layer_kv in kv_cache:
    for tensor in layer_kv:
        total_bytes += tensor.nelement() * tensor.element_size()

print(f"\nTotal KV cache size: {total_bytes / 1024:.2f} KB")
print(f"  Per layer: {total_bytes / len(kv_cache) / 1024:.2f} KB")
print(f"  Per token: {total_bytes / input_ids.shape[1]:.0f} bytes")

## Part 2: The Weather Example - Shared Prefix

Consider this scenario from the book: a weather chatbot gets many requests with the same system prompt.

```
Shared prefix:  "You are a helpful weather assistant. Provide accurate weather info for..."
Request 1 adds: "San Francisco"
Request 2 adds: "New York City"
Request 3 adds: "London, UK"
```

Without prefix caching, we recompute the system prompt KV cache every time!
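
A note on measurement before timing anything: the first forward pass usually carries one-off warm-up cost, and on a GPU the kernels are launched asynchronously, so a wall-clock timer can stop before the work has finished. The helper below is a minimal sketch of a more careful timer. The name `time_call` and its parameters are assumptions made for this example, not part of any library.

```python
import statistics
import time


def time_call(fn, sync=None, warmup=1, repeats=5):
    """Return the median wall-clock time of fn() in milliseconds.

    sync: optional zero-argument callable that blocks until queued device
    work has finished (for example torch.cuda.synchronize on a CUDA device).
    """
    for _ in range(warmup):
        fn()  # warm-up runs are not timed
    samples = []
    for _ in range(repeats):
        if sync is not None:
            sync()  # drain earlier work so it is not charged to this run
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for the work launched by fn() to finish
        samples.append((time.perf_counter() - start) * 1000)
    return statistics.median(samples)


# Stand-in workload so the sketch runs without a model or a GPU.
calls = []
median_ms = time_call(lambda: calls.append(sum(range(10_000))), warmup=2, repeats=5)
print(f"median: {median_ms:.3f} ms over 5 timed runs ({len(calls)} calls in total)")
```

With the model from this notebook, `fn` would wrap the forward pass and `sync` would be `torch.cuda.synchronize` when `device == 'cuda'`.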

In [None]:
# Simulated weather chatbot scenario
system_prompt = (
    "You are a helpful weather assistant. You provide accurate, detailed weather "
    "information including temperature, humidity, wind speed, and precipitation "
    "forecasts. Always include both Fahrenheit and Celsius. Be concise but thorough. "
    "Current date is January 15, 2025. Provide weather information for: "
)

user_queries = [
    "San Francisco, California",
    "New York City, New York",
    "London, United Kingdom",
    "Tokyo, Japan",
    "Sydney, Australia",
]

# Tokenize
prefix_ids = tokenizer.encode(system_prompt)
print(f"System prompt: {len(system_prompt)} chars, {len(prefix_ids)} tokens")
print(f"Number of requests: {len(user_queries)}")

for q in user_queries:
    q_ids = tokenizer.encode(q)
    print(f"  '{q}': {len(q_ids)} tokens")

In [None]:
# Method 1: WITHOUT prefix caching (compute everything from scratch)
def prefill_without_cache(model, tokenizer, system_prompt, user_queries, n_generate=20):
    """Process each request independently - no prefix caching."""
    results = []
    
    for query in user_queries:
        full_prompt = system_prompt + query
        input_ids = tokenizer.encode(full_prompt, return_tensors='pt').to(device)
        
        start = time.perf_counter()
        
        with torch.no_grad():
            # Prefill: compute KV cache for entire prompt
            outputs = model(input_ids, use_cache=True)
            
        ttft = time.perf_counter() - start
        
        # Generate a few tokens
        past = outputs.past_key_values
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = [next_token.item()]
        
        gen_start = time.perf_counter()
        for _ in range(n_generate - 1):
            with torch.no_grad():
                outputs = model(next_token, past_key_values=past, use_cache=True)
            past = outputs.past_key_values
            next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated.append(next_token.item())
        gen_time = time.perf_counter() - gen_start
        
        results.append({
            'query': query,
            'ttft_ms': ttft * 1000,
            'gen_time_ms': gen_time * 1000,
            'total_tokens_processed': input_ids.shape[1],
            'text': tokenizer.decode(generated)
        })
    
    return results


# Run without cache
results_no_cache = prefill_without_cache(model, tokenizer, system_prompt, user_queries)

print("WITHOUT Prefix Caching:")
print("=" * 60)
total_ttft = 0
for r in results_no_cache:
    print(f"  {r['query']:<30} TTFT: {r['ttft_ms']:>7.2f} ms  "
          f"(processed {r['total_tokens_processed']} tokens)")
    total_ttft += r['ttft_ms']
print(f"\n  Total TTFT: {total_ttft:.2f} ms")

In [None]:
# Method 2: WITH prefix caching
def prefill_with_cache(model, tokenizer, system_prompt, user_queries, n_generate=20):
    """Cache the system prompt KV and reuse it."""
    results = []
    
    # Step 1: Build the prefix cache ONCE
    prefix_ids = tokenizer.encode(system_prompt, return_tensors='pt').to(device)
    
    cache_start = time.perf_counter()
    with torch.no_grad():
        prefix_outputs = model(prefix_ids, use_cache=True)
    cache_time = time.perf_counter() - cache_start
    
    prefix_kv = prefix_outputs.past_key_values
    print(f"Prefix cache built in {cache_time*1000:.2f} ms ({prefix_ids.shape[1]} tokens)")
    
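    # Step 2: For each query, reuse the prefix cache
    import copy  # standard library; used for the per-request cache copy below
    for query in user_queries:
        query_ids = tokenizer.encode(query, return_tensors='pt').to(device)
        
        # Recent transformers versions return a Cache object that the model
        # extends in place, so each request gets its own copy of the prefix KV.
        # The copy is made outside the timed region; a serving engine would
        # share the cached blocks instead of copying them.
        request_kv = copy.deepcopy(prefix_kv)
        
        start = time.perf_counter()
        
        with torch.no_grad():
            # Only process the NEW query tokens, with prefix KV already cached
            outputs = model(
                query_ids, 
                past_key_values=request_kv,
                use_cache=True
            )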
        
        ttft = time.perf_counter() - start
        
        # Generate tokens
        past = outputs.past_key_values
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = [next_token.item()]
        
        gen_start = time.perf_counter()
        for _ in range(n_generate - 1):
            with torch.no_grad():
                outputs = model(next_token, past_key_values=past, use_cache=True)
            past = outputs.past_key_values
            next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated.append(next_token.item())
        gen_time = time.perf_counter() - gen_start
        
        results.append({
            'query': query,
            'ttft_ms': ttft * 1000,
            'gen_time_ms': gen_time * 1000,
            'new_tokens_processed': query_ids.shape[1],
            'text': tokenizer.decode(generated)
        })
    
    return results, cache_time * 1000


# Run with cache
results_with_cache, cache_build_time = prefill_with_cache(
    model, tokenizer, system_prompt, user_queries)

print("\nWITH Prefix Caching:")
print("=" * 60)
total_ttft_cached = cache_build_time  # Include the initial cache build
for r in results_with_cache:
    print(f"  {r['query']:<30} TTFT: {r['ttft_ms']:>7.2f} ms  "
          f"(processed {r['new_tokens_processed']} tokens)")
    total_ttft_cached += r['ttft_ms']
print(f"\n  Total TTFT: {total_ttft_cached:.2f} ms (including {cache_build_time:.2f} ms cache build)")

In [None]:
# Compare TTFT
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

queries_short = [q.split(',')[0] for q in user_queries]

# Per-request TTFT
ttft_no_cache = [r['ttft_ms'] for r in results_no_cache]
ttft_with_cache = [r['ttft_ms'] for r in results_with_cache]

x = np.arange(len(queries_short))
width = 0.35

bars1 = axes[0].bar(x - width/2, ttft_no_cache, width, label='No Cache', 
                     color='#e74c3c', edgecolor='black')
bars2 = axes[0].bar(x + width/2, ttft_with_cache, width, label='With Prefix Cache', 
                     color='#2ecc71', edgecolor='black')

axes[0].set_xlabel('Request', fontsize=12)
axes[0].set_ylabel('TTFT (ms)', fontsize=12)
axes[0].set_title('Time to First Token (TTFT) per Request', fontsize=13, fontweight='bold')
axes[0].set_xticks(x)
axes[0].set_xticklabels(queries_short, rotation=15, ha='right')
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3, axis='y')

# Cumulative time
cum_no_cache = np.cumsum(ttft_no_cache)
cum_with_cache = np.cumsum([cache_build_time] + ttft_with_cache)

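# No-cache points sit at 1..N requests. The cached curve starts at 0 requests
# (cache build only), so both curves share the same x-axis meaning.
axes[1].plot(range(1, len(cum_no_cache) + 1), cum_no_cache, 'ro-', linewidth=2, 
             markersize=10, label='No Cache')
axes[1].plot(range(len(cum_with_cache)), cum_with_cache, 'go-', linewidth=2, 
             markersize=10, label='With Prefix Cache')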

axes[1].set_xlabel('Number of Requests', fontsize=12)
axes[1].set_ylabel('Cumulative Time (ms)', fontsize=12)
axes[1].set_title('Cumulative Prefill Time', fontsize=13, fontweight='bold')
axes[1].legend(fontsize=11)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

speedup = sum(ttft_no_cache) / (cache_build_time + sum(ttft_with_cache))
print(f"\nOverall speedup: {speedup:.2f}x")
print(f"Savings grow with more requests sharing the same prefix!")

## Part 3: Context Engineering - Token Ordering Matters!

For prefix caching to work, the **shared tokens must come first**. This is a context engineering principle:

**RIGHT way**: `[shared system prompt] + [unique user query]`  
**WRONG way**: `[unique user query] + [shared system prompt]`

With the wrong ordering, the prompts diverge at the very first token, so there is no shared prefix to reuse and caching provides no benefit.
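
The same rule applies inside the system prompt itself: any value that changes between requests (a date, a user name) ends the reusable prefix at the point where it appears. The sketch below compares two hypothetical prompt layouts. It measures the shared prefix in characters with `os.path.commonprefix`, which is only a proxy for the token-level prefix a cache actually matches on; the function names and prompt text are assumptions made for this example.

```python
import os

STATIC_INSTRUCTIONS = (
    "You are a helpful weather assistant. "
    "Always include both Fahrenheit and Celsius. "
)


def build_prompt(user_query, current_date):
    """Stable text first, per-request values last."""
    return (
        STATIC_INSTRUCTIONS
        + f"Current date is {current_date}. "
        + f"Provide weather information for: {user_query}"
    )


def build_prompt_date_first(user_query, current_date):
    """Same content, but a changing value sits at the front."""
    return (
        f"Current date is {current_date}. "
        + STATIC_INSTRUCTIONS
        + f"Provide weather information for: {user_query}"
    )


args_a = ("San Francisco", "January 15, 2025")
args_b = ("London", "January 16, 2025")

shared_good = len(os.path.commonprefix([build_prompt(*args_a), build_prompt(*args_b)]))
shared_bad = len(os.path.commonprefix(
    [build_prompt_date_first(*args_a), build_prompt_date_first(*args_b)]))

print(f"Static text first: {shared_good} shared characters")
print(f"Date first:        {shared_bad} shared characters")
```

In the date-first layout the two prompts stop matching inside the date, so the static instructions that follow cannot be reused even though their text is identical.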

In [None]:
# Demonstrate the impact of token ordering

# RIGHT: Shared prefix first
right_order_prompts = [
    system_prompt + query for query in user_queries
]

# WRONG: Unique query first
wrong_order_prompts = [
    query + " " + system_prompt for query in user_queries
]

# Calculate shared prefix length
def find_shared_prefix_length(prompts, tokenizer):
    """Find the number of shared prefix tokens."""
    token_lists = [tokenizer.encode(p) for p in prompts]
    
    min_len = min(len(t) for t in token_lists)
    shared = 0
    
    for i in range(min_len):
        if all(t[i] == token_lists[0][i] for t in token_lists):
            shared += 1
        else:
            break
    
    return shared


right_shared = find_shared_prefix_length(right_order_prompts, tokenizer)
wrong_shared = find_shared_prefix_length(wrong_order_prompts, tokenizer)

total_right = len(tokenizer.encode(right_order_prompts[0]))
total_wrong = len(tokenizer.encode(wrong_order_prompts[0]))

print("Token Ordering Analysis:")
print("=" * 50)
print(f"\nRIGHT order (shared prefix first):")
print(f"  Shared prefix tokens: {right_shared}/{total_right} ({100*right_shared/total_right:.1f}%)")
print(f"  Cache hit rate: HIGH")
print(f"\nWRONG order (unique query first):")
print(f"  Shared prefix tokens: {wrong_shared}/{total_wrong} ({100*wrong_shared/total_wrong:.1f}%)")
print(f"  Cache hit rate: ZERO")

In [None]:
# Visualize the difference
fig, axes = plt.subplots(2, 1, figsize=(14, 8))

# RIGHT ordering
right_tokens = tokenizer.encode(right_order_prompts[0])
colors_right = ['#2ecc71' if i < right_shared else '#e74c3c' for i in range(len(right_tokens))]

axes[0].barh([0], [right_shared], color='#2ecc71', edgecolor='black', 
             height=0.5, label=f'Cached prefix ({right_shared} tokens)')
axes[0].barh([0], [total_right - right_shared], left=[right_shared], 
             color='#e74c3c', edgecolor='black', height=0.5, 
             label=f'Unique query ({total_right - right_shared} tokens)')
axes[0].set_title('RIGHT: System Prompt First (High Cache Reuse)', 
                   fontsize=13, fontweight='bold')
axes[0].legend(fontsize=11)
axes[0].set_xlim(0, total_right + 5)
axes[0].set_xlabel('Token Position')
axes[0].set_yticks([])

# WRONG ordering
query_len = len(tokenizer.encode(user_queries[0]))
axes[1].barh([0], [query_len], color='#e74c3c', edgecolor='black', 
             height=0.5, label=f'Unique query ({query_len} tokens)')
axes[1].barh([0], [total_wrong - query_len], left=[query_len], 
             color='#95a5a6', edgecolor='black', height=0.5, 
             label=f'System prompt ({total_wrong - query_len} tokens) - NOT cacheable')
axes[1].set_title('WRONG: Unique Query First (No Cache Reuse)', 
                   fontsize=13, fontweight='bold')
axes[1].legend(fontsize=11)
axes[1].set_xlim(0, total_wrong + 5)
axes[1].set_xlabel('Token Position')
axes[1].set_yticks([])

plt.suptitle('Context Engineering: Token Order Determines Cache Effectiveness',
             fontsize=15, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

## Part 4: Scaling Analysis - How Much Does Prefix Length Matter?
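
A rough expectation helps when reading the measurements in this part. If prefill cost were exactly proportional to the number of tokens processed, caching a prefix of `P` tokens ahead of a suffix of `S` tokens would give a speedup of `(P + S) / S`. The sketch below computes that idealized figure. It is an upper bound under that assumption, not a prediction: measured speedups are typically lower because each forward pass has fixed overhead and the suffix tokens still attend to the cached prefix.

```python
def ideal_prefill_speedup(prefix_tokens, suffix_tokens):
    """Speedup if prefill time were proportional to tokens processed."""
    if suffix_tokens <= 0:
        raise ValueError("at least one suffix token must be processed")
    return (prefix_tokens + suffix_tokens) / suffix_tokens


# Hypothetical prefix lengths with a 6-token suffix
for prefix_tokens in (8, 80, 400):
    speedup = ideal_prefill_speedup(prefix_tokens, 6)
    print(f"prefix={prefix_tokens:>4} tokens -> ideal speedup {speedup:.1f}x")
```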

In [None]:
# Measure TTFT savings as prefix length varies
base_prompt = "You are a helpful AI assistant. " * 1  # Will be repeated
unique_suffix = "What is 2+2?"

prefix_multipliers = [1, 2, 5, 10, 20, 50]
n_requests = 5

results_scaling = []

for mult in prefix_multipliers:
    prefix = base_prompt * mult
    prefix_tokens = len(tokenizer.encode(prefix))
    suffix_tokens = len(tokenizer.encode(unique_suffix))
    
    # Without cache: process full prompt each time
    full_ids = tokenizer.encode(prefix + unique_suffix, return_tensors='pt').to(device)
    
    times_no_cache = []
    for _ in range(n_requests):
        start = time.perf_counter()
        with torch.no_grad():
            model(full_ids, use_cache=True)
        times_no_cache.append((time.perf_counter() - start) * 1000)
    
    # With cache: build cache once, then only process suffix
    prefix_ids = tokenizer.encode(prefix, return_tensors='pt').to(device)
    suffix_ids = tokenizer.encode(unique_suffix, return_tensors='pt').to(device)
    
    # Build cache
    with torch.no_grad():
        prefix_out = model(prefix_ids, use_cache=True)
    prefix_kv = prefix_out.past_key_values
    
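    import copy  # standard library; used for the per-run cache copy below
    times_with_cache = []
    for _ in range(n_requests):
        # Recent transformers versions extend the cache object in place, so
        # copy it (outside the timed region) to keep the prefix length fixed.
        request_kv = copy.deepcopy(prefix_kv)
        start = time.perf_counter()
        with torch.no_grad():
            model(suffix_ids, past_key_values=request_kv, use_cache=True)
        times_with_cache.append((time.perf_counter() - start) * 1000)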
    
    results_scaling.append({
        'prefix_tokens': prefix_tokens,
        'suffix_tokens': suffix_tokens,
        'avg_no_cache': np.mean(times_no_cache),
        'avg_with_cache': np.mean(times_with_cache),
        'speedup': np.mean(times_no_cache) / np.mean(times_with_cache)
    })

print(f"{'Prefix Tokens':>15} {'No Cache (ms)':>15} {'With Cache (ms)':>15} {'Speedup':>10}")
print("=" * 60)
for r in results_scaling:
    print(f"{r['prefix_tokens']:>15} {r['avg_no_cache']:>15.2f} {r['avg_with_cache']:>15.2f} {r['speedup']:>9.2f}x")

In [None]:
# Plot scaling results
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

prefix_tokens = [r['prefix_tokens'] for r in results_scaling]
no_cache_times = [r['avg_no_cache'] for r in results_scaling]
with_cache_times = [r['avg_with_cache'] for r in results_scaling]
speedups = [r['speedup'] for r in results_scaling]

# TTFT comparison
axes[0].plot(prefix_tokens, no_cache_times, 'ro-', linewidth=2, markersize=8, label='No Cache')
axes[0].plot(prefix_tokens, with_cache_times, 'go-', linewidth=2, markersize=8, label='With Prefix Cache')
axes[0].set_xlabel('Prefix Length (tokens)', fontsize=12)
axes[0].set_ylabel('TTFT (ms)', fontsize=12)
axes[0].set_title('TTFT vs Prefix Length', fontsize=13, fontweight='bold')
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)

# Speedup
axes[1].plot(prefix_tokens, speedups, 'bs-', linewidth=2, markersize=10)
axes[1].set_xlabel('Prefix Length (tokens)', fontsize=12)
axes[1].set_ylabel('Speedup (x)', fontsize=12)
axes[1].set_title('Prefix Cache Speedup vs Prefix Length', fontsize=13, fontweight='bold')
axes[1].grid(True, alpha=0.3)

for i, (pt, sp) in enumerate(zip(prefix_tokens, speedups)):
    axes[1].annotate(f'{sp:.1f}x', (pt, sp), textcoords="offset points",
                     xytext=(0, 10), ha='center', fontweight='bold')

plt.tight_layout()
plt.show()

print("Longer shared prefix = more savings from caching!")
print("The cached TTFT stays nearly constant here: only the suffix is processed, though it still attends to the cached prefix.")

## Part 5: Real-World Use Cases

Prefix caching is valuable in several common scenarios. Let's demonstrate each one.

In [None]:
# Use Case 1: System Prompts
# Many providers cache the system prompt across all user requests

system_prompts = {
    'Customer Service': (
        "You are a customer service agent for TechCorp. Always be polite, "
        "apologize for any inconvenience, and try to resolve the issue. "
        "If you cannot resolve it, escalate to a human agent. "
        "Never share internal pricing or policies. "
    ),
    'Code Review': (
        "You are a senior software engineer conducting code reviews. "
        "Focus on: correctness, performance, readability, and security. "
        "Provide specific line-by-line feedback. Suggest improvements "
        "with code examples. Be constructive but thorough. "
    ),
    'Legal Assistant': (
        "You are a legal research assistant. Provide information about "
        "legal concepts and precedents. Always include disclaimers that "
        "you are not providing legal advice. Cite relevant cases when "
        "possible. Be precise with legal terminology. "
    ),
}

print("System Prompt Lengths (tokens):")
print("=" * 40)
for name, prompt in system_prompts.items():
    tokens = len(tokenizer.encode(prompt))
    print(f"  {name:<20} {tokens:>5} tokens")

print("\n=> All these would benefit from prefix caching.")
print("   If 1000 users hit the customer service bot,")
print("   the system prompt KV is computed only ONCE.")

In [None]:
# Use Case 2: Multi-turn Chat
# Each turn extends the context, and all previous turns form the prefix

chat_turns = [
    "User: Hi, I need help with my Python code.\nAssistant: Of course! What issue are you facing?\n",
    "User: I'm getting a TypeError when I try to concatenate a string and integer.\nAssistant: That's a common issue. In Python, you need to convert the integer to a string first using str().\n",
    "User: Can you show me an example?\nAssistant: Sure! Instead of 'Hello ' + 42, use 'Hello ' + str(42).\n",
    "User: What about f-strings?\n",  # Latest turn (unique part)
]

# The prefix grows with each turn
for i in range(1, len(chat_turns) + 1):
    context = "".join(chat_turns[:i])
    tokens = len(tokenizer.encode(context))
    prefix = "".join(chat_turns[:i-1]) if i > 1 else ""
    prefix_tokens = len(tokenizer.encode(prefix)) if prefix else 0
    new_tokens = tokens - prefix_tokens
    
    print(f"Turn {i}: {tokens:>4} total tokens "
          f"(prefix: {prefix_tokens:>4}, new: {new_tokens:>3}) "
          f"-> {100*prefix_tokens/max(tokens,1):.0f}% cacheable")

print("\n=> By the final turn, most tokens come from previous turns (all cacheable); see the exact share above.")
print("   This is why multi-turn chat gets faster with prefix caching!")

In [None]:
# Use Case 3: Code Completion (same file context, different cursor positions)

code_context = '''
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, n_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
    
    def forward(self, x):
        # Self-attention with residual connection
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)
        # FFN with residual connection
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)
        return x

'''

# Different completions at the cursor
completions = [
    "    def configure_optimizers(self",  # User types method name
    "class Encoder(nn.Module):",           # User starts new class  
    "# Training loop",                     # User adds comment
]

code_tokens = len(tokenizer.encode(code_context))
print(f"Code context: {code_tokens} tokens (this is the prefix)")
print(f"\nDifferent completions at the cursor:")
for c in completions:
    c_tokens = len(tokenizer.encode(c))
    print(f"  '{c}' -> {c_tokens} new tokens")

print(f"\n=> All completions share the same {code_tokens}-token prefix.")
print(f"   With prefix caching, each completion only needs to process its own new tokens.")

## Part 6: Memory Analysis

Prefix caching saves compute (TTFT) but **costs memory** - we need to store the cached KV tensors. Let's analyze the tradeoff.
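
One way to read the tradeoff is to ask how many distinct prefixes fit in a fixed memory budget. The sketch below does that arithmetic for the LLaMA-7B shape used in this part (32 layers, 32 KV heads, head dimension 128) in fp16. The 8 GiB budget and the helper names are hypothetical.

```python
def kv_bytes_per_token(n_layers, n_kv_heads, d_head, dtype_bytes=2):
    """Bytes of K and V stored for one token across all layers."""
    return 2 * n_layers * n_kv_heads * d_head * dtype_bytes


def max_cached_prefixes(budget_bytes, prefix_len, bytes_per_token):
    """How many prefixes of prefix_len tokens fit in the budget."""
    return budget_bytes // (prefix_len * bytes_per_token)


per_token = kv_bytes_per_token(n_layers=32, n_kv_heads=32, d_head=128)
budget = 8 * 1024**3  # hypothetical 8 GiB set aside for cached prefixes
n_prefixes = max_cached_prefixes(budget, prefix_len=1024, bytes_per_token=per_token)

print(f"KV per token: {per_token / 1024:.0f} KiB")
print(f"1024-token prefixes that fit in 8 GiB: {n_prefixes}")
```

Under these assumptions each cached 1024-token prefix occupies 512 MiB, so the number of prefixes worth keeping resident is small and depends on how often each one is reused.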

In [None]:
def kv_cache_memory(n_layers, n_heads, d_head, seq_len, dtype_bytes=2):
    """Calculate KV cache memory for a given sequence length."""
    # Key and Value per layer: (n_heads, seq_len, d_head)
    # Total: 2 (K+V) * n_layers * n_heads * seq_len * d_head * dtype_bytes
    total_bytes = 2 * n_layers * n_heads * seq_len * d_head * dtype_bytes
    return total_bytes


# GPT-2 config
gpt2_config = {
    'n_layers': 12, 'n_heads': 12, 'd_head': 64, 'name': 'GPT-2 (124M)'
}

# Larger model configs for comparison
configs = [
    {'n_layers': 12, 'n_heads': 12, 'd_head': 64, 'name': 'GPT-2 (124M)'},
    {'n_layers': 32, 'n_heads': 32, 'd_head': 128, 'name': 'LLaMA-7B'},
    {'n_layers': 40, 'n_heads': 40, 'd_head': 128, 'name': 'LLaMA-13B'},
    # Llama 2 70B uses grouped-query attention: 8 KV heads shared by 64 query heads
    {'n_layers': 80, 'n_heads': 8, 'd_head': 128, 'name': 'LLaMA-70B'},
]

prefix_lengths = [128, 256, 512, 1024, 2048, 4096]

print(f"{'Model':<15} {'Prefix Len':>12} {'KV Cache Size':>15}")
print("=" * 44)

for config in configs:
    for pl in [512, 2048]:
        mem = kv_cache_memory(config['n_layers'], config['n_heads'], 
                             config['d_head'], pl)
        print(f"{config['name']:<15} {pl:>12} {mem/1e6:>12.1f} MB")
    print()

In [None]:
# Visualize memory vs compute savings tradeoff
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Memory cost of caching different prefix lengths
config = configs[1]  # LLaMA-7B
memories = [kv_cache_memory(config['n_layers'], config['n_heads'], 
                            config['d_head'], pl) / 1e6 
            for pl in prefix_lengths]

axes[0].bar(range(len(prefix_lengths)), memories, color='#3498db', edgecolor='black')
axes[0].set_xticks(range(len(prefix_lengths)))
axes[0].set_xticklabels(prefix_lengths)
axes[0].set_xlabel('Prefix Length (tokens)', fontsize=12)
axes[0].set_ylabel('KV Cache Memory (MB)', fontsize=12)
axes[0].set_title(f'Memory Cost of Prefix Cache\n({config["name"]})', 
                   fontsize=13, fontweight='bold')
axes[0].grid(True, alpha=0.3, axis='y')

for i, (pl, mem) in enumerate(zip(prefix_lengths, memories)):
    axes[0].text(i, mem + 10, f'{mem:.0f} MB', ha='center', fontweight='bold', fontsize=9)

# Savings: N requests sharing a prefix
n_requests_range = [1, 5, 10, 50, 100, 500, 1000]
prefix_len = 1024
prefix_mem = kv_cache_memory(config['n_layers'], config['n_heads'], 
                              config['d_head'], prefix_len)

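unique_len = 128  # tokens unique to each request
unique_mem = kv_cache_memory(config['n_layers'], config['n_heads'], 
                             config['d_head'], unique_len)

# Without caching, every request stores its own copy of prefix + unique tokens
mem_without = [n * (prefix_mem + unique_mem) / 1e9 for n in n_requests_range]
# With caching, the prefix is stored once and only the unique part is per-request
mem_with = [(prefix_mem + n * unique_mem) / 1e9 for n in n_requests_range]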

axes[1].plot(n_requests_range, mem_without, 'ro-', linewidth=2, markersize=8, label='Without Cache')
axes[1].plot(n_requests_range, mem_with, 'go-', linewidth=2, markersize=8, label='With Prefix Cache')
axes[1].set_xlabel('Number of Concurrent Requests', fontsize=12)
axes[1].set_ylabel('Total KV Memory (GB)', fontsize=12)
axes[1].set_title(f'Memory Savings vs Request Count\n(Prefix: {prefix_len} tokens, {config["name"]})',
                   fontsize=13, fontweight='bold')
axes[1].legend(fontsize=11)
axes[1].grid(True, alpha=0.3)
axes[1].set_xscale('log')

plt.tight_layout()
plt.show()

print(f"With 1000 requests sharing a {prefix_len}-token prefix:")
print(f"  Without caching: {mem_without[-1]:.1f} GB")
print(f"  With caching: {mem_with[-1]:.1f} GB")
print(f"  Savings: {(1 - mem_with[-1]/mem_without[-1])*100:.1f}%")

## Part 7: Prefix Caching Implementation Details

Let's look at how prefix caching actually works at the KV cache level.
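
Serving engines often track cached prefixes in fixed-size blocks of tokens rather than as whole prompts, so that two requests can share whatever leading blocks they have in common. The following is a minimal sketch of that idea, assuming a block size of 4 and SHA-256 hashes chained from block to block. Both choices, and the function names, are illustrative and not taken from any particular engine.

```python
import hashlib


def block_hashes(token_ids, block_size=4):
    """Hash each full block of tokens, chaining in the previous block's hash."""
    hashes = []
    prev = b""
    n_full = len(token_ids) // block_size * block_size
    for start in range(0, n_full, block_size):
        block = token_ids[start:start + block_size]
        payload = prev + b"|" + ",".join(map(str, block)).encode()
        prev = hashlib.sha256(payload).digest()
        hashes.append(prev.hex())
    return hashes


def cached_prefix_tokens(token_ids, cached, block_size=4):
    """Number of leading tokens whose blocks are already in `cached`."""
    n = 0
    for h in block_hashes(token_ids, block_size):
        if h not in cached:
            break
        n += block_size
    return n


req1 = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
req2 = [1, 2, 3, 4, 5, 6, 7, 8, 42, 43, 44, 45]  # same first two blocks
req3 = [99, 2, 3, 4, 5, 6, 7, 8]                 # differs at the first token

cached = set(block_hashes(req1))  # req1 contributes two full blocks
reusable_2 = cached_prefix_tokens(req2, cached)
reusable_3 = cached_prefix_tokens(req3, cached)
print(f"req2 can reuse {reusable_2} tokens, req3 can reuse {reusable_3}")
```

Because each hash includes the one before it, a block only matches when every token ahead of it matches too. That is why `req3` reuses nothing even though its second block holds the same tokens as `req1`.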

In [None]:
class PrefixCacheManager:
    """Simple prefix cache manager for demonstration."""
    
    def __init__(self, model, tokenizer, device='cpu'):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.cache = {}  # hash of prefix token IDs -> KV cache
        self.stats = {'hits': 0, 'misses': 0, 'tokens_saved': 0}
    
    def _hash_prefix(self, token_ids):
        """Hash token IDs for cache lookup."""
        return hash(tuple(token_ids))
    
    def get_or_compute(self, full_text):
        """Process text, using cached prefix if available."""
        full_ids = self.tokenizer.encode(full_text)
        
        # Find the longest cached prefix. Stop one token short of the full
        # prompt so at least one token is left to run through the model.
        best_match_len = 0
        best_kv = None
        
        for length in range(len(full_ids) - 1, 0, -1):
            prefix_hash = self._hash_prefix(full_ids[:length])
            if prefix_hash in self.cache:
                best_match_len = length
                best_kv = self.cache[prefix_hash]
                break
        
        if best_kv is not None:
            # Cache hit! Only process the remaining tokens
            self.stats['hits'] += 1
            self.stats['tokens_saved'] += best_match_len
            
            remaining_ids = torch.tensor([full_ids[best_match_len:]], device=self.device)
            
            # Work on a copy: recent transformers versions extend the cache
            # object in place, which would corrupt the stored prefix.
            import copy
            request_kv = copy.deepcopy(best_kv)
            
            start = time.perf_counter()
            with torch.no_grad():
                outputs = self.model(remaining_ids, past_key_values=request_kv, use_cache=True)
            elapsed = time.perf_counter() - start
            
            return outputs, elapsed, best_match_len, len(full_ids) - best_match_len
        else:
            # Cache miss: process everything
            self.stats['misses'] += 1
            
            input_ids = torch.tensor([full_ids], device=self.device)
            
            start = time.perf_counter()
            with torch.no_grad():
                outputs = self.model(input_ids, use_cache=True)
            elapsed = time.perf_counter() - start
            
            # Cache the prefix for future use
            prefix_hash = self._hash_prefix(full_ids)
            self.cache[prefix_hash] = outputs.past_key_values
            
            return outputs, elapsed, 0, len(full_ids)
    
    def cache_prefix(self, prefix_text):
        """Pre-cache a prefix."""
        prefix_ids = self.tokenizer.encode(prefix_text)
        input_ids = torch.tensor([prefix_ids], device=self.device)
        
        with torch.no_grad():
            outputs = self.model(input_ids, use_cache=True)
        
        prefix_hash = self._hash_prefix(prefix_ids)
        self.cache[prefix_hash] = outputs.past_key_values
        
        print(f"Cached prefix: {len(prefix_ids)} tokens")
        return prefix_ids


# Use the cache manager
cache_mgr = PrefixCacheManager(model, tokenizer, device)

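# Pre-cache the system prompt. The trailing space is stripped because GPT-2's
# tokenizer merges it into the next word (" San"), so the prompt encoded with
# its trailing space is not a token-level prefix of the full request.
prefix_ids = cache_mgr.cache_prefix(system_prompt.rstrip())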

# Process requests
print("\nProcessing requests:")
print("=" * 70)

for query in user_queries:
    full_text = system_prompt + query
    outputs, elapsed, cached_tokens, new_tokens = cache_mgr.get_or_compute(full_text)
    
    status = "HIT" if cached_tokens > 0 else "MISS"
    print(f"  [{status}] {query:<30} | Cached: {cached_tokens:>4} | New: {new_tokens:>3} | "
          f"Time: {elapsed*1000:.2f} ms")

print(f"\nCache stats: {cache_mgr.stats}")
print(f"Total tokens saved: {cache_mgr.stats['tokens_saved']}")

## Part 8: The Cost of Prefix Mismatch

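What happens when the prefix is *almost* the same but has small differences? Reuse stops at the first token that differs: everything before it can still come from the cache, but every token from that point on must be recomputed.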

In [None]:
# Demonstrate prefix mismatch sensitivity
prefix_v1 = "You are a helpful AI assistant. Be concise and accurate. "
prefix_v2 = "You are a helpful AI assistant. Be concise and accurate! "  # ! instead of .
prefix_v3 = "You are a helpful AI assistant.  Be concise and accurate. "  # Extra space

# Tokenize and compare
ids_v1 = tokenizer.encode(prefix_v1)
ids_v2 = tokenizer.encode(prefix_v2)
ids_v3 = tokenizer.encode(prefix_v3)

print("Prefix sensitivity analysis:")
print("=" * 60)

print(f"\nV1 (period):     {ids_v1}")
print(f"V2 (excl. mark): {ids_v2}")
print(f"V3 (extra space): {ids_v3}")

# Find divergence point
def find_divergence(a, b):
    for i in range(min(len(a), len(b))):
        if a[i] != b[i]:
            return i
    return min(len(a), len(b))

div_12 = find_divergence(ids_v1, ids_v2)
div_13 = find_divergence(ids_v1, ids_v3)

print(f"\nV1 vs V2 diverge at token {div_12}/{len(ids_v1)} ({100*div_12/len(ids_v1):.0f}% shared)")
print(f"V1 vs V3 diverge at token {div_13}/{len(ids_v1)} ({100*div_13/len(ids_v1):.0f}% shared)")
print(f"\nEven tiny text changes can break prefix cache alignment!")
print(f"Best practice: Use EXACT same system prompt text for all requests.")

In [None]:
# Visualize the prefix match/mismatch
fig, axes = plt.subplots(3, 1, figsize=(14, 7))

max_len = max(len(ids_v1), len(ids_v2), len(ids_v3))

for idx, (ids, label) in enumerate([(ids_v1, 'V1 (period)'), 
                                      (ids_v2, 'V2 (excl. mark)'),
                                      (ids_v3, 'V3 (extra space)')]):
    # Only tokens before the first divergence from V1 can be reused; a token
    # that happens to match V1 at a later position is still not cacheable.
    n_shared = find_divergence(ids_v1, ids)
    colors = ['#2ecc71' if i < n_shared else '#e74c3c' for i in range(len(ids))]
    
    axes[idx].bar(range(len(ids)), [1]*len(ids), color=colors, edgecolor='black', linewidth=0.5)
    axes[idx].set_ylabel(label, fontsize=10)
    axes[idx].set_xlim(-0.5, max_len + 0.5)
    axes[idx].set_yticks([])

axes[2].set_xlabel('Token Position', fontsize=12)

# Add legend
import matplotlib.patches as mpatches
green_patch = mpatches.Patch(color='#2ecc71', label='Shared prefix with V1 (cacheable)')
red_patch = mpatches.Patch(color='#e74c3c', label='After first divergence (recomputed)')
fig.legend(handles=[green_patch, red_patch], loc='upper right', fontsize=11)

plt.suptitle('Prefix Cache Sensitivity: Tiny Changes Break the Cache',
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## Part 9: Comprehensive Benchmark

In [None]:
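# Run a comprehensive benchmark
import copy

n_trials = 10
n_generate = 30

def sync_device():
    """Wait for queued GPU work so wall-clock timings are meaningful."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()

def benchmark_generation(model, tokenizer, full_prompt, prefix_kv=None, n_generate=30):
    """Benchmark TTFT and generation speed, optionally starting from a cached prefix."""
    if prefix_kv is not None:
        # With prefix cache: only process the suffix.
        # Assumes full_prompt starts with system_prompt. Encoding the suffix on its
        # own may split tokens differently at the boundary than encoding full_prompt.
        prefix_text = system_prompt
        suffix = full_prompt[len(prefix_text):]
        input_ids = tokenizer.encode(suffix, return_tensors='pt').to(device)
        # Copy outside the timed region: recent transformers versions may update the
        # cache object in place, which would make it grow from trial to trial.
        past = copy.deepcopy(prefix_kv)
    else:
        # Without: process everything
        input_ids = tokenizer.encode(full_prompt, return_tensors='pt').to(device)
        past = None
    
    # Prefill (TTFT)
    sync_device()
    start = time.perf_counter()
    with torch.no_grad():
        outputs = model(input_ids, past_key_values=past, use_cache=True)
    sync_device()
    ttft = time.perf_counter() - start
    
    # Generate (greedy decoding, one token per forward pass)
    past = outputs.past_key_values
    next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    
    gen_start = time.perf_counter()
    for _ in range(n_generate):
        with torch.no_grad():
            outputs = model(next_token, past_key_values=past, use_cache=True)
        past = outputs.past_key_values
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    sync_device()
    gen_time = time.perf_counter() - gen_start
    
    return {
        'ttft_ms': ttft * 1000,
        'gen_time_ms': gen_time * 1000,
        'tps': n_generate / gen_time
    }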


# Build prefix cache
prefix_ids = tokenizer.encode(system_prompt, return_tensors='pt').to(device)
with torch.no_grad():
    prefix_out = model(prefix_ids, use_cache=True)
prefix_kv = prefix_out.past_key_values

# Benchmark
full_prompt = system_prompt + "San Francisco, California"

ttft_no_cache = []
ttft_with_cache = []
tps_no_cache = []
tps_with_cache = []

for _ in range(n_trials):
    r1 = benchmark_generation(model, tokenizer, full_prompt, prefix_kv=None)
    r2 = benchmark_generation(model, tokenizer, full_prompt, prefix_kv=prefix_kv)
    
    ttft_no_cache.append(r1['ttft_ms'])
    ttft_with_cache.append(r2['ttft_ms'])
    tps_no_cache.append(r1['tps'])
    tps_with_cache.append(r2['tps'])

print(f"Benchmark Results ({n_trials} trials):")
print("=" * 57)
print(f"{'Metric':<20} {'No Cache':>12} {'With Cache':>12} {'Speedup':>10}")
print("-" * 57)
print(f"{'TTFT (ms)':<20} {np.mean(ttft_no_cache):>12.2f} {np.mean(ttft_with_cache):>12.2f} "
      f"{np.mean(ttft_no_cache)/np.mean(ttft_with_cache):>9.2f}x")
print(f"{'TPS':<20} {np.mean(tps_no_cache):>12.1f} {np.mean(tps_with_cache):>12.1f} "
      f"{np.mean(tps_with_cache)/np.mean(tps_no_cache):>9.2f}x")
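
A mean over ten trials can be skewed by a slow first iteration, which often includes one-time setup costs. As a rough sketch (the helper name `summarize_latencies` and the sample numbers are hypothetical, not measurements from this notebook), the timings could be summarized with warm-up trials dropped and a median reported alongside the mean:

```python
import statistics

def summarize_latencies(samples_ms, warmup=1):
    """Drop the first `warmup` samples, then report median, mean, and stdev."""
    kept = list(samples_ms)[warmup:]
    if not kept:
        raise ValueError("need more samples than warm-up trials")
    return {
        "n": len(kept),
        "median_ms": statistics.median(kept),
        "mean_ms": statistics.fmean(kept),
        "stdev_ms": statistics.stdev(kept) if len(kept) > 1 else 0.0,
    }

# Synthetic numbers for illustration only (not measurements).
synthetic = [50.0, 10.0, 12.0, 11.0, 9.0]
stats = summarize_latencies(synthetic, warmup=1)
print(stats)
```

With the lists collected above, this would be called as `summarize_latencies(ttft_no_cache)` and `summarize_latencies(ttft_with_cache)`.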

In [None]:
# Visualize benchmark results
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# TTFT distribution
axes[0].hist(ttft_no_cache, bins=10, alpha=0.7, color='#e74c3c', 
             label='No Cache', edgecolor='black')
axes[0].hist(ttft_with_cache, bins=10, alpha=0.7, color='#2ecc71', 
             label='With Prefix Cache', edgecolor='black')
axes[0].set_xlabel('TTFT (ms)', fontsize=12)
axes[0].set_ylabel('Count', fontsize=12)
axes[0].set_title('TTFT Distribution', fontsize=13, fontweight='bold')
axes[0].legend(fontsize=11)
axes[0].axvline(np.mean(ttft_no_cache), color='#e74c3c', linestyle='--', linewidth=2)
axes[0].axvline(np.mean(ttft_with_cache), color='#2ecc71', linestyle='--', linewidth=2)

# Summary bars
metrics = ['TTFT (ms)', 'TPS']
no_cache_vals = [np.mean(ttft_no_cache), np.mean(tps_no_cache)]
with_cache_vals = [np.mean(ttft_with_cache), np.mean(tps_with_cache)]

x = np.arange(2)
width = 0.3
bars1 = axes[1].bar(x - width/2, no_cache_vals, width, label='No Cache', 
                     color='#e74c3c', edgecolor='black')
bars2 = axes[1].bar(x + width/2, with_cache_vals, width, label='With Cache', 
                     color='#2ecc71', edgecolor='black')

axes[1].set_xticks(x)
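axes[1].set_xticklabels(metrics)
# The two bar groups use different units, so label the axis accordingly
axes[1].set_ylabel('ms (TTFT) / tokens per second (TPS)', fontsize=12)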
axes[1].set_title('Performance Comparison', fontsize=13, fontweight='bold')
axes[1].legend(fontsize=11)
axes[1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

---

## Key Takeaways

1. **Prefix caching** stores the KV cache for shared prefixes and reuses it across requests. This avoids redundant prefill computation.

2. **TTFT improvement** scales with prefix length. Caching a 1000-token system prompt means 1000 fewer tokens to prefill on every request that reuses it.

3. **Token ordering matters**: Put shared content (system prompts, chat history) **first** and unique content **last**. This is a core context engineering principle.

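4. **Exact match required**: The cache is reusable only up to the first differing token, so a one-token change near the start of the prompt invalidates almost everything after it. Keep system prompts deterministic and identical across requests.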

5. **Common use cases**: System prompts (shared by all users), multi-turn chat (all prior turns form the prefix), code completion (the file context is the prefix).

6. **Memory tradeoff**: Cached KV tensors occupy extra GPU memory, but when many requests share a prefix, storing it once instead of once per request saves both compute and memory.

7. **Savings grow with scale**: More requests sharing a prefix = more compute saved. This is why prefix caching is essential for production inference servers.
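
To put rough numbers on the memory tradeoff, the sketch below estimates KV cache size from the model shape. It uses GPT-2 small's dimensions (12 layers, 12 heads, head dimension 64) in fp32; the 1000-token prefix and the 32 concurrent requests are hypothetical values chosen for illustration, and `kv_cache_bytes` is a helper name introduced here. The comparison assumes a serving system that lets requests share one cached copy of the prefix rather than copying it.

```python
def kv_cache_bytes(n_tokens, n_layer, n_head, head_dim, bytes_per_value=4):
    """Bytes for keys and values of `n_tokens` across all layers (one sequence)."""
    return 2 * n_layer * n_head * head_dim * n_tokens * bytes_per_value

# GPT-2 small: 12 layers, 12 heads, head_dim 64; fp32 = 4 bytes per value.
per_token = kv_cache_bytes(1, n_layer=12, n_head=12, head_dim=64)
prefix_tokens = 1000   # hypothetical system prompt length
n_requests = 32        # hypothetical number of concurrent requests

stored_once = kv_cache_bytes(prefix_tokens, 12, 12, 64)
stored_per_request = n_requests * stored_once

print(f"KV per token:        {per_token / 1024:.0f} KiB")
print(f"Prefix stored once:  {stored_once / 2**20:.1f} MiB")
print(f"Prefix copied {n_requests}x:   {stored_per_request / 2**20:.1f} MiB")
```

The per-token cost grows linearly with layer count and hidden size, so the gap between sharing and copying widens for larger models.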

---

*Next: We'll set up vLLM, which includes built-in prefix caching support, and see these concepts in action at scale.*