# Lab-2.2 Part 1: KV Cache Optimization

## Objectives
- Understand KV Cache structure and memory usage
- Calculate cache size for different scenarios
- Implement cache management strategies
- Optimize for long conversations

## Estimated Time: 60-90 minutes

---
## 1. KV Cache Fundamentals

In [None]:
# Imports
import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple

# Show the installed PyTorch version and whether PyTorch can see a CUDA device.
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

### Understanding KV Cache

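In transformer models, each decoding step of the attention mechanism requires:
- **Query (Q)**: Computed for the current token only
- **Key (K)**: One vector per token seen so far (all previous tokens plus the current one)
- **Value (V)**: One vector per token seen so far (all previous tokens plus the current one)

The K and V of earlier tokens do not change between steps, so instead of recomputing them for every new token we store them and only append the newest entry - this is the **KV Cache**.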

In [None]:
# KV Cache sizing helper
def calculate_kv_cache_size(
    num_layers: int,
    num_heads: int,
    head_dim: int,
    batch_size: int,
    seq_len: int,
    precision: int = 2,  # FP16 = 2 bytes
) -> Tuple[float, dict]:
    """
    Estimate the memory footprint of a KV Cache.

    Formula:
    size = 2 (K+V) * batch * layers * heads * seq_len * head_dim * precision

    Args:
        num_layers: Number of transformer layers.
        num_heads: Number of heads whose K/V are stored per layer.
        head_dim: Dimension of each head.
        batch_size: Number of sequences cached together.
        seq_len: Number of cached tokens per sequence.
        precision: Bytes per stored element (default 2, i.e. FP16).

    Returns:
        A ``(size_gb, breakdown)`` tuple. ``breakdown`` reports the same
        quantity as ``size_bytes``, ``size_mb`` and ``size_gb`` plus the
        per-layer share ``per_layer_mb``; MB/GB are 1024-based.
    """
    bytes_per_mb = 1024 ** 2
    bytes_per_gb = 1024 ** 3

    # Element count first (the leading 2 accounts for K and V), then bytes.
    num_elements = 2 * batch_size * num_layers * num_heads * seq_len * head_dim
    size_bytes = num_elements * precision

    size_gb = size_bytes / bytes_per_gb
    size_mb = size_bytes / bytes_per_mb
    per_layer_mb = size_bytes / num_layers / bytes_per_mb

    breakdown = dict(
        size_bytes=size_bytes,
        size_mb=size_mb,
        size_gb=size_gb,
        per_layer_mb=per_layer_mb,
    )
    return size_gb, breakdown

print("KV Cache calculation function defined ✓")

### Example: Llama-2-7B

In [None]:
# Llama-2-7B attention geometry
llama2_7b_config = dict(num_layers=32, num_heads=32, head_dim=128)

# Workloads to size, each given as (label, batch size, sequence length)
scenarios = [
    ('Single request, short', 1, 512),
    ('Single request, long', 1, 2048),
    ('Batch 8, medium', 8, 1024),
    ('Batch 16, medium', 16, 1024),
    ('Batch 32, short', 32, 512),
]

# Table header
print("KV Cache Size Analysis (Llama-2-7B, FP16)")
print("=" * 70)
print(f"{'Scenario':<30} {'Batch':<8} {'Seq Len':<10} {'Cache Size'}")
print("=" * 70)

# One table row per workload; the raw numbers are also collected in `results`
results = []
for name, batch_size, seq_len in scenarios:
    size_gb, _ = calculate_kv_cache_size(
        seq_len=seq_len,
        batch_size=batch_size,
        **llama2_7b_config,
    )
    print(f"{name:<30} {batch_size:<8} {seq_len:<10} {size_gb:>6.2f} GB")
    results.append((name, batch_size, seq_len, size_gb))

print("=" * 70)

In [None]:
# Visualize how the KV Cache grows with sequence length (left panel)
# and with batch size (right panel), using the Llama-2-7B configuration.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Scaling with sequence length (batch size held at 1)
seq_lens = [128, 256, 512, 1024, 2048, 4096]
cache_sizes_seq = [
    calculate_kv_cache_size(batch_size=1, seq_len=s, **llama2_7b_config)[0]
    for s in seq_lens
]

ax1.plot(seq_lens, cache_sizes_seq, marker='o', linewidth=2, markersize=8)
ax1.set_xlabel('Sequence Length')
ax1.set_ylabel('KV Cache Size (GB)')
ax1.set_title('Cache Size vs Sequence Length\n(Batch=1)')
ax1.grid(True, alpha=0.3)
# The sampled lengths double at each step, so a base-2 log axis spaces them evenly.
# NOTE: with a log x-axis and a linear y-axis, a linear relationship is drawn
# as an upward curve rather than a straight line.
ax1.set_xscale('log', base=2)

# Scaling with batch size (sequence length held at 1024)
batch_sizes = [1, 2, 4, 8, 16, 32, 64]
cache_sizes_batch = [
    calculate_kv_cache_size(batch_size=b, seq_len=1024, **llama2_7b_config)[0]
    for b in batch_sizes
]

ax2.plot(batch_sizes, cache_sizes_batch, marker='s', linewidth=2, markersize=8, color='orange')
ax2.set_xlabel('Batch Size')
ax2.set_ylabel('KV Cache Size (GB)')
ax2.set_title('Cache Size vs Batch Size\n(Seq Len=1024)')
ax2.grid(True, alpha=0.3)
# Batch sizes are also powers of two, hence the same base-2 log axis.
ax2.set_xscale('log', base=2)

plt.tight_layout()
plt.show()

# Summary of what the two panels show; both follow from the size formula,
# in which batch size and sequence length each appear as a single factor.
print("\n📊 Observations:")
print("- KV Cache size scales linearly with sequence length")
print("- KV Cache size scales linearly with batch size")
print("- Doubling either parameter doubles memory usage")

---
## 2. Memory Bottleneck Analysis

In [None]:
# Total GPU memory estimate: weights + KV Cache + activations
def analyze_total_memory(
    model_size_gb: float,
    batch_size: int,
    seq_len: int,
    model_config: dict,
) -> dict:
    """
    Break an estimated GPU memory budget into its main components.

    Args:
        model_size_gb: Size of the model weights in GB, used as given.
        batch_size: Number of sequences processed together.
        seq_len: Number of tokens per sequence.
        model_config: Keyword arguments forwarded to
            ``calculate_kv_cache_size``; ``num_heads`` and ``head_dim``
            are also read from it for the activation estimate.

    Returns:
        Dict with the keys ``model``, ``kv_cache``, ``activations`` and
        ``total``, all expressed in GB.
    """
    bytes_per_gb = 1024 ** 3

    # KV Cache for the requested batch and sequence length
    kv_cache_gb = calculate_kv_cache_size(
        seq_len=seq_len,
        batch_size=batch_size,
        **model_config,
    )[0]

    # Activations (rough estimate): one hidden-size vector per token, 2 bytes each (FP16)
    hidden_size = model_config['num_heads'] * model_config['head_dim']
    activation_gb = (batch_size * seq_len * hidden_size * 2) / bytes_per_gb

    return {
        'model': model_size_gb,
        'kv_cache': kv_cache_gb,
        'activations': activation_gb,
        'total': model_size_gb + kv_cache_gb + activation_gb,
    }

# Llama-2-7B analysis
model_size = 14.0  # GB (FP16)

# (batch size, sequence length) pairs to profile
test_configs = [(1, 512), (8, 1024), (16, 1024), (32, 512)]

print("Memory Usage Analysis (Llama-2-7B)")
print("=" * 80)

# For each workload, print every component with its share of the total
for batch, seq_len in test_configs:
    mem = analyze_total_memory(
        model_size_gb=model_size,
        batch_size=batch,
        seq_len=seq_len,
        model_config=llama2_7b_config,
    )

    print(f"\nBatch={batch}, Seq Len={seq_len}:")
    print(f"  Model weights:  {mem['model']:6.2f} GB ({mem['model']/mem['total']*100:.1f}%)")
    print(f"  KV Cache:       {mem['kv_cache']:6.2f} GB ({mem['kv_cache']/mem['total']*100:.1f}%)")
    print(f"  Activations:    {mem['activations']:6.2f} GB ({mem['activations']/mem['total']*100:.1f}%)")
    print(f"  Total:          {mem['total']:6.2f} GB")

print("\n" + "=" * 80)
print("\n💡 Key Insight: KV Cache becomes dominant as batch/seq_len increases!")

---
## 3. Cache Management Strategies

### 3.1 Static vs Dynamic Allocation

In [None]:
# Allocation strategy 1: reserve the worst case up front
class StaticCacheManager:
    """Reserve a full ``max_seq_len`` cache slot for every request in the batch."""

    def __init__(self, max_seq_len: int, batch_size: int):
        self.max_seq_len = max_seq_len
        self.batch_size = batch_size
        # Computed once at construction; sizes use the module-level llama2_7b_config
        self.allocated_memory = self._calculate_allocation()

    def _calculate_allocation(self):
        """Return the GB needed for ``batch_size`` sequences of ``max_seq_len`` tokens."""
        reserved_gb, _ = calculate_kv_cache_size(
            seq_len=self.max_seq_len,
            batch_size=self.batch_size,
            **llama2_7b_config,
        )
        return reserved_gb

    def get_utilization(self, actual_seq_lens: list) -> float:
        """Return the fraction of the reservation that the given sequence lengths would occupy."""
        used = 0
        for tokens in actual_seq_lens:
            # Each real sequence is sized individually (batch of one)
            used += calculate_kv_cache_size(
                seq_len=tokens,
                batch_size=1,
                **llama2_7b_config,
            )[0]
        return used / self.allocated_memory

class DynamicCacheManager:
    """Reserve cache one request at a time, sized to each request's actual length."""

    def __init__(self):
        # Running total of reserved KV Cache, in GB
        self.allocated_memory = 0

    def allocate(self, seq_len: int):
        """Reserve cache for one sequence of ``seq_len`` tokens and return its size in GB."""
        cache_size, _ = calculate_kv_cache_size(
            seq_len=seq_len,
            batch_size=1,
            **llama2_7b_config,
        )
        self.allocated_memory = self.allocated_memory + cache_size
        return cache_size

    def get_total_allocation(self):
        """Return the GB reserved so far across all ``allocate`` calls."""
        return self.allocated_memory

# Compare both strategies on the same workload
actual_lens = [512, 256, 1024, 128, 768]  # Varied lengths
max_len = 2048

# Static: one max-length slot per request
static_mgr = StaticCacheManager(max_seq_len=max_len, batch_size=len(actual_lens))
static_util = static_mgr.get_utilization(actual_lens)

# Dynamic: reserve exactly what each request needs
dynamic_mgr = DynamicCacheManager()
for length in actual_lens:
    dynamic_mgr.allocate(length)

print("Cache Allocation Comparison")
print("=" * 60)
print(f"Actual sequence lengths: {actual_lens}")
print(f"\nStatic Allocation:")
print(f"  Allocated:    {static_mgr.allocated_memory:.2f} GB")
print(f"  Utilization:  {static_util*100:.1f}%")
print(f"  Wasted:       {(1-static_util)*100:.1f}%")
print(f"\nDynamic Allocation:")
print(f"  Allocated:    {dynamic_mgr.get_total_allocation():.2f} GB")
print(f"  Utilization:  100.0%")
print(f"  Wasted:       0.0%")
print(f"\n💾 Memory Saved: {static_mgr.allocated_memory - dynamic_mgr.get_total_allocation():.2f} GB")
print(f"📊 Efficiency Gain: {(1 - static_util)*100:.1f}%")
print("=" * 60)

### 3.2 PagedAttention (vLLM)

In [None]:
# Simulate PagedAttention block management
class PagedCacheManager:
    """Simulate vLLM's PagedAttention.

    The cache is modelled as ``total_blocks`` fixed-size blocks holding
    ``block_size`` tokens each. A request owns a list of block ids taken
    from a shared free list; the ids need not be contiguous, and freed
    blocks go back to the end of the free list.
    """

    def __init__(self, block_size: int = 16, total_blocks: int = 1000):
        """Create a pool of ``total_blocks`` free blocks.

        Raises:
            ValueError: If ``block_size`` is not positive or
                ``total_blocks`` is negative.
        """
        # Reject sizes that would only fail later with an unrelated
        # arithmetic error inside allocate().
        if block_size <= 0:
            raise ValueError("block_size must be a positive integer")
        if total_blocks < 0:
            raise ValueError("total_blocks must be non-negative")

        self.block_size = block_size  # tokens per block
        self.total_blocks = total_blocks
        self.free_blocks = list(range(total_blocks))
        self.allocations = {}  # {request_id: [block_ids]}

    def allocate(self, request_id: int, seq_len: int) -> list:
        """Allocate blocks for a request.

        If ``request_id`` already holds blocks, they are returned to the
        pool and replaced by the new allocation, so a repeated call never
        leaks blocks.

        Args:
            request_id: Key under which the blocks are recorded.
            seq_len: Number of tokens the request must be able to store.

        Returns:
            The list of block ids now owned by the request.

        Raises:
            ValueError: If ``seq_len`` is negative.
            MemoryError: If the pool cannot satisfy the request. The
                manager's state is left unchanged in that case.
        """
        if seq_len < 0:
            raise ValueError("seq_len must be non-negative")

        # Ceiling division: a partially filled block still occupies a whole block.
        num_blocks_needed = (seq_len + self.block_size - 1) // self.block_size

        # Blocks already held by this request can be reused for the new allocation.
        previous = self.allocations.get(request_id)
        reclaimable = len(previous) if previous is not None else 0

        if len(self.free_blocks) + reclaimable < num_blocks_needed:
            raise MemoryError("Not enough blocks available")

        if previous is not None:
            self.free_blocks.extend(self.allocations.pop(request_id))

        # Take the blocks from the front of the free list in one slice
        # instead of repeated pop(0) calls, each of which shifts the list.
        allocated = self.free_blocks[:num_blocks_needed]
        del self.free_blocks[:num_blocks_needed]

        self.allocations[request_id] = allocated
        return allocated

    def free(self, request_id: int):
        """Free blocks for a request; unknown request ids are ignored."""
        blocks = self.allocations.pop(request_id, None)
        if blocks is not None:
            self.free_blocks.extend(blocks)

    def get_utilization(self) -> float:
        """Return the fraction of blocks in use (0.0 for an empty pool)."""
        if self.total_blocks == 0:
            return 0.0
        used = self.total_blocks - len(self.free_blocks)
        return used / self.total_blocks

    def get_stats(self) -> dict:
        """Return block counts, utilization and the number of active requests."""
        free_count = len(self.free_blocks)
        return {
            'total_blocks': self.total_blocks,
            'free_blocks': free_count,
            'used_blocks': self.total_blocks - free_count,
            'utilization': self.get_utilization(),
            'active_requests': len(self.allocations),
        }

# Test PagedAttention: allocate a few requests, report pool usage,
# then free two of them and report again.
paged_mgr = PagedCacheManager(block_size=16, total_blocks=1000)

print("PagedAttention Simulation")
print("=" * 60)

# Simulate requests as (request id, sequence length in tokens)
requests = [
    (1, 512),
    (2, 256),
    (3, 1024),
    (4, 128),
]

# Each request takes ceil(seq_len / block_size) blocks from the pool
for req_id, seq_len in requests:
    blocks = paged_mgr.allocate(req_id, seq_len)
    print(f"Request {req_id}: {seq_len} tokens → {len(blocks)} blocks")

# Pool usage with all four requests active
stats = paged_mgr.get_stats()
print(f"\nMemory Status:")
print(f"  Used blocks:   {stats['used_blocks']}/{stats['total_blocks']}")
print(f"  Utilization:   {stats['utilization']*100:.1f}%")
print(f"  Active reqs:   {stats['active_requests']}")

# Free some requests; their blocks go back to the free list
print(f"\nFreeing requests 1 and 2...")
paged_mgr.free(1)
paged_mgr.free(2)

# Pool usage again, now with only requests 3 and 4 holding blocks
stats = paged_mgr.get_stats()
print(f"  Used blocks:   {stats['used_blocks']}/{stats['total_blocks']}")
print(f"  Utilization:   {stats['utilization']*100:.1f}%")
print(f"  Active reqs:   {stats['active_requests']}")

print("=" * 60)

---
## 4. MQA and GQA Optimization

### Multi-Query Attention (MQA) vs Grouped-Query Attention (GQA)

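**Standard Multi-Head Attention (MHA)**:
- Each head has its own K and V
- Memory per token, per layer (for each of K and V): num_heads × head_dim elements

**Multi-Query Attention (MQA)**:
- All heads share a single K and V
- Memory per token, per layer (for each of K and V): 1 × head_dim elements
- Reduction: num_heads× smaller than MHA

**Grouped-Query Attention (GQA)**:
- Groups of heads share K and V
- Memory per token, per layer (for each of K and V): num_groups × head_dim elements
- Balance between MHA and MQA: (num_heads / num_groups)× smaller than MHA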

In [None]:
# Compare KV Cache sizes for MHA, MQA, GQA
def calculate_kv_cache_variants(
    num_layers: int,
    num_heads: int,
    head_dim: int,
    batch_size: int,
    seq_len: int,
    num_kv_heads: int = None,  # For GQA
) -> dict:
    """
    Calculate KV Cache for MHA, MQA, and GQA.

    All three variants use the same formula and differ only in how many
    K/V heads are stored per layer: ``num_heads`` for MHA, 1 for MQA and
    ``num_kv_heads`` for GQA. Elements are counted as 2 bytes (FP16).

    Args:
        num_layers: Number of transformer layers.
        num_heads: Number of query heads (and K/V heads under MHA).
        head_dim: Dimension of each head.
        batch_size: Number of sequences cached together.
        seq_len: Number of cached tokens per sequence.
        num_kv_heads: Number of K/V heads for GQA. When omitted, one K/V
            head per 4 query heads is used, with a minimum of one.

    Returns:
        Dict with the cache sizes in GB under ``MHA``, ``MQA`` and ``GQA``
        and the reduction factors relative to MHA under ``MQA_reduction``
        and ``GQA_reduction``.

    Raises:
        ValueError: If ``num_kv_heads`` is given but not positive.
    """
    precision = 2  # FP16

    if num_kv_heads is None:
        # Default: 4 query heads per KV head. Clamp to 1 so models with
        # fewer than 4 heads do not end up with zero K/V heads.
        num_kv_heads = max(1, num_heads // 4)
    elif num_kv_heads <= 0:
        raise ValueError("num_kv_heads must be a positive integer")

    def cache_size_gb(kv_heads: int) -> float:
        """Return the cache size in GB when ``kv_heads`` K/V heads are stored."""
        return (
            2 * batch_size * num_layers * kv_heads * seq_len * head_dim * precision
        ) / (1024 ** 3)

    return {
        'MHA': cache_size_gb(num_heads),
        'MQA': cache_size_gb(1),
        'GQA': cache_size_gb(num_kv_heads),
        # Every other factor cancels, so the reduction is just the ratio of
        # head counts. Computing it this way also avoids dividing by a
        # zero-sized cache when batch_size or seq_len is 0.
        'MQA_reduction': num_heads / 1,
        'GQA_reduction': num_heads / num_kv_heads,
    }

# Compare the three attention variants using the Llama-2-7B geometry
variants = calculate_kv_cache_variants(
    batch_size=16,
    seq_len=2048,
    num_kv_heads=8,  # 4:1 ratio
    **llama2_7b_config,
)

print("Attention Variants Comparison (Batch=16, Seq=2048)")
print("=" * 70)
print(f"MHA (Multi-Head):        {variants['MHA']:6.2f} GB")
print(f"GQA (Grouped, 8 KV):     {variants['GQA']:6.2f} GB ({variants['GQA_reduction']:.1f}x reduction)")
print(f"MQA (Multi-Query, 1 KV): {variants['MQA']:6.2f} GB ({variants['MQA_reduction']:.1f}x reduction)")
print("=" * 70)

# Visualize the three cache sizes from `variants` as a bar chart
fig, ax = plt.subplots(figsize=(10, 6))

# Bar labels, heights and colors are listed in the same order: MHA, GQA, MQA.
# The head counts in the labels are hard-coded and mirror the call that built `variants`.
methods = ['MHA\n(32 KV heads)', 'GQA\n(8 KV heads)', 'MQA\n(1 KV head)']
sizes = [variants['MHA'], variants['GQA'], variants['MQA']]
colors = ['#ff6b6b', '#ffd93d', '#51cf66']

bars = ax.bar(methods, sizes, color=colors, width=0.6)
ax.set_ylabel('KV Cache Size (GB)', fontsize=12)
ax.set_title('KV Cache Size: MHA vs GQA vs MQA', fontsize=14, fontweight='bold')
# 20% headroom above the tallest bar leaves room for its value label
ax.set_ylim(0, max(sizes) * 1.2)
ax.grid(axis='y', alpha=0.3)

# Add value labels, centered horizontally and sitting just above each bar
for bar, size in zip(bars, sizes):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height,
            f'{size:.2f} GB',
            ha='center', va='bottom', fontsize=11, fontweight='bold')

plt.tight_layout()
plt.show()

print("\n💡 GQA provides excellent balance: ~4x memory reduction with minimal quality loss!")

---
## 5. Long Conversation Optimization

In [None]:
# Conversation-level cache bookkeeping
class ConversationCacheManager:
    """Track the turns of a conversation and the KV Cache they occupy."""

    def __init__(self, max_context: int = 4096):
        self.max_context = max_context
        self.messages = []      # one dict per retained turn, oldest first
        self.total_tokens = 0   # token count summed over the retained turns

    def add_turn(self, user_tokens: int, assistant_tokens: int):
        """Record one user/assistant exchange, evicting old turns if over budget."""
        turn_total = user_tokens + assistant_tokens
        self.messages.append({
            'user': user_tokens,
            'assistant': assistant_tokens,
            'total': turn_total,
        })
        self.total_tokens += turn_total

        # Drop the oldest turns until the conversation fits in max_context.
        # The most recent turn is always kept, even if it alone is too large.
        while len(self.messages) > 1 and self.total_tokens > self.max_context:
            oldest = self.messages.pop(0)
            self.total_tokens -= oldest['total']

    def get_cache_size(self) -> float:
        """Return the KV Cache size in GB for the retained tokens (single sequence)."""
        size_gb, _ = calculate_kv_cache_size(
            seq_len=self.total_tokens,
            batch_size=1,
            **llama2_7b_config,
        )
        return size_gb

    def get_stats(self) -> dict:
        """Return turn count, token count, cache size and context utilization."""
        return {
            'num_turns': len(self.messages),
            'total_tokens': self.total_tokens,
            'cache_size_gb': self.get_cache_size(),
            'utilization': self.total_tokens / self.max_context,
        }

# Ten turns, each given as (user tokens, assistant tokens)
conversation_turns = [
    (50, 100),   # Short Q, medium A
    (30, 150),   # Short Q, long A
    (100, 200),  # Long Q, long A
    (40, 80),
    (60, 120),
    (50, 100),
    (70, 140),
    (45, 90),
    (80, 160),
    (55, 110),
]

# Simulate conversation
conv_mgr = ConversationCacheManager(max_context=4096)

print("Long Conversation Simulation")
print("=" * 70)

# Feed the turns in order, recording the cache size after each one
cache_sizes = []
for i, (user, assistant) in enumerate(conversation_turns, start=1):
    conv_mgr.add_turn(user_tokens=user, assistant_tokens=assistant)
    stats = conv_mgr.get_stats()
    cache_sizes.append(stats['cache_size_gb'])

    print(f"Turn {i:2d}: {stats['total_tokens']:4d} tokens, "
          f"{stats['cache_size_gb']:.3f} GB cache, "
          f"{stats['utilization']*100:.1f}% utilized")

print("=" * 70)
print(f"\nFinal: {conv_mgr.get_stats()['num_turns']} turns active in memory")

In [None]:
# Visualize cache growth: one point per simulated turn, taken from `cache_sizes`
plt.figure(figsize=(10, 6))
plt.plot(range(1, len(cache_sizes)+1), cache_sizes, marker='o', linewidth=2)
# Horizontal reference line at the cache size conv_mgr reports when this cell
# runs (the size after the last simulated turn if the cells are run in order).
plt.axhline(y=conv_mgr.get_cache_size(), color='r', linestyle='--', 
            label=f'Current: {conv_mgr.get_cache_size():.3f} GB')
plt.xlabel('Conversation Turn')
plt.ylabel('KV Cache Size (GB)')
plt.title('KV Cache Growth During Long Conversation')
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
plt.show()

# Closing remark; of the strategies listed, only truncation is implemented
# by ConversationCacheManager in this notebook.
print("\n💡 Cache management is crucial for long conversations!")
print("   Strategies: truncation, summarization, or prefix caching")

---
## Summary

✅ **Completed**:
1. Calculated KV Cache sizes for various scenarios
2. Analyzed memory bottlenecks
3. Compared static vs dynamic allocation
4. Simulated PagedAttention
5. Evaluated MQA/GQA benefits
6. Optimized for long conversations

📊 **Key Findings**:
- KV Cache scales linearly with batch and sequence length
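- Dynamic, on-demand allocation (the idea behind PagedAttention) avoids reserving the maximum length for every request; in the Section 3.1 example, static allocation left about 74% of its reserved memory unused
- GQA with 8 KV heads for 32 query heads reduces the cache by 4x compared with MHA, with minimal quality loss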
- Long conversations require cache management strategies

➡️ **Next**: In `02-Speculative_Decoding.ipynb`, we'll learn:
- Speculative Decoding algorithm
- Draft model selection
- 1.5-3x speedup techniques

In [None]:
# Completion marker, printed when this last cell is executed.
print("✅ Lab 2.2 Part 1 Complete!")