In [None]:
# 🔧 Setup: Run this cell first!
# Detects the compute device, prints interpreter/library versions and seeds the
# Python, NumPy and PyTorch (CPU + CUDA) random number generators.
# (Package installation happens in the next code cell.)

import random
import sys

import numpy as np
import torch

SEED = 42

# Check GPU
if not torch.cuda.is_available():
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")
else:
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility: every RNG the notebook touches.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Case Study: Optimizing Cloud LLM Inference at Stratos AI
## Implementation Notebook

---

**Scenario:** You are an inference engineer at Stratos AI, a cloud AI platform serving LLM inference for 120+ enterprise clients. Your GPU fleet is hitting memory limits, SLA violations are mounting, and inference costs consume 43% of revenue. Your task is to build the core optimizations -- paged KV cache, GPU-side sampling, continuous batching, and multi-tenant LoRA serving -- that will improve throughput by 3x and restore healthy margins.

**Current system:** Naive KV cache with contiguous allocation, CPU-side sampling, static batching. 20 concurrent requests per GPU, 340 tokens/second, 12% SLA violation rate at peak.

**Target:** 60+ concurrent requests per GPU, 1200+ tokens/second, < 1% SLA violations.

---

## 3.1 Baseline Inference Engine

We start by building a minimal transformer model and establishing baseline performance metrics. This gives us concrete numbers to improve against.

In [None]:
# Install dependencies
!pip install -q torch matplotlib numpy

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import time
from dataclasses import dataclass
from typing import Optional, List, Tuple, Dict

%matplotlib inline

# Re-derive the device so this cell also works on its own, then report it.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Re-seed PyTorch and NumPy so the cells below start from a known RNG state.
for _reseed in (torch.manual_seed, np.random.seed):
    _reseed(42)

First, let us define a small transformer model that is large enough to demonstrate the bottlenecks but small enough to run on a free Colab T4 GPU. We use dimensions inspired by a scaled-down version of Stratos's 13B model.

In [None]:
@dataclass
class ModelConfig:
    """Transformer model configuration (scaled down for Colab).

    Attributes:
        vocab_size: Number of token ids covered by the embedding and LM head.
        n_layers: Number of stacked decoder blocks.
        n_heads: Number of attention heads per block.
        d_model: Hidden width; must be a multiple of ``n_heads`` because each
            head works on a ``d_model // n_heads`` slice.
        d_ff: Hidden width of each block's feed-forward MLP.
        max_seq_len: Rows in the learned positional-embedding table, i.e. the
            longest prompt + generation the model can process.
        dropout: Stored for completeness; the model classes defined in this
            notebook do not read it.

    Raises:
        ValueError: if ``n_heads`` is not positive or ``d_model`` is not
            divisible by ``n_heads``.
    """
    vocab_size: int = 10000
    n_layers: int = 8
    n_heads: int = 8
    d_model: int = 512
    d_ff: int = 2048
    max_seq_len: int = 1024
    dropout: float = 0.0

    def __post_init__(self):
        # Fail fast here: a remainder would otherwise only surface as a
        # confusing .view(...) shape error deep inside attention's forward().
        if self.n_heads <= 0 or self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be a positive multiple of "
                f"n_heads ({self.n_heads})"
            )


class MultiHeadAttention(nn.Module):
    """Multi-head causal self-attention with an optional concatenation-based KV cache.

    All four projections are bias-free ``d_model -> d_model`` linear maps; each
    of the ``n_heads`` heads operates on a ``d_head = d_model // n_heads`` slice.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_model // config.n_heads
        self.d_model = config.d_model

        self.W_q = nn.Linear(config.d_model, config.d_model, bias=False)
        self.W_k = nn.Linear(config.d_model, config.d_model, bias=False)
        self.W_v = nn.Linear(config.d_model, config.d_model, bias=False)
        self.W_o = nn.Linear(config.d_model, config.d_model, bias=False)

    def forward(self, x, kv_cache=None, use_cache=False):
        """Attend from the ``T`` new positions in ``x`` to all cached + new keys.

        Args:
            x: input unpacked as ``(B, T, D)``; ``D`` is expected to equal
                ``d_model`` for the per-head reshape to work.
            kv_cache: optional ``(k, v)`` from a previous call, each of shape
                ``(B, n_heads, S_past, d_head)``; the new keys/values are
                concatenated after them along the sequence axis (dim 2).
            use_cache: when True, return the full (past + new) ``(k, v)``.

        Returns:
            ``(out, new_cache)``: ``out`` has shape ``(B, T, D)``; ``new_cache``
            is the concatenated ``(k, v)`` or ``None`` if ``use_cache`` is False.
        """
        B, T, D = x.shape

        # Project and split into heads: (B, T, D) -> (B, n_heads, T, d_head).
        q = self.W_q(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        if kv_cache is not None:
            # Append new K, V to the cache (copies the whole history every call)
            cached_k, cached_v = kv_cache
            k = torch.cat([cached_k, k], dim=2)
            v = torch.cat([cached_v, v], dim=2)

        new_cache = (k, v) if use_cache else None

        # Scaled dot-product attention: scores has shape (B, n_heads, T, S).
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_head ** 0.5)

        # Causal mask. The T queries are the LAST T of the S key positions, so
        # query row i sits at absolute position S - T + i and must not see key
        # column j when j > S - T + i, i.e. j - i >= S - T + 1. triu with that
        # diagonal marks exactly those entries (with no cache, S == T and this
        # is the usual strictly-upper-triangular mask).
        S = k.shape[2]
        causal_mask = torch.triu(torch.ones(T, S, device=x.device), diagonal=S - T + 1).bool()
        scores = scores.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))

        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)
        # Merge heads back: (B, n_heads, T, d_head) -> (B, T, D).
        out = out.transpose(1, 2).contiguous().view(B, T, D)
        out = self.W_o(out)

        return out, new_cache


class TransformerBlock(nn.Module):
    """A single pre-norm transformer decoder block.

    Computes ``h = x + Attn(LN1(x))`` followed by ``h + FFN(LN2(h))``, where the
    FFN is Linear(d_model -> d_ff) -> GELU -> Linear(d_ff -> d_model).
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        width, hidden = config.d_model, config.d_ff
        # Construction order (attn, ln1, ln2, ff) fixes the order in which
        # parameters are randomly initialised; keep it so seeded runs reproduce.
        self.attn = MultiHeadAttention(config)
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)
        self.ff = nn.Sequential(
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
        )

    def forward(self, x, kv_cache=None, use_cache=False):
        """Apply the block; returns ``(hidden_states, attention_cache_or_None)``."""
        attended, layer_cache = self.attn(self.ln1(x), kv_cache=kv_cache, use_cache=use_cache)
        residual = x + attended
        return residual + self.ff(self.ln2(residual)), layer_cache


class MiniLLM(nn.Module):
    """Minimal GPT-style language model.

    Token embeddings plus learned absolute position embeddings, ``n_layers``
    pre-norm decoder blocks, a final LayerNorm and a bias-free LM head.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.ln_f = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

    def forward(self, input_ids, kv_caches=None, use_cache=False, start_pos=0):
        """Run the model on tokens occupying positions ``start_pos .. start_pos + T - 1``.

        Args:
            input_ids: ``(B, T)`` token ids.
            kv_caches: optional list with one ``(k, v)`` entry per block, as
                returned by a previous call with ``use_cache=True``.
            use_cache: when True, also return the updated per-block caches.
            start_pos: absolute position of ``input_ids[:, 0]``; when decoding
                incrementally this should equal the number of cached positions.

        Returns:
            ``(logits, caches)``: logits of shape ``(B, T, vocab_size)``;
            ``caches`` is the per-block list of ``(k, v)`` or ``None`` when
            ``use_cache`` is False.

        Raises:
            ValueError: if the requested positions fall outside
                ``[0, config.max_seq_len)``.
        """
        _, T = input_ids.shape
        end_pos = start_pos + T
        # pos_emb has exactly max_seq_len rows; indexing past it would otherwise
        # surface as an opaque IndexError (CPU) or a device-side assert (CUDA).
        if start_pos < 0 or end_pos > self.config.max_seq_len:
            raise ValueError(
                f"positions [{start_pos}, {end_pos}) exceed max_seq_len={self.config.max_seq_len}"
            )

        positions = torch.arange(start_pos, end_pos, device=input_ids.device)
        x = self.token_emb(input_ids) + self.pos_emb(positions)

        new_caches = []
        for i, block in enumerate(self.blocks):
            cache = None if kv_caches is None else kv_caches[i]
            x, layer_cache = block(x, kv_cache=cache, use_cache=use_cache)
            new_caches.append(layer_cache)

        logits = self.lm_head(self.ln_f(x))
        return logits, (new_caches if use_cache else None)


# Instantiate the model on the selected device and put it in evaluation mode.
config = ModelConfig()
model = MiniLLM(config).to(device).eval()  # nn.Module.eval() returns the module itself

# Parameter count, and a weight-memory estimate at 4 bytes per parameter (fp32).
total_params = sum(map(torch.numel, model.parameters()))
print(f"Model parameters: {total_params:,}")
print(f"Model memory (fp32): {total_params * 4 / 1e6:.1f} MB")

### TODO 1: Naive Autoregressive Generation (No KV Cache)

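Implement generation without any caching. At each step, feed the entire growing sequence through the model. This is a worst-case reference point rather than Stratos's production path: the current system already uses a (contiguous) KV cache, so this loop is even slower and simply gives us a floor against which to measure the gains from caching.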

In [None]:
def generate_naive(model, prompt_ids, max_new_tokens=50):
    """
    Generate tokens WITHOUT KV cache.
    At each step, feed the ENTIRE sequence (prompt + generated so far) through the model.

    Args:
        model: MiniLLM instance
        prompt_ids: tensor of shape (1, prompt_len) with token indices
        max_new_tokens: number of tokens to generate (greedy / argmax)

    Returns:
        generated_ids: tensor of all tokens (prompt + generated)
        elapsed_time: total generation time in seconds. On CUDA the timer is
            bracketed by device synchronisation, so it measures finished work
            rather than just kernel launches.
    """
    model.eval()
    tokens = prompt_ids.clone()

    # CUDA kernels run asynchronously: drain pending work before starting the
    # clock, and again before stopping it.
    if tokens.is_cuda:
        torch.cuda.synchronize(tokens.device)
    start_time = time.perf_counter()
    with torch.no_grad():
        for _step in range(max_new_tokens):
            # TODO: Feed the ENTIRE tokens sequence through the model (no cache)
            # Get logits for the LAST position only
            # Use greedy decoding (argmax) to select the next token
            # Append the new token to the tokens tensor

            # --- YOUR CODE HERE ---
            logits, _ = model(tokens, use_cache=False)
            next_logit = logits[:, -1, :]
            next_token = torch.argmax(next_logit, dim=-1, keepdim=True)
            tokens = torch.cat([tokens, next_token], dim=1)
            # --- END YOUR CODE ---

    if tokens.is_cuda:
        torch.cuda.synchronize(tokens.device)
    elapsed = time.perf_counter() - start_time
    return tokens, elapsed


# Test naive generation on a random 64-token prompt. The throughput numerator
# reuses the same constant as max_new_tokens so the two cannot drift apart.
n_new_tokens = 50
prompt = torch.randint(0, config.vocab_size, (1, 64), device=device)
output_naive, time_naive = generate_naive(model, prompt, max_new_tokens=n_new_tokens)
print(f"Naive generation: {n_new_tokens / time_naive:.1f} tokens/sec, {time_naive:.3f}s total")
print(f"Output shape: {output_naive.shape}")

### TODO 2: KV Cache Generation

Now implement generation with the KV cache. On the first pass, process the full prompt and cache all K, V tensors. On subsequent passes, feed only the newest token.

In [None]:
def generate_with_cache(model, prompt_ids, max_new_tokens=50):
    """
    Generate tokens WITH KV cache.
    First pass: process full prompt, cache K/V for all layers.
    Subsequent passes: feed only the new token, append to cache.

    Args:
        model: MiniLLM instance
        prompt_ids: tensor of shape (1, prompt_len)
        max_new_tokens: number of tokens to generate (greedy / argmax); when
            0 the prompt is returned unchanged and the model is not run.

    Returns:
        generated_ids: tensor of all tokens (prompt + generated)
        elapsed_time: total generation time in seconds, bracketed by CUDA
            synchronisation when running on a GPU.
    """
    model.eval()
    tokens = prompt_ids.clone()

    if tokens.is_cuda:
        torch.cuda.synchronize(tokens.device)
    start_time = time.perf_counter()
    with torch.no_grad():
        # Prefill: process the entire prompt, cache K/V.
        # The prefill pass itself emits the first generated token, so it is
        # skipped when no tokens are requested.
        # TODO: Run the model on the full prompt with use_cache=True
        # Store the returned caches for use in the decode loop
        if max_new_tokens > 0:
            # --- YOUR CODE HERE ---
            logits, kv_caches = model(tokens, use_cache=True, start_pos=0)
            next_logit = logits[:, -1, :]
            next_token = torch.argmax(next_logit, dim=-1, keepdim=True)
            tokens = torch.cat([tokens, next_token], dim=1)
            # --- END YOUR CODE ---

        # Decode: generate one token at a time using the cache
        for step in range(1, max_new_tokens):
            # TODO: Feed ONLY the last token through the model
            # Pass the kv_caches and use_cache=True
            # Set start_pos to the current sequence length - 1

            # --- YOUR CODE HERE ---
            # The newest token is not cached yet; its absolute position equals
            # the number of already-cached positions.
            pos = tokens.shape[1] - 1
            logits, kv_caches = model(
                tokens[:, -1:], kv_caches=kv_caches, use_cache=True, start_pos=pos
            )
            next_logit = logits[:, -1, :]
            next_token = torch.argmax(next_logit, dim=-1, keepdim=True)
            tokens = torch.cat([tokens, next_token], dim=1)
            # --- END YOUR CODE ---

    if tokens.is_cuda:
        torch.cuda.synchronize(tokens.device)
    elapsed = time.perf_counter() - start_time
    return tokens, elapsed


# Test cached generation on the same prompt as the naive run. The throughput
# numerator reuses the max_new_tokens constant so the two cannot drift apart.
n_new_tokens = 50
output_cached, time_cached = generate_with_cache(model, prompt, max_new_tokens=n_new_tokens)
print(f"Cached generation: {n_new_tokens / time_cached:.1f} tokens/sec, {time_cached:.3f}s total")
print(f"Speedup: {time_naive / time_cached:.2f}x")

# Verify outputs match (both use greedy decoding, so they should be identical).
# Cached and uncached attention take different floating-point paths, which can
# in principle flip an argmax at a near-tie, so the result is reported rather
# than asserted.
match = torch.equal(output_naive, output_cached)
print(f"Outputs match: {match}")

Let us benchmark both approaches across different prompt lengths to understand how the speedup scales.

In [None]:
def benchmark_generation(model, prompt_lengths, max_new_tokens=100, n_trials=3, n_warmup=1):
    """Benchmark naive vs cached generation across prompt lengths.

    For every prompt length, ``n_trials`` random prompts (ids drawn from
    ``model.config.vocab_size`` and created on the model's own device) are
    generated with both ``generate_naive`` and ``generate_with_cache``; the mean
    wall-clock times are converted to tokens/sec.

    Args:
        model: MiniLLM instance (reads ``model.config.vocab_size``).
        prompt_lengths: iterable of prompt lengths to sweep.
        max_new_tokens: tokens generated per run.
        n_trials: timed runs per prompt length; must be >= 1.
        n_warmup: untimed runs of both generators before measuring, so one-off
            start-up costs (e.g. CUDA context / kernel initialisation) are not
            charged to the first prompt length.

    Returns:
        dict of equal-length lists keyed by 'prompt_len', 'naive_tps',
        'cached_tps' and 'speedup'.

    Raises:
        ValueError: if ``n_trials`` < 1 (the mean would be undefined).
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be >= 1, got {n_trials}")

    prompt_lengths = list(prompt_lengths)
    vocab_size = model.config.vocab_size
    model_device = next(model.parameters()).device
    results = {'prompt_len': [], 'naive_tps': [], 'cached_tps': [], 'speedup': []}

    if prompt_lengths and n_warmup > 0:
        # A fixed all-zero prompt keeps warm-up from consuming the torch RNG, so
        # the timed prompts are the same as they would be without warm-up.
        warmup_prompt = torch.zeros((1, prompt_lengths[0]), dtype=torch.long, device=model_device)
        warmup_tokens = min(max_new_tokens, 8)
        for _ in range(n_warmup):
            generate_naive(model, warmup_prompt, max_new_tokens=warmup_tokens)
            generate_with_cache(model, warmup_prompt, max_new_tokens=warmup_tokens)

    for plen in prompt_lengths:
        naive_times = []
        cached_times = []

        for _ in range(n_trials):
            prompt = torch.randint(0, vocab_size, (1, plen), device=model_device)

            _, t_naive = generate_naive(model, prompt, max_new_tokens=max_new_tokens)
            naive_times.append(t_naive)

            _, t_cached = generate_with_cache(model, prompt, max_new_tokens=max_new_tokens)
            cached_times.append(t_cached)

        avg_naive = np.mean(naive_times)
        avg_cached = np.mean(cached_times)

        results['prompt_len'].append(plen)
        results['naive_tps'].append(max_new_tokens / avg_naive)
        results['cached_tps'].append(max_new_tokens / avg_cached)
        results['speedup'].append(avg_naive / avg_cached)

        print(f"Prompt len {plen:4d}: naive={max_new_tokens/avg_naive:.0f} tok/s, "
              f"cached={max_new_tokens/avg_cached:.0f} tok/s, "
              f"speedup={avg_naive/avg_cached:.1f}x")

    return results

# Sweep prompt lengths. The longest prompt (512) plus 50 generated tokens stays
# inside the model's 1024-position window (ModelConfig.max_seq_len).
prompt_lengths = [32, 64, 128, 256, 512]
results = benchmark_generation(model, prompt_lengths, max_new_tokens=50)

In [None]:
# Visualize the benchmark results: throughput per prompt length (left) and the
# cached-over-naive speedup factor (right).
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_tps, ax_speedup = axes

for series_key, line_fmt, colour, series_label in (
    ('naive_tps', 'o-', '#e74c3c', 'Naive (no cache)'),
    ('cached_tps', 's-', '#2ecc71', 'With KV cache'),
):
    ax_tps.plot(results['prompt_len'], results[series_key], line_fmt,
                color=colour, label=series_label, linewidth=2)
ax_tps.set_xlabel('Prompt Length (tokens)')
ax_tps.set_ylabel('Tokens per Second')
ax_tps.set_title('Generation Throughput')
ax_tps.legend()
ax_tps.grid(True, alpha=0.3)

bar_positions = range(len(results['prompt_len']))
ax_speedup.bar(bar_positions, results['speedup'], color='#3498db')
ax_speedup.set_xticks(bar_positions)
ax_speedup.set_xticklabels(results['prompt_len'])
ax_speedup.set_xlabel('Prompt Length (tokens)')
ax_speedup.set_ylabel('Speedup Factor')
ax_speedup.set_title('KV Cache Speedup vs Naive')
ax_speedup.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('benchmark_kv_cache.png', dpi=150, bbox_inches='tight')
plt.show()

## 3.2 Paged KV Cache

The standard KV cache allocates a contiguous memory block for each request up to the maximum sequence length. At Stratos, this wastes 79% of allocated memory because most responses are far shorter than the 4096-token maximum.

The **paged KV cache** allocates memory in fixed-size blocks (pages), growing on demand as the sequence lengthens. When a request completes, its pages are returned to a free pool and immediately available for new requests.

### TODO 3: Paged KV Cache Implementation

Build a page table and block allocator that manages KV cache memory in fixed-size pages.

In [None]:
class PagedKVCache:
    """
    Paged KV cache that allocates memory in fixed-size blocks.

    Instead of pre-allocating max_seq_len * d_model for each request,
    we allocate pages of page_size tokens as needed.

    Storage layout: one pool tensor of shape
    ``(max_pages, 2, n_layers, page_size, n_heads, d_head)``; index 0/1 on the
    second axis selects K/V. Each request owns an ordered list of physical page
    indices (its page table). Token ``t`` of a request is stored, for every
    layer, in page ``page_table[t // page_size]`` at slot ``t % page_size``.
    """

    def __init__(
        self,
        n_layers: int,
        n_heads: int,
        d_head: int,
        page_size: int = 16,
        max_pages: int = 512,
        device: str = 'cuda',
        dtype: torch.dtype = torch.float32,
    ):
        if page_size <= 0:
            raise ValueError(f"page_size must be positive, got {page_size}")
        if max_pages <= 0:
            raise ValueError(f"max_pages must be positive, got {max_pages}")

        self.n_layers = n_layers
        self.n_heads = n_heads
        self.d_head = d_head
        self.page_size = page_size
        self.max_pages = max_pages
        self.device = device
        self.dtype = dtype

        # TODO: Create the physical page pool
        # A single large tensor of shape (max_pages, 2, n_layers, page_size, n_heads, d_head)
        # The '2' dimension is for K and V
        # Initialize a free list tracking which pages are available

        # --- YOUR CODE HERE ---
        self.page_pool = torch.zeros(
            max_pages, 2, n_layers, page_size, n_heads, d_head,
            device=device, dtype=dtype
        )
        # Physical page indices not currently owned by any request.
        self.free_pages = list(range(max_pages))

        # Page tables: request_id -> list of page indices (in order)
        self.page_tables: Dict[int, List[int]] = {}
        # Total number of tokens written (through the last layer) per request
        self.seq_lengths: Dict[int, int] = {}
        # --- END YOUR CODE ---

    def _take_free_page(self) -> int:
        """Remove and return the page at the front of the free list.

        Raises:
            RuntimeError: if every page in the pool is in use.
        """
        if not self.free_pages:
            raise RuntimeError("Out of pages")
        return self.free_pages.pop(0)

    def allocate_request(self, request_id: int):
        """Register a new request. Allocate its first page.

        Raises:
            ValueError: if ``request_id`` is already registered.
            RuntimeError: if no free page is available.
        """
        if request_id in self.page_tables:
            raise ValueError(f"Request {request_id} already exists")
        # The page is taken before the request is registered, so a failure
        # leaves no half-created entry behind.
        self.page_tables[request_id] = [self._take_free_page()]
        self.seq_lengths[request_id] = 0

    def append_token(self, request_id: int, layer: int, k: torch.Tensor, v: torch.Tensor):
        """
        Append a single token's K and V to the cache for the given layer.

        k, v: shape (n_heads, d_head)

        Every layer of a token must be written before the next token; the
        token counts as stored (``seq_lengths`` advances) once layer
        ``n_layers - 1`` has been written. A new page is allocated only the
        first time a token crosses a page boundary.

        Raises:
            IndexError: if ``layer`` is outside ``[0, n_layers)``.
            RuntimeError: if a new page is needed but none is free.
        """
        if not 0 <= layer < self.n_layers:
            raise IndexError(f"layer {layer} out of range for {self.n_layers} layers")

        # TODO: Determine the current position within the last page
        # If the page is full, allocate a new page
        # Write k and v into the correct position in the page pool

        # --- YOUR CODE HERE ---
        seq_len = self.seq_lengths[request_id]
        pages = self.page_tables[request_id]

        # All layers of the same token see the same seq_len, hence the same
        # logical page and slot.
        page_number, page_offset = divmod(seq_len, self.page_size)

        # Allocate only when the logical page does not exist yet (i.e. for the
        # first layer written at a page boundary). Testing `page_offset == 0`
        # alone would allocate a fresh page for every layer of that token.
        if page_number == len(pages):
            pages.append(self._take_free_page())
        physical_page = pages[page_number]

        self.page_pool[physical_page, 0, layer, page_offset] = k  # Key
        self.page_pool[physical_page, 1, layer, page_offset] = v  # Value

        if layer == self.n_layers - 1:
            # Only increment after the last layer writes
            self.seq_lengths[request_id] = seq_len + 1
        # --- END YOUR CODE ---

    def get_kv(self, request_id: int, layer: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Retrieve all cached K and V for a request at a given layer.
        Returns: (k, v) each of shape (seq_len, n_heads, d_head)
        """
        seq_len = self.seq_lengths[request_id]
        pages = self.page_tables[request_id]

        # One gather over the request's pages gives (n_pages, page_size, n_heads,
        # d_head); flatten to token order and trim to the tokens actually stored.
        k = self.page_pool[pages, 0, layer].reshape(-1, self.n_heads, self.d_head)[:seq_len]
        v = self.page_pool[pages, 1, layer].reshape(-1, self.n_heads, self.d_head)[:seq_len]
        return k, v

    def free_request(self, request_id: int):
        """Release all pages for a completed request (no-op if unknown)."""
        if request_id not in self.page_tables:
            return
        pages = self.page_tables.pop(request_id)
        self.free_pages.extend(pages)
        del self.seq_lengths[request_id]

    def memory_stats(self) -> dict:
        """Report page and byte usage of the pool (bytes use the pool's dtype)."""
        total_pages = self.max_pages
        used_pages = total_pages - len(self.free_pages)
        bytes_per_page = (
            2 * self.n_layers * self.page_size * self.n_heads * self.d_head
            * self.page_pool.element_size()
        )
        return {
            'total_pages': total_pages,
            'used_pages': used_pages,
            'free_pages': len(self.free_pages),
            'utilization': used_pages / total_pages,
            'used_memory_mb': used_pages * bytes_per_page / 1e6,
            'total_memory_mb': total_pages * bytes_per_page / 1e6,
        }


# Exercise the paged KV cache: 5 concurrent requests of different lengths, each
# token writing random K/V vectors for every layer, then report pool usage.
head_dim = config.d_model // config.n_heads
kv_shape = (config.n_heads, head_dim)

paged_cache = PagedKVCache(
    n_layers=config.n_layers,
    n_heads=config.n_heads,
    d_head=head_dim,
    page_size=16,
    max_pages=256,
    device=device,
)

request_lengths = [47, 128, 15, 200, 83]

for req_id, length in enumerate(request_lengths):
    paged_cache.allocate_request(req_id)
    for _token in range(length):
        for layer in range(config.n_layers):
            k = torch.randn(*kv_shape, device=device)
            v = torch.randn(*kv_shape, device=device)
            paged_cache.append_token(req_id, layer, k, v)

stats = paged_cache.memory_stats()
print(f"Pages used: {stats['used_pages']}/{stats['total_pages']}")
print(f"Memory used: {stats['used_memory_mb']:.1f} MB / {stats['total_memory_mb']:.1f} MB")
print(f"Utilization: {stats['utilization']:.1%}")

### TODO 4: Compare Memory Efficiency -- Contiguous vs Paged

Simulate Stratos's workload: 50 concurrent requests with sequence lengths drawn from a log-normal distribution (median 847 tokens), clipped to the range 100–4096 tokens.

In [None]:
def compare_memory_efficiency(n_requests=50, max_seq_len=4096, n_layers=8, n_heads=8, d_head=64,
                              page_size=16, bytes_per_elem=4, gpu_memory_gb=80, seed=42,
                              plot=True):
    """
    Compare memory consumption between contiguous and paged KV cache.

    Contiguous: allocates max_seq_len for every request upfront.
    Paged: allocates only what is needed, rounded up to whole pages of
    ``page_size`` tokens.

    Sequence lengths are drawn from a log-normal distribution with median ~847
    tokens (sigma=0.6), clipped to [100, max_seq_len]. A private
    ``np.random.RandomState(seed)`` is used, so the global NumPy RNG is left
    untouched while the draws match ``np.random.seed(seed)`` followed by
    ``np.random.lognormal(...)``.

    Args:
        n_requests: number of concurrent requests to simulate (> 0).
        max_seq_len: per-request reservation of the contiguous scheme (> 0).
        n_layers, n_heads, d_head: model shape used for the per-token cost.
        page_size: tokens per page in the paged scheme (> 0).
        bytes_per_elem: bytes per cached element (4 = fp32, 2 = fp16/bf16).
        gpu_memory_gb: GPU memory used for the "can serve" estimate.
        seed: seed for the private RNG.
        plot: draw the histogram / bar chart when True.

    Returns:
        seq_lengths: int array of shape (n_requests,).

    Raises:
        ValueError: if n_requests, max_seq_len or page_size is not positive.
    """
    if n_requests <= 0:
        raise ValueError(f"n_requests must be positive, got {n_requests}")
    if max_seq_len <= 0:
        raise ValueError(f"max_seq_len must be positive, got {max_seq_len}")
    if page_size <= 0:
        raise ValueError(f"page_size must be positive, got {page_size}")

    # TODO: Draw sequence lengths from a log-normal distribution
    # with parameters that give median ~847 tokens, clipped to [100, max_seq_len]
    # Compute total memory for contiguous (n_requests * max_seq_len * per-token cost)
    # Compute total memory for paged (sum of ceil(actual_len / page_size) * page_size * per-token cost)
    # Report savings

    # --- YOUR CODE HERE ---
    # Log-normal: the median of lognormal(mu, sigma) is exp(mu), so mu = ln(847).
    rng = np.random.RandomState(seed)
    seq_lengths = rng.lognormal(mean=np.log(847), sigma=0.6, size=n_requests)
    seq_lengths = np.clip(seq_lengths, 100, max_seq_len).astype(int)

    bytes_per_token = 2 * n_layers * n_heads * d_head * bytes_per_elem  # K + V, all layers

    # Contiguous: every request reserves max_seq_len
    contiguous_bytes = n_requests * max_seq_len * bytes_per_token

    # Paged: each request uses ceil(actual_len / page_size) pages
    n_pages = -(-seq_lengths // page_size)  # integer ceil division
    paged_bytes = int(n_pages.sum()) * page_size * bytes_per_token

    contiguous_gb = contiguous_bytes / 1e9
    paged_gb = paged_bytes / 1e9
    savings = 1 - paged_gb / contiguous_gb
    # --- END YOUR CODE ---

    print(f"Number of requests: {n_requests}")
    print(f"Sequence length distribution: median={int(np.median(seq_lengths))}, "
          f"mean={int(np.mean(seq_lengths))}, max={int(np.max(seq_lengths))}")
    print(f"\nContiguous allocation: {contiguous_gb:.2f} GB")
    print(f"Paged allocation:     {paged_gb:.2f} GB")
    print(f"Memory savings:       {savings:.1%}")
    print(f"\nAt {gpu_memory_gb} GB per A100:")
    print(f"  Contiguous: can serve {int(gpu_memory_gb / (contiguous_gb / n_requests))} concurrent requests")
    print(f"  Paged:      can serve {int(gpu_memory_gb / (paged_gb / n_requests))} concurrent requests")

    if plot:
        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        axes[0].hist(seq_lengths, bins=30, color='#3498db', alpha=0.7, edgecolor='black')
        axes[0].axvline(np.median(seq_lengths), color='#e74c3c', linestyle='--', linewidth=2,
                        label=f'Median: {int(np.median(seq_lengths))}')
        axes[0].set_xlabel('Sequence Length (tokens)')
        axes[0].set_ylabel('Count')
        axes[0].set_title('Request Sequence Length Distribution')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

        bars = axes[1].bar(['Contiguous', 'Paged'], [contiguous_gb, paged_gb],
                           color=['#e74c3c', '#2ecc71'], edgecolor='black')
        axes[1].set_ylabel('Total Memory (GB)')
        axes[1].set_title(f'KV Cache Memory: {n_requests} Concurrent Requests')
        for bar, val in zip(bars, [contiguous_gb, paged_gb]):
            axes[1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.1,
                         f'{val:.1f} GB', ha='center', fontsize=12, fontweight='bold')
        axes[1].grid(True, alpha=0.3, axis='y')

        plt.tight_layout()
        plt.savefig('memory_comparison.png', dpi=150, bbox_inches='tight')
        plt.show()

    return seq_lengths

# Simulate the default workload (50 requests, 4096-token contiguous reservations).
seq_lengths = compare_memory_efficiency()

## 3.3 GPU-Optimized Sampling

At Stratos, sampling happens on the CPU after transferring logits from the GPU. This adds 2-8 ms per token. Let us fix that.

### TODO 5: Fused GPU Sampling

Implement temperature scaling and top-p sampling entirely on the GPU in a single function. No CPU transfers.

In [None]:
def sample_gpu_naive(logits, temperature=1.0, top_p=0.9):
    """
    BASELINE: CPU-side sampling (simulating Stratos's current approach).
    Transfer logits to CPU, sort, filter, sample, transfer back.

    Despite the name, every step after the initial ``.cpu()`` copy runs on the
    CPU; only the sampled ids are copied back to ``logits.device``.

    Args:
        logits: shape (batch_size, vocab_size), on any device
        temperature: float > 0; logits are divided by it before the softmax
        top_p: float in (0, 1]; probability mass of the nucleus to keep

    Returns:
        token: shape (batch_size, 1) vocabulary ids on ``logits.device``

    Raises:
        ValueError: if ``temperature`` or ``top_p`` is out of range (these
            would otherwise produce NaN probabilities and a cryptic
            ``torch.multinomial`` error).
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    if not 0 < top_p <= 1:
        raise ValueError(f"top_p must be in (0, 1], got {top_p}")

    logits_cpu = logits.cpu().float()
    logits_cpu = logits_cpu / temperature
    probs = F.softmax(logits_cpu, dim=-1)

    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # `cumulative - sorted_probs` is the mass strictly before each token: drop a
    # token once that already reaches top_p. The token that crosses the
    # threshold, and always the most likely token, are therefore kept.
    mask = cumulative - sorted_probs >= top_p
    sorted_probs = sorted_probs.masked_fill(mask, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)

    sampled_idx = torch.multinomial(sorted_probs, num_samples=1)
    token = sorted_indices.gather(-1, sampled_idx)
    return token.to(logits.device)


def sample_gpu_fused(logits, temperature=1.0, top_p=0.9):
    """
    TODO: Implement fused GPU sampling.
    All operations stay on GPU -- no .cpu() calls.

    Steps:
    1. Divide logits by temperature
    2. Compute softmax probabilities
    3. Sort probabilities descending
    4. Compute cumulative sum
    5. Mask probabilities beyond top-p threshold
    6. Renormalize
    7. Sample with torch.multinomial

    ("Fused" here means the whole pipeline stays on the logits' device as a
    short sequence of tensor ops; it is not a single hand-written kernel.)

    Args:
        logits: shape (batch_size, vocab_size) on GPU
        temperature: float > 0
        top_p: float in (0, 1]

    Returns:
        token: shape (batch_size, 1) on GPU

    Raises:
        ValueError: if ``temperature`` or ``top_p`` is out of range (these
            would otherwise produce NaN probabilities and a cryptic
            ``torch.multinomial`` error).
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    if not 0 < top_p <= 1:
        raise ValueError(f"top_p must be in (0, 1], got {top_p}")

    # --- YOUR CODE HERE ---
    logits = logits / temperature
    probs = F.softmax(logits, dim=-1)

    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Mass strictly before each token; dropping once it reaches top_p keeps the
    # threshold-crossing token and always the most likely token.
    mask = cumulative - sorted_probs >= top_p
    # masked_fill is a pure element-wise op, so no data-dependent indexing
    # (and no host synchronisation) is involved.
    sorted_probs = sorted_probs.masked_fill(mask, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)

    sampled_idx = torch.multinomial(sorted_probs, num_samples=1)
    token = sorted_indices.gather(-1, sampled_idx)
    return token
    # --- END YOUR CODE ---
    # --- END YOUR CODE ---


# Benchmark: CPU vs GPU sampling
# NOTE: without a GPU both variants run on the CPU, so the "speedup" then only
# reflects the extra copies/overhead of the naive path.
vocab_size = 32000  # Realistic vocabulary size
batch_size = 1
n_iters = 200               # timed calls per sampler
n_generated_tokens = 200    # response length used to extrapolate total overhead


def _sync_if_cuda():
    """Block until queued CUDA work has finished (no-op on CPU)."""
    if device.type == 'cuda':
        torch.cuda.synchronize()


def _time_sampler_ms(sampler, logits, iters):
    """Average wall-clock milliseconds per ``sampler(logits)`` call.

    Synchronises after every call so each measurement includes the GPU work
    a real decode step would have to wait for.
    """
    _sync_if_cuda()
    start = time.perf_counter()
    for _ in range(iters):
        sampler(logits)
        _sync_if_cuda()
    return (time.perf_counter() - start) / iters * 1000


# Warm up
dummy_logits = torch.randn(batch_size, vocab_size, device=device)
for _ in range(10):
    _ = sample_gpu_naive(dummy_logits)
    _ = sample_gpu_fused(dummy_logits)

cpu_time = _time_sampler_ms(sample_gpu_naive, dummy_logits, n_iters)  # ms
gpu_time = _time_sampler_ms(sample_gpu_fused, dummy_logits, n_iters)  # ms

print(f"CPU-side sampling: {cpu_time:.2f} ms/token")
print(f"GPU fused sampling: {gpu_time:.2f} ms/token")
print(f"Speedup: {cpu_time / gpu_time:.1f}x")
print(f"\nOver {n_generated_tokens} generated tokens:")
print(f"  CPU total overhead: {cpu_time * n_generated_tokens:.0f} ms")
print(f"  GPU total overhead: {gpu_time * n_generated_tokens:.0f} ms")
print(f"  Savings: {(cpu_time - gpu_time) * n_generated_tokens:.0f} ms")

### TODO 6: Compare Sampling Strategies

Generate completions under different sampling strategies and compare output diversity (unique sequences and unique bigrams) and latency. Note that the Top-p strategy also applies a temperature of 0.8.

In [None]:
def _generate_one_sample(model, prompt, params, max_new_tokens):
    """Generate one completion of ``prompt`` under one sampling config.

    Returns ``(new_token_ids_as_list, elapsed_seconds)``. The timer covers
    prefill plus decode and waits for queued CUDA work before stopping.
    """
    tokens = prompt.clone()
    if tokens.is_cuda:
        torch.cuda.synchronize(tokens.device)
    start = time.perf_counter()

    with torch.no_grad():
        logits, kv_caches = model(tokens, use_cache=True, start_pos=0)

        for step in range(max_new_tokens):
            last_logits = logits[:, -1, :]

            if params.get('greedy', False):
                next_token = torch.argmax(last_logits, dim=-1, keepdim=True)
            else:
                scaled_logits = last_logits / params['temperature']

                if 'top_k' in params:
                    # Keep logits >= the k-th largest (ties at the threshold survive).
                    top_vals, _ = torch.topk(scaled_logits, params['top_k'])
                    threshold = top_vals[:, -1:]
                    scaled_logits = scaled_logits.masked_fill(scaled_logits < threshold, float('-inf'))

                # Temperature has already been applied above, hence 1.0 here.
                next_token = sample_gpu_fused(scaled_logits, temperature=1.0, top_p=params['top_p'])

            tokens = torch.cat([tokens, next_token], dim=1)

            # Run the decode pass only if another token will be sampled from it;
            # the final token's logits are never used.
            if step + 1 < max_new_tokens:
                pos = tokens.shape[1] - 1
                logits, kv_caches = model(
                    tokens[:, -1:], kv_caches=kv_caches, use_cache=True, start_pos=pos
                )

    if tokens.is_cuda:
        torch.cuda.synchronize(tokens.device)
    elapsed = time.perf_counter() - start
    return tokens[0, prompt.shape[1]:].cpu().tolist(), elapsed


def compare_sampling_strategies(model, prompt, n_samples=50, max_new_tokens=30):
    """
    Generate multiple completions with different sampling strategies.
    Compare output diversity and latency.

    Strategies: greedy, temperature 0.7, top-k 40 (temperature 1.0), and
    top-p 0.9 combined with temperature 0.8.

    Args:
        model: MiniLLM instance.
        prompt: (1, prompt_len) token ids on the model's device.
        n_samples: completions generated per strategy (>= 1).
        max_new_tokens: tokens generated per completion.

    Returns:
        dict: strategy name -> {'unique_bigrams', 'unique_sequences',
        'avg_latency_ms', 'p99_latency_ms'}.

    Raises:
        ValueError: if ``n_samples`` < 1 (latency statistics would be undefined).
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    strategies = {
        'Greedy': {'temperature': 1.0, 'top_p': 1.0, 'greedy': True},
        'Temp=0.7': {'temperature': 0.7, 'top_p': 1.0, 'greedy': False},
        'Top-k=40': {'temperature': 1.0, 'top_p': 1.0, 'top_k': 40, 'greedy': False},
        'Top-p=0.9': {'temperature': 0.8, 'top_p': 0.9, 'greedy': False},
    }

    results = {}

    for name, params in strategies.items():
        all_tokens = []
        times = []

        for _ in range(n_samples):
            new_tokens, elapsed = _generate_one_sample(model, prompt, params, max_new_tokens)
            all_tokens.append(new_tokens)
            times.append(elapsed)

        # Compute diversity: number of unique bigrams across all samples
        all_bigrams = set()
        for seq in all_tokens:
            all_bigrams.update(zip(seq, seq[1:]))

        unique_sequences = len(set(tuple(seq) for seq in all_tokens))

        results[name] = {
            'unique_bigrams': len(all_bigrams),
            'unique_sequences': unique_sequences,
            'avg_latency_ms': np.mean(times) * 1000,
            'p99_latency_ms': np.percentile(times, 99) * 1000,
        }

        print(f"{name:12s}: {unique_sequences:3d} unique sequences, "
              f"{len(all_bigrams):5d} unique bigrams, "
              f"latency={np.mean(times)*1000:.1f}ms avg / {np.percentile(times, 99)*1000:.1f}ms p99")

    # Visualize
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    names = list(results.keys())
    bigrams = [results[n]['unique_bigrams'] for n in names]
    latencies = [results[n]['avg_latency_ms'] for n in names]

    colors = ['#95a5a6', '#e74c3c', '#f39c12', '#2ecc71']
    axes[0].bar(names, bigrams, color=colors, edgecolor='black')
    axes[0].set_ylabel('Unique Bigrams')
    axes[0].set_title(f'Output Diversity ({n_samples} samples, {max_new_tokens} tokens each)')
    axes[0].grid(True, alpha=0.3, axis='y')

    axes[1].bar(names, latencies, color=colors, edgecolor='black')
    axes[1].set_ylabel('Avg Latency (ms)')
    axes[1].set_title('Generation Latency per Sample')
    axes[1].grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig('sampling_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()

    return results

# 30 samples x 20 new tokens per strategy on a random 32-token prompt; prompt
# plus generation stays well inside the model's 1024-position window.
# NOTE: this rebinds the global `prompt` used by the earlier generation cells.
prompt = torch.randint(0, config.vocab_size, (1, 32), device=device)
sampling_results = compare_sampling_strategies(model, prompt, n_samples=30, max_new_tokens=20)

## 3.4 Continuous Batching

Static batching waits for the entire batch to finish before starting a new batch. This means a 10-token response holds up a slot until the slowest 500-token response finishes. Continuous batching inserts and removes requests at every decode iteration.

### TODO 7: Continuous Batching Scheduler

Build a scheduler that manages a dynamic batch, inserting new requests and removing completed ones at every decode step.

In [None]:
@dataclass
class InferenceRequest:
    """A single inference request in the queue.

    The scheduler fills in ``arrival_time`` (when queued) and
    ``first_token_time`` (after prefill) from ``time.time()``;
    ``completion_time`` is presumably set when the request finishes — confirm
    against the scheduler's decode/completion logic. All three default to 0.0.
    """
    request_id: int
    prompt_ids: torch.Tensor  # shape (prompt_len,)
    max_new_tokens: int
    # Generated token ids. None (the default) is replaced by a fresh list in
    # __post_init__, so instances never share one mutable default.
    generated_tokens: Optional[List[int]] = None
    is_complete: bool = False
    arrival_time: float = 0.0
    first_token_time: float = 0.0
    completion_time: float = 0.0
    n_generated: int = 0

    def __post_init__(self):
        if self.generated_tokens is None:
            self.generated_tokens = []


class ContinuousBatchScheduler:
    """
    Iteration-level continuous batching scheduler (greedy decoding).

    Each iteration of ``run``:
    1. Admits waiting requests in FIFO order until ``max_batch_size`` are
       active. Admission runs the request's prefill, which also produces its
       first token; a request that is already finished at that point is
       retired immediately instead of occupying a slot.
    2. Runs one decode step for every active request.
    3. Retires requests that reached ``max_new_tokens`` (or emitted
       ``eos_token_id``), freeing their slot and KV cache for the next
       iteration.

    Note: the model is still called once per active request; the "batching"
    happens at the scheduling level (requests join and leave every
    iteration), not as a fused batched forward pass.
    """

    def __init__(self, model, max_batch_size=16, device='cuda', eos_token_id=None):
        """
        Args:
            model: callable ``model(input_ids, kv_caches=..., use_cache=True,
                start_pos=...)`` returning ``(logits, kv_caches)``.
            max_batch_size: maximum number of simultaneously active requests.
            device: device that prompt/token tensors are moved to.
            eos_token_id: optional token id that ends a request early;
                ``None`` (the default) disables the EOS check.

        Raises:
            ValueError: if ``max_batch_size`` < 1 (no request could ever run).
        """
        if max_batch_size < 1:
            raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}")
        self.model = model
        self.max_batch_size = max_batch_size
        self.device = device
        self.eos_token_id = eos_token_id

        self.waiting_queue: List[InferenceRequest] = []
        self.active_batch: List[InferenceRequest] = []
        self.completed: List[InferenceRequest] = []

        # Per-request KV caches keyed by request_id (whatever the model returns
        # for use_cache=True); freed when the request is retired.
        self.request_caches: Dict[int, list] = {}

    def add_request(self, request: "InferenceRequest"):
        """Stamp the request's arrival time and append it to the waiting queue."""
        request.arrival_time = time.time()
        self.waiting_queue.append(request)

    def _is_finished(self, request: "InferenceRequest", token: int) -> bool:
        """True once the request used up its token budget or just emitted EOS."""
        if request.n_generated >= request.max_new_tokens:
            return True
        return self.eos_token_id is not None and token == self.eos_token_id

    def _retire(self, request: "InferenceRequest"):
        """Mark a request complete, record it, and free its KV cache."""
        request.is_complete = True
        request.completion_time = time.time()
        self.completed.append(request)
        self.request_caches.pop(request.request_id, None)

    def _prefill_request(self, request: "InferenceRequest"):
        """Run prefill for a single request, cache its KV state, and emit its first token."""
        with torch.no_grad():
            input_ids = request.prompt_ids.unsqueeze(0).to(self.device)
            logits, kv_caches = self.model(input_ids, use_cache=True, start_pos=0)

            # Greedy decode the first token
            next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
            request.generated_tokens.append(next_token)
            request.n_generated = 1
            request.first_token_time = time.time()

            self.request_caches[request.request_id] = kv_caches

    def _decode_step(self):
        """
        Run one greedy decode step for every active request.

        Each request feeds its most recent token, at that token's absolute
        position, through the model together with its cached KV state, then
        appends the new token. Requests that reach ``max_new_tokens`` or emit
        ``eos_token_id`` are retired and their caches freed.
        """
        # --- YOUR CODE HERE ---
        finished = []

        for request in self.active_batch:
            if request.is_complete:
                continue

            last_token = torch.tensor(
                [[request.generated_tokens[-1]]], device=self.device
            )
            # The prompt fills positions 0..prompt_len-1, so the latest generated
            # token sits at prompt_len + n_generated - 1 (prompt_len right after prefill).
            pos = len(request.prompt_ids) + request.n_generated - 1

            with torch.no_grad():
                logits, new_caches = self.model(
                    last_token,
                    kv_caches=self.request_caches[request.request_id],
                    use_cache=True,
                    start_pos=pos,
                )

            self.request_caches[request.request_id] = new_caches

            next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
            request.generated_tokens.append(next_token)
            request.n_generated += 1

            if self._is_finished(request, next_token):
                finished.append(request)

        # Remove completed requests after the loop (never mutate while iterating).
        for request in finished:
            self.active_batch.remove(request)
            self._retire(request)
        # --- END YOUR CODE ---

    def _admit_requests(self):
        """Prefill waiting requests (FIFO) until the active batch is full.

        Prefill always yields one token, so a request that is already finished
        afterwards (``max_new_tokens <= 1`` or a first token equal to EOS) is
        retired immediately rather than occupying a decode slot.
        """
        while self.waiting_queue and len(self.active_batch) < self.max_batch_size:
            request = self.waiting_queue.pop(0)
            self._prefill_request(request)
            if self._is_finished(request, request.generated_tokens[-1]):
                self._retire(request)
            else:
                self.active_batch.append(request)

    def run(self, max_iterations=5000):
        """Run the scheduler until all requests are complete or max iterations reached.

        Returns:
            The list of completed requests, in completion order.
        """
        for _ in range(max_iterations):
            # Admit new requests
            self._admit_requests()

            if not self.active_batch and not self.waiting_queue:
                break

            # Run one decode step
            self._decode_step()

        return self.completed


# Smoke-test the scheduler: 20 requests with random 32-token prompts and random
# generation budgets in [10, 80), served with at most 8 requests active at once.
scheduler = ContinuousBatchScheduler(model, max_batch_size=8, device=device)

n_test_requests = 20
for i in range(n_test_requests):
    max_tokens = np.random.randint(10, 80)
    prompt = torch.randint(0, config.vocab_size, (32,))
    req = InferenceRequest(i, prompt, max_tokens)
    scheduler.add_request(req)

print(f"Submitted {n_test_requests} requests to continuous batcher")
completed = scheduler.run()
print(f"Completed: {len(completed)} requests")

# Throughput measured from the first arrival to the last completion.
total_tokens = sum(done_req.n_generated for done_req in completed)
first_arrival = min(done_req.arrival_time for done_req in completed)
last_completion = max(done_req.completion_time for done_req in completed)
total_time = last_completion - first_arrival
print(f"Total tokens generated: {total_tokens}")
print(f"Wall clock time: {total_time:.2f}s")
print(f"Throughput: {total_tokens / total_time:.0f} tokens/sec")

### TODO 8: Throughput Under Load -- Static vs Continuous Batching

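Submit a burst of requests with varied generation lengths (all arriving at once) and compare static batching against continuous batching on throughput and time-to-first-token. Simulating Poisson arrivals would be a natural, more realistic extension.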

In [None]:
def run_static_batching(model, requests, batch_size=8, device='cuda'):
    """
    Static batching baseline: take the next ``batch_size`` requests, run every
    one of them to completion, and only then move on to the next group.

    Requests inside a group are processed one after another (worst case for
    static batching) using greedy decoding and a per-request KV cache.

    Args:
        model: callable ``model(input_ids, kv_caches=..., use_cache=True,
            start_pos=...)`` returning ``(logits, kv_caches)``.
        requests: list of InferenceRequest-like objects; mutated in place.
        batch_size: number of requests per static group (must be >= 1).
        device: device the prompt tensors are moved to.

    Returns:
        The requests in submission order with ``generated_tokens``,
        ``n_generated``, ``first_token_time``, ``completion_time`` and
        ``is_complete`` filled in. Prefill always yields one token, so a request
        with ``max_new_tokens <= 1`` ends up with exactly one token.

    Raises:
        ValueError: if ``batch_size`` < 1 (the original loop never advanced).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    is_cuda = torch.device(device).type == 'cuda'

    def _sync():
        # CUDA kernels run asynchronously; wait for them so wall-clock
        # timestamps reflect finished work rather than kernel launches.
        if is_cuda:
            torch.cuda.synchronize()

    completed = []
    for group_start in range(0, len(requests), batch_size):
        # Process each request independently (worst case for static batching)
        for req in requests[group_start:group_start + batch_size]:
            prompt = req.prompt_ids.unsqueeze(0).to(device)
            prompt_len = prompt.shape[1]

            with torch.no_grad():
                logits, kv_caches = model(prompt, use_cache=True, start_pos=0)
                next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
                _sync()
                # TTFT ends when the first token exists, i.e. after prefill.
                req.first_token_time = time.time()

                # Keep per-step tokens in a list and concatenate once at the end,
                # instead of regrowing one tensor every step.
                generated = [next_token]
                for step in range(1, req.max_new_tokens):
                    # The token being fed sits at absolute position prompt_len + step - 1.
                    logits, kv_caches = model(
                        next_token, kv_caches=kv_caches, use_cache=True,
                        start_pos=prompt_len + step - 1,
                    )
                    next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
                    generated.append(next_token)

            # .tolist() copies to the host, which also waits for the GPU.
            req.generated_tokens = torch.cat(generated, dim=1).squeeze(0).tolist()
            req.n_generated = len(req.generated_tokens)
            req.completion_time = time.time()
            req.is_complete = True
            completed.append(req)

    return completed


def benchmark_batching(model, n_requests=30, batch_size=8, device='cuda'):
    """
    Compare static vs continuous batching on an identical workload.

    Both strategies receive the same ``n_requests`` requests: the same random
    32-token prompts and the same generation budgets drawn from [10, 60) with
    a fixed seed. All requests are submitted up front. The function prints
    throughput and time-to-first-token (TTFT) statistics and saves/shows a
    comparison figure.

    Args:
        model: model passed through to both serving strategies.
        n_requests: number of requests in the workload.
        batch_size: static group size, and max active requests for the scheduler.
        device: device the requests are served on.
    """
    # Local RNG: the same draws as np.random.seed(42) + randint, without
    # resetting the notebook's global NumPy RNG as a side effect.
    gen_lengths = np.random.RandomState(42).randint(10, 60, size=n_requests)
    # Draw prompts once so both strategies serve exactly the same inputs.
    prompts = [torch.randint(0, config.vocab_size, (32,)) for _ in range(n_requests)]

    def make_requests():
        """Fresh request objects (requests are mutated while being served)."""
        return [
            InferenceRequest(request_id=i, prompt_ids=prompts[i].clone(),
                             max_new_tokens=int(gen_lengths[i]))
            for i in range(n_requests)
        ]

    # Static batching: every request is treated as arriving before the run starts.
    static_requests = make_requests()
    for req in static_requests:
        req.arrival_time = time.time()

    print("Running static batching...")
    t0 = time.time()
    static_completed = run_static_batching(model, static_requests, batch_size=batch_size, device=device)
    static_time = time.time() - t0
    static_tokens = sum(r.n_generated for r in static_completed)
    static_tps = static_tokens / static_time

    # Continuous batching: add_request stamps arrival_time itself.
    scheduler = ContinuousBatchScheduler(model, max_batch_size=batch_size, device=device)
    for req in make_requests():
        scheduler.add_request(req)

    print("Running continuous batching...")
    t0 = time.time()
    cont_completed = scheduler.run()
    cont_time = time.time() - t0
    cont_tokens = sum(r.n_generated for r in cont_completed)
    cont_tps = cont_tokens / cont_time

    print(f"\n{'Metric':<30} {'Static':>12} {'Continuous':>12} {'Improvement':>12}")
    print("-" * 68)
    print(f"{'Total tokens':<30} {static_tokens:>12,} {cont_tokens:>12,}")
    print(f"{'Wall clock time (s)':<30} {static_time:>12.2f} {cont_time:>12.2f}")
    print(f"{'Throughput (tok/s)':<30} {static_tps:>12.0f} "
          f"{cont_tps:>12.0f} "
          f"{cont_tps / static_tps:>11.1f}x")

    # Latency analysis
    static_ttft = [(r.first_token_time - r.arrival_time) * 1000 for r in static_completed]
    cont_ttft = [(r.first_token_time - r.arrival_time) * 1000 for r in cont_completed]

    print(f"\n{'TTFT p50 (ms)':<30} {np.percentile(static_ttft, 50):>12.0f} "
          f"{np.percentile(cont_ttft, 50):>12.0f}")
    print(f"{'TTFT p99 (ms)':<30} {np.percentile(static_ttft, 99):>12.0f} "
          f"{np.percentile(cont_ttft, 99):>12.0f}")

    # Visualize
    labels = ['Static', 'Continuous']
    colors = ['#e74c3c', '#2ecc71']
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    axes[0].bar(labels, [static_tps, cont_tps], color=colors, edgecolor='black')
    axes[0].set_ylabel('Tokens / Second')
    axes[0].set_title('Throughput: Static vs Continuous Batching')
    axes[0].grid(True, alpha=0.3, axis='y')

    # `boxprops` accepts a single dict (a list raises), so color each box afterwards.
    box = axes[1].boxplot([static_ttft, cont_ttft], patch_artist=True)
    for patch, color in zip(box['boxes'], colors):
        patch.set_facecolor(color)
        patch.set_alpha(0.5)
    # Set tick labels explicitly: boxplot's `labels=` kwarg is deprecated in newer Matplotlib.
    axes[1].set_xticks([1, 2])
    axes[1].set_xticklabels(labels)
    axes[1].set_ylabel('Time to First Token (ms)')
    axes[1].set_title('TTFT Distribution')
    axes[1].grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig('batching_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()

benchmark_batching(model, n_requests=20, batch_size=8, device=device)

## 3.5 Multi-Tenant LoRA Serving

Stratos serves 120+ clients, many with custom fine-tuned models. LoRA makes this feasible: store one base model in GPU memory and swap lightweight adapters per request.

### TODO 9: LoRA Adapter Manager

Implement a LoRA manager that stores multiple adapters and applies the correct one per request.

In [None]:
class LoRAAdapter:
    """A single LoRA adapter: low-rank (A, B) factors for the Q and V projections
    of every target layer.

    ``A`` starts as small Gaussian noise (scaled by 0.01) and ``B`` as zeros, so
    a freshly created adapter's update ``B @ A`` is exactly zero.
    """

    def __init__(self, target_layers: List[int], d_model: int, rank: int, device='cuda', dtype=None):
        """
        Args:
            target_layers: layer indices that receive an adapter.
            d_model: hidden size; A is (rank, d_model) and B is (d_model, rank).
            rank: LoRA rank r.
            device: device on which the factors are allocated.
            dtype: dtype of the factors; ``None`` uses torch's default dtype.
        """
        self.target_layers = target_layers
        self.rank = rank
        self.adapters = {}

        for layer_idx in target_layers:
            # LoRA for Q and V projections
            for proj_name in ['q', 'v']:
                key = f"layer_{layer_idx}_{proj_name}"
                A = torch.randn(rank, d_model, device=device, dtype=dtype) * 0.01
                B = torch.zeros(d_model, rank, device=device, dtype=dtype)
                self.adapters[key] = {'A': A, 'B': B}

    def get_delta(self, layer_idx: int, proj_name: str) -> Optional[torch.Tensor]:
        """Get the low-rank weight update delta_W = B @ A, shape (d_model, d_model).

        Returns None when this adapter has no factors for (layer_idx, proj_name).
        """
        key = f"layer_{layer_idx}_{proj_name}"
        if key not in self.adapters:
            return None
        A = self.adapters[key]['A']
        B = self.adapters[key]['B']
        return B @ A  # (d_model, d_model)

    def memory_bytes(self) -> int:
        """Total bytes held by all A and B factors.

        Uses each tensor's real element size instead of assuming fp32, so the
        figure stays correct for half-precision or float64 adapters.
        """
        return sum(
            tensor.numel() * tensor.element_size()
            for params in self.adapters.values()
            for tensor in params.values()
        )


class LoRAManager:
    """
    Registry of per-client LoRA adapters with least-recently-used eviction.

    ``access_order`` lists client ids from least to most recently used.
    Registering or fetching an adapter moves its client to the end; when the
    registry is full, the client at the front is evicted.
    """

    def __init__(self, max_adapters: int = 32):
        """
        Args:
            max_adapters: capacity of the registry; must be at least 1.

        Raises:
            ValueError: if ``max_adapters`` < 1 (eviction would pop from an
                empty order list).
        """
        if max_adapters < 1:
            raise ValueError(f"max_adapters must be >= 1, got {max_adapters}")
        self.adapters: Dict[str, LoRAAdapter] = {}
        self.max_adapters = max_adapters
        self.access_order: List[str] = []  # For LRU eviction: LRU first, MRU last

    def _touch(self, client_id: str):
        """Move client_id to the most-recently-used end of access_order."""
        if client_id in self.access_order:
            self.access_order.remove(client_id)
        self.access_order.append(client_id)

    def register_adapter(self, client_id: str, adapter: "LoRAAdapter"):
        """Register (or replace) a client's adapter, evicting LRU clients if full."""
        # --- YOUR CODE HERE ---
        if client_id not in self.adapters:
            # `while` also copes with max_adapters being lowered after construction.
            while len(self.adapters) >= self.max_adapters and self.access_order:
                evict_id = self.access_order.pop(0)
                del self.adapters[evict_id]
                print(f"Evicted adapter for client: {evict_id}")

        self.adapters[client_id] = adapter
        self._touch(client_id)
        # --- END YOUR CODE ---

    def get_adapter(self, client_id: str) -> Optional["LoRAAdapter"]:
        """Retrieve adapter for a client, updating LRU order; None if unknown."""
        if client_id not in self.adapters:
            return None
        self._touch(client_id)
        return self.adapters[client_id]

    def stats(self) -> dict:
        """Adapter count, total memory (MB) and average memory per adapter (KB)."""
        total_memory = sum(a.memory_bytes() for a in self.adapters.values())
        return {
            'n_adapters': len(self.adapters),
            'total_memory_mb': total_memory / 1e6,
            'avg_memory_per_adapter_kb': (total_memory / max(len(self.adapters), 1)) / 1e3,
        }


# Create and register 3 client adapters with different LoRA ranks on every layer,
# then compare their combined footprint against the shared base model.
manager = LoRAManager(max_adapters=32)
d_model = config.d_model
target_layers = list(range(config.n_layers))

clients = {
    'fintech_client_a': {'rank': 8, 'description': 'Financial Q&A'},
    'legal_client_b': {'rank': 16, 'description': 'Legal document summarization'},
    'support_client_c': {'rank': 4, 'description': 'Customer support chatbot'},
}

for client_id, info in clients.items():
    adapter = LoRAAdapter(target_layers, d_model, rank=info['rank'], device=device)
    manager.register_adapter(client_id, adapter)
    print(f"Registered {client_id} (rank={info['rank']}): {adapter.memory_bytes()/1e3:.1f} KB")

stats = manager.stats()
print(f"\nTotal adapters: {stats['n_adapters']}")
print(f"Total adapter memory: {stats['total_memory_mb']:.2f} MB")
print(f"Avg per adapter: {stats['avg_memory_per_adapter_kb']:.1f} KB")

# Compare to base model memory, using each parameter's real element size
# (correct even if the model was cast to fp16/bf16, unlike a fixed 4 bytes).
base_model_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / 1e6
print(f"\nBase model memory: {base_model_mb:.1f} MB")
print(f"Adapter overhead: {stats['total_memory_mb'] / base_model_mb:.2%} of base model")

### TODO 10: End-to-End System Benchmark

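Combine the KV cache, on-GPU greedy token selection, and continuous batching, and measure the final system's performance against the original baseline. (The paged KV cache and LoRA serving are not part of this benchmark.)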

In [None]:
def end_to_end_benchmark(model, n_requests=30, device='cuda',
                         baseline_tps_per_gpu=340, baseline_concurrency=20):
    """
    Compare the full baseline against the optimized system on an identical
    workload, then project the measured speedup onto Stratos' production numbers.

    Baseline: ``generate_naive`` run one request at a time.
    Optimized: KV cache + greedy token selection on the device + continuous
    batching (``ContinuousBatchScheduler``, max 8 active requests).

    Args:
        model: model under test.
        n_requests: number of requests. Generation budgets are drawn from
            [15, 60) with a fixed seed; both systems serve the same prompts.
        device: device to run on.
        baseline_tps_per_gpu: current production throughput (tok/s/GPU) used
            for the projection.
        baseline_concurrency: current concurrent requests per GPU used for the
            projection.
    """
    # Local RNG: the same draws as np.random.seed(42) + randint, without
    # resetting the notebook's global NumPy RNG.
    gen_lengths = np.random.RandomState(42).randint(15, 60, size=n_requests)
    # Shared prompts so both systems serve exactly the same inputs.
    prompts = [torch.randint(0, config.vocab_size, (32,)) for _ in range(n_requests)]
    is_cuda = torch.device(device).type == 'cuda'

    # === BASELINE: Naive generation, no optimizations ===
    print("Running BASELINE (naive generation, static batching)...")
    baseline_start = time.time()
    baseline_total_tokens = 0

    for prompt, n_new in zip(prompts, gen_lengths):
        generate_naive(model, prompt.unsqueeze(0).to(device), max_new_tokens=int(n_new))
        baseline_total_tokens += int(n_new)

    if is_cuda:
        # CUDA work is asynchronous; wait for it so the timer covers finished work.
        torch.cuda.synchronize()
    baseline_time = time.time() - baseline_start
    baseline_tps = baseline_total_tokens / baseline_time

    # === OPTIMIZED: KV cache + GPU sampling + continuous batching ===
    print("Running OPTIMIZED (KV cache, GPU sampling, continuous batching)...")
    scheduler = ContinuousBatchScheduler(model, max_batch_size=8, device=device)

    for i, (prompt, n_new) in enumerate(zip(prompts, gen_lengths)):
        scheduler.add_request(InferenceRequest(
            request_id=i,
            prompt_ids=prompt,
            max_new_tokens=int(n_new),
        ))

    opt_start = time.time()
    completed = scheduler.run()
    if is_cuda:
        torch.cuda.synchronize()
    opt_time = time.time() - opt_start
    opt_total_tokens = sum(r.n_generated for r in completed)
    opt_tps = opt_total_tokens / opt_time

    # Results
    print(f"\n{'='*60}")
    print(f"{'END-TO-END RESULTS':^60}")
    print(f"{'='*60}")
    print(f"\n{'Metric':<35} {'Baseline':>10} {'Optimized':>10} {'Gain':>8}")
    print("-" * 65)
    print(f"{'Total tokens generated':<35} {baseline_total_tokens:>10,} {opt_total_tokens:>10,}")
    print(f"{'Wall clock time (s)':<35} {baseline_time:>10.2f} {opt_time:>10.2f}")
    print(f"{'Throughput (tokens/s)':<35} {baseline_tps:>10.0f} {opt_tps:>10.0f} "
          f"{opt_tps/baseline_tps:>7.1f}x")

    # Stratos projections. This is a crude linear extrapolation: it assumes
    # the toy-model speedup carries over to production throughput AND concurrency.
    print(f"\n{'--- Stratos Projection (13B model, A100 80GB) ---':^60}")
    scale_factor = opt_tps / baseline_tps
    projected_tps = baseline_tps_per_gpu * scale_factor
    projected_concurrency = baseline_concurrency * scale_factor
    print(f"{'Baseline throughput (tok/s/GPU)':<35} {baseline_tps_per_gpu:>10}")
    print(f"{'Projected optimized (tok/s/GPU)':<35} {projected_tps:>10.0f}")
    print(f"{'Baseline concurrent requests':<35} {baseline_concurrency:>10}")
    print(f"{'Projected concurrent requests':<35} {projected_concurrency:>10.0f}")

    # Visualization: one panel per metric, since throughput and concurrency
    # differ by orders of magnitude and would flatten each other on one axis.
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    panels = [
        (axes[0], 'Throughput (tok/s)', [baseline_tps, opt_tps]),
        (axes[1], 'Projected Concurrent Requests/GPU', [baseline_concurrency, projected_concurrency]),
    ]
    for ax, metric, values in panels:
        bars = ax.bar(['Baseline', 'Optimized'], values,
                      color=['#e74c3c', '#2ecc71'], edgecolor='black')
        ax.set_ylabel(metric)
        ax.set_title(metric)
        ax.grid(True, alpha=0.3, axis='y')
        for bar in bars:
            height = bar.get_height()
            # Offset in points (not data units) so labels sit just above any bar height.
            ax.annotate(f'{height:.0f}', (bar.get_x() + bar.get_width() / 2, height),
                        xytext=(0, 3), textcoords='offset points',
                        ha='center', fontsize=10, fontweight='bold')

    fig.suptitle('End-to-End: Baseline vs Optimized Inference')
    plt.tight_layout()
    plt.savefig('end_to_end_results.png', dpi=150, bbox_inches='tight')
    plt.show()

end_to_end_benchmark(model, n_requests=20, device=device)

## Summary

In this implementation notebook, you built the core inference optimizations that Stratos AI needs to restore their SLA compliance and improve margins:

1. **KV Cache:** Eliminated redundant computation during autoregressive generation, achieving substantial speedups that scale with prompt length.

2. **Paged KV Cache:** Replaced wasteful contiguous memory allocation with demand-paged blocks, reducing memory waste from 79% to near zero and enabling 3x more concurrent requests.

3. **GPU-Fused Sampling:** Moved temperature scaling and top-p sampling from CPU to GPU, eliminating per-token transfer overhead that accumulated to seconds over a full generation.

4. **Continuous Batching:** Replaced static batching with iteration-level scheduling, ensuring GPU slots are never wasted waiting for the slowest request in a batch.

5. **Multi-Tenant LoRA:** Demonstrated serving multiple fine-tuned model variants from a single base model, with adapter memory overhead under 1% of the base model.

These optimizations are not theoretical. Every major LLM serving system (vLLM, TensorRT-LLM, TGI) implements variants of these exact techniques. The specific numbers will differ at scale, but the principles -- eliminate redundant computation, allocate memory dynamically, keep the GPU busy, and share base model weights -- are universal.