# Day 32: KV Caching Implementation - Part 1

In this notebook, we'll explore Key-Value (KV) caching, a fundamental optimization technique for efficient inference in large language models. We'll implement a basic version of KV caching and measure its impact on inference speed.

## Overview

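1. Setup and dependencies
2. Understanding KV caching
3. Loading a pre-trained model
4. Implementing KV caching from scratch
5. Measuring performance improvements
6. Analyzing KV cache memory usage
7. Measuring speedup for different generation lengths
8. Understanding the trade-offs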

## 1. Setup and Dependencies

In [None]:
!pip install -q torch transformers datasets evaluate accelerate matplotlib

In [None]:
import os
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

## 2. Understanding KV Caching

KV caching is an optimization technique that stores the key (K) and value (V) tensors computed during the forward pass of a transformer model. When generating text token by token, instead of recomputing these tensors for all tokens in each step, we reuse the cached values from previous steps and only compute K and V for the new token.

After the first step, each step processes only one new token instead of the whole sequence. This substantially reduces the cost of autoregressive generation, and the savings grow with sequence length.

### 2.1 The Transformer Attention Mechanism

Let's first review how attention works in transformer models:

1. Input tokens are embedded and passed through the model
2. For each layer, we compute query (Q), key (K), and value (V) matrices
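3. Attention weights are computed as `softmax(Q @ K^T / sqrt(d_k))`, where `d_k` is the per-head key dimension. In decoder models such as GPT-2, a causal mask first sets the scores for future positions to `-inf`, so each token attends only to itself and earlier tokens.
4. The output is `attention_weights @ V`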
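To make these steps concrete, here is a minimal single-head NumPy sketch. It assumes no batching and no learned projections (Q, K, and V are random arrays), and `causal_attention` is a hypothetical helper, not a Transformers API. It also applies the causal mask that GPT-style decoders use.

```python
import numpy as np

def causal_attention(q, k, v):
    """Scaled dot-product attention with a causal mask for a single head.

    q, k, v: arrays of shape (seq_len, d_k).
    Returns (output of shape (seq_len, d_k), attention weights (seq_len, seq_len)).
    """
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)  # (seq_len, seq_len)
    # Causal mask: position i may only attend to positions j <= i
    mask = np.triu(np.ones_like(scores, dtype=bool), k=1)
    scores = np.where(mask, -np.inf, scores)
    # Numerically stable softmax over the key dimension
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((4, 8)) for _ in range(3))
out, weights = causal_attention(q, k, v)
print(out.shape)             # (4, 8)
print(np.round(weights, 2))  # entries above the diagonal are all 0
```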

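During autoregressive generation, we generate one token at a time. Without KV caching, every step re-runs the model over the entire sequence so far. This recomputes Q, K, and V for all previous tokens, even though their K and V have not changed, and only the last position's output is needed to predict the next token.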

### 2.2 KV Caching Approach

With KV caching:

1. We compute and store K and V for all tokens in the initial input
2. For each new token, we compute K and V only for that token and append them to the cache
3. We compute Q only for the new token
4. We use the cached K and V along with the new Q to compute attention
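The four steps can be checked numerically with a toy single-head layer in NumPy. This sketch rests on simplifying assumptions. `KVCache` and `attend` are hypothetical names. The hidden states `x` are random stand-ins fixed in advance, whereas in a real model each new token's input depends on the previous output. There is also only one layer and no batching. The example shows that attending with the new query over cached K/V gives the same outputs as recomputing causal attention from scratch.

```python
import numpy as np

def attend(q, k, v):
    """Single-head scaled dot-product attention (no mask applied).

    The caller passes only the keys/values the queries are allowed to see.
    """
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

class KVCache:
    """Toy per-head cache that grows by one row per processed token."""

    def __init__(self, d_k):
        self.k = np.empty((0, d_k))
        self.v = np.empty((0, d_k))

    def append(self, k_new, v_new):
        self.k = np.concatenate([self.k, k_new], axis=0)
        self.v = np.concatenate([self.v, v_new], axis=0)
        return self.k, self.v

rng = np.random.default_rng(0)
d_model, d_k, seq_len = 16, 8, 6
W_q, W_k, W_v = (rng.standard_normal((d_model, d_k)) for _ in range(3))
x = rng.standard_normal((seq_len, d_model))  # stand-in hidden states, one row per token

# Reference: recompute Q, K, V for the whole sequence; position i sees keys 0..i
q_all, k_all, v_all = x @ W_q, x @ W_k, x @ W_v
full = np.stack([attend(q_all[i:i + 1], k_all[:i + 1], v_all[:i + 1])[0]
                 for i in range(seq_len)])

# Incremental: project only the newest token, append its K/V, reuse the rest
cache = KVCache(d_k)
incremental = []
for t in range(seq_len):
    x_t = x[t:t + 1]  # shape (1, d_model)
    k_cached, v_cached = cache.append(x_t @ W_k, x_t @ W_v)
    incremental.append(attend(x_t @ W_q, k_cached, v_cached)[0])
incremental = np.stack(incremental)

print(np.allclose(full, incremental))  # True
```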

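At a generation step with t tokens in context, recomputing everything costs O(t²) attention work per layer because the full t × t score matrix is rebuilt. With the cache, only one new query row is scored against t keys, which is O(t). The Q/K/V projections and the MLP likewise shrink from t token positions to one. Summed over n generated tokens, attention work drops from roughly O(n³) to O(n²), and projection and MLP work drops from O(n²) to O(n).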
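To make the counts concrete, here is a small pure-Python tally of query-key dot products for one attention head in one layer. It follows the loop structure of the generation functions in this notebook, where the first step processes the full prompt. `attention_score_count` is a hypothetical helper, and it ignores projections, MLPs, and constant factors.

```python
def attention_score_count(n_prompt, n_new, use_cache):
    """Count query-key dot products for one head in one layer.

    Without a cache, every step re-runs attention over all L tokens in context
    (a causal triangle of L*(L+1)/2 scores). With a cache, the prompt is
    processed once, then each later step scores one query against L keys.
    """
    total = 0
    for s in range(n_new):
        L = n_prompt + s  # tokens in context when predicting new token s
        if use_cache and s > 0:
            total += L                # one new query row against L keys
        else:
            total += L * (L + 1) // 2  # full causal score triangle
    return total

print(attention_score_count(10, 100, use_cache=False))  # 221650
print(attention_score_count(10, 100, use_cache=True))   # 5995
```

For a 10-token prompt and 100 new tokens, this gives about a 37x reduction in attention work in this toy setting.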

## 3. Loading a Pre-trained Model

Let's load a small pre-trained model to demonstrate KV caching.

In [None]:
# Define model name
model_name = "gpt2"  # Using a smaller model for demonstration

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Set padding token

model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

# Print model information
print(f"Model: {model_name}")
print(f"Number of parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")
print(f"Number of layers: {len(model.transformer.h)}")
print(f"Hidden size: {model.config.hidden_size}")
print(f"Number of attention heads: {model.config.num_attention_heads}")

## 4. Implementing KV Caching from Scratch

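Now, let's implement KV-cached autoregressive generation. We write the generation loop ourselves. The model computes the per-layer K and V tensors and returns them as `past_key_values`, which we feed back in at the next step. We'll create two functions: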

1. `generate_without_kv_cache`: Standard generation without KV caching
2. `generate_with_kv_cache`: Optimized generation with KV caching

In [None]:
def generate_without_kv_cache(model, input_ids, max_length=50, temperature=1.0):
    """Generate text without KV caching: every step re-runs the full sequence.

    max_length is the total length (prompt + generated tokens).
    Returns (token_ids, generation_time_in_seconds).
    """
    input_ids = input_ids.to(device)
    current_length = input_ids.shape[1]

    # Wait for any pending GPU work so it is not counted in our timing
    if device.type == "cuda":
        torch.cuda.synchronize()
    start_time = time.perf_counter()

    with torch.no_grad():
        for _ in range(max_length - current_length):
            # Re-process the entire sequence; use_cache=False stops the model
            # from building (and returning) a KV cache we would throw away
            outputs = model(input_ids, use_cache=False)
            next_token_logits = outputs.logits[:, -1, :] / temperature

            # Sample the next token from the temperature-scaled distribution
            probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append the new token to the sequence
            input_ids = torch.cat([input_ids, next_token], dim=1)

    # CUDA kernels run asynchronously; synchronize before stopping the clock
    if device.type == "cuda":
        torch.cuda.synchronize()
    generation_time = time.perf_counter() - start_time

    return input_ids, generation_time

In [None]:
def generate_with_kv_cache(model, input_ids, max_length=50, temperature=1.0):
    """Generate text with KV caching: after the prompt, feed only the newest token.

    max_length is the total length (prompt + generated tokens).
    Returns (token_ids, generation_time_in_seconds).
    """
    input_ids = input_ids.to(device)
    current_length = input_ids.shape[1]

    # Wait for any pending GPU work so it is not counted in our timing
    if device.type == "cuda":
        torch.cuda.synchronize()
    start_time = time.perf_counter()

    # No cache yet: the first forward pass processes the whole prompt
    past_key_values = None

    with torch.no_grad():
        for _ in range(max_length - current_length):
            # First step: the full prompt. Later steps: only the last token,
            # because K and V for every earlier position are already cached.
            model_input = input_ids if past_key_values is None else input_ids[:, -1:]
            outputs = model(model_input, past_key_values=past_key_values, use_cache=True)
            next_token_logits = outputs.logits[:, -1, :] / temperature

            # Keep the cache returned by the model (it now includes this step's K/V)
            past_key_values = outputs.past_key_values

            # Sample the next token from the temperature-scaled distribution
            probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append the new token to the sequence
            input_ids = torch.cat([input_ids, next_token], dim=1)

    # CUDA kernels run asynchronously; synchronize before stopping the clock
    if device.type == "cuda":
        torch.cuda.synchronize()
    generation_time = time.perf_counter() - start_time

    return input_ids, generation_time

## 5. Measuring Performance Improvements

Now, let's compare the performance of generation with and without KV caching.
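Each generation function times a single run, so the numbers are noisy, and the very first call also pays one-off start-up costs. Below is a sketch of a more careful timer using only the standard library. `time_call` is a hypothetical helper, not part of PyTorch or Transformers. It runs untimed warm-up calls, then reports the median of several timed runs. It also accepts an optional `sync` callable so pending GPU work can be flushed before the clock is read.

```python
import statistics
import time

def time_call(fn, *args, warmup=1, runs=3, sync=None, **kwargs):
    """Return the median wall-clock time (seconds) of fn(*args, **kwargs).

    warmup: number of untimed calls made first, to absorb one-off costs
            (e.g. CUDA context setup, kernel loading, allocator growth).
    sync:   optional zero-argument callable run before reading the clock,
            e.g. torch.cuda.synchronize, because GPU kernels are launched
            asynchronously and may still be running when fn returns.
    """
    for _ in range(warmup):
        fn(*args, **kwargs)
    if sync is not None:
        sync()
    times = []
    for _ in range(runs):
        start = time.perf_counter()
        fn(*args, **kwargs)
        if sync is not None:
            sync()
        times.append(time.perf_counter() - start)
    return statistics.median(times)
```

On a GPU, you might call `time_call(generate_with_kv_cache, model, input_ids.clone(), max_length=100, sync=torch.cuda.synchronize)`. The median is less sensitive than the mean to an occasional slow run.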

In [None]:
# Define a test prompt
prompt = "Artificial intelligence will transform the future by"

# Tokenize the prompt
input_ids = tokenizer.encode(prompt, return_tensors="pt")

In [None]:
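# Warm-up run: the first call pays one-off costs (CUDA context setup, kernel
# loading, memory allocation) that would otherwise inflate whichever method runs first
_ = generate_with_kv_cache(model, input_ids.clone(), max_length=input_ids.shape[1] + 5)

# Generate text without KV caching (re-seed so both runs draw from the same random stream)
print("Generating without KV caching...")
torch.manual_seed(42)
output_without_cache, time_without_cache = generate_without_kv_cache(
    model, input_ids.clone(), max_length=100, temperature=0.7
)

# Generate text with KV caching
print("Generating with KV caching...")
torch.manual_seed(42)
output_with_cache, time_with_cache = generate_with_kv_cache(
    model, input_ids.clone(), max_length=100, temperature=0.7
)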

# Print the results
print(f"\nGeneration time without KV caching: {time_without_cache:.4f} seconds")
print(f"Generation time with KV caching: {time_with_cache:.4f} seconds")
print(f"Speedup: {time_without_cache / time_with_cache:.2f}x")

In [None]:
# Decode the generated text
text_without_cache = tokenizer.decode(output_without_cache[0], skip_special_tokens=True)
text_with_cache = tokenizer.decode(output_with_cache[0], skip_special_tokens=True)

print("Generated text without KV caching:")
print(text_without_cache)
print("\n" + "-"*50 + "\n")
print("Generated text with KV caching:")
print(text_with_cache)

## 6. Analyzing KV Cache Memory Usage

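Let's analyze how much memory the KV cache needs for GPT-2 at different sequence lengths. GPT-2's context window is 1024 tokens, so the longer lengths below are extrapolations that illustrate the linear growth. They are not sizes you could actually reach with this model.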

In [None]:
def calculate_kv_cache_size(model, seq_length, batch_size=1, bytes_per_element=None):
    """Calculate the KV cache size (in MB, 1 MB = 1024**2 bytes) for a given sequence length."""
    # Get model configuration
    num_layers = len(model.transformer.h)
    hidden_size = model.config.hidden_size
    num_heads = model.config.num_attention_heads
    head_dim = hidden_size // num_heads

    # Default to the model's own dtype: from_pretrained loads GPT-2 in FP32
    # (4 bytes per element); pass bytes_per_element=2 to estimate FP16/BF16
    if bytes_per_element is None:
        bytes_per_element = next(model.parameters()).element_size()

    # Each layer stores a K cache and a V cache (hence the factor 2),
    # each of shape [batch_size, num_heads, seq_length, head_dim]
    kv_cache_size = 2 * num_layers * batch_size * num_heads * seq_length * head_dim * bytes_per_element

    # Convert bytes to MB
    kv_cache_size_mb = kv_cache_size / (1024 * 1024)

    return kv_cache_size_mb

In [None]:
# Calculate KV cache size for different sequence lengths
seq_lengths = [128, 256, 512, 1024, 2048, 4096, 8192]
cache_sizes = [calculate_kv_cache_size(model, length) for length in seq_lengths]

# Print the results
print("KV Cache Size Analysis:")
print("-" * 40)
print(f"{'Sequence Length':<20} {'Cache Size (MB)':<15}")
print("-" * 40)
for length, size in zip(seq_lengths, cache_sizes):
    print(f"{length:<20} {size:<15.2f}")
print("-" * 40)

In [None]:
# Visualize the KV cache size growth
plt.figure(figsize=(10, 6))
plt.plot(seq_lengths, cache_sizes, marker='o', linewidth=2)
plt.title('KV Cache Size vs. Sequence Length', fontsize=14)
plt.xlabel('Sequence Length', fontsize=12)
plt.ylabel('Cache Size (MB)', fontsize=12)
plt.grid(True, alpha=0.3)
plt.xscale('log', base=2)
plt.xticks(seq_lengths, [str(x) for x in seq_lengths])
plt.tight_layout()
plt.show()

## 7. Measuring Speedup for Different Generation Lengths

Let's measure how the speedup from KV caching changes as we generate more tokens from the same prompt.

In [None]:
def measure_speedup(model, tokenizer, prompt, gen_length, num_runs=3):
    """Measure speedup of KV caching for a given prompt and generation length."""
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    
    # Measure time without KV caching
    times_without_cache = []
    for _ in range(num_runs):
        _, time_without_cache = generate_without_kv_cache(
            model, input_ids.clone(), max_length=len(input_ids[0])+gen_length, temperature=0.7
        )
        times_without_cache.append(time_without_cache)
    avg_time_without_cache = sum(times_without_cache) / len(times_without_cache)
    
    # Measure time with KV caching
    times_with_cache = []
    for _ in range(num_runs):
        _, time_with_cache = generate_with_kv_cache(
            model, input_ids.clone(), max_length=len(input_ids[0])+gen_length, temperature=0.7
        )
        times_with_cache.append(time_with_cache)
    avg_time_with_cache = sum(times_with_cache) / len(times_with_cache)
    
    # Calculate speedup
    speedup = avg_time_without_cache / avg_time_with_cache
    
    return avg_time_without_cache, avg_time_with_cache, speedup

In [None]:
# Define test cases with different generation lengths
gen_lengths = [10, 20, 50, 100]
results = []

# Measure speedup for each generation length
for gen_length in gen_lengths:
    print(f"Measuring speedup for generation length: {gen_length}")
    time_without_cache, time_with_cache, speedup = measure_speedup(
        model, tokenizer, prompt, gen_length
    )
    results.append({
        "gen_length": gen_length,
        "time_without_cache": time_without_cache,
        "time_with_cache": time_with_cache,
        "speedup": speedup
    })
    print(f"  Time without cache: {time_without_cache:.4f}s")
    print(f"  Time with cache: {time_with_cache:.4f}s")
    print(f"  Speedup: {speedup:.2f}x")
    print()

In [None]:
# Visualize the speedup results
labels = [str(r["gen_length"]) for r in results]
x = np.arange(len(labels))
bar_width = 0.4

plt.figure(figsize=(12, 6))

# Plot generation times as grouped bars so neither series hides the other
plt.subplot(1, 2, 1)
plt.bar(x - bar_width / 2, [r["time_without_cache"] for r in results],
        width=bar_width, label="Without KV Cache")
plt.bar(x + bar_width / 2, [r["time_with_cache"] for r in results],
        width=bar_width, label="With KV Cache")
plt.xticks(x, labels)
plt.title("Generation Time Comparison")
plt.xlabel("Generation Length (tokens)")
plt.ylabel("Time (seconds)")
plt.legend()
plt.grid(axis="y", alpha=0.3)

# Plot speedup
plt.subplot(1, 2, 2)
plt.bar(labels, [r["speedup"] for r in results], color="green")
plt.title("Speedup Factor")
plt.xlabel("Generation Length (tokens)")
plt.ylabel("Speedup (x)")
plt.grid(axis="y", alpha=0.3)

plt.tight_layout()
plt.show()

## 8. Understanding the Trade-offs

KV caching provides significant speedup for autoregressive generation, but it comes with trade-offs:

### 8.1 Memory Usage

The KV cache consumes memory proportional to the sequence length, and also to the batch size, the number of layers, and the hidden size. For large models and long sequences, this can be substantial.

For example, for GPT-3 (175B parameters: 96 layers, hidden size 12288) at 2048 tokens, the FP16 KV cache takes about 9 GiB per sequence.
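The size follows directly from the cache tensor shapes used in `calculate_kv_cache_size`. The factor 2 counts K and V, $b$ is the batch size, $L$ the sequence length, and $s$ the bytes per element:

$$
\text{KV cache bytes} = 2 \times n_{\text{layers}} \times b \times n_{\text{heads}} \times L \times d_{\text{head}} \times s
$$

For GPT-3 175B, $n_{\text{layers}} = 96$ and $n_{\text{heads}} \times d_{\text{head}} = 96 \times 128 = 12288$. With $b = 1$, $L = 2048$, and FP16 ($s = 2$):

$$
2 \times 96 \times 1 \times 12288 \times 2048 \times 2 = 9{,}663{,}676{,}416 \text{ bytes} = 9 \text{ GiB}
$$

For comparison, GPT-2 small in FP32 needs $2 \times 12 \times 768 \times 4 = 73{,}728$ bytes (72 KiB) per token.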

In [None]:
# Calculate KV cache size for a large model (hypothetical)
def calculate_large_model_kv_cache(seq_length, num_layers=96, hidden_size=12288, num_heads=96):
    """Calculate KV cache size for a large model like GPT-3."""
    head_dim = hidden_size // num_heads
    bytes_per_element = 2  # FP16
    batch_size = 1
    
    kv_cache_size = 2 * num_layers * batch_size * num_heads * seq_length * head_dim * bytes_per_element
    kv_cache_size_gb = kv_cache_size / (1024 * 1024 * 1024)
    
    return kv_cache_size_gb

# Calculate for different sequence lengths
large_seq_lengths = [1024, 2048, 4096, 8192, 16384, 32768]
large_cache_sizes = [calculate_large_model_kv_cache(length) for length in large_seq_lengths]

# Print the results
print("KV Cache Size for Large Model (GPT-3 scale):")
print("-" * 40)
print(f"{'Sequence Length':<20} {'Cache Size (GB)':<15}")
print("-" * 40)
for length, size in zip(large_seq_lengths, large_cache_sizes):
    print(f"{length:<20} {size:<15.2f}")
print("-" * 40)

### 8.2 Batch Processing

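KV caching pays off only when the same prefix would otherwise be processed repeatedly, as in token-by-token autoregressive generation. A single forward pass over fixed-length sequences (for example, scoring or classifying a batch) computes every K and V exactly once, so a cache saves nothing there. Batching also affects memory. The cache holds K and V for every token of every sequence, so its size scales linearly with batch size as well as with sequence length, which limits how many sequences can be generated concurrently.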
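Because the cache grows linearly with batch size as well as sequence length, a fixed memory budget caps how many sequences can be decoded at once. Here is a back-of-the-envelope sketch in pure Python. `max_batch_size` is a hypothetical helper, and the memory budgets are illustrative assumptions. Real deployments must also leave room for weights and activations.

```python
def max_batch_size(memory_budget_bytes, seq_length, num_layers, hidden_size,
                   bytes_per_element=2):
    """Largest batch whose full-length KV cache fits in memory_budget_bytes.

    Assumes every sequence reserves cache space for seq_length tokens and that
    K and V each hold hidden_size values per token per layer (standard
    multi-head attention, where num_heads * head_dim == hidden_size).
    """
    per_sequence = 2 * num_layers * seq_length * hidden_size * bytes_per_element
    return memory_budget_bytes // per_sequence

GiB = 1024 ** 3
# Hypothetical 40 GiB cache budget, GPT-3-sized model, FP16, 2048 tokens
print(max_batch_size(40 * GiB, 2048, num_layers=96, hidden_size=12288))  # 4
# Hypothetical 1 GiB cache budget, GPT-2 small in FP32, 1024 tokens
print(max_batch_size(1 * GiB, 1024, num_layers=12, hidden_size=768,
                     bytes_per_element=4))  # 14
```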

## Conclusion

In this notebook, we've explored KV caching, a fundamental optimization technique for efficient inference in large language models. We've implemented a basic version of KV caching and measured its impact on inference speed.

Key takeaways:

1. KV caching significantly speeds up autoregressive generation by reusing previously computed key and value tensors
2. The speedup tends to grow with the number of generated tokens, because the uncached loop re-processes an ever-longer sequence at every step
3. KV caching trades memory for computation, with memory usage growing linearly with sequence length and batch size
4. For large models and long sequences, memory management becomes critical

In the next part, we'll explore more advanced techniques like paged attention, which addresses the memory management challenges of KV caching.