# Lab 3.3.4: Speculative Decoding with SGLang

**Module:** 3.3 - Model Deployment & Inference Engines  
**Time:** 2 hours  
**Difficulty:** ⭐⭐⭐⭐

---

## 🎯 Learning Objectives

By the end of this notebook, you will:
- [ ] Understand how speculative decoding accelerates inference
- [ ] Set up SGLang with EAGLE speculative decoding
- [ ] Measure and analyze the speedup from speculation
- [ ] Know when speculative decoding helps vs hurts

---

## 📚 Prerequisites

- Completed: Labs 3.3.1-3.3.3
- Understanding of: Autoregressive generation
- Hardware: GPU with enough memory for both draft and target models

---

## 🌍 Real-World Context

**The Problem with LLM Generation:**

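LLMs generate text one token at a time. Each token requires a full forward pass through the model, and each pass has to wait for the previous one to finish.
A 100-token response needs 100 sequential forward passes, which is slow.

**Speculative Decoding's Solution:**

What if we could quickly "guess" the next several tokens, then verify all of them in a single forward pass?
That's exactly what speculative decoding does. At small batch sizes a decoding step is usually limited by memory bandwidth (streaming the model weights from GPU memory) rather than by arithmetic, so scoring several candidate tokens in one pass costs little more than scoring one.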

**Real-world impact:**
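- Often 2-3x faster generation when the draft model predicts the target model well
- Same output distribution as the target model alone: verification is lossless, so quality is unchanged
- Introduced in 2022-2023 research from Google and DeepMind, and now supported by inference engines such as vLLM, SGLang, and TensorRT-LLM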

---

## 🧒 ELI5: Speculative Decoding

> **Imagine you're a slow but careful painter, and you have a fast but less careful assistant...**
>
> **Without speculative decoding:**
> You (the expert) paint one brushstroke, step back, think carefully, paint another brushstroke.
> Very slow, but every stroke is perfect.
>
> **With speculative decoding:**
> 1. Your assistant quickly sketches the next 5 brushstrokes (draft)
> 2. You look at all 5 at once and say "Yes, yes, yes, no, ..." (verify)
> 3. You keep every stroke up to the first "no" (here, 3 strokes accepted!)
> 4. You repaint the rejected stroke correctly and discard everything the assistant sketched after it
> 5. Repeat!
>
> Even though you reject some guesses, you're still faster because you checked
> 5 strokes in roughly the time it would have taken to paint 1!
>
> **In AI terms:**
> - **Draft model** = Fast, smaller model that guesses multiple tokens
> - **Target model** = Your actual large model that verifies
> - **Acceptance rate** = How often the draft's guesses are accepted
> - **Speedup** = Depends on acceptance rate and draft speed

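In the greedy (temperature 0) case, the "verify" step from the analogy reduces to a prefix comparison. The sketch below is a minimal illustration, assuming the target model's single forward pass has already produced its top choice at every draft position plus one position past the end. The names `greedy_verify` and `target_choices` are hypothetical and are not part of the SGLang API.

```python
from typing import List, Tuple


def greedy_verify(draft: List[str], target_choices: List[str]) -> Tuple[List[str], int]:
    """Keep draft tokens while they match the target's greedy choice.

    target_choices[i] is the target's top token at draft position i; it has
    one extra entry for the position after the last draft token.
    Returns (tokens emitted this round, number of draft tokens accepted).
    """
    assert len(target_choices) == len(draft) + 1
    emitted: List[str] = []
    for drafted, chosen in zip(draft, target_choices):
        if drafted != chosen:
            emitted.append(chosen)  # first mismatch: take the target's token, drop the rest
            return emitted, len(emitted) - 1
        emitted.append(drafted)  # match: accepted at no extra target cost
    emitted.append(target_choices[-1])  # all accepted: the same pass yields a bonus token
    return emitted, len(draft)


tokens, n_accepted = greedy_verify(
    ["The", "cat", "sat", "in", "a"],
    ["The", "cat", "sat", "on", "the", "mat"],
)
print(tokens, n_accepted)  # ['The', 'cat', 'sat', 'on'] 3
```

Every round emits at least one token (either the correction or the bonus), so a target pass never produces fewer tokens than in plain decoding. Whether speculation saves wall-clock time then depends on how much the drafting costs.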
---

## 🔑 Key Concepts

| Term | Meaning |
|------|--------|
| **Draft Model** | Small, fast model that proposes tokens |
| **Target Model** | Your main model that verifies proposals |
| **Speculation Length** | How many tokens to guess at once (typically 3-8) |
| **Acceptance Rate** | % of draft tokens accepted by target |
| **Wallclock Speedup** | Actual time saved (what you care about) |
| **EAGLE** | A lightweight learned draft head that drafts tokens from the target model's own hidden features |

---

## Part 1: Understanding the Algorithm

Let's implement a simplified version to understand how it works.
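The simulation in this part replaces the real acceptance test with a biased coin flip. For reference, here is a minimal sketch of the sampling rule from the original speculative decoding papers (Leviathan et al.; Chen et al.) for a single draft position. It assumes toy next-token distributions `p` (target) and `q` (draft) over the same small vocabulary. The draft token is accepted with probability min(1, p(x)/q(x)). On rejection, a replacement is drawn from the normalized leftover mass max(0, p - q). Together these steps make the output follow `p` exactly, which is why speculative decoding does not change output quality. The function name is illustrative. EAGLE applies a tree-structured variant of this idea, not this exact code.

```python
import random
from collections import Counter
from typing import List


def speculative_sample_one(p: List[float], q: List[float], rng: random.Random) -> int:
    """Return one token id whose distribution equals the target's p."""
    vocab = range(len(p))
    x = rng.choices(vocab, weights=q)[0]  # draft proposes x ~ q
    if rng.random() < min(1.0, p[x] / q[x]):  # accept with prob min(1, p/q)
        return x
    # Rejected: resample from the part of p that q under-covers
    residual = [max(pi - qi, 0.0) for pi, qi in zip(p, q)]
    return rng.choices(vocab, weights=residual)[0]


rng = random.Random(0)
p = [0.6, 0.3, 0.1]  # toy target distribution
q = [0.2, 0.5, 0.3]  # toy draft distribution (deliberately different)
n = 100_000
counts = Counter(speculative_sample_one(p, q, rng) for _ in range(n))
freqs = [counts[i] / n for i in range(len(p))]
print(freqs)  # close to p, not q
```

The rejection branch is reached only when q puts more mass on x than p does. In that case the leftover weights are guaranteed to be non-empty, so the resampling step is always well defined.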

In [None]:
import random
import time
from typing import List, Tuple
from dataclasses import dataclass

@dataclass
class SpeculationResult:
    """Result of a speculation round."""
    draft_tokens: List[str]
    accepted_tokens: List[str]
    acceptance_rate: float
    time_saved_factor: float


def simulate_speculative_decoding(
    target_vocabulary: List[str],
    sequence_length: int = 50,
    speculation_length: int = 5,
    draft_accuracy: float = 0.7,  # Probability each draft token is accepted
    draft_speed_multiplier: float = 10.0,  # How much faster a draft pass is than a target pass
) -> Tuple[List[str], dict]:
    """
    Simulate speculative decoding to understand the algorithm.
    
    Token identities are random; only the accept/reject pattern matters.
    Each checked draft token is accepted independently with probability
    `draft_accuracy`, a stand-in for the real probability-based test.
    
    Args:
        target_vocabulary: Possible tokens to generate
        sequence_length: Total tokens to generate
        speculation_length: Tokens to speculate per round
        draft_accuracy: Probability a draft token is accepted
        draft_speed_multiplier: How much faster draft model is
    """
    
    generated = []
    stats = {
        "target_forward_passes": 0,
        "draft_forward_passes": 0,
        "tokens_accepted": 0,
        "tokens_rejected": 0,
        "speculation_rounds": 0
    }
    
    while len(generated) < sequence_length:
        stats["speculation_rounds"] += 1
        
        # Step 1: Draft model proposes speculation_length tokens, one cheap pass each
        draft_tokens = [random.choice(target_vocabulary) for _ in range(speculation_length)]
        stats["draft_forward_passes"] += speculation_length
        
        # Step 2: ONE target forward pass scores every draft position, plus the
        # position after the last draft token (used for the bonus token)
        target_tokens = [random.choice(target_vocabulary) for _ in range(speculation_length + 1)]
        stats["target_forward_passes"] += 1
        
        # Step 3: Accept draft tokens until the first rejection
        all_accepted = True
        for draft, target in zip(draft_tokens, target_tokens):
            if random.random() < draft_accuracy:
                generated.append(draft)  # Accepted
                stats["tokens_accepted"] += 1
            else:
                # Rejected: use the target's token; later draft tokens are discarded
                generated.append(target)
                stats["tokens_rejected"] += 1
                all_accepted = False
                break
        
        # Step 4: If every draft token was accepted, the same target pass
        # also yields one extra token for free
        if all_accepted:
            generated.append(target_tokens[-1])
    
    # Cost in units of one target forward pass
    # Without speculation: sequence_length target passes (1 per token)
    # With speculation: target passes + draft passes / draft_speed_multiplier
    baseline_cost = sequence_length
    speculative_cost = (
        stats["target_forward_passes"] + 
        stats["draft_forward_passes"] / draft_speed_multiplier
    )
    
    stats["baseline_cost"] = baseline_cost
    stats["speculative_cost"] = speculative_cost
    stats["speedup"] = baseline_cost / speculative_cost
    # Fraction of *checked* draft tokens that were accepted (estimates draft_accuracy)
    stats["acceptance_rate"] = stats["tokens_accepted"] / (stats["tokens_accepted"] + stats["tokens_rejected"])
    
    return generated[:sequence_length], stats

In [None]:
# Run simulation (fixed seed so reruns give the same numbers)
random.seed(0)
vocabulary = ["the", "a", "is", "was", "it", "that", "for", "on", "with", "as", 
              "be", "at", "by", "this", "have", "from", "or", "but", "not", "are"]

print("📊 Speculative Decoding Simulation")
print("=" * 60)

# Test different draft accuracies
for accuracy in [0.5, 0.7, 0.85, 0.95]:
    tokens, stats = simulate_speculative_decoding(
        target_vocabulary=vocabulary,
        sequence_length=100,
        speculation_length=5,
        draft_accuracy=accuracy,
        draft_speed_multiplier=10.0
    )
    
    print(f"\nDraft Accuracy: {accuracy:.0%}")
    print(f"   Acceptance Rate: {stats['acceptance_rate']:.1%}")
    print(f"   Target Forward Passes: {stats['target_forward_passes']} (baseline: 100)")
    print(f"   Speedup: {stats['speedup']:.2f}x")

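### 🔍 Key Insights

1. **Acceptance rate is crucial**: Higher acceptance means more tokens per target pass, and therefore more speedup
2. **Even 50% acceptance helps**: One target pass still yields about 2 tokens on average, as long as drafting is cheap
3. **The draft model must be FAST**: Every draft token costs time whether or not it is accepted, so the speed multiplier matters a lot
4. **Speedup is capped by the speculation length**: One target pass can emit at most `speculation_length + 1` tokens, so at very high acceptance the gains level off unless you also speculate further ahead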
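These insights can be made quantitative under one simplification: assume each draft token is accepted independently with probability $\alpha$. Also count the bonus token that a target pass yields when all $\gamma$ drafts are accepted. Then the expected number of tokens per target pass has a closed form, which follows the analysis in Leviathan et al. (2023). Let $c$ be the cost of one draft pass measured in target passes (in the simulation, $c$ = `1 / draft_speed_multiplier`). Under these assumptions:

$$
\mathbb{E}[\text{tokens per target pass}] = \sum_{i=0}^{\gamma} \alpha^{i} = \frac{1-\alpha^{\gamma+1}}{1-\alpha},
\qquad
\text{speedup} \approx \frac{1-\alpha^{\gamma+1}}{(1-\alpha)\,(1+\gamma c)}
$$

The $i$-th term is the probability that the first $i$ drafts all survive, and the $i=0$ term is the one token every round always produces. For example, with $\alpha = 0.7$, $\gamma = 5$, $c = 0.1$:

$$
\frac{1-0.7^{6}}{0.3} = \frac{0.8824}{0.3} \approx 2.94 \text{ tokens per pass},
\qquad
\frac{2.94}{1 + 5 \cdot 0.1} \approx 1.96\times
$$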

---

## Part 2: Setting Up SGLang with Speculative Decoding

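Speculative decoding methods you will encounter include the following. SGLang's supported set changes between releases, so check its documentation for your version.
- **EAGLE**: A learned draft head that predicts from the target model's hidden features
- **EAGLE-3**: A newer EAGLE version trained for higher acceptance
- **Medusa**: Multiple extra decoding heads, each predicting a different future position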

In [None]:
import subprocess
import sys

# Check SGLang installation
def check_sglang():
    """Return the installed SGLang version, or None if it can't be imported."""
    try:
        # Use the notebook's own interpreter, not whatever "python" is on PATH
        result = subprocess.run(
            [sys.executable, "-c", "import sglang; print(sglang.__version__)"],
            capture_output=True, text=True
        )
        if result.returncode == 0:
            return result.stdout.strip()
    except OSError:
        pass
    return None

sglang_version = check_sglang()
if sglang_version:
    print(f"✅ SGLang version: {sglang_version}")
else:
    print("❌ SGLang not installed")
    print("\n📦 Install with:")
    print('   pip install "sglang[all]"')

### 🚀 Starting SGLang with EAGLE

SGLang supports speculative decoding with EAGLE draft models:

In [None]:
from typing import Optional

def generate_sglang_commands(
    model: str = "meta-llama/Meta-Llama-3-8B-Instruct",
    eagle_model: Optional[str] = None,  # EAGLE draft head trained for `model`
    port: int = 30000,
    speculate_num_tokens: int = 5
) -> dict:
    """
    Generate SGLang startup commands.
    
    Returns dict with 'basic' and 'speculative' commands.
    """
    
    # Basic SGLang (no speculation)
    basic_cmd = f"""python -m sglang.launch_server \\
    --model-path {model} \\
    --port {port} \\
    --trust-remote-code"""
    
    # No draft given: print a placeholder the user must replace
    draft_note = ""
    if eagle_model is None:
        eagle_model = "<EAGLE-draft-for-your-target-model>"
        draft_note = (
            "# Note: use the EAGLE draft trained for your exact target model.\n"
            "# Check: https://huggingface.co/collections/yuhuili/eagle-models\n"
        )
    
    # With EAGLE speculative decoding. EAGLE also takes --speculative-num-steps
    # and --speculative-eagle-topk; see the SGLang docs for tuning them.
    speculative_cmd = draft_note + f"""python -m sglang.launch_server \\
    --model-path {model} \\
    --port {port} \\
    --speculative-algorithm EAGLE \\
    --speculative-draft-model-path {eagle_model} \\
    --speculative-num-draft-tokens {speculate_num_tokens} \\
    --trust-remote-code"""
    
    return {
        "basic": basic_cmd,
        "speculative": speculative_cmd
    }

# The draft must match the target: this EAGLE head was trained for Llama-3-8B-Instruct
commands = generate_sglang_commands(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    eagle_model="yuhuili/EAGLE-LLaMA3-Instruct-8B"
)

print("📋 SGLang Startup Commands")
print("=" * 70)
print("\n🐢 WITHOUT Speculative Decoding:")
print(commands["basic"])
print("\n🚀 WITH EAGLE Speculative Decoding:")
print(commands["speculative"])

### Available EAGLE Models

EAGLE draft models are pre-trained for specific target models:

| Target Model | EAGLE Draft Model |
|--------------|-------------------|
| Llama-3-8B-Instruct | yuhuili/EAGLE-LLaMA3-Instruct-8B |
| Llama-3-70B-Instruct | yuhuili/EAGLE-LLaMA3-Instruct-70B |
| Vicuna-7B | yuhuili/EAGLE-Vicuna-7B-v1.3 |
| Vicuna-13B | yuhuili/EAGLE-Vicuna-13B-v1.3 |
| Mixtral-8x7B | yuhuili/EAGLE-Mixtral-8x7B-Instruct-v0.1 |

Check [EAGLE Models Collection](https://huggingface.co/collections/yuhuili/eagle-models) for the latest.
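Each draft is tied to one target checkpoint, so it can help to encode the table as data and look up the draft before launching. The sketch below is minimal and assumes that the lowercase substrings used as keys (derived from the draft names in the table) are enough to identify the target. `EAGLE_DRAFTS` and `eagle_draft_for` are hypothetical names, and the mapping is only as complete as the table above.

```python
from typing import Dict, Optional

# Pairings from the table above. Keys are lowercase fragments of the target
# model name (derived from the draft names); illustrative, not exhaustive.
EAGLE_DRAFTS: Dict[str, str] = {
    "llama-3-8b-instruct": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
    "llama-3-70b-instruct": "yuhuili/EAGLE-LLaMA3-Instruct-70B",
    "vicuna-7b-v1.3": "yuhuili/EAGLE-Vicuna-7B-v1.3",
    "vicuna-13b-v1.3": "yuhuili/EAGLE-Vicuna-13B-v1.3",
    "mixtral-8x7b-instruct-v0.1": "yuhuili/EAGLE-Mixtral-8x7B-Instruct-v0.1",
}


def eagle_draft_for(target_model: str) -> Optional[str]:
    """Return the EAGLE draft listed for a target repo id, or None if unlisted."""
    name = target_model.split("/")[-1].lower()  # drop the org prefix
    for fragment, draft in EAGLE_DRAFTS.items():
        if fragment in name:
            return draft
    return None


for target in ["meta-llama/Meta-Llama-3-8B-Instruct", "meta-llama/Llama-3.1-8B-Instruct"]:
    print(target, "->", eagle_draft_for(target) or "not in table: check the EAGLE collection")
```

A Llama-3.1 target does not match any entry, because the table lists Llama-3. That is the cue to look for a draft trained on that exact model instead of reusing the Llama-3 head.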

---

## Part 3: Benchmarking Speculative Decoding

Let's create benchmarking tools to measure speedup.

In [None]:
import requests
import time
import json
from dataclasses import dataclass
from typing import Optional

@dataclass
class SpeedBenchmarkResult:
    """Result from a single speed benchmark."""
    prompt: str
    output: str
    output_tokens: int
    total_time_s: float
    tokens_per_second: float  # End-to-end rate: includes time to first token
    ttft_s: float
    is_speculative: bool


def benchmark_generation(
    server_url: str,
    prompt: str,
    max_tokens: int = 200,
    is_speculative: bool = False,
    temperature: float = 0.0
) -> Optional[SpeedBenchmarkResult]:
    """
    Benchmark text generation speed.
    
    Args:
        server_url: SGLang server URL
        prompt: Input prompt
        max_tokens: Maximum tokens to generate
        is_speculative: Whether this is speculative mode (for labeling)
        temperature: 0.0 (greedy) keeps runs comparable across servers
    """
    try:
        start_time = time.perf_counter()
        first_token_time = None
        output_chunks = []
        usage_tokens = None
        
        response = requests.post(
            f"{server_url}/v1/chat/completions",
            json={
                "model": "default",
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": max_tokens,
                "stream": True,
                # Ask for a final usage chunk with the exact token count
                "stream_options": {"include_usage": True},
                "temperature": temperature
            },
            stream=True,
            timeout=120
        )
        response.raise_for_status()
        
        for line in response.iter_lines():
            if not line:
                continue
            line_str = line.decode()
            if not line_str.startswith("data: "):
                continue
            data_str = line_str[len("data: "):]
            if data_str == "[DONE]":
                break
            try:
                chunk = json.loads(data_str)
            except json.JSONDecodeError:
                continue  # Skip malformed lines
            if chunk.get("usage"):
                usage_tokens = chunk["usage"].get("completion_tokens")
            choices = chunk.get("choices") or [{}]  # The usage chunk has no choices
            content = (choices[0].get("delta") or {}).get("content") or ""
            if content:
                if first_token_time is None:
                    first_token_time = time.perf_counter()
                output_chunks.append(content)
        
        end_time = time.perf_counter()
        output = "".join(output_chunks)
        
        # Prefer the server's exact count. Don't count stream chunks instead:
        # with speculation one chunk can carry several tokens.
        # Fallback: rough word-based estimate.
        output_tokens = usage_tokens or int(len(output.split()) * 1.3)
        total_time = end_time - start_time
        ttft = (first_token_time - start_time) if first_token_time else total_time
        
        return SpeedBenchmarkResult(
            prompt=prompt[:50] + ("..." if len(prompt) > 50 else ""),
            output=output[:100] + ("..." if len(output) > 100 else ""),
            output_tokens=output_tokens,
            total_time_s=total_time,
            tokens_per_second=output_tokens / total_time if total_time > 0 else 0.0,
            ttft_s=ttft,
            is_speculative=is_speculative
        )
        
    except Exception as e:
        print(f"Benchmark error: {e}")
        return None

In [None]:
# Benchmark prompts designed to test speculation
# Speculation works best when output is somewhat predictable

speculation_benchmark_prompts = {
    "predictable": [
        "Count from 1 to 20.",
        "List the days of the week.",
        "Recite the alphabet.",
        "List the months of the year.",
    ],
    "semi_predictable": [
        "Write a Python function to calculate factorial.",
        "Explain the water cycle step by step.",
        "Describe how to make a peanut butter sandwich.",
    ],
    "creative": [
        "Write a creative short story about a time-traveling cat.",
        "Compose a unique poem about the color blue.",
        "Invent a new word and define it.",
    ]
}

print("📝 Speculation Benchmark Prompts:")
for category, prompts in speculation_benchmark_prompts.items():
    print(f"\n   {category.upper()}:")
    for p in prompts:
        print(f"      - {p[:50]}..." if len(p) > 50 else f"      - {p}")

In [None]:
# Run the benchmark against whichever server is currently running.
# To compare, run it once against a baseline server and once against an EAGLE server.
SGLANG_URL = "http://localhost:30000"
IS_SPECULATIVE = False  # Set to True when the server was launched with EAGLE

def check_sglang_server(url: str) -> bool:
    """Check if SGLang server is running."""
    try:
        response = requests.get(f"{url}/v1/models", timeout=5)
        return response.status_code == 200
    except requests.RequestException:
        return False

if check_sglang_server(SGLANG_URL):
    mode = "EAGLE" if IS_SPECULATIVE else "baseline"
    print(f"✅ SGLang server is running ({mode})!")
    print("\n📊 Running speculation benchmark...")
    print("=" * 60)
    
    all_results = []
    
    for category, prompts in speculation_benchmark_prompts.items():
        print(f"\nCategory: {category}")
        
        for prompt in prompts[:2]:  # First 2 per category
            result = benchmark_generation(
                SGLANG_URL, 
                prompt, 
                max_tokens=100,
                is_speculative=IS_SPECULATIVE
            )
            
            if result:
                print(f"   {prompt[:30]}... → {result.tokens_per_second:.1f} tok/s")
                all_results.append((category, result))
    
    # Summary
    if all_results:
        print("\n" + "=" * 60)
        print("📈 SUMMARY BY CATEGORY")
        print("=" * 60)
        
        for cat in speculation_benchmark_prompts.keys():
            cat_results = [r for c, r in all_results if c == cat]
            if cat_results:
                avg_speed = sum(r.tokens_per_second for r in cat_results) / len(cat_results)
                print(f"   {cat}: {avg_speed:.1f} tokens/sec average")

else:
    print("❌ SGLang server is not running")
    print("\n" + "=" * 60)
    print("⚠️  SIMULATED DATA - For Demonstration Only")
    print("=" * 60)
    print("\n📝 These are illustrative numbers, not measurements.")
    print("   Run SGLang to get actual measurements on your hardware.")
    print("")
    print(f"{'Category':<20} {'Without Spec':<15} {'With EAGLE':<15} {'Speedup'}")
    print("-" * 60)
    print(f"{'predictable':<20} {'45 tok/s':<15} {'110 tok/s':<15} {'2.4x'}")
    print(f"{'semi_predictable':<20} {'45 tok/s':<15} {'85 tok/s':<15} {'1.9x'}")
    print(f"{'creative':<20} {'45 tok/s':<15} {'55 tok/s':<15} {'1.2x'}")
    print("")
    print("💡 Key insight: Speculative decoding helps most when output is predictable!")
    print("   Start SGLang with EAGLE to measure actual speedup on your workload.")

### 🔍 Understanding the Results

**Why does speedup vary by prompt type?**

1. **Predictable outputs** (counting, lists): Draft model easily predicts next tokens → high acceptance rate → big speedup

2. **Semi-predictable** (code, explanations): Many patterns are predictable (function definitions, common phrases) → moderate speedup

3. **Creative outputs** (stories, poems): Hard to predict → low acceptance rate → little speedup. With a slow draft or large batches, speculation can even be slower than plain decoding.
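To connect these categories to a number, you can invert the relation between the per-token acceptance probability α and the expected tokens per target pass, (1 - α^(γ+1)) / (1 - α). The sketch below is minimal and rests on three assumptions. First, you have an observed average tokens-per-pass figure; some SGLang versions report an average accept length in the server log, but check that for your version. Second, drafting is a simple chain of γ tokens. Third, acceptances are independent. EAGLE drafts a tree, so treat the result as an effective rate, not an exact one. The function names and the observed values are hypothetical.

```python
def expected_tokens_per_pass(alpha: float, gamma: int) -> float:
    """Expected tokens per target pass: 1 + alpha + ... + alpha**gamma."""
    return sum(alpha ** i for i in range(gamma + 1))


def estimate_alpha(tokens_per_pass: float, gamma: int, tol: float = 1e-9) -> float:
    """Invert expected_tokens_per_pass by bisection (it increases with alpha)."""
    if not 1.0 <= tokens_per_pass <= gamma + 1:
        raise ValueError("tokens_per_pass must lie in [1, gamma + 1]")
    lo, hi = 0.0, 1.0
    while hi - lo > tol:
        mid = (lo + hi) / 2
        if expected_tokens_per_pass(mid, gamma) < tokens_per_pass:
            lo = mid  # Need a higher acceptance rate to emit this many tokens
        else:
            hi = mid
    return (lo + hi) / 2


# Hypothetical observations for two prompt styles with gamma = 5
for label, observed in [("predictable", 4.5), ("creative", 1.6)]:
    print(f"{label}: ~{estimate_alpha(observed, gamma=5):.2f} per-token acceptance")
```

Bisection is enough here because the expected tokens per pass rises monotonically from 1 (at α = 0) to γ + 1 (at α = 1).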

---

## Part 4: When to Use Speculative Decoding

Let's create a decision framework.

In [None]:
# Decision framework for speculative decoding

use_cases = {
    "strongly_recommended": {
        "title": "Strongly Recommended",
        "examples": [
            "Code completion (predictable syntax)",
            "Structured data generation (JSON, YAML)",
            "Translation (predictable patterns)",
            "Summarization (extractive style)",
            "Question answering (factual)",
        ],
        "expected_speedup": "2-3x"
    },
    "recommended": {
        "title": "Recommended",
        "examples": [
            "Technical explanations",
            "Instructions/how-to guides",
            "Email drafting",
            "Documentation writing",
        ],
        "expected_speedup": "1.5-2x"
    },
    "worth_testing": {
        "title": "Worth Testing",
        "examples": [
            "General chatbot",
            "Semi-creative writing",
            "Paraphrasing",
        ],
        "expected_speedup": "1.2-1.5x"
    },
    "unlikely_to_help": {
        "title": "Unlikely to Help Much",
        "examples": [
            "Highly creative fiction",
            "Poetry with unusual structure",
            "Brainstorming novel ideas",
            "Very short responses (< 20 tokens)",
        ],
        "expected_speedup": "1.0-1.2x"
    }
}

print("🎯 When to Use Speculative Decoding")
print("=" * 70)

for category, info in use_cases.items():
    print(f"\n📌 {info['title']} (Expected: {info['expected_speedup']})")
    for example in info['examples']:
        print(f"   • {example}")

In [None]:
# Trade-offs to consider

tradeoffs = {
    "pros": [
        "Often 2-3x faster decoding for predictable outputs",
        "Same output distribution as the target model (lossless verification)",
        "No retraining of target model required",
        "Applies to any autoregressive model, given a suitable draft model",
    ],
    "cons": [
        "Requires additional GPU memory for draft model",
        "More complex deployment setup",
        "EAGLE heads needed (not available for every model)",
        "Less benefit for creative/unpredictable tasks",
        "Can be slower than baseline when acceptance is low or batch sizes are large",
        "May increase TTFT slightly (draft overhead)",
    ],
    "requirements": [
        "Sufficient GPU memory for both models",
        "Compatible draft model (an EAGLE head, or a smaller model with the same tokenizer)",
        "SGLang, vLLM, or TensorRT-LLM with speculation support",
        "Longer outputs (rule of thumb: > 20 tokens) for meaningful speedup",
    ]
}

print("⚖️ Trade-offs of Speculative Decoding")
print("=" * 60)

print("\n✅ PROS:")
for pro in tradeoffs["pros"]:
    print(f"   + {pro}")

print("\n❌ CONS:")
for con in tradeoffs["cons"]:
    print(f"   - {con}")

print("\n📋 REQUIREMENTS:")
for req in tradeoffs["requirements"]:
    print(f"   • {req}")

---

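## ⚠️ Common Mistakes

### Mistake 1: Using Wrong Draft Model

```bash
# ❌ Wrong - draft head trained for a different target model
--model-path meta-llama/Meta-Llama-3-8B-Instruct \
--speculative-draft-model-path yuhuili/EAGLE-Vicuna-7B-v1.3  # Trained for Vicuna-7B!

# ✅ Right - draft head trained for this exact target model
--model-path meta-llama/Meta-Llama-3-8B-Instruct \
--speculative-draft-model-path yuhuili/EAGLE-LLaMA3-Instruct-8B
```

**Why:** An EAGLE head is trained on one specific target model's hidden states, so it only pairs with that model and its tokenizer, not just with any model of the same architecture family. A mismatched head may fail to load, or load and accept very few tokens.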

### Mistake 2: Too Many Draft Tokens

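```bash
# ❌ Risky - a long draft with low acceptance wastes draft work
--speculative-num-draft-tokens 16

# ✅ Better - start moderate, then tune on your own workload
--speculative-num-draft-tokens 5
```

**Why:** With chain drafting, each draft token is kept only if every token before it was accepted, so the chance of accepting a long run drops geometrically. Meanwhile every extra draft token still costs time, so gains usually flatten out after about 5-8 tokens. For EAGLE, SGLang drafts a tree, and `--speculative-num-draft-tokens` is tuned together with `--speculative-num-steps` and `--speculative-eagle-topk`; check the SGLang docs for recommended combinations.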
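The trade-off can be made concrete with a simplified model: a chain of γ draft tokens, each accepted independently with probability α, where each draft pass costs c target passes. The sketch below is minimal and the function names are hypothetical. SGLang's EAGLE drafts a tree, so treat the result as a starting point for tuning, not a prescription.

```python
def expected_speedup(alpha: float, gamma: int, c: float) -> float:
    """Expected speedup over plain decoding for chain drafting of length gamma."""
    tokens_per_pass = sum(alpha ** i for i in range(gamma + 1))
    cost_per_pass = 1 + gamma * c  # One target pass plus gamma draft passes
    return tokens_per_pass / cost_per_pass


def best_gamma(alpha: float, c: float, max_gamma: int = 16) -> int:
    """Speculation length with the highest expected speedup."""
    return max(range(1, max_gamma + 1), key=lambda g: expected_speedup(alpha, g, c))


for alpha in (0.5, 0.7, 0.9):
    g = best_gamma(alpha, c=0.1)
    print(f"alpha={alpha}: best gamma={g}, "
          f"speedup {expected_speedup(alpha, g, 0.1):.2f}x "
          f"vs {expected_speedup(alpha, 16, 0.1):.2f}x at gamma=16")
```

Lower acceptance pushes the best length down and higher acceptance pushes it up, which is why this value is worth tuning per workload.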

### Mistake 3: Expecting Speedup on Short Outputs

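```python
# ❌ Won't show much speedup: the response is too short
result = benchmark_generation(SGLANG_URL, "What is 2+2?", max_tokens=5)

# ✅ Better for measuring speedup
result = benchmark_generation(SGLANG_URL, "Explain calculus in detail.", max_tokens=200)
```

**Why:** Speculation only speeds up the decode phase. For a handful of output tokens, total time is dominated by prefill (time to first token) and fixed overhead, so there is little decode time left to save.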
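One way to see this: speculation speeds up decoding but not the prefill that produces the first token, so the overall gain follows Amdahl's law. The sketch below assumes TTFT is unchanged by speculation and uses hypothetical numbers: 0.2 s TTFT, 45 tok/s baseline decode, and a 2x decode speedup. The function name is illustrative.

```python
def end_to_end_speedup(ttft_s: float, output_tokens: int,
                       decode_tok_per_s: float, decode_speedup: float) -> float:
    """Overall speedup when only the decode phase gets faster."""
    baseline_s = ttft_s + output_tokens / decode_tok_per_s
    speculative_s = ttft_s + output_tokens / (decode_tok_per_s * decode_speedup)
    return baseline_s / speculative_s


for n in (5, 200):
    print(f"{n:>3} tokens: {end_to_end_speedup(0.2, n, 45.0, 2.0):.2f}x overall")
```

With these numbers, the 5-token request gains only about 1.22x, while the 200-token request gets about 1.92x, close to the full decode speedup.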

---

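## ✋ Try It Yourself

### Exercise 1: Compare Speculation Lengths

Restart SGLang with different `--speculative-num-draft-tokens` values and measure throughput for each. Also record the acceptance length if your SGLang version reports it in the server log.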

In [None]:
# Exercise 1: Your code here
# TODO: Run SGLang with speculation lengths 3, 5, 8, 12
# TODO: Measure tokens/second for each
# TODO: Find the optimal value for your workload


### Exercise 2: Acceptance Rate Analysis

Create prompts that demonstrate high vs low acceptance rates.

In [None]:
# Exercise 2: Create test prompts
high_acceptance_prompts = [
    # TODO: Add 3 prompts likely to have high acceptance
    # Hint: Structured, predictable outputs
]

low_acceptance_prompts = [
    # TODO: Add 3 prompts likely to have low acceptance
    # Hint: Creative, unpredictable outputs
]

# TODO: Test both and compare speedups


---

## 🎉 Checkpoint

You've learned:
- ✅ How speculative decoding works (draft + verify)
- ✅ How to set up SGLang with EAGLE speculative decoding
- ✅ When speculation helps (predictable) vs doesn't help (creative)
- ✅ Trade-offs and decision framework for deployment

---

## 🚀 Challenge (Optional)

**Build an Adaptive Speculation System**

Create a system that:
1. Analyzes the prompt type (code, chat, creative)
2. Adjusts speculation length based on observed acceptance rate
3. Disables speculation entirely for prompts unlikely to benefit
4. Logs and visualizes acceptance rates over time

---

## 📖 Further Reading

- [EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty](https://arxiv.org/abs/2401.15077)
- [EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees](https://arxiv.org/abs/2406.16858)
- [SGLang: Efficient Execution of Structured Language Model Programs](https://arxiv.org/abs/2312.07104)
- [Medusa: Multiple Decode Heads for Parallel Decoding](https://sites.google.com/view/medusa-llm)
- [vLLM Speculative Decoding Guide](https://docs.vllm.ai/en/latest/models/spec_decode.html)

---

## 🧹 Cleanup

In [None]:
# Cleanup
import gc
import subprocess

# Collect garbage (frees notebook-side objects; the SGLang server is a separate process)
gc.collect()

# Check GPU memory usage
def get_gpu_memory():
    """Return (used_GB, total_GB) summed over all GPUs, or (None, None)."""
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.used,memory.total", "--format=csv,noheader,nounits"],
            capture_output=True, text=True
        )
        if result.returncode == 0:
            used = total = 0
            for row in result.stdout.strip().splitlines():  # One row per GPU
                u, t = (int(v) for v in row.split(","))
                used += u
                total += t
            return used / 1024, total / 1024  # MiB -> GB
    except (OSError, ValueError):
        pass
    return None, None

used, total = get_gpu_memory()
if used is not None and total:
    print(f"📊 GPU Memory: {used:.1f}GB / {total:.1f}GB ({used/total*100:.0f}% used)")

print("\n✅ Cleanup complete!")
print("\n💡 To stop the SGLang server:")
print("   pkill -f sglang.launch_server")
print("\n   Or find and kill the process:")
print("   ps aux | grep sglang")
print("   kill <pid>")