# 11. Continuous Batching vs Static vs Dynamic Batching

---

## What You'll Learn

1. **Static Batching** - The simplest approach: wait for a full batch, then process everything together
2. **Dynamic Batching** - Smarter: wait for a batch OR a timeout, whichever comes first
3. **Continuous Batching** - The state of the art: insert and remove requests at the token level
4. **Why it matters** - How batching strategy dramatically affects throughput and latency
5. **Visual comparison** - Gantt-chart style timelines showing exactly what happens to each request

---

### The Core Problem

When an LLM inference server receives requests, it faces a scheduling problem:
- Requests arrive at **random times**
- Each request generates a **different number of tokens**
- GPU utilization is maximized when processing **multiple requests simultaneously** (batching)
- But users hate **waiting** for their turn

The batching strategy determines the tradeoff between throughput and latency.

In [None]:
# Install dependencies
!pip install matplotlib numpy -q

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import FancyBboxPatch
import time
from dataclasses import dataclass, field
from typing import List, Optional
import random

# Set random seed for reproducibility
np.random.seed(42)
random.seed(42)

# Nice plot defaults
plt.rcParams['figure.figsize'] = (14, 6)
plt.rcParams['font.size'] = 12
plt.rcParams['axes.grid'] = True
plt.rcParams['grid.alpha'] = 0.3

print("Setup complete!")

## Part 1: Defining the Problem

Let's first model our inference workload. Each request:
- **Arrives** at some time
- Has an **input sequence** (prompt) of some length
- Needs to **generate** some number of output tokens
- Each token takes a certain amount of time to generate

We'll simulate this with a request queue.

In [None]:
@dataclass
class Request:
    """Represents a single inference request."""
    id: int
    arrival_time: float        # When the request arrives (seconds)
    input_tokens: int          # Number of input (prompt) tokens
    output_tokens: int         # Number of output tokens to generate
    
    # Filled in during simulation
    start_time: float = 0.0    # When processing actually begins
    end_time: float = 0.0      # When all tokens are generated
    tokens_generated: int = 0  # Tracking progress
    
    @property
    def wait_time(self):
        """Time spent waiting before processing starts."""
        return self.start_time - self.arrival_time
    
    @property
    def total_latency(self):
        """Total time from arrival to completion."""
        return self.end_time - self.arrival_time
    
    @property
    def processing_time(self):
        """Time spent actually generating tokens."""
        return self.end_time - self.start_time


def generate_request_queue(n_requests=20, 
                           avg_arrival_gap=0.3,
                           min_output=10, max_output=80):
    """Generate a realistic request queue with varying arrival times and lengths."""
    requests = []
    current_time = 0.0
    
    for i in range(n_requests):
        # Exponential inter-arrival gaps give a Poisson arrival process
        gap = np.random.exponential(avg_arrival_gap)
        current_time += gap
        
        req = Request(
            id=i,
            arrival_time=round(current_time, 3),
            input_tokens=np.random.randint(20, 200),
            output_tokens=np.random.randint(min_output, max_output)
        )
        requests.append(req)
    
    return requests

# Generate our workload
requests = generate_request_queue(n_requests=20)

print(f"Generated {len(requests)} requests")
print(f"\n{'ID':>3} | {'Arrival':>8} | {'Input Tokens':>12} | {'Output Tokens':>13}")
print("-" * 50)
for r in requests[:10]:
    print(f"{r.id:>3} | {r.arrival_time:>8.3f}s | {r.input_tokens:>12} | {r.output_tokens:>13}")
print(f"... and {len(requests)-10} more requests")

## Part 2: The Time Model

In real LLM inference:
- **Prefill** (processing input tokens) takes some time proportional to input length
- **Decode** (generating each output token) takes ~constant time per token
- **Batching** means multiple requests share the GPU, but each token step takes slightly longer

We'll model this with simple timing parameters.
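To make the time model concrete, here is the arithmetic for a single request. The constants repeat the values defined in the next cell (1 ms per prompt token, 20 ms per decode step, 2 ms per extra batch member). `request_latency` is a hypothetical helper that assumes the batch size stays fixed for the whole request.

```python
# Toy constants matching the notebook's timing model (assumed values, not measurements)
PREFILL_TIME_PER_TOKEN = 0.001  # seconds per prompt token
DECODE_TIME_PER_TOKEN = 0.02    # seconds per decode step with one request
BATCH_OVERHEAD = 0.002          # extra seconds per additional request in a step

def request_latency(input_tokens: int, output_tokens: int, batch_size: int = 1) -> float:
    """Prefill once, then one decode step per output token.

    Hypothetical helper: assumes the batch size stays fixed for the whole request.
    """
    step = DECODE_TIME_PER_TOKEN + (batch_size - 1) * BATCH_OVERHEAD
    return input_tokens * PREFILL_TIME_PER_TOKEN + output_tokens * step

alone = request_latency(100, 50)                 # 0.1 s prefill + 50 steps of 20 ms
shared = request_latency(100, 50, batch_size=4)  # 0.1 s prefill + 50 steps of 26 ms
print(f"alone: {alone:.2f}s, in a batch of 4: {shared:.2f}s")
```

The request itself gets slower (1.10 s vs 1.40 s), but the batch of four produces four tokens per 26 ms step instead of one per 20 ms. That trade is what the rest of the notebook explores.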

In [None]:
# Timing model parameters
PREFILL_TIME_PER_TOKEN = 0.001   # 1ms per input token for prefill
DECODE_TIME_PER_TOKEN = 0.02     # 20ms per output token (single request)
BATCH_OVERHEAD = 0.002           # 2ms overhead per additional request in batch
MAX_BATCH_SIZE = 8               # Maximum concurrent requests

def prefill_time(input_tokens: int) -> float:
    """Time to process input tokens (prefill phase)."""
    return input_tokens * PREFILL_TIME_PER_TOKEN

def decode_step_time(batch_size: int) -> float:
    """Time for one decode step with a given batch size.
    
    Key insight: batching is efficient because one decode step 
    generates one token for EACH request in the batch simultaneously.
    The cost increases only slightly with batch size.
    """
    return DECODE_TIME_PER_TOKEN + (batch_size - 1) * BATCH_OVERHEAD

# Show the efficiency of batching
print("Decode step time vs batch size:")
print(f"{'Batch Size':>10} | {'Step Time':>10} | {'Throughput (tok/s)':>18} | {'Speedup':>8}")
print("-" * 55)
base_throughput = 1.0 / decode_step_time(1)
for bs in range(1, MAX_BATCH_SIZE + 1):
    step_t = decode_step_time(bs)
    throughput = bs / step_t  # tokens per second across all requests
    speedup = throughput / base_throughput
    print(f"{bs:>10} | {step_t*1000:>8.1f}ms | {throughput:>18.1f} | {speedup:>7.2f}x")

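**Key insight above**: With batch size 8, each step takes only ~34ms (vs 20ms for batch size 1), but it produces 8 tokens instead of 1, a ~4.7x throughput improvement. This is why batching matters. On real GPUs, decoding at small batch sizes is usually limited by memory bandwidth rather than compute: each step must stream the model weights from memory, and that cost is shared by every request in the batch, so extra requests are cheap until compute becomes the bottleneck. The linear `BATCH_OVERHEAD` term is a simple stand-in for that behavior.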
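The table's numbers follow from a short closed form. Writing $t_0$ for `DECODE_TIME_PER_TOKEN`, $\delta$ for `BATCH_OVERHEAD`, and $b$ for the batch size, the step time, aggregate throughput, and speedup in this toy model are:

$$
t(b) = t_0 + (b-1)\,\delta, \qquad X(b) = \frac{b}{t(b)}, \qquad S(b) = \frac{X(b)}{X(1)} = \frac{b\,t_0}{t_0 + (b-1)\,\delta}
$$

For $b = 8$ this gives $S(8) = 8 \cdot 0.02 / 0.034 \approx 4.71$. As $b$ grows, $S(b)$ approaches $t_0/\delta = 10$, so under this linear model throughput can never exceed $1/\delta = 500$ tokens/s however large the batch (a property of the model, not of real GPUs).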

## Part 3: Static Batching

The simplest approach:
1. Wait until we have `batch_size` requests queued up
2. Process them ALL together
3. **Wait for ALL of them to finish** before starting the next batch (new arrivals just queue up)
4. The batch runs for as long as the LONGEST sequence

This is how many older systems work. The problem? If one request needs 80 tokens and another needs 10, the short request is stuck waiting for the long one to finish.
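The head-of-line blocking can be counted in decode steps, independent of any timing constants. A minimal sketch (the helper name is hypothetical), using the output lengths of the first batch in Part 9:

```python
from typing import List

def static_release_steps(output_lengths: List[int]) -> List[int]:
    """Decode step at which each member of one static batch is returned.

    The whole batch runs until its longest member is done, so everyone is
    released at the same step.
    """
    longest = max(output_lengths)
    return [longest] * len(output_lengths)

lengths = [10, 60, 15, 70]  # output tokens per request (illustrative)
for own, released in zip(lengths, static_release_steps(lengths)):
    print(f"needs {own:>2} steps, returned after {released} ({released - own} steps of waiting)")
```

The 10-token request is returned after 70 steps, seven times later than its own work requires.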

In [None]:
import copy

def simulate_static_batching(requests: List[Request], batch_size: int = 4) -> List[Request]:
    """Simulate static batching: wait for full batch, process together, wait for all to finish."""
    results = [copy.deepcopy(r) for r in requests]
    current_time = 0.0
    i = 0
    
    while i < len(results):
        # Collect the next batch_size requests (the final partial batch is flushed so the run ends)
        batch_end = min(i + batch_size, len(results))
        batch = results[i:batch_end]
        
        # Wait for all requests in this batch to arrive
        last_arrival = max(r.arrival_time for r in batch)
        current_time = max(current_time, last_arrival)
        
        # Prefill all requests in the batch
        for r in batch:
            r.start_time = current_time
        
        max_prefill = max(prefill_time(r.input_tokens) for r in batch)
        current_time += max_prefill
        
        # Decode: run for the MAXIMUM output tokens in the batch
        max_output = max(r.output_tokens for r in batch)
        bs = len(batch)
        
        for step in range(max_output):
            step_time = decode_step_time(bs)
            current_time += step_time
            
            # All requests get a token (even if they're already done)
            for r in batch:
                if r.tokens_generated < r.output_tokens:
                    r.tokens_generated += 1
        
        # All requests in batch finish at the same time (when longest finishes)
        for r in batch:
            r.end_time = current_time
        
        i = batch_end
    
    return results

static_results = simulate_static_batching(requests, batch_size=4)

print("Static Batching Results:")
print(f"{'ID':>3} | {'Arrive':>7} | {'Start':>7} | {'End':>7} | {'Wait':>6} | {'Total':>6} | {'Out Tokens':>10}")
print("-" * 70)
for r in static_results:
    print(f"{r.id:>3} | {r.arrival_time:>6.2f}s | {r.start_time:>6.2f}s | {r.end_time:>6.2f}s | {r.wait_time:>5.2f}s | {r.total_latency:>5.2f}s | {r.output_tokens:>10}")

## Part 4: Dynamic Batching

An improvement over static batching:
1. Wait for `batch_size` requests **OR** a timeout (e.g., 0.5 seconds)
2. Process whatever we have when the trigger fires
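3. Still wait for every request in the batch to finish before starting the next batch

This caps how long an early request can sit in a half-empty queue. Once a batch starts, though, it still runs until its longest request finishes.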
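The dispatch decision can be sketched as a small pure function. This is a minimal sketch that assumes the timeout is counted from the arrival of the oldest queued request; `dispatch` and its arguments are hypothetical names for illustration.

```python
from typing import List, Tuple

def dispatch(arrivals: List[float], free_at: float,
             batch_size: int, timeout: float) -> Tuple[float, int]:
    """Return (dispatch_time, n_requests) for the next dynamic batch.

    arrivals: sorted arrival times of the requests still queued (non-empty).
    free_at: when the server finished its previous batch.
    Assumption: the timeout is measured from the oldest queued request's arrival.
    """
    deadline = max(free_at, arrivals[0] + timeout)
    # Requests that show up by the deadline are eligible, earliest first, up to batch_size
    eligible = [a for a in arrivals if a <= deadline][:batch_size]
    if len(eligible) == batch_size:
        # Full batch: go as soon as the last member is here and the server is free
        return max(free_at, eligible[-1]), batch_size
    # Partial batch: go when the timeout fires
    return deadline, len(eligible)

print(dispatch([0.0, 0.1, 0.2, 0.3], free_at=0.0, batch_size=4, timeout=0.5))  # fills early
print(dispatch([0.0, 0.1, 0.9], free_at=0.0, batch_size=4, timeout=0.5))       # timeout fires
print(dispatch([0.5, 0.55], free_at=2.0, batch_size=4, timeout=0.3))           # backlog
```

The third call shows the backlog case: both queued requests have already waited past the timeout when the server frees up at 2.0 s, so they go out immediately as a partial batch.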

In [None]:
def simulate_dynamic_batching(requests: List[Request], 
                               batch_size: int = 4,
                               timeout: float = 0.5) -> List[Request]:
    """Simulate dynamic batching: dispatch when the batch is full OR the timeout expires.

    The timeout is counted from the arrival of the oldest queued request, so the
    simulator never uses knowledge that no further requests are coming.
    """
    results = [copy.deepcopy(r) for r in requests]
    current_time = 0.0
    i = 0
    
    while i < len(results):
        # Start no earlier than when the server is free and the first request has arrived
        current_time = max(current_time, results[i].arrival_time)
        # Latest dispatch time: once the oldest request has waited `timeout`
        deadline = max(current_time, results[i].arrival_time + timeout)
        
        # Collect requests that arrive by the deadline, up to batch_size
        batch = [results[i]]
        j = i + 1
        while (j < len(results) and len(batch) < batch_size
               and results[j].arrival_time <= deadline):
            batch.append(results[j])
            j += 1
        
        if len(batch) == batch_size:
            # Full batch: dispatch as soon as its last member has arrived
            current_time = max(current_time, batch[-1].arrival_time)
        else:
            # Partial batch: dispatch when the timeout fires
            current_time = deadline
        
        # Process the batch
        for r in batch:
            r.start_time = current_time
        
        max_prefill = max(prefill_time(r.input_tokens) for r in batch)
        current_time += max_prefill
        
        max_output = max(r.output_tokens for r in batch)
        bs = len(batch)
        
        for step in range(max_output):
            step_time = decode_step_time(bs)
            current_time += step_time
            for r in batch:
                if r.tokens_generated < r.output_tokens:
                    r.tokens_generated += 1
        
        for r in batch:
            r.end_time = current_time
        
        i = j
    
    return results

dynamic_results = simulate_dynamic_batching(requests, batch_size=4, timeout=0.5)

print("Dynamic Batching Results:")
print(f"{'ID':>3} | {'Arrive':>7} | {'Start':>7} | {'End':>7} | {'Wait':>6} | {'Total':>6} | {'Out Tokens':>10}")
print("-" * 70)
for r in dynamic_results:
    print(f"{r.id:>3} | {r.arrival_time:>6.2f}s | {r.start_time:>6.2f}s | {r.end_time:>6.2f}s | {r.wait_time:>5.2f}s | {r.total_latency:>5.2f}s | {r.output_tokens:>10}")

## Part 5: Continuous Batching

The key insight of continuous batching (also called **iteration-level batching**):

**Don't wait for the whole batch to finish. As soon as one request completes, slot in a new one.**

At every decode step:
1. Generate one token for each active request in the batch
2. Remove any requests that have finished (reached their output length)
3. If there's room in the batch and requests are waiting, add them
4. Repeat

This means:
- Short requests leave quickly (don't wait for long ones)
- New requests can start immediately if there's space
- GPU stays maximally utilized
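The loop above can be sketched in decode-step units, without the timing constants. This is a simplified model, not the notebook's simulator: prefill is treated as free, every step costs the same, and a request that has arrived by step `s` can join step `s`. The function name and arguments are hypothetical.

```python
from collections import deque
from typing import Dict, List

def continuous_finish_steps(lengths: List[int], arrival_steps: List[int],
                            max_batch_size: int) -> List[int]:
    """Decode step at which each request finishes under iteration-level scheduling.

    Hypothetical helper. Simplifying assumptions: prefill is free, every decode
    step costs the same, and a request that has arrived by step s can join step s.
    """
    # Waiting requests, ordered by arrival (ties keep input order)
    queue = deque(sorted(range(len(lengths)), key=lambda k: arrival_steps[k]))
    remaining = list(lengths)
    active: List[int] = []
    finish: Dict[int, int] = {}
    step = 0
    while queue or active:
        # Fill free slots with requests that have already arrived
        while queue and len(active) < max_batch_size and arrival_steps[queue[0]] <= step:
            active.append(queue.popleft())
        if not active:
            step = arrival_steps[queue[0]]  # GPU idle: jump ahead to the next arrival
            continue
        # One decode step: one token for every active request
        step += 1
        for k in active:
            remaining[k] -= 1
        # Finished requests leave at once, freeing their slots for the next step
        for k in [k for k in active if remaining[k] == 0]:
            finish[k] = step
            active.remove(k)
    return [finish[k] for k in range(len(lengths))]

print(continuous_finish_steps([10, 60, 15, 70, 10, 20], [0, 0, 0, 0, 5, 5], max_batch_size=4))
```

With a static batch of four, the first four requests would all be returned at step 70 and the last two at step 90. Here the 10-token request finishes at step 10, and the two late arrivals slip into freed slots and finish at steps 20 and 35.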

In [None]:
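from typing import Tuple

def simulate_continuous_batching(requests: List[Request], 
                                  max_batch_size: int = 4) -> Tuple[List[Request], List[dict]]:
    """Simulate continuous batching: insert/remove requests at token level.

    Returns the per-request results and a per-decode-step log of the batch contents.
    """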
    results = [copy.deepcopy(r) for r in requests]
    current_time = 0.0
    
    active_batch = []       # Currently processing
    waiting_queue = list(results)  # Waiting to be processed
    completed = []          # Done
    
    # Track per-step events for visualization
    step_log = []
    
    while waiting_queue or active_batch:
        # Add new requests to batch if space available and they've arrived
        while (len(active_batch) < max_batch_size and 
               waiting_queue and 
               waiting_queue[0].arrival_time <= current_time):
            req = waiting_queue.pop(0)
            req.start_time = current_time
            # Prefill this request (happens inline)
            current_time += prefill_time(req.input_tokens)
            active_batch.append(req)
        
        if not active_batch:
            # No active requests - jump to next arrival
            if waiting_queue:
                current_time = waiting_queue[0].arrival_time
                continue
            else:
                break
        
        # One decode step: generate one token for each active request
        bs = len(active_batch)
        step_time = decode_step_time(bs)
        current_time += step_time
        
        step_log.append({
            'time': current_time,
            'batch_size': bs,
            'active_ids': [r.id for r in active_batch]
        })
        
        # Generate tokens and check for completion
        newly_completed = []
        for r in active_batch:
            r.tokens_generated += 1
            if r.tokens_generated >= r.output_tokens:
                r.end_time = current_time
                newly_completed.append(r)
        
        # Remove completed requests (making room for new ones!)
        for r in newly_completed:
            active_batch.remove(r)
            completed.append(r)
    
    return results, step_log

continuous_results, step_log = simulate_continuous_batching(requests, max_batch_size=4)

print("Continuous Batching Results:")
print(f"{'ID':>3} | {'Arrive':>7} | {'Start':>7} | {'End':>7} | {'Wait':>6} | {'Total':>6} | {'Out Tokens':>10}")
print("-" * 70)
for r in continuous_results:
    print(f"{r.id:>3} | {r.arrival_time:>6.2f}s | {r.start_time:>6.2f}s | {r.end_time:>6.2f}s | {r.wait_time:>5.2f}s | {r.total_latency:>5.2f}s | {r.output_tokens:>10}")

## Part 6: Gantt Chart Visualization

Let's visualize what's happening with each strategy. A Gantt chart shows:
- Each request as a horizontal bar
- **Gray** = waiting time
- **Colored** = processing time
- The x-axis is wall clock time

In [None]:
def plot_gantt_chart(results: List[Request], title: str, ax=None):
    """Create a Gantt chart showing request timelines."""
    if ax is None:
        fig, ax = plt.subplots(figsize=(16, 8))
    
    colors = plt.cm.Set3(np.linspace(0, 1, 12))
    
    for idx, r in enumerate(results):
        y = len(results) - idx - 1
        
        # Waiting time (gray)
        if r.wait_time > 0.01:
            ax.barh(y, r.wait_time, left=r.arrival_time, 
                    height=0.6, color='lightgray', edgecolor='gray',
                    alpha=0.7, label='Waiting' if idx == 0 else '')
        
        # Processing time (colored by request id)
        color = colors[r.id % len(colors)]
        ax.barh(y, r.processing_time, left=r.start_time,
                height=0.6, color=color, edgecolor='black',
                alpha=0.8, linewidth=0.5)
        
        # Arrival marker
        ax.plot(r.arrival_time, y, 'v', color='red', markersize=6, zorder=5)
        
        # Label with output tokens
        ax.text(r.start_time + r.processing_time / 2, y, 
                f'{r.output_tokens}t', ha='center', va='center', fontsize=7,
                fontweight='bold')
    
    ax.set_yticks(range(len(results)))
    ax.set_yticklabels([f'Req {r.id}' for r in reversed(results)], fontsize=8)
    ax.set_xlabel('Time (seconds)')
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.axvline(x=0, color='black', linewidth=0.5)
    
    # Add legend
    from matplotlib.patches import Patch
    legend_elements = [
        Patch(facecolor='lightgray', edgecolor='gray', label='Waiting'),
        Patch(facecolor='skyblue', edgecolor='black', label='Processing'),
        plt.Line2D([0], [0], marker='v', color='red', linestyle='None', label='Arrival')
    ]
    ax.legend(handles=legend_elements, loc='upper right', fontsize=9)
    
    return ax

# Plot all three strategies
fig, axes = plt.subplots(3, 1, figsize=(18, 20))

plot_gantt_chart(static_results, 'Static Batching (batch_size=4)', axes[0])
plot_gantt_chart(dynamic_results, 'Dynamic Batching (batch_size=4, timeout=0.5s)', axes[1])
plot_gantt_chart(continuous_results, 'Continuous Batching (max_batch=4)', axes[2])

plt.tight_layout()
plt.show()

## Part 7: The GPU Bubble Problem

The biggest problem with static batching is **GPU bubbles** -- wasted compute cycles.

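When a short request finishes before the batch ends, its slot does no useful work for the rest of the batch, yet it cannot be handed to a waiting request. Let's visualize this. Note that these plots count a request as holding a slot from the moment its batch starts until the moment it is returned. Because a static or dynamic batch returns all its requests together, a finished request's slot still counts as busy, so the plots show slots *held* rather than slots doing useful work; the dips mark periods when fewer requests are admitted than the batch size allows.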
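Counting slot-steps makes the waste concrete. A minimal sketch for one static batch (the helper name is hypothetical), using the output lengths of the first batch in Part 9:

```python
from typing import List

def static_bubble_fraction(output_lengths: List[int]) -> float:
    """Share of slot-steps in one static batch spent on requests that already finished.

    Hypothetical helper: assumes one decode step per output token and that the
    batch runs until its longest member is done.
    """
    steps = max(output_lengths)
    total_slot_steps = steps * len(output_lengths)  # every slot is held for the whole batch
    useful_slot_steps = sum(output_lengths)         # each request only needs its own length
    return 1 - useful_slot_steps / total_slot_steps

print(f"{static_bubble_fraction([10, 60, 15, 70]):.1%} of slot-steps wasted")
```

For output lengths of 10, 60, 15 and 70 tokens, 125 of the 280 slot-steps (about 45%) go to requests that have already finished. Continuous batching hands those slot-steps to waiting requests instead.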

In [None]:
def visualize_gpu_utilization(results: List[Request], title: str, 
                               max_batch_size: int = 4, is_continuous: bool = False):
    """Show GPU slot utilization over time."""
    fig, ax = plt.subplots(figsize=(16, 5))
    
    # Determine time range
    max_time = max(r.end_time for r in results)
    time_steps = np.linspace(0, max_time, 500)
    
    # For each time point, count active requests
    utilization = []
    for t in time_steps:
        active = sum(1 for r in results if r.start_time <= t <= r.end_time)
        utilization.append(min(active, max_batch_size))
    
    utilization = np.array(utilization)
    
    # Fill area chart
    ax.fill_between(time_steps, utilization, alpha=0.3, color='blue')
    ax.plot(time_steps, utilization, color='blue', linewidth=1.5)
    ax.axhline(y=max_batch_size, color='red', linestyle='--', 
               label=f'Max batch size ({max_batch_size})', alpha=0.7)
    
    avg_util = np.mean(utilization)
    ax.axhline(y=avg_util, color='green', linestyle=':', 
               label=f'Avg utilization ({avg_util:.1f})', alpha=0.7)
    
    ax.set_xlabel('Time (seconds)')
    ax.set_ylabel('Active GPU Slots')
    ax.set_title(f'{title} - GPU Utilization ({avg_util/max_batch_size*100:.0f}% average)', 
                 fontsize=13, fontweight='bold')
    ax.set_ylim(0, max_batch_size + 0.5)
    ax.legend()
    plt.tight_layout()
    plt.show()
    
    return avg_util / max_batch_size

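print("GPU Utilization Comparison:")
print("=" * 50)

u1 = visualize_gpu_utilization(static_results, 'Static Batching')
u2 = visualize_gpu_utilization(dynamic_results, 'Dynamic Batching')
u3 = visualize_gpu_utilization(continuous_results, 'Continuous Batching', is_continuous=True)

# Average share of the max_batch_size slots held over each run
for name, u in [('Static', u1), ('Dynamic', u2), ('Continuous', u3)]:
    print(f"{name:<12} {u:.0%}")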

## Part 8: Comprehensive Metrics Comparison

Let's calculate all the key metrics for each strategy side by side.
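One caution before reading the P99 columns: with only 20 requests, P99 is essentially the single slowest request. NumPy's `np.percentile` interpolates linearly between sorted samples by default, so on 20 values the 99th percentile sits 81% of the way from the 19th to the 20th value. The latencies below are made-up numbers for illustration.

```python
import numpy as np

latencies = np.arange(1.0, 21.0)    # 20 made-up latencies: 1 s, 2 s, ..., 20 s
p99 = np.percentile(latencies, 99)  # position 0.99 * (20 - 1) = 18.81 in the sorted array
print(f"P99 = {p99:.2f}s, max = {latencies.max():.0f}s")
```

Treat the P99 numbers in this comparison as "worst request" numbers; a stable tail estimate needs many more samples.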

In [None]:
def compute_metrics(results: List[Request], name: str) -> dict:
    """Compute comprehensive metrics for a batching strategy."""
    wait_times = [r.wait_time for r in results]
    latencies = [r.total_latency for r in results]
    total_tokens = sum(r.output_tokens for r in results)
    total_time = max(r.end_time for r in results) - min(r.arrival_time for r in results)
    throughput = total_tokens / total_time
    
    metrics = {
        'name': name,
        'avg_wait': np.mean(wait_times),
        'max_wait': np.max(wait_times),
        'p50_wait': np.percentile(wait_times, 50),
        'p99_wait': np.percentile(wait_times, 99),
        'avg_latency': np.mean(latencies),
        'p50_latency': np.percentile(latencies, 50),
        'p99_latency': np.percentile(latencies, 99),
        'max_latency': np.max(latencies),
        'throughput': throughput,
        'total_time': total_time,
        'total_tokens': total_tokens
    }
    return metrics

m_static = compute_metrics(static_results, 'Static')
m_dynamic = compute_metrics(dynamic_results, 'Dynamic')
m_continuous = compute_metrics(continuous_results, 'Continuous')

all_metrics = [m_static, m_dynamic, m_continuous]

print(f"\n{'Metric':<25} | {'Static':>11} | {'Dynamic':>11} | {'Continuous':>11}")
print("=" * 70)
for key in ['avg_wait', 'max_wait', 'p50_wait', 'p99_wait', 
            'avg_latency', 'p50_latency', 'p99_latency', 'max_latency',
            'throughput', 'total_time']:
    unit = 'tok/s' if key == 'throughput' else 's'
    label = f"{key.replace('_', ' ').title()} ({unit})"
    values = [m[key] for m in all_metrics]
    # Higher is better for throughput; lower is better for every time metric
    best_idx = np.argmax(values) if key == 'throughput' else np.argmin(values)
    
    row = f"{label:<25} |"
    for i, v in enumerate(values):
        marker = ' *' if i == best_idx else '  '
        row += f" {v:>9.2f}{marker} |"
    print(row)

print("\n* = best value for that metric")

In [None]:
# Bar chart comparison
fig, axes = plt.subplots(1, 3, figsize=(16, 5))

strategies = ['Static', 'Dynamic', 'Continuous']
colors_bar = ['#e74c3c', '#f39c12', '#27ae60']

# Average Latency
vals = [m['avg_latency'] for m in all_metrics]
axes[0].bar(strategies, vals, color=colors_bar, edgecolor='black', alpha=0.8)
axes[0].set_title('Average Latency (lower is better)', fontweight='bold')
axes[0].set_ylabel('Seconds')
for i, v in enumerate(vals):
    axes[0].text(i, v + 0.02, f'{v:.2f}s', ha='center', fontweight='bold')

# Throughput
vals = [m['throughput'] for m in all_metrics]
axes[1].bar(strategies, vals, color=colors_bar, edgecolor='black', alpha=0.8)
axes[1].set_title('Throughput (higher is better)', fontweight='bold')
axes[1].set_ylabel('Tokens/second')
for i, v in enumerate(vals):
    axes[1].text(i, v + 0.5, f'{v:.1f}', ha='center', fontweight='bold')

# P99 Latency
vals = [m['p99_latency'] for m in all_metrics]
axes[2].bar(strategies, vals, color=colors_bar, edgecolor='black', alpha=0.8)
axes[2].set_title('P99 Latency (lower is better)', fontweight='bold')
axes[2].set_ylabel('Seconds')
for i, v in enumerate(vals):
    axes[2].text(i, v + 0.02, f'{v:.2f}s', ha='center', fontweight='bold')

plt.tight_layout()
plt.show()

## Part 9: Why Continuous Batching Dominates

Let's zoom into a specific example to see exactly why continuous batching wins.

In [None]:
# Create a deliberately illustrative example
illustrative_requests = [
    Request(id=0, arrival_time=0.0, input_tokens=50, output_tokens=10),   # Short
    Request(id=1, arrival_time=0.05, input_tokens=50, output_tokens=60),  # Long
    Request(id=2, arrival_time=0.1, input_tokens=50, output_tokens=15),   # Short
    Request(id=3, arrival_time=0.15, input_tokens=50, output_tokens=70),  # Very long
    Request(id=4, arrival_time=0.5, input_tokens=50, output_tokens=10),   # Short, arrives later
    Request(id=5, arrival_time=0.55, input_tokens=50, output_tokens=20),  # Medium, arrives later
]

static_ill = simulate_static_batching(illustrative_requests, batch_size=4)
dynamic_ill = simulate_dynamic_batching(illustrative_requests, batch_size=4, timeout=0.3)
continuous_ill, _ = simulate_continuous_batching(illustrative_requests, max_batch_size=4)

fig, axes = plt.subplots(3, 1, figsize=(16, 12))

plot_gantt_chart(static_ill, 'Static: Short requests trapped by long ones!', axes[0])
plot_gantt_chart(dynamic_ill, 'Dynamic: Identical to static here (the first four arrive within the timeout)', axes[1])
plot_gantt_chart(continuous_ill, 'Continuous: Short requests finish ASAP!', axes[2])

plt.tight_layout()
plt.show()

# Show the specific numbers
print("\nRequest 0 (10 tokens) latency comparison:")
print(f"  Static:     {static_ill[0].total_latency:.3f}s")
print(f"  Dynamic:    {dynamic_ill[0].total_latency:.3f}s")
print(f"  Continuous: {continuous_ill[0].total_latency:.3f}s")
print(f"\nRequest 4 (10 tokens, arrives at 0.5s) wait time:")
print(f"  Static:     {static_ill[4].wait_time:.3f}s")
print(f"  Dynamic:    {dynamic_ill[4].wait_time:.3f}s")
print(f"  Continuous: {continuous_ill[4].wait_time:.3f}s")

## Part 10: Scaling Analysis

Let's see how each strategy performs as load increases.
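Before reading the sweep, it helps to estimate where the server saturates. This is a back-of-the-envelope bound under the notebook's toy constants: with four requests per step, decode can produce at most 4 / 0.026 ≈ 153.8 tokens/s, and `np.random.randint(10, 80)` draws output lengths uniformly from 10 to 79 (mean 44.5). The sketch ignores prefill, which only lowers the real capacity; `decode_capacity` is a hypothetical helper.

```python
# Toy constants from the notebook's timing model (assumed unchanged)
DECODE_TIME_PER_TOKEN = 0.02
BATCH_OVERHEAD = 0.002

def decode_capacity(max_batch_size: int) -> float:
    """Upper bound on output tokens/s: every decode step runs a full batch, prefill ignored."""
    step = DECODE_TIME_PER_TOKEN + (max_batch_size - 1) * BATCH_OVERHEAD
    return max_batch_size / step

mean_output = (10 + 79) / 2  # np.random.randint(10, 80) is uniform on 10..79
capacity = decode_capacity(4)
print(f"capacity: {capacity:.1f} tok/s, about {capacity / mean_output:.2f} req/s")

for gap in [0.8, 0.5, 0.3, 0.2, 0.15, 0.1, 0.08]:  # the load levels used in the sweep
    demand = mean_output / gap  # output tokens/s the arrivals ask for on average
    print(f"{1 / gap:5.1f} req/s -> demand {demand:6.1f} tok/s ({demand / capacity:.0%} of capacity)")
```

Under these assumptions, the four heaviest load levels (5 req/s and above) ask for more tokens than even a perfectly packed batch of four can produce. At those points latency is dominated by queueing and keeps growing for as long as requests keep arriving.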

In [None]:
# Test with different load levels
load_levels = [0.8, 0.5, 0.3, 0.2, 0.15, 0.1, 0.08]  # avg gap between arrivals
load_labels = [f'{1/g:.1f} req/s' for g in load_levels]

results_by_strategy = {'Static': [], 'Dynamic': [], 'Continuous': []}

for gap in load_levels:
    reqs = generate_request_queue(n_requests=30, avg_arrival_gap=gap)
    
    static_r = simulate_static_batching(reqs, batch_size=4)
    dynamic_r = simulate_dynamic_batching(reqs, batch_size=4, timeout=0.3)
    continuous_r, _ = simulate_continuous_batching(reqs, max_batch_size=4)
    
    results_by_strategy['Static'].append(compute_metrics(static_r, 'Static'))
    results_by_strategy['Dynamic'].append(compute_metrics(dynamic_r, 'Dynamic'))
    results_by_strategy['Continuous'].append(compute_metrics(continuous_r, 'Continuous'))

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for name, color in zip(['Static', 'Dynamic', 'Continuous'], ['#e74c3c', '#f39c12', '#27ae60']):
    metrics_list = results_by_strategy[name]
    
    avg_lat = [m['avg_latency'] for m in metrics_list]
    p99_lat = [m['p99_latency'] for m in metrics_list]
    throughput = [m['throughput'] for m in metrics_list]
    
    axes[0].plot(load_labels, avg_lat, 'o-', color=color, label=name, linewidth=2, markersize=8)
    axes[1].plot(load_labels, p99_lat, 's-', color=color, label=name, linewidth=2, markersize=8)
    axes[2].plot(load_labels, throughput, '^-', color=color, label=name, linewidth=2, markersize=8)

axes[0].set_title('Average Latency vs Load', fontweight='bold')
axes[0].set_ylabel('Seconds')
axes[0].legend()
axes[0].tick_params(axis='x', rotation=45)

axes[1].set_title('P99 Latency vs Load', fontweight='bold')
axes[1].set_ylabel('Seconds')
axes[1].legend()
axes[1].tick_params(axis='x', rotation=45)

axes[2].set_title('Throughput vs Load', fontweight='bold')
axes[2].set_ylabel('Tokens/second')
axes[2].legend()
axes[2].tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

## Part 11: Continuous Batching - Token-Level View

Let's create a detailed view showing exactly how the batch composition changes at each decode step in continuous batching.

In [None]:
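def visualize_continuous_batch_slots(step_log, max_batch_size=4):
    """Visualize how batch slots are filled/freed over time in continuous batching."""
    fig, ax = plt.subplots(figsize=(18, 5))
    
    colors = plt.cm.tab20(np.linspace(0, 1, 20))
    slot_of = {}  # request id -> slot index, kept fixed for the request's lifetime
    
    for step_idx, step in enumerate(step_log):
        active = set(step['active_ids'])
        # Release the slots of requests that finished before this step
        for rid in [rid for rid in slot_of if rid not in active]:
            del slot_of[rid]
        # Newly admitted requests take the lowest-numbered free slot
        for rid in step['active_ids']:
            if rid not in slot_of:
                used = set(slot_of.values())
                slot_of[rid] = min(s for s in range(max_batch_size) if s not in used)
        
        for req_id, slot_idx in slot_of.items():
            color = colors[req_id % 20]
            # facecolor (not color) so the black edge is kept
            rect = plt.Rectangle((step_idx, slot_idx), 1, 0.8, 
                                  facecolor=color, edgecolor='black', 
                                  linewidth=0.5, alpha=0.8)
            ax.add_patch(rect)
            ax.text(step_idx + 0.5, slot_idx + 0.4, f'R{req_id}',
                   ha='center', va='center', fontsize=6, fontweight='bold')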
    
    ax.set_xlim(0, len(step_log))
    ax.set_ylim(0, max_batch_size)
    ax.set_xlabel('Decode Step', fontsize=12)
    ax.set_ylabel('Batch Slot', fontsize=12)
    ax.set_title('Continuous Batching: Batch Composition Over Time\n'
                 '(Each color = a different request, slots freed when request completes)',
                 fontsize=13, fontweight='bold')
    ax.set_yticks(np.arange(max_batch_size) + 0.4)
    ax.set_yticklabels([f'Slot {i}' for i in range(max_batch_size)])
    
    # Show utilization
    utils = [len(s['active_ids']) / max_batch_size * 100 for s in step_log]
    ax2 = ax.twinx()
    ax2.plot([i + 0.5 for i in range(len(step_log))], utils, 
             color='red', linewidth=1.5, alpha=0.5, linestyle='--')
    ax2.set_ylabel('Utilization %', color='red')
    ax2.set_ylim(0, 120)
    
    plt.tight_layout()
    plt.show()

# Re-run with illustrative example
_, step_log_ill = simulate_continuous_batching(illustrative_requests, max_batch_size=4)
visualize_continuous_batch_slots(step_log_ill, max_batch_size=4)

## Part 12: Real-World Context

In production LLM inference engines:

| Feature | Static | Dynamic | Continuous |
|---------|--------|---------|------------|
| **Used by** | Simple scripts | Triton, older systems | vLLM, TGI, TensorRT-LLM |
| **GPU Utilization** | Low (bubbles) | Medium | High |
| **Implementation** | Trivial | Moderate | Complex |
| **P99 Latency** | Very high | High | Low |
| **Throughput** | Low | Medium | High |

Continuous batching was popularized by the **Orca paper** (2022) and is now standard in:
- **vLLM** (PagedAttention + continuous batching)
- **HuggingFace TGI** 
- **NVIDIA TensorRT-LLM**
- **DeepSpeed-MII**

In [None]:
# Final summary visualization
fig, ax = plt.subplots(figsize=(12, 7))

# Scatter plot: throughput vs P99 latency
for name, color, marker in zip(
    ['Static', 'Dynamic', 'Continuous'],
    ['#e74c3c', '#f39c12', '#27ae60'],
    ['o', 's', '^']
):
    metrics_list = results_by_strategy[name]
    throughputs = [m['throughput'] for m in metrics_list]
    p99s = [m['p99_latency'] for m in metrics_list]
    
    ax.scatter(throughputs, p99s, c=color, marker=marker, s=150, 
              label=name, edgecolors='black', linewidth=1, zorder=5)
    ax.plot(throughputs, p99s, color=color, alpha=0.3, linewidth=2)

ax.set_xlabel('Throughput (tokens/second)', fontsize=13)
ax.set_ylabel('P99 Latency (seconds)', fontsize=13)
ax.set_title('Throughput vs P99 Latency: The Batching Strategy Frontier\n'
             '(Each point = different load level. Bottom-right is ideal)', 
             fontsize=14, fontweight='bold')
ax.legend(fontsize=12)

# Add annotation for the ideal region (bottom-right, in axes coordinates)
ax.annotate('Ideal: High throughput,\nLow latency', 
           xy=(0.78, 0.08), xycoords='axes fraction',
           fontsize=11, color='green', fontweight='bold',
           bbox=dict(boxstyle='round,pad=0.3', facecolor='lightgreen', alpha=0.3))

plt.tight_layout()
plt.show()

## Part 13: Batch Size Sensitivity Analysis

In [None]:
# How does max batch size affect continuous batching?
batch_sizes = [1, 2, 4, 8, 16]
reqs = generate_request_queue(n_requests=30, avg_arrival_gap=0.15)

bs_results = []
for bs in batch_sizes:
    cont_r, _ = simulate_continuous_batching(reqs, max_batch_size=bs)
    bs_results.append(compute_metrics(cont_r, f'BS={bs}'))

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].plot(batch_sizes, [m['avg_latency'] for m in bs_results], 'bo-', linewidth=2, markersize=10)
axes[0].set_xlabel('Max Batch Size')
axes[0].set_ylabel('Average Latency (s)')
axes[0].set_title('Latency vs Max Batch Size\n(Continuous Batching)', fontweight='bold')

axes[1].plot(batch_sizes, [m['throughput'] for m in bs_results], 'go-', linewidth=2, markersize=10)
axes[1].set_xlabel('Max Batch Size')
axes[1].set_ylabel('Throughput (tok/s)')
axes[1].set_title('Throughput vs Max Batch Size\n(Continuous Batching)', fontweight='bold')

plt.tight_layout()
plt.show()

print("\nKey insight: a larger max batch raises the throughput ceiling and cuts queueing delay,")
print("but every decode step gets slower (BATCH_OVERHEAD per extra request), so each request")
print("receives its own tokens more slowly. The sweet spot depends on your hardware and latency targets.")

## Part 14: Latency Distribution Comparison

In [None]:
# Generate a larger workload for distribution analysis
large_reqs = generate_request_queue(n_requests=100, avg_arrival_gap=0.15)

static_large = simulate_static_batching(large_reqs, batch_size=4)
dynamic_large = simulate_dynamic_batching(large_reqs, batch_size=4, timeout=0.3)
continuous_large, _ = simulate_continuous_batching(large_reqs, max_batch_size=4)

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for ax, results, name, color in zip(
    axes,
    [static_large, dynamic_large, continuous_large],
    ['Static', 'Dynamic', 'Continuous'],
    ['#e74c3c', '#f39c12', '#27ae60']
):
    latencies = [r.total_latency for r in results]
    ax.hist(latencies, bins=20, color=color, edgecolor='black', alpha=0.7)
    
    p50 = np.percentile(latencies, 50)
    p99 = np.percentile(latencies, 99)
    ax.axvline(p50, color='blue', linestyle='--', linewidth=2, label=f'P50={p50:.2f}s')
    ax.axvline(p99, color='red', linestyle='--', linewidth=2, label=f'P99={p99:.2f}s')
    
    ax.set_title(f'{name} Batching\nLatency Distribution', fontweight='bold')
    ax.set_xlabel('Total Latency (seconds)')
    ax.set_ylabel('Count')
    ax.legend()

plt.tight_layout()
plt.show()

---

## Key Takeaways

### 1. Static Batching is Simple but Wasteful
- Waits for a full batch before processing
- **All requests wait for the longest one** to finish (GPU bubbles)
- Short requests suffer from "head-of-line blocking"

### 2. Dynamic Batching Adds Timeouts
- Processes partial batches after a timeout
- Reduces wait time but still has the "wait for longest" problem during processing

### 3. Continuous Batching is the Gold Standard
- Inserts and removes requests at the **token level** (every decode step)
- Short requests finish quickly without waiting for long ones
- GPU stays maximally utilized by filling freed slots immediately
- Used by the major modern inference engines (vLLM, TGI, TensorRT-LLM)

### 4. The Numbers Don't Lie
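- Continuous batching delivers higher throughput and **significantly lower P99 latency**; how large the gain is depends on the workload, batch size, and hardware, so read it off the metrics table and load sweep above rather than assuming a fixed multiplier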
- The advantage grows with more diverse request lengths and higher load

### 5. Implementation Complexity is the Tradeoff
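- Continuous batching requires careful memory management (hence PagedAttention in vLLM)
- Each request's KV cache grows by one token per decode step and is freed the moment the request finishes, so memory must be handed out incrementally rather than reserved up front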
- But the performance gains make it worth the engineering effort
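To see why the KV cache makes this hard, here is the memory arithmetic as a sketch. The model dimensions (32 layers, 32 KV heads, head dimension 128, fp16) are assumptions chosen to resemble a 7B-class model, and the 16-token block size is an assumed value for a paged allocator in the style of PagedAttention; both helpers are hypothetical.

```python
import math

def kv_bytes_per_token(n_layers: int, n_kv_heads: int, head_dim: int,
                       bytes_per_elem: int = 2) -> int:
    """Bytes for one token's K and V vectors across every layer (fp16 -> 2 bytes/element)."""
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem

def blocks_needed(seq_len: int, block_size: int = 16) -> int:
    """Fixed-size blocks a paged allocator would hand out for one sequence."""
    return math.ceil(seq_len / block_size)

# Hypothetical 7B-class dimensions: 32 layers, 32 KV heads, head_dim 128, fp16
per_token = kv_bytes_per_token(32, 32, 128)
print(f"{per_token / 2**20:.2f} MiB per token")

seq_len = 130  # prompt + generated tokens so far (illustrative)
n_blocks = blocks_needed(seq_len)
print(f"{n_blocks} blocks of 16 tokens, {n_blocks * 16 - seq_len} token slots unused")
```

Reserving room for a hypothetical 2,048-token maximum up front would pin 2,048 × 0.5 MiB = 1 GiB per request whether or not it is used; allocating 16-token blocks on demand wastes at most 15 token slots per request.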