# Notebook 24: vLLM / SGLang Quick Start

## Inference Engineering Course

---

## Overview

Serving LLMs efficiently at scale is one of the most critical challenges in production AI. Naive approaches that process one request at a time waste enormous amounts of GPU compute. **vLLM** and **SGLang** are two state-of-the-art serving frameworks that solve this through techniques like:

- **Continuous Batching**: Dynamically adding/removing requests from a batch as they complete
- **PagedAttention**: Efficiently managing GPU memory for KV caches using virtual memory concepts
- **Optimized Kernels**: Custom CUDA kernels for attention and other operations

### What You'll Learn

| Topic | Description |
|-------|-------------|
| vLLM Installation | Setting up vLLM in Colab |
| Model Deployment | Loading and serving a small model |
| Continuous Batching | How dynamic batching improves throughput |
| Benchmarking | Measuring throughput under different configs |
| SGLang Alternative | Quick look at SGLang as an alternative |

### Prerequisites
- Basic Python knowledge
- Understanding of LLM inference (tokenization, generation)
- Google Colab with GPU runtime (T4 is sufficient)

> **Important**: Make sure to enable GPU runtime: `Runtime > Change runtime type > T4 GPU`

---

## Section 1: Understanding the Problem - Why We Need Serving Frameworks

Before diving into vLLM, let's understand **why** naive LLM serving is inefficient.

### The Naive Approach

In a naive setup:
1. A request arrives
2. The model processes the full prompt (prefill phase)
3. The model generates tokens one-by-one (decode phase)
4. Only after completion does the next request start

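This leaves the GPU idle between requests. Even while a request is running, decoding a single sequence uses only a small fraction of the GPU's compute, because each decode step is dominated by reading the model weights from memory rather than by arithmetic.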

### The Continuous Batching Solution

```
Naive (Static) Batching:        Continuous Batching:
                                 
Req 1: [====PPPP====DDDD]       Req 1: [==PP==DDDD]
Req 2: [    ====PPPP====DDDD]   Req 2: [  ==PP==DDDD]
Req 3: [        ====PPPP====]   Req 3: [    ==PP==DDDD]
        ^^^^^^^^ wasted GPU      No wasted GPU cycles!
```

P = Prefill, D = Decode

---

## Section 2: Installing vLLM

In [None]:
# Step 1: Install vLLM and dependencies
# This may take a few minutes on Colab
# Quote the requirement so the shell does not treat ">" as output redirection
!pip install "vllm>=0.4.0" -q
!pip install matplotlib numpy pandas -q

print("Installation complete!")

In [None]:
# Step 2: Verify installation and check GPU
import torch
import subprocess

print("=" * 60)
print("ENVIRONMENT CHECK")
print("=" * 60)

# Check GPU availability
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"GPU Memory: {gpu_memory:.1f} GB")
else:
    print("WARNING: No GPU detected! Enable GPU runtime.")
    print("Go to: Runtime > Change runtime type > T4 GPU")

# Check vLLM version
try:
    import vllm
    print(f"vLLM version: {vllm.__version__}")
except ImportError:
    print("vLLM not installed properly. Re-run the install cell.")

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA version: {torch.version.cuda}")
print("=" * 60)

---

## Section 3: Deploying a Small Model with vLLM

We'll use a small model that fits on a T4 GPU (16GB). Good options include (weight sizes are approximate, in FP16):
- `facebook/opt-125m` (125M params, ~250MB) - Great for demos
- `microsoft/phi-2` (2.7B params, ~5.4GB) - Good quality, fits on T4
- `TinyLlama/TinyLlama-1.1B-Chat-v1.0` (1.1B params, ~2.2GB)

We'll start with `facebook/opt-125m` for speed, then optionally try larger models.

In [None]:
from vllm import LLM, SamplingParams
import time

# ============================================================
# Load a small model with vLLM
# ============================================================
# vLLM handles all the optimization automatically:
#   - PagedAttention for memory management
#   - Continuous batching
#   - Optimized CUDA kernels

MODEL_NAME = "facebook/opt-125m"  # Small model for demo

print(f"Loading model: {MODEL_NAME}")
print("This may take a minute for first download...")

start = time.time()
llm = LLM(
    model=MODEL_NAME,
    dtype="float16",          # Use FP16 for efficiency
    gpu_memory_utilization=0.8,  # Use 80% of GPU memory
    max_model_len=512,        # Max sequence length
)
load_time = time.time() - start

print(f"\nModel loaded in {load_time:.1f} seconds!")
print(f"Model: {MODEL_NAME}")

In [None]:
# ============================================================
# Generate text with vLLM
# ============================================================

# Define sampling parameters
sampling_params = SamplingParams(
    temperature=0.7,      # Controls randomness
    top_p=0.9,           # Nucleus sampling
    max_tokens=100,      # Max tokens to generate
)

# A small batch of test prompts
prompts = [
    "The future of artificial intelligence is",
    "In a world where robots can think,",
    "The most important invention of the 21st century",
]

print("Generating responses...")
start = time.time()
outputs = llm.generate(prompts, sampling_params)
gen_time = time.time() - start

print(f"\nGenerated {len(prompts)} responses in {gen_time:.2f}s")
print("=" * 70)

total_tokens = 0
for output in outputs:
    prompt = output.prompt
    generated = output.outputs[0].text
    num_tokens = len(output.outputs[0].token_ids)
    total_tokens += num_tokens
    print(f"\nPrompt: {prompt}")
    print(f"Output ({num_tokens} tokens): {generated}")
    print("-" * 70)

print(f"\nTotal tokens generated: {total_tokens}")
print(f"Throughput: {total_tokens / gen_time:.1f} tokens/sec")
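
The `temperature` and `top_p` values passed to `SamplingParams` above control how each next token is drawn. As a rough illustration only (this is a NumPy sketch of the general technique, not vLLM's internal sampler, and the function name and toy logits are made up for the example), temperature scaling followed by nucleus filtering can be written as:

```python
import numpy as np

def sample_top_p(logits, temperature=0.7, top_p=0.9, rng=None):
    """Sample one token id from raw logits using temperature and nucleus (top-p) filtering."""
    if rng is None:
        rng = np.random.default_rng()
    # Temperature rescales logits: values below 1.0 sharpen the distribution.
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    scaled -= scaled.max()                      # subtract the max for numerical stability
    probs = np.exp(scaled) / np.exp(scaled).sum()

    # Keep the smallest set of tokens whose cumulative probability reaches top_p.
    order = np.argsort(probs)[::-1]             # token ids, most probable first
    cumulative = np.cumsum(probs[order])
    cutoff = int(np.searchsorted(cumulative, top_p)) + 1
    kept = order[:cutoff]

    # Renormalize over the kept tokens and draw one of them.
    kept_probs = probs[kept] / probs[kept].sum()
    return int(rng.choice(kept, p=kept_probs))

toy_logits = [4.0, 3.0, 1.0, 0.5, -2.0]         # hypothetical 5-token vocabulary
rng = np.random.default_rng(0)
draws = [sample_top_p(toy_logits, rng=rng) for _ in range(1000)]
print("token ids that were ever sampled:", sorted(set(draws)))
```

With these toy logits the two most likely tokens already cover more than 90% of the probability mass, so the low-probability tail is never sampled.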

---

## Section 4: Benchmarking - Throughput vs Batch Size

One of vLLM's key advantages is its ability to handle batched requests efficiently. Let's measure throughput at different batch sizes to see how it scales.

### Key Metrics
- **Throughput**: Total tokens generated per second (tokens/s)
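- **Latency**: Time per request. The benchmark below approximates it as batch wall time divided by batch size, which is an amortized cost rather than the end-to-end latency an individual request experiences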
- **Tokens per request**: How many tokens each request generates
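
As a minimal sketch of how these metrics relate, the helper below summarizes one run. It assumes you have already recorded a latency and a token count for each request; the function name and the sample numbers are hypothetical and are not produced by the benchmark in this notebook.

```python
import numpy as np

def summarize_run(request_latencies_s, tokens_per_request, wall_time_s):
    """Summarize one benchmark run from per-request measurements."""
    latencies_ms = np.asarray(request_latencies_s, dtype=float) * 1000.0
    total_tokens = int(np.sum(tokens_per_request))
    return {
        "throughput_tps": total_tokens / wall_time_s,
        "mean_latency_ms": float(latencies_ms.mean()),
        "p50_latency_ms": float(np.percentile(latencies_ms, 50)),
        "p95_latency_ms": float(np.percentile(latencies_ms, 95)),
        "mean_tokens_per_request": total_tokens / len(tokens_per_request),
    }

# Hypothetical measurements for four requests served concurrently.
stats = summarize_run(
    request_latencies_s=[0.40, 0.42, 0.45, 0.90],
    tokens_per_request=[64, 64, 60, 64],
    wall_time_s=0.90,
)
for name, value in stats.items():
    print(f"{name}: {value:.1f}")
```

Reporting a tail percentile alongside the mean is a common choice because one slow request (0.90 s here) can hide behind a good average.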

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# ============================================================
# Benchmark: Throughput at different batch sizes
# ============================================================

batch_sizes = [1, 2, 4, 8, 16, 32]
base_prompts = [
    "Explain the concept of machine learning in simple terms.",
    "What are the benefits of renewable energy sources?",
    "Describe the process of photosynthesis step by step.",
    "How does the internet work at a basic level?",
    "What are the key principles of good software design?",
    "Explain quantum computing to a five-year-old.",
    "What is the theory of relativity about?",
    "How do neural networks learn from data?",
]

sampling_params = SamplingParams(
    temperature=0.7,
    max_tokens=64,
)

results = {
    'batch_size': [],
    'throughput_tps': [],
    'avg_latency_ms': [],
    'total_tokens': [],
    'total_time_s': [],
}

print("Running throughput benchmark...")
print("=" * 60)

for bs in batch_sizes:
    # Create prompts for this batch size (cycle through base prompts)
    prompts = [base_prompts[i % len(base_prompts)] for i in range(bs)]
    
    # Warm-up run
    _ = llm.generate(prompts[:1], sampling_params)
    
    # Timed run (average over 3 iterations)
    times = []
    token_counts = []
    
    for _ in range(3):
        start = time.time()
        outputs = llm.generate(prompts, sampling_params)
        elapsed = time.time() - start
        times.append(elapsed)
        
        total_tok = sum(len(o.outputs[0].token_ids) for o in outputs)
        token_counts.append(total_tok)
    
    avg_time = np.mean(times)
    avg_tokens = np.mean(token_counts)
    throughput = avg_tokens / avg_time
    avg_latency = (avg_time / bs) * 1000  # amortized ms per request (batch wall time / batch size)
    
    results['batch_size'].append(bs)
    results['throughput_tps'].append(throughput)
    results['avg_latency_ms'].append(avg_latency)
    results['total_tokens'].append(avg_tokens)
    results['total_time_s'].append(avg_time)
    
    print(f"Batch {bs:>3d}: {throughput:>8.1f} tok/s | "
          f"Latency: {avg_latency:>7.1f} ms/req | "
          f"Tokens: {avg_tokens:>5.0f} | Time: {avg_time:.3f}s")

print("=" * 60)
print("Benchmark complete!")

In [None]:
# ============================================================
# Visualize: Throughput and Latency vs Batch Size
# ============================================================

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Plot 1: Throughput vs Batch Size
axes[0].plot(results['batch_size'], results['throughput_tps'], 
             'b-o', linewidth=2, markersize=8, label='vLLM')
axes[0].set_xlabel('Batch Size', fontsize=12)
axes[0].set_ylabel('Throughput (tokens/sec)', fontsize=12)
axes[0].set_title('Throughput vs Batch Size', fontsize=14, fontweight='bold')
axes[0].grid(True, alpha=0.3)
axes[0].set_xscale('log', base=2)
axes[0].set_xticks(results['batch_size'])
axes[0].get_xaxis().set_major_formatter(plt.ScalarFormatter())

# Plot 2: Average Latency per Request
axes[1].plot(results['batch_size'], results['avg_latency_ms'], 
             'r-s', linewidth=2, markersize=8)
axes[1].set_xlabel('Batch Size', fontsize=12)
axes[1].set_ylabel('Avg Latency per Request (ms)', fontsize=12)
axes[1].set_title('Latency per Request vs Batch Size', fontsize=14, fontweight='bold')
axes[1].grid(True, alpha=0.3)
axes[1].set_xscale('log', base=2)
axes[1].set_xticks(results['batch_size'])
axes[1].get_xaxis().set_major_formatter(plt.ScalarFormatter())

# Plot 3: Throughput Scaling Efficiency
base_throughput = results['throughput_tps'][0]
scaling_efficiency = [t / (base_throughput * bs) * 100 
                      for t, bs in zip(results['throughput_tps'], results['batch_size'])]
ideal = [100] * len(results['batch_size'])

axes[2].bar(range(len(results['batch_size'])), scaling_efficiency, 
            color='steelblue', alpha=0.7, label='Actual')
axes[2].axhline(y=100, color='red', linestyle='--', label='Ideal Linear', alpha=0.7)
axes[2].set_xlabel('Batch Size', fontsize=12)
axes[2].set_ylabel('Scaling Efficiency (%)', fontsize=12)
axes[2].set_title('Scaling Efficiency', fontsize=14, fontweight='bold')
axes[2].set_xticks(range(len(results['batch_size'])))
axes[2].set_xticklabels(results['batch_size'])
axes[2].legend()
axes[2].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('vllm_benchmark.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nKey Insight: vLLM's continuous batching allows throughput to scale")
print("significantly with batch size while keeping per-request latency manageable.")

---

## Section 5: Continuous Batching in Action

Continuous batching is vLLM's killer feature. Unlike **static batching** (which waits for the longest sequence in a batch), continuous batching:

1. **Immediately starts new requests** as slots open up
2. **Releases memory** as soon as a sequence finishes
3. **Maximizes GPU utilization** at all times
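
The three points above can be sketched with a toy step counter. This is a simplified model under stated assumptions, not vLLM's scheduler: it ignores prefill, treats every decode step as equally expensive, and uses made-up generation lengths.

```python
def static_batching_steps(lengths, batch_size):
    """Decode steps when each batch runs until its longest sequence finishes."""
    return sum(max(lengths[i:i + batch_size])
               for i in range(0, len(lengths), batch_size))

def continuous_batching_steps(lengths, batch_size):
    """Decode steps when a finished sequence's slot is refilled on the next step."""
    waiting = list(lengths)        # remaining decode steps for queued requests
    running = []                   # remaining decode steps for active requests
    steps = 0
    while waiting or running:
        # Fill free slots from the queue before each step.
        while waiting and len(running) < batch_size:
            running.append(waiting.pop(0))
        # One decode step emits one token per running sequence;
        # sequences that just produced their last token leave the batch.
        running = [r - 1 for r in running if r > 1]
        steps += 1
    return steps

lengths = [10, 50, 12, 48, 15, 45, 11, 49]   # hypothetical tokens to generate per request
static_steps = static_batching_steps(lengths, batch_size=4)
continuous_steps = continuous_batching_steps(lengths, batch_size=4)
print("static batching steps:    ", static_steps)
print("continuous batching steps:", continuous_steps)
```

The gap grows as generation lengths within a batch become more uneven, since static batching always pays for the longest sequence in each batch.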

### How PagedAttention Enables This

```
Traditional KV Cache:          PagedAttention:
                               
[Seq1 KV: ████████░░░░]       [Block1][Block2][Block3]
[Seq2 KV: ██████░░░░░░]          ↑       ↑       ↑
[     WASTED: ░░░░░░░░ ]       Seq1 → [B1,B3]  Seq2 → [B2]
                               No wasted memory!
```
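
The block-table idea in the diagram can be illustrated with a toy allocator. This is a sketch of the bookkeeping only, assuming a block size of 16 tokens and a contiguous baseline that reserves 512 slots per sequence; the class and its methods are hypothetical and do not correspond to vLLM's internal classes.

```python
BLOCK_SIZE = 16          # tokens per KV block (illustrative value)
MAX_SEQ_LEN = 512        # slots reserved per sequence in the contiguous scheme

class PagedKVAllocator:
    """Toy block table: maps each sequence to a list of physical block ids."""

    def __init__(self, num_blocks):
        self.free_blocks = list(range(num_blocks))
        self.block_tables = {}   # seq_id -> list of physical block ids
        self.seq_lens = {}       # seq_id -> tokens stored so far

    def append_token(self, seq_id):
        """Reserve room for one more token, allocating a block only when needed."""
        length = self.seq_lens.get(seq_id, 0)
        if length % BLOCK_SIZE == 0:             # current block is full (or none yet)
            if not self.free_blocks:
                raise MemoryError("no free KV blocks")
            self.block_tables.setdefault(seq_id, []).append(self.free_blocks.pop())
        self.seq_lens[seq_id] = length + 1

    def free(self, seq_id):
        """Return a finished sequence's blocks to the pool immediately."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.seq_lens.pop(seq_id, None)

alloc = PagedKVAllocator(num_blocks=64)
for seq_id, length in {"seq1": 100, "seq2": 37}.items():
    for _ in range(length):
        alloc.append_token(seq_id)

used_tokens = sum(alloc.seq_lens.values())
paged_slots = sum(len(t) for t in alloc.block_tables.values()) * BLOCK_SIZE
contiguous_slots = len(alloc.seq_lens) * MAX_SEQ_LEN
print(f"tokens stored:        {used_tokens}")
print(f"paged slots reserved: {paged_slots}")
print(f"contiguous reserved:  {contiguous_slots}")
```

In the paged scheme the only unused slots are in the last, partially filled block of each sequence, and a finished sequence's blocks become available to other requests as soon as `free` is called.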

Let's simulate this to understand the difference.

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np

# ============================================================
# Simulation: Static Batching vs Continuous Batching
# ============================================================

np.random.seed(42)

# Simulate 8 requests with varying generation lengths
num_requests = 8
prefill_times = np.random.uniform(0.5, 1.5, num_requests)  # Prefill time
decode_lengths = np.random.randint(10, 60, num_requests)     # Decode steps
decode_time_per_step = 0.1  # Time per decode step
decode_times = decode_lengths * decode_time_per_step

# ---- Static Batching Simulation ----
# Process in batches of 4, wait for longest to finish
static_batch_size = 4
static_schedule = []
current_time = 0

for batch_start in range(0, num_requests, static_batch_size):
    batch_end = min(batch_start + static_batch_size, num_requests)
    batch_prefill = max(prefill_times[batch_start:batch_end])
    batch_decode = max(decode_times[batch_start:batch_end])
    
    for i in range(batch_start, batch_end):
        static_schedule.append({
            'req': i,
            'prefill_start': current_time,
            'prefill_end': current_time + prefill_times[i],
            'decode_start': current_time + batch_prefill,
            'decode_end': current_time + batch_prefill + decode_times[i],
            'batch_end': current_time + batch_prefill + batch_decode,
        })
    current_time += batch_prefill + batch_decode

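# ---- Continuous Batching Simulation ----
# Model a fixed number of slots: a request starts as soon as it has arrived
# and a slot is free, instead of waiting for a whole batch to finish.
import heapq

continuous_schedule = []
arrival_offset = 0.3  # Requests arrive with small gaps
slot_free_times = [0.0] * static_batch_size  # Same concurrency as the static case
heapq.heapify(slot_free_times)

for i in range(num_requests):
    arrival = i * arrival_offset
    start = max(arrival, heapq.heappop(slot_free_times))  # Earliest free slot
    decode_end = start + prefill_times[i] + decode_times[i]
    heapq.heappush(slot_free_times, decode_end)
    continuous_schedule.append({
        'req': i,
        'prefill_start': start,
        'prefill_end': start + prefill_times[i],
        'decode_start': start + prefill_times[i],
        'decode_end': decode_end,
    })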

# ---- Visualization ----
fig, axes = plt.subplots(2, 1, figsize=(16, 10))

colors_prefill = plt.cm.Set3(np.linspace(0, 1, num_requests))
colors_decode = plt.cm.Set2(np.linspace(0, 1, num_requests))

# Static Batching
ax = axes[0]
for entry in static_schedule:
    i = entry['req']
    # Prefill
    ax.barh(i, entry['prefill_end'] - entry['prefill_start'],
            left=entry['prefill_start'], color='#2196F3', alpha=0.8, height=0.6)
    # Decode
    ax.barh(i, entry['decode_end'] - entry['decode_start'],
            left=entry['decode_start'], color='#4CAF50', alpha=0.8, height=0.6)
    # Wasted time (waiting for batch)
    if entry['decode_end'] < entry['batch_end']:
        ax.barh(i, entry['batch_end'] - entry['decode_end'],
                left=entry['decode_end'], color='#F44336', alpha=0.3, 
                height=0.6, hatch='//')

ax.set_xlabel('Time (seconds)', fontsize=12)
ax.set_ylabel('Request ID', fontsize=12)
ax.set_title('Static Batching (Wasteful)', fontsize=14, fontweight='bold')
ax.set_yticks(range(num_requests))
ax.grid(True, alpha=0.2, axis='x')
ax.axvline(x=static_schedule[3]['batch_end'], color='gray', 
           linestyle='--', alpha=0.5, label='Batch boundary')

# Continuous Batching
ax = axes[1]
for entry in continuous_schedule:
    i = entry['req']
    # Prefill
    ax.barh(i, entry['prefill_end'] - entry['prefill_start'],
            left=entry['prefill_start'], color='#2196F3', alpha=0.8, height=0.6)
    # Decode
    ax.barh(i, entry['decode_end'] - entry['decode_start'],
            left=entry['decode_start'], color='#4CAF50', alpha=0.8, height=0.6)

ax.set_xlabel('Time (seconds)', fontsize=12)
ax.set_ylabel('Request ID', fontsize=12)
ax.set_title('Continuous Batching (Efficient - vLLM)', fontsize=14, fontweight='bold')
ax.set_yticks(range(num_requests))
ax.grid(True, alpha=0.2, axis='x')

# Legend
legend_elements = [
    mpatches.Patch(color='#2196F3', alpha=0.8, label='Prefill'),
    mpatches.Patch(color='#4CAF50', alpha=0.8, label='Decode'),
    mpatches.Patch(color='#F44336', alpha=0.3, label='Wasted (waiting)', hatch='//'),
]
fig.legend(handles=legend_elements, loc='upper center', ncol=3, 
           fontsize=12, bbox_to_anchor=(0.5, 1.02))

plt.tight_layout()
plt.savefig('batching_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

# Calculate total times
static_total = max(e['batch_end'] for e in static_schedule)
continuous_total = max(e['decode_end'] for e in continuous_schedule)

print(f"\nStatic Batching total time:     {static_total:.2f}s")
print(f"Continuous Batching total time: {continuous_total:.2f}s")
print(f"Speedup: {static_total / continuous_total:.2f}x")

---

## Section 6: Comparing Throughput - Naive vs vLLM

Let's compare the throughput of:
1. **Naive approach**: Using HuggingFace `transformers` with sequential generation
2. **vLLM**: Using optimized serving with continuous batching

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import time

# ============================================================
# Naive HuggingFace Generation
# ============================================================

print("Loading model with HuggingFace Transformers (naive)...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model_hf = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, 
    torch_dtype=torch.float16,
    device_map="auto"
)

# Test prompts
test_prompts = [
    "The key to successful artificial intelligence is",
    "In the year 2050, humanity will",
    "The most fascinating thing about the ocean is",
    "Machine learning algorithms can be used to",
    "The history of computing begins with",
    "Quantum mechanics tells us that particles",
    "The best programming language for beginners is",
    "Climate change affects the planet by",
] * 2  # 16 prompts total

MAX_NEW_TOKENS = 50

# ---- Naive: Sequential Generation ----
print(f"\nGenerating {len(test_prompts)} prompts sequentially (naive)...")
naive_tokens = 0
start = time.time()

for prompt in test_prompts:
    inputs = tokenizer(prompt, return_tensors="pt").to(model_hf.device)
    with torch.no_grad():
        output = model_hf.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=0.7,
        )
    naive_tokens += output.shape[1] - inputs['input_ids'].shape[1]

naive_time = time.time() - start
naive_throughput = naive_tokens / naive_time

print(f"Naive: {naive_tokens} tokens in {naive_time:.2f}s = {naive_throughput:.1f} tok/s")

# ---- vLLM: Batched Generation ----
print(f"\nGenerating {len(test_prompts)} prompts with vLLM...")
vllm_params = SamplingParams(temperature=0.7, max_tokens=MAX_NEW_TOKENS)

start = time.time()
vllm_outputs = llm.generate(test_prompts, vllm_params)
vllm_time = time.time() - start

vllm_tokens = sum(len(o.outputs[0].token_ids) for o in vllm_outputs)
vllm_throughput = vllm_tokens / vllm_time

print(f"vLLM:  {vllm_tokens} tokens in {vllm_time:.2f}s = {vllm_throughput:.1f} tok/s")

# Cleanup HF model to save memory
del model_hf
torch.cuda.empty_cache()

speedup = vllm_throughput / naive_throughput
print(f"\nSpeedup: {speedup:.1f}x faster with vLLM!")

In [None]:
# ============================================================
# Visualization: Naive vs vLLM Comparison
# ============================================================

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Throughput comparison
methods = ['Naive\n(HuggingFace)', 'vLLM\n(Optimized)']
throughputs = [naive_throughput, vllm_throughput]
colors = ['#E57373', '#4CAF50']

bars = axes[0].bar(methods, throughputs, color=colors, width=0.5, 
                   edgecolor='white', linewidth=2)
axes[0].set_ylabel('Throughput (tokens/sec)', fontsize=12)
axes[0].set_title('Throughput Comparison', fontsize=14, fontweight='bold')
axes[0].grid(True, alpha=0.3, axis='y')

# Add value labels
for bar, val in zip(bars, throughputs):
    axes[0].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 5,
                f'{val:.0f}', ha='center', va='bottom', fontweight='bold', fontsize=13)

# Add speedup annotation
axes[0].annotate(f'{speedup:.1f}x faster!', 
                xy=(1, vllm_throughput), fontsize=14,
                xytext=(1.3, vllm_throughput * 0.7),
                arrowprops=dict(arrowstyle='->', color='green', lw=2),
                color='green', fontweight='bold')

# Time comparison
times = [naive_time, vllm_time]
bars2 = axes[1].bar(methods, times, color=colors, width=0.5,
                    edgecolor='white', linewidth=2)
axes[1].set_ylabel('Total Time (seconds)', fontsize=12)
axes[1].set_title('Total Generation Time', fontsize=14, fontweight='bold')
axes[1].grid(True, alpha=0.3, axis='y')

for bar, val in zip(bars2, times):
    axes[1].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.1,
                f'{val:.2f}s', ha='center', va='bottom', fontweight='bold', fontsize=13)

plt.tight_layout()
plt.savefig('naive_vs_vllm.png', dpi=150, bbox_inches='tight')
plt.show()

---

## Section 7: Exploring vLLM Configuration Options

vLLM exposes many settings that affect performance. Here we vary one per-request setting, `max_tokens`, and measure its effect on throughput and latency.

In [None]:
# ============================================================
# Benchmark: Effect of max_tokens on throughput
# ============================================================

max_token_configs = [16, 32, 64, 128, 256]
token_throughputs = []
token_latencies = []

test_batch = [
    "Explain the concept of gravity in detail.",
    "What is the meaning of life according to philosophy?",
    "How do computers store and process information?",
    "Describe the water cycle in Earth's atmosphere.",
] * 4  # 16 prompts

print("Benchmarking effect of max_tokens...")
print("=" * 50)

for max_tok in max_token_configs:
    params = SamplingParams(temperature=0.7, max_tokens=max_tok)
    
    start = time.time()
    outputs = llm.generate(test_batch, params)
    elapsed = time.time() - start
    
    total = sum(len(o.outputs[0].token_ids) for o in outputs)
    tps = total / elapsed
    lat = elapsed / len(test_batch) * 1000
    
    token_throughputs.append(tps)
    token_latencies.append(lat)
    
    print(f"max_tokens={max_tok:>3d}: {tps:>8.1f} tok/s | "
          f"Latency: {lat:>6.1f} ms/req | Total tokens: {total}")

# Visualization
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.plot(max_token_configs, token_throughputs, 'b-o', linewidth=2, markersize=8)
ax1.fill_between(max_token_configs, token_throughputs, alpha=0.1, color='blue')
ax1.set_xlabel('max_tokens', fontsize=12)
ax1.set_ylabel('Throughput (tokens/sec)', fontsize=12)
ax1.set_title('Throughput vs Generation Length', fontsize=14, fontweight='bold')
ax1.grid(True, alpha=0.3)

ax2.plot(max_token_configs, token_latencies, 'r-s', linewidth=2, markersize=8)
ax2.fill_between(max_token_configs, token_latencies, alpha=0.1, color='red')
ax2.set_xlabel('max_tokens', fontsize=12)
ax2.set_ylabel('Avg Latency per Request (ms)', fontsize=12)
ax2.set_title('Latency vs Generation Length', fontsize=14, fontweight='bold')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('max_tokens_effect.png', dpi=150, bbox_inches='tight')
plt.show()
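
Engine-level settings such as `gpu_memory_utilization` and `max_model_len` (set when the model was loaded in Section 3) mainly determine how much memory is left for the KV cache. The back-of-the-envelope estimate below is a sketch, not what vLLM reports: it assumes standard multi-head attention, the published OPT-125M shape (12 layers, hidden size 768), a nominal 16 GB card, and it ignores activations and other runtime overhead. The helper names are hypothetical.

```python
def kv_bytes_per_token(num_layers, hidden_size, dtype_bytes=2):
    """Bytes of KV cache per token: one key and one value vector per layer."""
    return 2 * num_layers * hidden_size * dtype_bytes

def estimate_kv_capacity(total_gpu_bytes, gpu_memory_utilization,
                         weight_bytes, num_layers, hidden_size):
    """Rough upper bound on cacheable tokens; ignores activations and other overhead."""
    budget = total_gpu_bytes * gpu_memory_utilization - weight_bytes
    return int(budget // kv_bytes_per_token(num_layers, hidden_size))

# Assumed shapes for facebook/opt-125m: 12 layers, hidden size 768, FP16 weights.
per_token = kv_bytes_per_token(num_layers=12, hidden_size=768)
tokens = estimate_kv_capacity(
    total_gpu_bytes=16e9,            # nominal T4 memory
    gpu_memory_utilization=0.8,
    weight_bytes=125e6 * 2,          # ~125M parameters at 2 bytes each
    num_layers=12,
    hidden_size=768,
)
print(f"KV cache per token: {per_token / 1024:.0f} KiB")
print(f"Approx. cacheable tokens: {tokens:,}")
print(f"Approx. full-length sequences at max_model_len=512: {tokens // 512:,}")
```

The real number of concurrent sequences will be lower, since the engine also reserves memory for activations and applies its own scheduling limits, but the estimate shows why a small model leaves most of the GPU available for batching.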

---

## Section 8: SGLang Alternative Setup

**SGLang** (Structured Generation Language) is another high-performance serving framework with unique features:

| Feature | vLLM | SGLang |
|---------|------|--------|
| Continuous Batching | Yes | Yes |
| PagedAttention | Yes | Yes |
| Prefix caching | Yes (automatic prefix caching) | Yes (RadixAttention, native) |
| Structured Generation | Basic | Advanced |
| Programming Model | Python API | Domain-specific language |
| Multi-turn optimization | Limited | Excellent |

SGLang's key innovation is **RadixAttention**, which keeps KV cache entries in a radix tree so that requests sharing a common prefix (such as a system prompt) can reuse them instead of recomputing the prefill.
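
To make the prefix-reuse idea concrete, here is a simplified sketch. It uses a plain token-level trie as a stand-in for a radix tree and has no eviction policy, so it illustrates only the lookup, not SGLang's actual data structure; the class name and the token ids are invented for the example.

```python
class PrefixCache:
    """Token-level trie that records which prefixes already have cached KV entries."""

    def __init__(self):
        self.root = {}

    def match_and_insert(self, token_ids):
        """Return how many leading tokens were already cached, then cache the rest."""
        node, matched = self.root, 0
        for tok in token_ids:
            if tok in node:
                matched += 1          # still walking along a cached prefix
            else:
                node[tok] = {}        # first miss: every later node is new as well
            node = node[tok]
        return matched

system_prompt = list(range(500))      # stand-in token ids for a shared system prompt
requests = [system_prompt + [1000 + i, 2000 + i] for i in range(3)]

cache = PrefixCache()
hits = [cache.match_and_insert(req) for req in requests]
for req, hit in zip(requests, hits):
    print(f"prompt tokens: {len(req)}, reused from cache: {hit}, to prefill: {len(req) - hit}")
```

Only the first request pays for the full 502-token prefill; later requests reuse the 500 shared tokens and prefill just their own suffix.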

In [None]:
# ============================================================
# SGLang Installation (Optional - may require specific setup)
# ============================================================

# NOTE: SGLang installation can be more complex on Colab.
# This cell shows the setup process.

# Uncomment to install SGLang:
# !pip install "sglang[all]" -q

# For now, let's demonstrate SGLang concepts programmatically

print("="*60)
print("SGLang Conceptual Overview")
print("="*60)
print("""
SGLang provides a frontend language for structured LLM programs:

Example SGLang program:
---------------------------------------
import sglang as sgl

@sgl.function
def multi_turn_qa(s, question1, question2):
    s += sgl.system("You are a helpful assistant.")
    s += sgl.user(question1)
    s += sgl.assistant(sgl.gen("answer1", max_tokens=256))
    s += sgl.user(question2)
    s += sgl.assistant(sgl.gen("answer2", max_tokens=256))

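# Requires a running SGLang backend; for a local server this is typically
# sgl.set_default_backend(sgl.RuntimeEndpoint('http://localhost:30000'))
# Run with RadixAttention optimization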
state = multi_turn_qa.run(
    question1="What is Python?",
    question2="How does it compare to Java?"
)
print(state["answer1"])
print(state["answer2"])
---------------------------------------

Key advantages:
1. RadixAttention: Shares KV cache across requests with common prefixes
2. Constrained decoding: Native support for regex/JSON constraints
3. Parallelism: Fork-join execution for branching programs
""")

In [None]:
# ============================================================
# Simulate RadixAttention benefit
# ============================================================

# RadixAttention is especially beneficial when many requests share
# the same system prompt (common in production)

import numpy as np
import matplotlib.pyplot as plt

# Simulation: Time savings from prefix caching
num_requests = [10, 50, 100, 200, 500, 1000]
system_prompt_tokens = 500  # Typical system prompt length
user_prompt_tokens = 50    # Average user prompt
prefill_time_per_token = 0.001  # seconds

# Without RadixAttention: Every request processes the full prompt
no_radix_times = [
    n * (system_prompt_tokens + user_prompt_tokens) * prefill_time_per_token
    for n in num_requests
]

# With RadixAttention: System prompt cached after first request
radix_times = [
    (system_prompt_tokens + user_prompt_tokens) * prefill_time_per_token  # First request
    + (n - 1) * user_prompt_tokens * prefill_time_per_token  # Subsequent (cached prefix)
    for n in num_requests
]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Time comparison
ax1.plot(num_requests, no_radix_times, 'r-o', linewidth=2, label='Without RadixAttention')
ax1.plot(num_requests, radix_times, 'g-s', linewidth=2, label='With RadixAttention (SGLang)')
ax1.fill_between(num_requests, radix_times, no_radix_times, alpha=0.1, color='green')
ax1.set_xlabel('Number of Requests', fontsize=12)
ax1.set_ylabel('Total Prefill Time (seconds)', fontsize=12)
ax1.set_title('Prefill Time: RadixAttention Savings', fontsize=14, fontweight='bold')
ax1.legend(fontsize=11)
ax1.grid(True, alpha=0.3)

# Speedup
speedups = [a/b for a, b in zip(no_radix_times, radix_times)]
ax2.bar(range(len(num_requests)), speedups, color='green', alpha=0.7)
ax2.set_xlabel('Number of Requests', fontsize=12)
ax2.set_ylabel('Speedup Factor', fontsize=12)
ax2.set_title('RadixAttention Speedup', fontsize=14, fontweight='bold')
ax2.set_xticks(range(len(num_requests)))
ax2.set_xticklabels(num_requests)
ax2.grid(True, alpha=0.3, axis='y')

for i, v in enumerate(speedups):
    ax2.text(i, v + 0.1, f'{v:.1f}x', ha='center', fontweight='bold')

plt.tight_layout()
plt.savefig('radix_attention.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nWith 1000 requests sharing a {system_prompt_tokens}-token system prompt:")
print(f"RadixAttention provides a {speedups[-1]:.1f}x speedup in prefill time!")
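
The simulated speedup also has a simple closed form. Writing $S$ for the system prompt length, $U$ for the user prompt length, and $n$ for the number of requests (symbols introduced here for convenience), and keeping the simulation's assumption that prefill time is proportional to token count:

$$
\text{speedup}(n) = \frac{n\,(S + U)}{(S + U) + (n - 1)\,U}
\quad\xrightarrow{\;n \to \infty\;}\quad \frac{S + U}{U}
$$

With $S = 500$ and $U = 50$, the limit is $550 / 50 = 11$, and at $n = 1000$ the formula gives $550000 / 50500 \approx 10.9$. The benefit is therefore bounded by the ratio of total prompt length to the unshared part, however many requests arrive.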

---

## Section 9: Summary & Key Takeaways

### What We Learned

| Concept | Key Insight |
|---------|-------------|
| **vLLM** | High-throughput serving with PagedAttention and continuous batching |
| **Continuous Batching** | Dynamically manages requests to maximize GPU utilization |
| **PagedAttention** | Block-based KV cache allocation, modeled on virtual memory paging, nearly eliminates fragmentation waste |
| **Batch Size Scaling** | Throughput generally rises with batch size until the GPU saturates |
| **SGLang** | Alternative with RadixAttention for prefix-heavy workloads |

### When to Use What

- **vLLM**: General-purpose high-throughput serving, well-established ecosystem
- **SGLang**: Multi-turn conversations, shared system prompts, structured generation
- **Naive HuggingFace**: Prototyping, single-request scenarios, fine-tuning

### Performance Rules of Thumb

1. **Always use a serving framework** in production (vLLM or SGLang)
2. **Batch requests** whenever possible - even small batches help
3. **Monitor GPU memory utilization** - aim for 80-90%
4. **Use prefix caching** if many requests share common prefixes

---

## Exercises

### Exercise 1: Larger Model Benchmarking
Load a larger model (e.g., `TinyLlama/TinyLlama-1.1B-Chat-v1.0`) and compare throughput with the OPT-125M model. How does model size affect throughput scaling?

### Exercise 2: Sampling Parameter Exploration
Benchmark the effect of different `temperature` and `top_p` values on generation speed. Does sampling strategy affect throughput?

### Exercise 3: Memory Utilization Analysis
Vary the `gpu_memory_utilization` parameter (0.5, 0.7, 0.9) and measure maximum batch size and throughput at each level.

### Exercise 4: Real-World Workload Simulation
Create a simulation with requests arriving at different rates (Poisson process) and measure how vLLM handles varying load levels.

In [None]:
# ============================================================
# Exercise 1 Starter Code: Larger Model Benchmarking
# ============================================================

# Uncomment and modify this code:

# LARGER_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# 
# llm_large = LLM(
#     model=LARGER_MODEL,
#     dtype="float16",
#     gpu_memory_utilization=0.85,
#     max_model_len=512,
# )
# 
# # Run the same benchmark as Section 4
# # Compare results with OPT-125M
# 
# # Your code here...

print("Complete the exercises above to deepen your understanding!")

In [None]:
# ============================================================
# Exercise 4 Starter Code: Poisson Arrival Simulation
# ============================================================

import numpy as np
import matplotlib.pyplot as plt

# Simulate request arrivals with Poisson process
np.random.seed(42)

# arrival_rate = 10  # requests per second
# duration = 30      # seconds
# 
# inter_arrival_times = np.random.exponential(1/arrival_rate, size=int(arrival_rate * duration * 1.5))
# arrival_times = np.cumsum(inter_arrival_times)
# arrival_times = arrival_times[arrival_times <= duration]
# 
# print(f"Simulated {len(arrival_times)} request arrivals over {duration}s")
# print(f"Average arrival rate: {len(arrival_times)/duration:.1f} req/s")
# 
# # Visualize arrivals
# plt.figure(figsize=(12, 4))
# plt.eventplot(arrival_times, lineoffsets=0, linelengths=0.5, color='blue')
# plt.xlabel('Time (seconds)')
# plt.title('Simulated Request Arrivals (Poisson Process)')
# plt.grid(True, alpha=0.3)
# plt.show()

print("Uncomment the code above to complete Exercise 4!")