# Lab 3.3.4: Speculative Decoding - Solutions

This notebook provides solutions for the exercises in the speculative decoding notebook.

## Exercise 1: Compare Speculation Lengths

Test different `--speculative-num-draft-tokens` values (3, 5, 8, 12) and measure tokens/second.

In [None]:
# Solution: Testing different speculation lengths

# You would run SGLang with different configurations and measure performance:
# python -m sglang.launch_server ... --speculative-num-draft-tokens 3
# python -m sglang.launch_server ... --speculative-num-draft-tokens 5
# etc.

# Expected results for predictable prompts (e.g., "Count from 1 to 50"):
speculation_length_results = {
    "no_speculation": {
        "tokens_per_sec": 45,
        "acceptance_rate": None,
        "speedup": 1.0
    },
    "draft_tokens_3": {
        "tokens_per_sec": 85,
        "acceptance_rate": 0.82,
        "speedup": 1.9
    },
    "draft_tokens_5": {
        "tokens_per_sec": 110,
        "acceptance_rate": 0.78,
        "speedup": 2.4
    },
    "draft_tokens_8": {
        "tokens_per_sec": 105,
        "acceptance_rate": 0.72,
        "speedup": 2.3
    },
    "draft_tokens_12": {
        "tokens_per_sec": 95,
        "acceptance_rate": 0.65,
        "speedup": 2.1
    }
}

print("Speculation Length Comparison (Predictable Prompt):")
print("=" * 70)
print(f"{'Configuration':<20} {'Tokens/sec':<15} {'Accept Rate':<15} {'Speedup':<10}")
print("-" * 70)

for config, result in speculation_length_results.items():
    # Compare against None explicitly so a measured 0% rate still prints
    accept = f"{result['acceptance_rate']:.0%}" if result['acceptance_rate'] is not None else "N/A"
    print(f"{config:<20} {result['tokens_per_sec']:<15} {accept:<15} {result['speedup']:.1f}x")

print("\nConclusion:")
print("- draft_tokens=5 is optimal for this workload (highest speedup)")
print("- Higher values (8, 12) lower the acceptance rate, so more draft work is wasted and throughput falls")
print("- The 'sweet spot' depends on your prompt type and draft model quality")
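Why does speedup peak at a middle value and then fall? One common back-of-the-envelope model, from Leviathan et al. (2023), "Fast Inference from Transformers via Speculative Decoding", assumes each draft token is accepted independently with probability α and that one draft step costs a fraction c of a target forward pass. With a draft of γ tokens, the expected number of tokens per verification step is (1 - α^(γ+1)) / (1 - α), and the modeled walltime speedup divides that by (γc + 1). The sketch below evaluates this model. α and c are assumed inputs, not values SGLang reports; the helper names are hypothetical; and SGLang's tree-based drafting only loosely matches this simple chain of γ tokens.

```python
def expected_tokens_per_step(alpha: float, gamma: int) -> float:
    """Expected tokens emitted per target verification step (chain model).

    Assumes each of the gamma draft tokens is accepted independently with
    probability alpha, as in the simplified analysis of Leviathan et al. (2023).
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    if alpha == 1.0:
        return float(gamma + 1)  # all drafts accepted, plus the target's bonus token
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


def modeled_speedup(alpha: float, gamma: int, cost_ratio: float) -> float:
    """Modeled walltime speedup over plain decoding.

    cost_ratio (c) is the assumed cost of one draft step relative to one
    target-model forward pass; drafting gamma tokens costs gamma * c.
    """
    return expected_tokens_per_step(alpha, gamma) / (gamma * cost_ratio + 1)


def best_draft_length(alpha: float, cost_ratio: float, candidates) -> int:
    """Candidate draft length with the highest modeled speedup."""
    return max(candidates, key=lambda g: modeled_speedup(alpha, g, cost_ratio))


# Hypothetical inputs: alpha = 0.8, one draft step costs 10% of a target pass
for gamma in (3, 5, 8, 12):
    print(
        f"gamma={gamma:<3} "
        f"tokens/step={expected_tokens_per_step(0.8, gamma):.2f} "
        f"speedup={modeled_speedup(0.8, gamma, 0.1):.2f}x"
    )
```

Longer drafts add tokens per step with diminishing returns (the k-th draft token survives only with probability α^k), while draft cost grows linearly in γ, so the modeled speedup has an interior maximum. With a low α and a non-trivial c, the model predicts a speedup below 1.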

## Exercise 2: Acceptance Rate Analysis

Create prompts that demonstrate high vs low acceptance rates.

In [None]:
# Solution: High vs Low Acceptance Rate Prompts

# HIGH ACCEPTANCE prompts (predictable outputs)
high_acceptance_prompts = [
    "Count from 1 to 30.",
    "List all months of the year.",
    "Write the first 10 lines of 'Twinkle Twinkle Little Star'.",
]

# LOW ACCEPTANCE prompts (creative/unpredictable outputs)
low_acceptance_prompts = [
    "Invent a completely new word and give it a creative definition.",
    "Write an absurdist poem about quantum physics and breakfast.",
    "Create a unique fictional language greeting with pronunciation guide.",
]

# Expected results
acceptance_comparison = {
    "High Acceptance (Predictable)": {
        "prompts": high_acceptance_prompts,
        "expected_acceptance_rate": 0.85,
        "expected_speedup": 2.4,
        "why": "Draft model easily predicts next tokens (sequences, patterns)"
    },
    "Low Acceptance (Creative)": {
        "prompts": low_acceptance_prompts,
        "expected_acceptance_rate": 0.35,
        "expected_speedup": 1.1,
        "why": "Creative output is hard to predict; most drafts rejected"
    }
}

print("Acceptance Rate Analysis:")
print("=" * 70)

for category, data in acceptance_comparison.items():
    print(f"\n{category}")
    print("-" * 40)
    print(f"Expected Acceptance Rate: {data['expected_acceptance_rate']:.0%}")
    print(f"Expected Speedup: {data['expected_speedup']:.1f}x")
    print(f"Why: {data['why']}")
    print("\nExample prompts:")
    for prompt in data['prompts']:
        print(f"  - {prompt}")
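The gap between the two prompt categories comes down to how many draft tokens survive verification. Below is a minimal sketch of greedy verification, a simplification: sampling-based speculative decoding uses a probabilistic accept/reject rule instead. The target model scores the prompt plus all draft tokens in one forward pass; the longest draft prefix that matches the target's own greedy choices is kept, and the target's prediction at the first mismatch is appended as a correction. The function name is hypothetical and tokens are plain integers.

```python
from typing import List, Tuple


def verify_greedy(draft: List[int], target_preds: List[int]) -> Tuple[List[int], int]:
    """Greedy speculative verification (illustrative sketch).

    draft:        the gamma tokens proposed by the draft model.
    target_preds: the target model's greedy token at each of the gamma + 1
                  positions, computed in a single forward pass over
                  prompt + draft (position i conditions on draft[:i]).

    Returns (tokens emitted this step, number of draft tokens accepted).
    """
    if len(target_preds) != len(draft) + 1:
        raise ValueError("target_preds must have len(draft) + 1 entries")

    accepted = 0
    for proposed, expected in zip(draft, target_preds):
        if proposed != expected:
            break  # first mismatch: reject this and every later draft token
        accepted += 1

    # Keep the accepted prefix, then always emit one token from the target:
    # the correction at the mismatch, or a bonus token if all were accepted.
    emitted = draft[:accepted] + [target_preds[accepted]]
    return emitted, accepted


# Predictable text: every draft token matches, so gamma + 1 tokens are emitted
print(verify_greedy([1, 2, 3, 4], [1, 2, 3, 4, 5]))
# Creative text: an early miss leaves 1 accepted token plus the correction
print(verify_greedy([7, 8, 9, 10], [7, 3, 9, 10, 11]))
```

Because the target always contributes one token per step, the output is identical to plain greedy decoding; only the number of tokens produced per target pass changes.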

In [None]:
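# Implementation for measuring throughput (tokens/second) per prompt
import json
import time

import requests

SGLANG_URL = "http://localhost:30000"

def measure_speedup(prompt: str, max_tokens: int = 100) -> dict:
    """
    Measure generation throughput (tokens/second) for one prompt.

    With speculative decoding, a single streamed chunk can carry several
    accepted tokens, so counting chunks undercounts tokens. We request a
    final usage report via stream_options (assuming the server honors it)
    and fall back to counting content chunks if none arrives.

    This function measures throughput only; it does not read acceptance
    statistics. Those come from the server (e.g., speculative-decoding stats
    in SGLang's logs or response metadata, depending on version).
    Elapsed time includes time to first token.
    """
    start = time.perf_counter()
    content_chunks = 0
    completion_tokens = None

    try:
        response = requests.post(
            f"{SGLANG_URL}/v1/chat/completions",
            json={
                "model": "default",
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": max_tokens,
                "stream": True,
                "stream_options": {"include_usage": True},
            },
            stream=True,
            timeout=60,
        )
        response.raise_for_status()

        for line in response.iter_lines():
            if not line:
                continue
            line_str = line.decode("utf-8")
            if not line_str.startswith("data: ") or line_str.strip() == "data: [DONE]":
                continue
            try:
                chunk = json.loads(line_str[len("data: "):])
            except json.JSONDecodeError:
                continue  # skip malformed lines
            # The usage chunk, if present, carries the true completion token count
            usage = chunk.get("usage")
            if usage and usage.get("completion_tokens") is not None:
                completion_tokens = usage["completion_tokens"]
            # Some chunks (e.g., the final usage chunk) have an empty choices list
            choices = chunk.get("choices") or []
            if choices and (choices[0].get("delta") or {}).get("content"):
                content_chunks += 1

        elapsed = time.perf_counter() - start
        tokens = completion_tokens if completion_tokens is not None else content_chunks
        return {
            "prompt": prompt if len(prompt) <= 40 else prompt[:40] + "...",
            "tokens": tokens,
            "tokens_per_sec": tokens / elapsed if elapsed > 0 else 0.0,
        }
    except requests.RequestException as e:
        return {"error": str(e)}

print("Test Function Ready")
print("Run measure_speedup(prompt) for each prompt type, once against a server with")
print("speculative decoding and once without, then divide to get the speedup")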
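Throughput alone does not show a speedup; each prompt group has to be measured against a baseline server without speculative decoding as well. A minimal sketch of that comparison, using hypothetical helper names. It takes the measurement function as a parameter, so it works with `measure_speedup` above or with any callable that returns a dict containing `"tokens_per_sec"` (or `"error"` on failure).

```python
import math
from statistics import mean
from typing import Callable, Dict, List


def compare_prompt_groups(
    groups: Dict[str, List[str]],
    measure: Callable[[str], dict],
) -> Dict[str, float]:
    """Mean tokens/sec per prompt group, skipping failed measurements.

    Groups with no successful measurement map to NaN.
    """
    summary = {}
    for name, prompts in groups.items():
        rates = [r["tokens_per_sec"] for r in map(measure, prompts) if "error" not in r]
        summary[name] = mean(rates) if rates else math.nan
    return summary


def speedup_by_group(spec: Dict[str, float], baseline: Dict[str, float]) -> Dict[str, float]:
    """Speculative throughput divided by baseline throughput, per shared group."""
    return {
        name: spec[name] / baseline[name]
        for name in spec
        if name in baseline and baseline[name] > 0
    }


# Usage (requires a running server; not executed here):
# groups = {"high": high_acceptance_prompts, "low": low_acceptance_prompts}
# spec = compare_prompt_groups(groups, measure_speedup)  # speculative server
# # restart without speculative decoding (or point SGLANG_URL elsewhere), then:
# base = compare_prompt_groups(groups, measure_speedup)
# print(speedup_by_group(spec, base))
```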

## Key Takeaways

1. **Optimal speculation length** is typically 5-8 tokens
2. **Predictable outputs** benefit most (2-3x speedup)
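3. **Creative outputs** see minimal benefit and can even run slower than plain decoding, because the cost of drafting tokens that are mostly rejected is pure overhead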
4. **Test your workload** to find the right configuration
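Once you have measured tokens/second for each setting on your own prompts, a small helper can turn the raw numbers into speedups and pick the winner. This is a sketch with a hypothetical helper name. It reuses the illustrative Exercise 1 throughputs, which you would replace with your own measurements taken on the same prompts and hardware.

```python
from typing import Dict, Tuple


def best_config(tokens_per_sec: Dict[str, float], baseline_key: str) -> Tuple[str, float]:
    """Return (config name, speedup) for the fastest config relative to baseline."""
    baseline = tokens_per_sec[baseline_key]
    if baseline <= 0:
        raise ValueError("baseline throughput must be positive")
    speedups = {
        name: rate / baseline
        for name, rate in tokens_per_sec.items()
        if name != baseline_key
    }
    if not speedups:
        raise ValueError("need at least one non-baseline config")
    best = max(speedups, key=speedups.get)
    return best, speedups[best]


# Illustrative throughputs from Exercise 1 (replace with your own measurements)
measured = {
    "no_speculation": 45,
    "draft_tokens_3": 85,
    "draft_tokens_5": 110,
    "draft_tokens_8": 105,
    "draft_tokens_12": 95,
}
name, speedup = best_config(measured, "no_speculation")
print(f"Best: {name} ({speedup:.1f}x)")
```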