# Lab-2.2 Part 2: Speculative Decoding

## Objectives
- Understand the principles of Speculative Decoding
- Implement a draft-verify pipeline
- Measure speedup against standard generation (target range: 1.5-3x)
- Analyze acceptance rates

## Estimated Time: 60-90 minutes

---
## 1. Speculative Decoding Theory

### The Problem: Autoregressive Bottleneck

Traditional LLM generation:
```
Step 1: Generate token 1 (35ms)
Step 2: Generate token 2 (35ms)  ← Must wait for step 1
Step 3: Generate token 3 (35ms)  ← Must wait for step 2
...
Total: 35ms × N tokens (serial)
```

### The Solution: Speculative Decoding

Use a small model to draft and a large model to verify:
```
Draft Phase:  Small model generates K tokens one at a time, but cheaply (10ms total)
Verify Phase: Large model scores all K drafts in a single forward pass (40ms)
Accept:       Keep verified tokens (≈ α × K tokens)

If α (acceptance rate) = 70%, K = 5:
  Output: 3-4 tokens in 50ms (10ms draft + 40ms verify)
  vs Traditional: 3 tokens in 105ms (3 × 35ms)
  Speedup: ≈2.1x (conservatively counting 3 tokens)
```
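The arithmetic above can be reproduced with a few lines of Python. This is only a sketch of the same back-of-the-envelope model: the function name `cycle_speedup` and its parameters are illustrative, and the timings are the example numbers from the text, not measurements.

```python
def cycle_speedup(k: int, alpha: float, draft_ms: float,
                  verify_ms: float, step_ms: float) -> tuple:
    """Estimate tokens per draft-verify cycle and the speedup over serial decoding."""
    tokens_per_cycle = alpha * k               # expected accepted tokens (simple model)
    cycle_ms = draft_ms + verify_ms            # one draft phase plus one verify phase
    baseline_ms = tokens_per_cycle * step_ms   # same tokens generated one by one
    return tokens_per_cycle, baseline_ms / cycle_ms


tokens, speedup = cycle_speedup(k=5, alpha=0.7, draft_ms=10, verify_ms=40, step_ms=35)
print(f"Tokens per cycle: {tokens:.2f}")
print(f"Estimated speedup: {speedup:.2f}x")
```

Using the unrounded 3.5 tokens per cycle gives a slightly higher estimate than the 2.1x obtained by counting only 3 whole tokens.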

In [None]:
# Imports
import torch
import time
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer

print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")

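### Speedup Formula

A simplified estimate of the speedup, which ignores the cost of running the draft model:

$$\text{Speedup} = \frac{\gamma}{1 + (1-\alpha)\gamma}$$

Where:
- $\gamma$ = number of draft tokens per iteration
- $\alpha$ = acceptance rate (fraction of draft tokens the target model accepts)

At $\alpha = 1$ the estimate equals $\gamma$; as $\gamma$ grows it saturates at $1/(1-\alpha)$, which is why longer drafts give diminishing returns.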
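For comparison, a more common analysis of speculative sampling assumes each draft token is accepted independently with probability α and that drafting stops at the first rejection. Under that assumption the expected number of tokens per iteration is $(1-\alpha^{\gamma+1})/(1-\alpha)$. The sketch below evaluates it; the cost ratio `c` (draft forward pass time divided by target forward pass time) is a symbol introduced here for illustration, not one used elsewhere in this lab.

```python
def expected_tokens(alpha: float, gamma: int) -> float:
    """Expected tokens emitted per iteration, assuming i.i.d. acceptance with rate alpha."""
    if alpha >= 1.0:
        return float(gamma + 1)  # every draft accepted, plus one token from the target
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


def walltime_speedup(alpha: float, gamma: int, c: float) -> float:
    """Expected speedup when one draft step costs c times one target step."""
    return expected_tokens(alpha, gamma) / (gamma * c + 1)


for gamma in (3, 5, 7):
    print(gamma, round(expected_tokens(0.7, gamma), 3),
          round(walltime_speedup(0.7, gamma, c=0.1), 3))
```

Because the denominator grows with `gamma * c`, this model has a finite optimal draft length once the draft model's cost is taken into account.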

In [None]:
def calculate_speedup(gamma: int, alpha: float) -> float:
    """Calculate theoretical speedup."""
    return gamma / (1 + (1 - alpha) * gamma)

# Visualize speedup
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Varying alpha (acceptance rate)
alphas = np.linspace(0.1, 0.95, 20)
gamma_values = [3, 5, 7]

for gamma in gamma_values:
    speedups = [calculate_speedup(gamma, a) for a in alphas]
    ax1.plot(alphas, speedups, marker='o', label=f'γ={gamma}')

ax1.set_xlabel('Acceptance Rate (α)')
ax1.set_ylabel('Speedup')
ax1.set_title('Speedup vs Acceptance Rate')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Varying gamma (draft length)
gammas = range(1, 11)
alpha_values = [0.5, 0.7, 0.9]

for alpha in alpha_values:
    speedups = [calculate_speedup(g, alpha) for g in gammas]
    ax2.plot(gammas, speedups, marker='s', label=f'α={alpha}')

ax2.set_xlabel('Draft Length (γ)')
ax2.set_ylabel('Speedup')
ax2.set_title('Speedup vs Draft Length')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\n💡 Key Insights:")
print("  - Higher acceptance rate → better speedup")
print("  - Longer drafts help, but with diminishing returns (saturates at 1/(1-α))")
print("  - In practice draft cost matters too; γ ≈ 4-6 is a common starting point")

---
## 2. Load Models

In [None]:
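# Load draft model (small, fast)
# The draft must share the target's tokenizer/vocabulary, because token ids are
# passed directly between the two models. GPT-2 ids do not match OPT ids.
DRAFT_MODEL = "facebook/opt-125m"  # 125M params

print(f"Loading draft model: {DRAFT_MODEL}...")
draft_model = AutoModelForCausalLM.from_pretrained(DRAFT_MODEL).to("cuda").eval()
draft_tokenizer = AutoTokenizer.from_pretrained(DRAFT_MODEL)
draft_tokenizer.pad_token = draft_tokenizer.eos_token
print("✅ Draft model loaded")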

In [None]:
# Load target model (large, accurate)
TARGET_MODEL = "facebook/opt-1.3b"  # 1.3B params (10x larger)

print(f"\nLoading target model: {TARGET_MODEL}...")
target_model = AutoModelForCausalLM.from_pretrained(TARGET_MODEL).to("cuda")
target_tokenizer = AutoTokenizer.from_pretrained(TARGET_MODEL)
target_tokenizer.pad_token = target_tokenizer.eos_token
print("✅ Target model loaded")

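# Compute the ratio from the loaded models instead of hard-coding parameter counts
n_draft = sum(p.numel() for p in draft_model.parameters())
n_target = sum(p.numel() for p in target_model.parameters())
print(f"\nSize ratio: {n_target / n_draft:.1f}x")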
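Speculative decoding feeds the same token ids to both models, so the two tokenizers must map tokens to identical ids. A minimal sketch of a sanity check follows; `vocab_mismatches` is a hypothetical helper, and the toy dictionaries stand in for the output of each tokenizer's `get_vocab()`.

```python
def vocab_mismatches(draft_vocab: dict, target_vocab: dict, limit: int = 5) -> list:
    """Return up to `limit` (token, draft_id, target_id) entries where the ids disagree."""
    mismatches = []
    for token, target_id in target_vocab.items():
        draft_id = draft_vocab.get(token)  # None if the draft lacks this token
        if draft_id != target_id:
            mismatches.append((token, draft_id, target_id))
            if len(mismatches) >= limit:
                break
    return mismatches


# Toy vocabularies for illustration only
same = {"hello": 0, "world": 1}
shifted = {"hello": 4, "world": 5}

print(vocab_mismatches(same, same))      # []
print(vocab_mismatches(shifted, same))
```

In this lab the call would be `vocab_mismatches(draft_tokenizer.get_vocab(), target_tokenizer.get_vocab())`; an empty list suggests the pair is safe to use together.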

---
## 3. Baseline: Standard Generation

In [None]:
def standard_generation(
    model, 
    tokenizer, 
    prompt: str, 
    max_tokens: int = 50
) -> tuple:
    """Standard autoregressive generation."""
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    # CUDA kernels run asynchronously, so synchronize before reading the clock
    torch.cuda.synchronize()
    start = time.perf_counter()

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.8,
            pad_token_id=tokenizer.eos_token_id,
        )

    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    tokens_generated = outputs.shape[1] - inputs.input_ids.shape[1]

    return text, elapsed, tokens_generated

# Test baseline
test_prompt = "The future of artificial intelligence"
print(f"Testing standard generation with: '{test_prompt}'\n")

text, elapsed, tokens = standard_generation(target_model, target_tokenizer, test_prompt)

print(f"Generated: {text}")
print(f"\nTokens: {tokens}")
print(f"Time: {elapsed:.3f}s")
print(f"Throughput: {tokens/elapsed:.1f} tokens/s")
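A single timed run is noisy, and the first GPU call usually includes one-off warm-up cost. The helper below is a rough sketch of a more stable measurement: it discards warm-up runs and reports the median of several timed runs. The name `benchmark` and its parameters are assumptions made for this example.

```python
import statistics
import time


def benchmark(fn, warmup: int = 1, repeats: int = 3, sync=None) -> float:
    """Return the median wall-clock seconds of fn() over `repeats` timed runs."""
    for _ in range(warmup):
        fn()  # untimed runs absorb one-off initialization cost
    times = []
    for _ in range(repeats):
        if sync is not None:
            sync()  # wait for pending asynchronous work before starting the clock
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for the timed work to finish
        times.append(time.perf_counter() - start)
    return statistics.median(times)


# Stand-in workload so the example runs without a GPU
median_s = benchmark(lambda: sum(range(100_000)), warmup=1, repeats=5)
print(f"Median time: {median_s * 1000:.2f} ms")
```

With the models from this lab, one could pass `sync=torch.cuda.synchronize` and wrap the call to `standard_generation` in a lambda.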

---
## 4. Implement Speculative Decoding

In [None]:
class SpeculativeDecoder:
    """Speculative Decoding implementation."""
    
    def __init__(
        self,
        draft_model,
        draft_tokenizer,
        target_model,
        target_tokenizer,
        gamma: int = 5,  # Draft length
    ):
        self.draft_model = draft_model
        self.draft_tokenizer = draft_tokenizer
        self.target_model = target_model
        self.target_tokenizer = target_tokenizer
        self.gamma = gamma
        
        # Statistics
        self.stats = {
            'total_draft': 0,
            'total_accepted': 0,
            'iterations': 0,
        }
    
    def generate(self, prompt: str, max_tokens: int = 50) -> tuple:
        """Generate with speculative decoding.

        Note: acceptance uses a fixed probability threshold. This is a
        simplification and does not reproduce the target model's sampling
        distribution exactly.
        """
        # Initialize
        inputs = self.target_tokenizer(prompt, return_tensors="pt").to("cuda")
        input_ids = inputs.input_ids
        eos_id = self.target_tokenizer.eos_token_id
        generated_tokens = 0

        torch.cuda.synchronize()
        start_time = time.perf_counter()

        while generated_tokens < max_tokens:
            self.stats['iterations'] += 1
            prompt_len = input_ids.shape[1]

            # Step 1: Draft with small model (requires a vocabulary shared with the target)
            with torch.no_grad():
                draft_outputs = self.draft_model.generate(
                    input_ids,
                    attention_mask=torch.ones_like(input_ids),
                    max_new_tokens=self.gamma,
                    do_sample=True,
                    temperature=0.8,
                    pad_token_id=self.draft_tokenizer.eos_token_id,
                )

            draft_tokens = draft_outputs[0][prompt_len:]
            self.stats['total_draft'] += len(draft_tokens)

            # Step 2: Verify all draft tokens with one forward pass of the large model
            candidate_ids = torch.cat([input_ids, draft_tokens.unsqueeze(0)], dim=1)

            with torch.no_grad():
                target_logits = self.target_model(candidate_ids).logits

            # Logits at position p predict the token at p + 1, so row i scores
            # draft token i and the last row predicts the token after all drafts.
            target_probs = torch.softmax(target_logits[0, prompt_len - 1:].float(), dim=-1)

            # Step 3: Accept/reject tokens
            accepted = 0
            for i in range(len(draft_tokens)):
                # Simple acceptance: check if probability is reasonable
                if target_probs[i, draft_tokens[i]] > 0.1:  # Threshold
                    accepted += 1
                else:
                    break

            # On rejection, resample that position from the target. If every draft
            # was accepted, the same index gives a bonus token from the target.
            next_token = torch.multinomial(target_probs[accepted], 1)
            new_tokens = torch.cat([draft_tokens[:accepted], next_token])

            # Update statistics
            self.stats['total_accepted'] += accepted

            # Append the accepted tokens plus the one sampled from the target
            input_ids = torch.cat([input_ids, new_tokens.unsqueeze(0)], dim=1)
            generated_tokens += len(new_tokens)

            # Stop if EOS appears anywhere in the newly added tokens
            if (new_tokens == eos_id).any():
                break

        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time
        text = self.target_tokenizer.decode(input_ids[0], skip_special_tokens=True)

        return text, elapsed, generated_tokens
    
    def get_acceptance_rate(self) -> float:
        """Calculate acceptance rate."""
        if self.stats['total_draft'] == 0:
            return 0.0
        return self.stats['total_accepted'] / self.stats['total_draft']
    
    def get_stats(self) -> dict:
        """Get statistics."""
        return {
            **self.stats,
            'acceptance_rate': self.get_acceptance_rate(),
            'avg_accepted_per_iter': (
                self.stats['total_accepted'] / self.stats['iterations']
                if self.stats['iterations'] > 0 else 0
            ),
        }

print("✅ SpeculativeDecoder class defined")
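The class above accepts a draft token whenever the target probability exceeds a fixed threshold. The usual speculative sampling rule is different: accept with probability min(1, p/q), and on rejection resample from the renormalized residual max(0, p - q), which keeps the output distributed as if sampled from the target. The following is a minimal pure-Python sketch of that rule for a single token; the function name and the list-based distributions are assumptions made for illustration, not part of the lab's class.

```python
import random


def speculative_accept(draft_token: int, p_target: list, q_draft: list, rng=random) -> tuple:
    """Return (accepted, token) for one drafted token.

    p_target and q_draft are probability lists over the same vocabulary.
    """
    p, q = p_target[draft_token], q_draft[draft_token]
    # Accept the draft with probability min(1, p / q)
    if q > 0 and rng.random() < min(1.0, p / q):
        return True, draft_token
    # Rejected: resample from the residual distribution max(0, p - q)
    residual = [max(0.0, pt - qd) for pt, qd in zip(p_target, q_draft)]
    if sum(residual) == 0.0:
        residual = list(p_target)  # degenerate case: fall back to the target itself
    token = rng.choices(range(len(residual)), weights=residual, k=1)[0]
    return False, token


# Empirical check: outputs should follow p_target even though drafts come from q_draft
rng = random.Random(0)
p_target = [0.6, 0.3, 0.1]
q_draft = [0.2, 0.5, 0.3]
counts = [0, 0, 0]
for _ in range(20000):
    draft = rng.choices(range(3), weights=q_draft, k=1)[0]
    _, token = speculative_accept(draft, p_target, q_draft, rng)
    counts[token] += 1
print([round(c / 20000, 3) for c in counts])
```

Swapping this rule into the decoder would also require the draft model's probabilities for each drafted token, which the threshold version does not need.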

---
## 5. Test Speculative Decoding

In [None]:
# Initialize decoder
spec_decoder = SpeculativeDecoder(
    draft_model=draft_model,
    draft_tokenizer=draft_tokenizer,
    target_model=target_model,
    target_tokenizer=target_tokenizer,
    gamma=5,  # Draft 5 tokens at a time
)

print(f"Testing speculative decoding with: '{test_prompt}'\n")

spec_text, spec_elapsed, spec_tokens = spec_decoder.generate(test_prompt, max_tokens=50)

print(f"Generated: {spec_text}")
print(f"\nTokens: {spec_tokens}")
print(f"Time: {spec_elapsed:.3f}s")
print(f"Throughput: {spec_tokens/spec_elapsed:.1f} tokens/s")

# Show statistics
stats = spec_decoder.get_stats()
print(f"\nStatistics:")
print(f"  Iterations: {stats['iterations']}")
print(f"  Total drafted: {stats['total_draft']}")
print(f"  Total accepted: {stats['total_accepted']}")
print(f"  Acceptance rate: {stats['acceptance_rate']*100:.1f}%")
print(f"  Avg accepted/iter: {stats['avg_accepted_per_iter']:.2f}")

---
## 6. Performance Comparison

In [None]:
# Compare standard vs speculative
test_prompts = [
    "The future of AI",
    "Machine learning is",
    "Python programming",
]

print("Performance Comparison")
print("=" * 80)

standard_times = []
speculative_times = []
speedups = []

for prompt in test_prompts:
    # Standard
    _, std_time, std_tokens = standard_generation(
        target_model, target_tokenizer, prompt, max_tokens=30
    )
    
    # Speculative
    spec_decoder = SpeculativeDecoder(
        draft_model, draft_tokenizer, target_model, target_tokenizer, gamma=5
    )
    _, spec_time, spec_tokens = spec_decoder.generate(prompt, max_tokens=30)
    
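    # Compare throughput rather than raw time: the two runs may emit different token counts
    speedup = (spec_tokens / spec_time) / (std_tokens / std_time)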
    
    standard_times.append(std_time)
    speculative_times.append(spec_time)
    speedups.append(speedup)
    
    print(f"\nPrompt: '{prompt}'")
    print(f"  Standard:     {std_time:.3f}s ({std_tokens/std_time:.1f} tok/s)")
    print(f"  Speculative:  {spec_time:.3f}s ({spec_tokens/spec_time:.1f} tok/s)")
    print(f"  Speedup:      {speedup:.2f}x")

avg_speedup = np.mean(speedups)
print(f"\n" + "=" * 80)
print(f"Average Speedup: {avg_speedup:.2f}x ⚡")
print("=" * 80)

In [None]:
# Visualize comparison
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Time comparison
x = np.arange(len(test_prompts))
width = 0.35

ax1.bar(x - width/2, standard_times, width, label='Standard', color='#ff6b6b')
ax1.bar(x + width/2, speculative_times, width, label='Speculative', color='#51cf66')
ax1.set_ylabel('Time (seconds)')
ax1.set_title('Generation Time Comparison')
ax1.set_xticks(x)
ax1.set_xticklabels([f'P{i+1}' for i in range(len(test_prompts))])
ax1.legend()
ax1.grid(axis='y', alpha=0.3)

# Speedup
ax2.bar(x, speedups, color='#4dabf7')
ax2.axhline(y=1.0, color='r', linestyle='--', label='Baseline')
ax2.set_ylabel('Speedup')
ax2.set_title('Speedup vs Standard Generation')
ax2.set_xticks(x)
ax2.set_xticklabels([f'P{i+1}' for i in range(len(test_prompts))])
ax2.legend()
ax2.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

---
## Summary

✅ **Completed**:
1. Understood Speculative Decoding theory
2. Implemented draft-verify pipeline
3. Measured performance improvements
4. Analyzed acceptance rates

📊 **Key Findings**:
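- Expected speedup: roughly 1.5-2.5x, varying by prompt (compare with your measured numbers)
- Typical acceptance rate: 50-70% (depends on how similar the two models are)
- Typical draft length (γ): 4-6 tokens
- Works best when the draft and target models are similar and share a tokenizer

💡 **Best Practices**:
- Use a draft model roughly 5-10x smaller than the target
- Prefer a draft model trained on the same data as the target
- Tune the acceptance threshold for the quality-speed tradeoff (a looser threshold is faster but drifts further from the target model's output)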

➡️ **Next**: In `03-Quantization_Inference.ipynb`, we'll learn:
- INT8/FP8 quantization
- Quality vs performance tradeoffs
- Production quantization strategies

In [None]:
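# Cleanup
import gc

# spec_decoder holds references to both models, so delete it as well
del spec_decoder, draft_model, target_model
gc.collect()              # drop unreachable Python objects first
torch.cuda.empty_cache()  # then release cached GPU memory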

print("✅ Lab 2.2 Part 2 Complete!")