## Lab Summary

### What We Learned

**Decoding Methods Implemented (6 total)**:
1. Greedy - Fastest but repetitive
2. Beam Search - Better quality, still deterministic
3. Temperature - Simple randomness control
4. Top-k - Fixed vocabulary filtering
5. Nucleus (Top-p) - Dynamic filtering, widely used default
6. **Contrastive (NEW)** - Explicit degeneration prevention

### Key Insights

- **Quality-Diversity Tradeoff**: No method dominates all others
- **Task Matters**: Factual → Greedy/Beam, Creative → Nucleus/Contrastive
- **Contrastive Search**: Strong choice for long creative generation (its similarity penalty reduces repetition)
- **Nucleus Sampling**: Robust general-purpose default, usually combined with a moderate temperature

### Recommended Settings (2025)

| Task Type | Method | Parameters |
|-----------|--------|------------|
| Factual Q&A | Greedy or low-temperature sampling | T=0 (greedy) or T=0.1-0.3 |
| Translation | Beam Search | width=4-5 |
| Code | Greedy/Beam | T=0, width=3 |
| Dialogue | Nucleus | p=0.85-0.95, T=0.7-0.9 |
| Creative Writing | Nucleus | p=0.9-0.95, T=0.9-1.2 |
| Long Stories | Contrastive | α=0.5-0.7, k=4-6 |

### Next Steps

- Experiment with stacked filters (Top-k + Top-p + Temperature applied together)
- Try constrained decoding for structured outputs
- Explore how alignment training (RLHF, DPO - Week 10) reshapes the distributions these decoders sample from
- Implement speculative decoding for speed (advanced topic)

In [None]:
# Summary table of task-method fit
print("\nTASK-METHOD FIT SUMMARY")
print("=" * 90)
print(f"{'Task':<20} {'Method':<25} {'Diversity':<12} {'Repetition':<12}")
print("=" * 90)

for result in task_results:
    print(f"{result['task']:<20} {result['method']:<25} {result['diversity']:<12.3f} {result['repetition']:<12.1f}%")

print("=" * 90)
print("\nKEY INSIGHT: Different tasks need different decoding strategies!")
print("  - Factual tasks: Use deterministic methods (Greedy, Beam)")
print("  - Creative tasks: Use stochastic methods (Nucleus, Contrastive)")
print("  - Long generation: Use Contrastive to reduce degeneration")

In [None]:
# Define 4 tasks with optimal methods
tasks = [
    {
        'name': 'Factual QA',
        'prompt': 'The capital of France is',
        'optimal_method': 'Greedy or Low Temp',
        'generate': lambda p: greedy_decode(p, 20)
    },
    {
        'name': 'Creative Story',
        'prompt': 'In a distant galaxy',
        'optimal_method': 'Contrastive (long text)',
        'generate': lambda p: contrastive_search(p, alpha=0.6, k=4, max_length=60)
    },
    {
        'name': 'Dialogue',
        'prompt': 'How are you today?',
        'optimal_method': 'Nucleus (natural variation)',
        'generate': lambda prompt: top_p_sampling(prompt, p=0.9, temperature=0.8, max_length=40)
    },
    {
        'name': 'Code Completion',
        'prompt': 'def calculate_fibonacci(n):',
        'optimal_method': 'Beam Search (high-likelihood output)',
        'generate': lambda p: beam_search_decode(p, beam_size=3, max_length=50)
    }
]

print("TASK-SPECIFIC EXPERIMENTS")
print("=" * 80)

task_results = []
for task in tasks:
    print(f"\nTask: {task['name']}")
    print(f"Prompt: \"{task['prompt']}\"")
    print(f"Optimal Method: {task['optimal_method']}")
    print("-" * 80)
    
    output = task['generate'](task['prompt'])
    rep_rate = count_repeated_ngrams(output, n=3)[0]
    diversity = calculate_distinct_n(output, n=2)
    
    print(f"Output: {output}")
    print(f"\nMetrics: Diversity={diversity:.3f}, Repetition={rep_rate:.1f}%")
    print("=" * 80)
    
    task_results.append({
        'task': task['name'],
        'method': task['optimal_method'],
        'diversity': diversity,
        'repetition': rep_rate,
        'output': output[:100] + '...' if len(output) > 100 else output
    })

## Section 15: Task-Specific Experiments (NEW)

Test the recommended method on each of 4 NLP tasks:
1. **Factual QA**: Need correct answer
2. **Creative Story**: Need diversity
3. **Dialogue**: Need natural variation
4. **Code Completion**: Need syntactic correctness
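As a cross-check on the hand-written decoders, the same task-to-method choices can be passed to Hugging Face's built-in `model.generate`. This is a sketch, not part of the lab's own method: `GENERATE_KWARGS` and `generate_for_task` are hypothetical names, the values copy the lambdas in the Section 15 experiment cell, and whether `penalty_alpha` triggers built-in contrastive search depends on your installed `transformers` version (check its generation docs).

```python
# Hypothetical mapping: lab task -> keyword arguments for Hugging Face `model.generate`.
# Values copy the lambdas in the Section 15 experiment cell.
GENERATE_KWARGS = {
    "Factual QA": {"do_sample": False, "max_new_tokens": 20},              # greedy
    "Creative Story": {"do_sample": False, "penalty_alpha": 0.6, "top_k": 4,
                       "max_new_tokens": 60},                               # contrastive search
    "Dialogue": {"do_sample": True, "top_p": 0.9, "top_k": 0,              # top_k=0 disables the
                 "temperature": 0.8, "max_new_tokens": 40},                 # default top-k filter
    "Code Completion": {"do_sample": False, "num_beams": 3,
                        "max_new_tokens": 50},                              # beam search
}


def generate_for_task(model, tokenizer, prompt, task):
    """Decode `prompt` with the library's built-in strategy for `task`."""
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(
        input_ids,
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 defines no pad token
        **GENERATE_KWARGS[task],
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
```

With the lab's `model` and `tokenizer`, `generate_for_task(model, tokenizer, "How are you today?", "Dialogue")` should behave like the hand-written nucleus sampler. `max_new_tokens` counts only generated tokens, which matches how the lab's functions use `max_length`.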

In [None]:
# Visualize Quality-Diversity Pareto Frontier
fig, ax = plt.subplots(figsize=(10, 7))

method_colors = {
    'Greedy': '#E74C3C', 'Beam-3': '#E74C3C',
    'Temp=0.7': '#95A5A6', 'Top-k=40': '#95A5A6',
    'Nucleus=0.9': '#27AE60', 'Contrastive': '#27AE60'
}

for result in results_all:
    color = method_colors.get(result['method'], '#3333B2')
    ax.scatter(result['diversity'], result['quality'], 
              s=400, c=color, alpha=0.7, edgecolors='black', linewidths=2,
              label=result['method'])
    
    # Add method label
    ax.annotate(result['method'], 
               (result['diversity'], result['quality']),
               xytext=(5, 5), textcoords='offset points', fontsize=10, fontweight='bold')

# Optimal zone
from matplotlib.patches import Rectangle
optimal_zone = Rectangle((0.65, 75), 0.30, 20, 
                         linewidth=2, edgecolor='green', facecolor='green',
                         alpha=0.1, linestyle='--')
ax.add_patch(optimal_zone)
ax.text(0.80, 87, 'Optimal\nZone', ha='center', fontsize=11, 
       fontweight='bold', color='green')

ax.set_xlabel('Diversity (Distinct-2)', fontsize=12, fontweight='bold')
ax.set_ylabel('Quality (100 - Repetition %)', fontsize=12, fontweight='bold')
ax.set_title('All 6 Methods: Quality-Diversity Tradeoff', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)
ax.set_xlim(0, 1)
ax.set_ylim(0, 100)

plt.tight_layout()
plt.show()

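print("\nEXPECTED PATTERN (compare with your plot; sampled methods vary between runs):")
print("- Greedy/Beam: typically low diversity and, on this repetition-based quality proxy, low quality (bottom-left)")
print("- Temperature/Top-k: position depends on the parameter values")
print("- Nucleus/Contrastive: usually high on both axes (toward the top-right optimal zone)")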

In [None]:
def compute_perplexity_simple(text):
    """Approximate perplexity of generated text."""
    input_ids = tokenizer.encode(text, return_tensors='pt').to(device)
    
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
        loss = outputs.loss
    
    perplexity = torch.exp(loss).item()
    return perplexity

def full_evaluation(generate_fn, prompt, method_name, n_samples=5):
    """Complete quality-diversity evaluation."""
    # The method lambdas are defined as (p, m), so pass max_length positionally
    outputs = [generate_fn(prompt, 50) for _ in range(n_samples)]
    
    # Diversity metrics
    distinct_1 = np.mean([calculate_distinct_n(out, 1) for out in outputs])
    distinct_2 = np.mean([calculate_distinct_n(out, 2) for out in outputs])
    repetition = np.mean([count_repeated_ngrams(out, 3)[0] for out in outputs])
    
    # Quality metric
    perplexities = [compute_perplexity_simple(out) for out in outputs]
    avg_perplexity = np.mean(perplexities)
    
    return {
        'method': method_name,
        'diversity': distinct_2,
        'quality': 100 - repetition,  # Higher is better
        'perplexity': avg_perplexity,
        'distinct-1': distinct_1,
        'distinct-2': distinct_2,
        'repetition': repetition,
        'sample': outputs[0]
    }

# Evaluate all 6 methods
prompt_eval = "The future of artificial intelligence"

all_methods = [
    ('Greedy', lambda p, m: greedy_decode(p, m)),
    ('Beam-3', lambda p, m: beam_search_decode(p, 3, m)),
    ('Temp=0.7', lambda p, m: sample_with_temperature(p, 0.7, m)),
    ('Top-k=40', lambda p, m: top_k_sampling(p, 40, 0.8, m)),
    ('Nucleus=0.9', lambda p, m: top_p_sampling(p, 0.9, 0.8, m)),
    ('Contrastive', lambda p, m: contrastive_search(p, 0.6, 4, m)),
]

print("COMPREHENSIVE EVALUATION")
print("=" * 80)
print(f"Prompt: {prompt_eval}")
print(f"Metrics: Distinct-2 (diversity), Quality (100-repetition%), Perplexity\n")

results_all = []
for method_name, method_fn in all_methods:
    print(f"Evaluating {method_name}...")
    result = full_evaluation(method_fn, prompt_eval, method_name, n_samples=3)
    results_all.append(result)
    
    print(f"  Diversity: {result['diversity']:.3f} | Quality: {result['quality']:.1f} | PPL: {result['perplexity']:.1f}")
    print(f"  Sample: {result['sample'][:100]}...")
    print()

print("=" * 80)
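`compute_perplexity_simple` reads `outputs.loss`, which for `GPT2LMHeadModel` with `labels=input_ids` is the mean cross-entropy over every token after the first (the library shifts the labels internally); perplexity is `exp` of that mean. A toy calculation with made-up token probabilities shows the arithmetic (`perplexity_from_token_probs` is a hypothetical helper). Because GPT-2 is scoring text it could have produced itself, greedy and beam output tends to get the lowest perplexity, so treat PPL here as a fluency signal rather than a quality verdict.

```python
import math


def perplexity_from_token_probs(token_probs):
    """Perplexity = exp(mean negative log-likelihood) over the scored tokens."""
    nll = [-math.log(p) for p in token_probs]
    return math.exp(sum(nll) / len(nll))


# Two scored tokens with probabilities 0.5 and 0.25:
# mean NLL = (ln 2 + ln 4) / 2 = 1.5 * ln 2, so PPL = 2 ** 1.5
print(round(perplexity_from_token_probs([0.5, 0.25]), 2))  # → 2.83
```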

## Section 14: Quality-Diversity Metrics Analysis (NEW)

Automated analysis of quality-diversity tradeoffs across all 6 methods.

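We'll plot all methods in a 2D diversity-quality space to visualize the Pareto frontier. "Quality" is a proxy, 100 minus the 3-gram repetition rate; perplexity under GPT-2 is reported alongside it.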
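The scatter plot shows every method but does not mark which ones are actually Pareto-optimal. A minimal sketch of that computation, using a hypothetical `pareto_frontier` helper and toy numbers rather than lab results: a method is on the frontier if no other method is at least as good on both axes and strictly better on one. Keep in mind that distinct-2 and the repetition-based quality proxy are both computed from n-gram overlap, so they tend to rise and fall together.

```python
def pareto_frontier(points):
    """Return the non-dominated (name, diversity, quality) points, sorted by diversity.

    Higher is better on both axes. A point is dominated if another point is at
    least as good on both axes and strictly better on at least one.
    """
    frontier = []
    for name, d, q in points:
        dominated = any(
            d2 >= d and q2 >= q and (d2 > d or q2 > q)
            for _, d2, q2 in points
        )
        if not dominated:
            frontier.append((name, d, q))
    return sorted(frontier, key=lambda point: point[1])


# Toy numbers, not lab results
toy = [("A", 0.40, 60.0), ("B", 0.80, 90.0), ("C", 0.90, 85.0), ("D", 0.70, 80.0)]
print(pareto_frontier(toy))  # → [('B', 0.8, 90.0), ('C', 0.9, 85.0)]
```

With the evaluation results, `pareto_frontier([(r['method'], r['diversity'], r['quality']) for r in results_all])` returns the frontier methods in diversity order, ready to connect with `ax.plot`.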

In [None]:
# Compare Contrastive vs Nucleus on long generation
long_prompt = "Once upon a time in a faraway land"

print("LONG GENERATION TEST (100 tokens)")
print("=" * 80)
print(f"Prompt: {long_prompt}\n")

# Nucleus (may have repetition)
nucleus_long = top_p_sampling(long_prompt, p=0.9, temperature=0.8, max_length=100)
nucleus_rep = count_repeated_ngrams(nucleus_long, n=3)[0]

print(f"NUCLEUS (p=0.9, T=0.8):")
print(nucleus_long)
print(f"\nRepetition rate: {nucleus_rep:.1f}%")
print("\n" + "-" * 80 + "\n")

# Contrastive (should have less repetition)
contrastive_long = contrastive_search(long_prompt, alpha=0.6, k=4, max_length=100)
contrastive_rep = count_repeated_ngrams(contrastive_long, n=3)[0]

print(f"CONTRASTIVE (α=0.6, k=4):")
print(contrastive_long)
print(f"\nRepetition rate: {contrastive_rep:.1f}%")
print("\n" + "=" * 80)
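if contrastive_rep < nucleus_rep:
    print(f"\nConclusion: Contrastive reduced 3-gram repetition from {nucleus_rep:.1f}% to {contrastive_rep:.1f}% on this sample.")
else:
    print(f"\nConclusion: Contrastive did not reduce repetition on this sample ({nucleus_rep:.1f}% -> {contrastive_rep:.1f}%); nucleus output is random, so try several runs.")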

In [None]:
def contrastive_search(prompt, alpha=0.6, k=4, max_length=50):
    """
    Contrastive search decoding (Su et al., 2022).
    
    Args:
        alpha: Degeneration-penalty weight (0 = greedy, 0.6 = common default, 1.0 = penalty only)
        k: Number of top-probability candidates to re-rank
    
    Score(v) = (1-alpha)*P(v | context) - alpha*max_j cos(h_v, h_j)
    where h_v is the last-layer hidden state of candidate v and h_j are the
    last-layer hidden states of the context tokens.
    """
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    
    for _ in range(max_length):
        with torch.no_grad():
            outputs = model(input_ids, output_hidden_states=True)
        context_hidden = outputs.hidden_states[-1][0]            # (seq_len, hidden)
        probs = torch.softmax(outputs.logits[0, -1, :], dim=-1)  # P(v | context)
        
        # Top-k candidates by model probability
        top_k_probs, top_k_ids = torch.topk(probs, k)
        
        # Feed each candidate (as a batch of k sequences) to get its hidden state h_v
        candidates = torch.cat([input_ids.repeat(k, 1), top_k_ids.unsqueeze(-1)], dim=-1)
        with torch.no_grad():
            candidate_out = model(candidates, output_hidden_states=True)
        candidate_hidden = candidate_out.hidden_states[-1][:, -1, :]  # (k, hidden)
        
        # Degeneration penalty: max cosine similarity to any context token
        similarities = torch.nn.functional.cosine_similarity(
            candidate_hidden.unsqueeze(1),   # (k, 1, hidden)
            context_hidden.unsqueeze(0),     # (1, seq_len, hidden)
            dim=-1,
        )                                    # (k, seq_len)
        max_sim = similarities.max(dim=-1).values
        
        # Contrastive score; select the candidate with the highest score
        scores = (1 - alpha) * top_k_probs - alpha * max_sim
        next_token = top_k_ids[torch.argmax(scores)]
        
        input_ids = torch.cat([input_ids, next_token.view(1, 1)], dim=-1)
        
        if next_token.item() == tokenizer.eos_token_id:
            break
    
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

# Test contrastive search
prompt = "The weather is"
print("Testing Contrastive Search...")
print(f"Prompt: {prompt}\n")

for alpha in [0.0, 0.4, 0.6, 0.8]:
    output = contrastive_search(prompt, alpha=alpha, k=4, max_length=40)
    print(f"α={alpha}: {output}")
    print("-" * 70)

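## Section 13: Contrastive Search (NEW)

Explicit degeneration prevention through a similarity penalty (Su et al., 2022).

**Key Innovation**: Among the top-k most likely tokens, penalize candidates whose hidden representation is too similar to any token already in the context, which reduces repetition in long generations. Score: (1 - α)·P(v | context) - α·max_j cos(h_v, h_j).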
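To see the score on numbers small enough to check by hand, here is a pure-Python sketch with made-up 2-D vectors standing in for GPT-2 hidden states (`cosine`, `contrastive_score` and the candidate names are hypothetical). The likelier candidate exactly repeats a context representation: at α=0 it wins on probability alone, while at α=0.6 the similarity penalty hands the step to the less likely but less repetitive candidate.

```python
import math


def cosine(u, v):
    """Cosine similarity of two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def contrastive_score(prob, cand_vec, context_vecs, alpha):
    """(1 - alpha) * P(v | context) - alpha * max_j cos(h_v, h_j)."""
    penalty = max(cosine(cand_vec, h) for h in context_vecs)
    return (1 - alpha) * prob - alpha * penalty


context = [[1.0, 0.0], [0.0, 1.0]]   # toy hidden states of two context tokens
candidates = {
    "repeat": (0.6, [1.0, 0.0]),     # likely, but identical to a context token
    "fresh": (0.3, [1.0, -1.0]),     # less likely, less similar to the context
}

for alpha in (0.0, 0.6):
    scores = {tok: contrastive_score(p, h, context, alpha) for tok, (p, h) in candidates.items()}
    print(f"alpha={alpha}: picks {max(scores, key=scores.get)!r}")
```

By hand at α=0.6: "repeat" scores 0.4·0.6 - 0.6·1.0 = -0.36 and "fresh" scores 0.4·0.3 - 0.6·0.707 ≈ -0.30, so "fresh" is selected.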

# Week 9 Lab: Decoding Strategies

## Learning Objectives
- Implement greedy decoding, beam search, and sampling methods
- Compare decoding strategies on real language models
- Understand temperature, top-k, and top-p parameters
- Evaluate trade-offs between accuracy and creativity

## Prerequisites
```bash
pip install transformers torch numpy matplotlib
```

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import warnings
warnings.filterwarnings('ignore')

print('Week 9: Decoding Strategies Lab')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

## Section 1: Load Pre-trained Model

We'll use GPT-2 for all experiments.

In [None]:
# Load model and tokenizer
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
model.eval()

print(f"Loaded {model_name}")
print(f"Vocabulary size: {len(tokenizer)}")

In [None]:
# Define test prompts
test_prompts = [
    "The weather is",
    "Once upon a time",
    "In a shocking discovery",
    "The capital of France is",
    "def factorial(n):"
]

print("Test prompts:")
for i, prompt in enumerate(test_prompts, 1):
    print(f"{i}. \"{prompt}\"")

## Section 2: Greedy Decoding

Always pick the highest probability token.

In [None]:
def greedy_decode(prompt, max_length=50):
    """Implement greedy decoding from scratch."""
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    
    for _ in range(max_length):
        with torch.no_grad():
            outputs = model(input_ids)
            logits = outputs.logits
        
        # Get last token logits
        next_token_logits = logits[:, -1, :]
        
        # Greedy: pick argmax
        next_token = torch.argmax(next_token_logits, dim=-1)
        
        # Append to sequence
        input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
        
        # Stop at end-of-sequence
        if next_token.item() == tokenizer.eos_token_id:
            break
    
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

# Test greedy decoding
prompt = test_prompts[0]
output = greedy_decode(prompt, max_length=30)
print(f"Prompt: {prompt}")
print(f"Output: {output}")

In [None]:
# Test greedy on all prompts
print("=" * 60)
print("GREEDY DECODING RESULTS")
print("=" * 60)

for prompt in test_prompts:
    output = greedy_decode(prompt, max_length=30)
    print(f"\nPrompt: {prompt}")
    print(f"Output: {output}")
    print("-" * 60)

### Exercise 1: Analyze Repetition

Count how many times greedy decoding produces repeated n-grams.
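A toy string makes the metric concrete before it is applied to model output. This sketch (hypothetical helper `ngram_counts`) uses the same definition as `count_repeated_ngrams` in the next cell, repetition rate = (1 - unique/total) × 100, and additionally lists which trigrams repeat and how many times.

```python
from collections import Counter


def ngram_counts(text, n=3):
    """Count every whitespace-token n-gram in `text`."""
    words = text.split()
    return Counter(tuple(words[i:i + n]) for i in range(len(words) - n + 1))


counts = ngram_counts("the cat sat the cat sat the cat sat", n=3)
repeats = {ngram: c for ngram, c in counts.items() if c > 1}
total, unique = sum(counts.values()), len(counts)   # 7 trigrams, 3 distinct
rate = (1 - unique / total) * 100                   # same formula as count_repeated_ngrams
print(f"{rate:.1f}%")  # → 57.1%
print(repeats)         # each of the 3 distinct trigrams occurs 2 or 3 times
```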

In [None]:
def count_repeated_ngrams(text, n=3):
    """Count repeated n-grams in text."""
    words = text.split()
    ngrams = [tuple(words[i:i+n]) for i in range(len(words)-n+1)]
    
    # Count unique vs total
    unique_ngrams = len(set(ngrams))
    total_ngrams = len(ngrams)
    
    repetition_rate = (1 - unique_ngrams / max(total_ngrams, 1)) * 100
    
    return repetition_rate, unique_ngrams, total_ngrams

# Test on greedy output
prompt = "The weather is"
output = greedy_decode(prompt, max_length=50)
rep_rate, unique, total = count_repeated_ngrams(output, n=3)

print(f"Text: {output}")
print(f"\n3-gram repetition rate: {rep_rate:.1f}%")
print(f"Unique 3-grams: {unique}/{total}")

## Section 3: Beam Search

Keep track of multiple hypotheses.
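A hand-built toy distribution (hypothetical, not from GPT-2) shows why keeping several hypotheses helps: greedy commits to the locally best first token, while a beam of width 2 finds a two-token sequence with higher total probability. `beam_search_toy` mirrors the scoring used by `beam_search_decode` below (summed log-probabilities); with beam size 1 it reduces to greedy.

```python
import math

# Hypothetical two-step "model": next-token probabilities given the prefix so far.
NEXT = {
    (): {"A": 0.5, "B": 0.4, "C": 0.1},
    ("A",): {"x": 0.4, "y": 0.3, "z": 0.3},
    ("B",): {"x": 0.9, "y": 0.05, "z": 0.05},
    ("C",): {"x": 0.5, "y": 0.3, "z": 0.2},
}


def beam_search_toy(beam_size, steps=2):
    """Return the best prefix and its sequence probability (rounded)."""
    beams = [((), 0.0)]  # (prefix, summed log-probability)
    for _ in range(steps):
        candidates = [
            (prefix + (tok,), score + math.log(p))
            for prefix, score in beams
            for tok, p in NEXT[prefix].items()
        ]
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
    best_prefix, best_score = beams[0]
    return best_prefix, round(math.exp(best_score), 4)


print(beam_search_toy(1))  # → (('A', 'x'), 0.2)   greedy: 0.5 * 0.4
print(beam_search_toy(2))  # → (('B', 'x'), 0.36)  beam-2: 0.4 * 0.9
```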

In [None]:
def beam_search_decode(prompt, beam_size=3, max_length=50):
    """Implement beam search decoding."""
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    
    # Initialize beam: (sequence, score)
    beams = [(input_ids, 0.0)]
    
    for _ in range(max_length):
        all_candidates = []
        
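        for seq, score in beams:
            # A finished beam (ends in EOS) is carried over unchanged, not extended
            if seq[0, -1].item() == tokenizer.eos_token_id:
                all_candidates.append((seq, score))
                continue
            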
            with torch.no_grad():
                outputs = model(seq)
                logits = outputs.logits
            
            next_token_logits = logits[:, -1, :]
            log_probs = torch.log_softmax(next_token_logits, dim=-1)
            
            # Get top-k candidates
            top_k_probs, top_k_ids = torch.topk(log_probs, beam_size)
            
            for i in range(beam_size):
                next_token = top_k_ids[0, i].unsqueeze(0).unsqueeze(0)
                next_score = score + top_k_probs[0, i].item()
                next_seq = torch.cat([seq, next_token], dim=-1)
                
                all_candidates.append((next_seq, next_score))
        
        # Keep top beam_size
        ordered = sorted(all_candidates, key=lambda x: x[1], reverse=True)
        beams = ordered[:beam_size]
        
        # Check if best beam ended
        if beams[0][0][0, -1].item() == tokenizer.eos_token_id:
            break
    
    # Return best sequence
    best_seq = beams[0][0]
    return tokenizer.decode(best_seq[0], skip_special_tokens=True)

# Test beam search
prompt = test_prompts[0]
output_beam = beam_search_decode(prompt, beam_size=3, max_length=30)
print(f"Prompt: {prompt}")
print(f"Beam Search (size=3): {output_beam}")

In [None]:
# Compare greedy vs beam search
prompt = "Once upon a time"

greedy_out = greedy_decode(prompt, max_length=40)
beam3_out = beam_search_decode(prompt, beam_size=3, max_length=40)
beam5_out = beam_search_decode(prompt, beam_size=5, max_length=40)

print("Comparison: Greedy vs Beam Search")
print("=" * 70)
print(f"Prompt: {prompt}\n")
print(f"Greedy:     {greedy_out}\n")
print(f"Beam-3:     {beam3_out}\n")
print(f"Beam-5:     {beam5_out}")

### Exercise 2: Beam Size Analysis

Test different beam sizes and measure diversity.

In [None]:
# Test beam sizes 1, 3, 5, 10
beam_sizes = [1, 3, 5, 10]
prompt = "The future of AI is"

results = []
for beam_size in beam_sizes:
    output = beam_search_decode(prompt, beam_size=beam_size, max_length=40)
    rep_rate, _, _ = count_repeated_ngrams(output, n=3)
    results.append((beam_size, output, rep_rate))

print("Beam Size Analysis")
print("=" * 70)
for beam_size, output, rep_rate in results:
    print(f"\nBeam={beam_size} | Repetition: {rep_rate:.1f}%")
    print(f"Output: {output}")
    print("-" * 70)

## Section 4: Temperature Sampling

Control randomness with temperature parameter.
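Temperature divides the logits before the softmax, so T < 1 sharpens the distribution toward the top token and T > 1 flattens it. A pure-Python sketch on three made-up logits (`softmax_with_temperature` is a hypothetical helper; `sample_with_temperature` below does the same scaling with `torch.softmax`): the top token's probability is about 0.67 at T=1.0, rises to about 0.87 at T=0.5 and falls to about 0.51 at T=2.0.

```python
import math


def softmax_with_temperature(logits, temperature):
    """softmax(logits / T): T < 1 sharpens, T > 1 flattens the distribution."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)                        # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]                   # toy logits for three tokens
for T in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, T)
    print(T, [round(p, 3) for p in probs])
```

As T approaches 0 the distribution approaches a one-hot on the argmax, i.e. greedy decoding; T=0 itself would divide by zero in `sample_with_temperature`, so use `greedy_decode` wherever the recommendations say "T=0".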

In [None]:
def sample_with_temperature(prompt, temperature=1.0, max_length=50):
    """Generate text with temperature sampling."""
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    
    for _ in range(max_length):
        with torch.no_grad():
            outputs = model(input_ids)
            logits = outputs.logits
        
        next_token_logits = logits[:, -1, :]
        
        # Apply temperature
        scaled_logits = next_token_logits / temperature
        probs = torch.softmax(scaled_logits, dim=-1)
        
        # Sample from distribution
        next_token = torch.multinomial(probs, num_samples=1)
        
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        
        if next_token.item() == tokenizer.eos_token_id:
            break
    
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

# Test different temperatures
prompt = "The weather is"
temperatures = [0.1, 0.5, 0.7, 1.0, 1.5]

print("Temperature Comparison")
print("=" * 70)
print(f"Prompt: {prompt}\n")

for temp in temperatures:
    output = sample_with_temperature(prompt, temperature=temp, max_length=30)
    print(f"T={temp:<4}: {output}")
    print("-" * 70)

In [None]:
# Multiple samples at same temperature
prompt = "Once upon a time"
temperature = 0.8
n_samples = 5

print(f"Multiple samples with T={temperature}")
print("=" * 70)
print(f"Prompt: {prompt}\n")

for i in range(n_samples):
    output = sample_with_temperature(prompt, temperature=temperature, max_length=30)
    print(f"Sample {i+1}: {output}")

## Section 5: Top-k Sampling

Sample from top-k most likely tokens.
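Top-k keeps the k highest-probability tokens, renormalizes them, and samples only among them. A toy sketch (hypothetical `top_k_filter`, made-up probabilities): with k=2, the tokens with probabilities 0.5 and 0.2 keep their 5:2 ratio and are rescaled to sum to 1. Because temperature scaling is monotone, applying it before the top-k cut (as `top_k_sampling` below does) does not change which tokens are kept, only their relative weights.

```python
def top_k_filter(probs, k):
    """Keep the k most probable tokens, renormalize them, and zero out the rest."""
    keep = set(sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k])
    mass = sum(probs[i] for i in keep)
    return [probs[i] / mass if i in keep else 0.0 for i in range(len(probs))]


probs = [0.5, 0.2, 0.15, 0.1, 0.05]                   # toy next-token distribution
print([round(q, 3) for q in top_k_filter(probs, 2)])  # → [0.714, 0.286, 0.0, 0.0, 0.0]
```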

In [None]:
def top_k_sampling(prompt, k=50, temperature=1.0, max_length=50):
    """Generate text with top-k sampling."""
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    
    for _ in range(max_length):
        with torch.no_grad():
            outputs = model(input_ids)
            logits = outputs.logits
        
        next_token_logits = logits[:, -1, :] / temperature
        
        # Get top-k
        top_k_logits, top_k_indices = torch.topk(next_token_logits, k)
        
        # Sample from top-k
        probs = torch.softmax(top_k_logits, dim=-1)
        next_token_idx = torch.multinomial(probs, num_samples=1)
        next_token = top_k_indices.gather(-1, next_token_idx)  # map back to vocab id, shape (1, 1)
        
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        
        if next_token.item() == tokenizer.eos_token_id:
            break
    
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

# Test different k values
prompt = "In a shocking discovery"
k_values = [5, 20, 50, 100]

print("Top-k Comparison")
print("=" * 70)
print(f"Prompt: {prompt}\n")

for k in k_values:
    output = top_k_sampling(prompt, k=k, temperature=0.8, max_length=30)
    print(f"k={k:<4}: {output}")
    print("-" * 70)

## Section 6: Top-p (Nucleus) Sampling

Sample from nucleus with cumulative probability p.
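Unlike top-k, the nucleus adapts its size to the shape of the distribution: it is the smallest set of most-probable tokens whose cumulative probability reaches p. A pure-Python sketch with made-up distributions (`nucleus_filter` is a hypothetical helper): at p=0.9 a confident distribution keeps a single token, while a flatter one keeps four.

```python
def nucleus_filter(probs, p):
    """Keep the smallest set of most-probable tokens whose cumulative mass reaches p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = [], 0.0
    for i in order:
        keep.append(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    kept = set(keep)
    mass = sum(probs[i] for i in kept)
    return [probs[i] / mass if i in kept else 0.0 for i in range(len(probs))]


peaked = [0.92, 0.05, 0.02, 0.01]         # confident model
flat = [0.35, 0.25, 0.2, 0.12, 0.08]      # uncertain model
for name, dist in (("peaked", peaked), ("flat", flat)):
    kept = sum(1 for q in nucleus_filter(dist, 0.9) if q > 0)
    print(f"{name}: nucleus keeps {kept} token(s) at p=0.9")  # peaked → 1, flat → 4
```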

In [None]:
def top_p_sampling(prompt, p=0.9, temperature=1.0, max_length=50):
    """Generate text with top-p (nucleus) sampling."""
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    
    for _ in range(max_length):
        with torch.no_grad():
            outputs = model(input_ids)
            logits = outputs.logits
        
        next_token_logits = logits[:, -1, :] / temperature
        probs = torch.softmax(next_token_logits, dim=-1)
        
        # Sort probabilities
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        
        # Find cutoff for nucleus: smallest prefix whose cumulative mass reaches p
        # (counted along the vocabulary dimension of the single batch row)
        cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
        cutoff_idx = min(int((cumsum_probs[0] < p).sum().item()) + 1, sorted_probs.size(-1))
        
        # Keep only nucleus
        nucleus_probs = sorted_probs[:, :cutoff_idx]
        nucleus_indices = sorted_indices[:, :cutoff_idx]
        
        # Renormalize and sample
        nucleus_probs = nucleus_probs / nucleus_probs.sum()
        next_token_idx = torch.multinomial(nucleus_probs, num_samples=1)
        next_token = nucleus_indices.gather(-1, next_token_idx)  # map back to vocab id, shape (1, 1)
        
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        
        if next_token.item() == tokenizer.eos_token_id:
            break
    
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

# Test different p values
prompt = "The future of technology"
p_values = [0.5, 0.7, 0.9, 0.95, 0.99]

print("Top-p Comparison")
print("=" * 70)
print(f"Prompt: {prompt}\n")

for p in p_values:
    output = top_p_sampling(prompt, p=p, temperature=0.8, max_length=30)
    print(f"p={p:<5}: {output}")
    print("-" * 70)

## Section 7: Comprehensive Comparison

Compare all methods on the same prompts.

In [None]:
def compare_all_methods(prompt, max_length=40):
    """Compare all decoding strategies."""
    results = {}
    
    # Greedy
    results['Greedy'] = greedy_decode(prompt, max_length)
    
    # Beam search
    results['Beam-3'] = beam_search_decode(prompt, beam_size=3, max_length=max_length)
    
    # Temperature sampling
    results['T=0.5'] = sample_with_temperature(prompt, temperature=0.5, max_length=max_length)
    results['T=1.0'] = sample_with_temperature(prompt, temperature=1.0, max_length=max_length)
    
    # Top-k
    results['Top-k=40'] = top_k_sampling(prompt, k=40, temperature=0.8, max_length=max_length)
    
    # Top-p
    results['Top-p=0.9'] = top_p_sampling(prompt, p=0.9, temperature=0.8, max_length=max_length)
    
    return results

# Test on multiple prompts
for prompt in test_prompts:
    print("=" * 80)
    print(f"Prompt: \"{prompt}\"")
    print("=" * 80)
    
    results = compare_all_methods(prompt, max_length=35)
    
    for method, output in results.items():
        print(f"\n{method:<15}: {output}")
    
    print("\n")

## Section 8: Diversity and Quality Metrics

Calculate distinct-n and repetition rate.
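For a single text, distinct-n and the repetition rate measure the same thing from opposite ends: distinct-n = unique/total, and the repetition rate is (1 - unique/total) × 100 when both use the same n. This sketch re-implements both on toy strings (hypothetical helpers `ngrams`, `distinct_n`, `repetition_rate`) to show the identity; the lab uses n=2 for diversity and n=3 for repetition, so its two axes differ but remain closely related.

```python
def ngrams(text, n):
    """Whitespace-token n-grams of `text`."""
    words = text.split()
    return [tuple(words[i:i + n]) for i in range(len(words) - n + 1)]


def distinct_n(text, n=2):
    """Fraction of n-grams that are unique (0.0 if the text is too short)."""
    grams = ngrams(text, n)
    return len(set(grams)) / len(grams) if grams else 0.0


def repetition_rate(text, n=3):
    """Percentage of n-grams that repeat an earlier one."""
    grams = ngrams(text, n)
    return (1 - len(set(grams)) / max(len(grams), 1)) * 100


for text in ("the cat sat on the mat", "go go go go"):
    d, r = distinct_n(text, 2), repetition_rate(text, 2)
    print(f"{text!r}: distinct-2={d:.3f}, repetition(n=2)={r:.1f}%")
# → 'the cat sat on the mat': distinct-2=1.000, repetition(n=2)=0.0%
# → 'go go go go': distinct-2=0.333, repetition(n=2)=66.7%
```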

In [None]:
def calculate_distinct_n(text, n=2):
    """Calculate distinct-n metric (diversity)."""
    words = text.split()
    ngrams = [tuple(words[i:i+n]) for i in range(len(words)-n+1)]
    
    if len(ngrams) == 0:
        return 0.0
    
    return len(set(ngrams)) / len(ngrams)

def evaluate_decoding_method(generate_fn, prompt, n_samples=5, max_length=40):
    """Evaluate a decoding method."""
    # The method lambdas are defined as (p, m), so pass max_length positionally
    outputs = [generate_fn(prompt, max_length) for _ in range(n_samples)]
    
    # Calculate metrics
    distinct_1 = np.mean([calculate_distinct_n(out, n=1) for out in outputs])
    distinct_2 = np.mean([calculate_distinct_n(out, n=2) for out in outputs])
    repetition = np.mean([count_repeated_ngrams(out, n=3)[0] for out in outputs])
    
    return {
        'distinct-1': distinct_1,
        'distinct-2': distinct_2,
        'repetition': repetition,
        'samples': outputs
    }

# Compare methods
prompt = "The weather is"

methods = [
    ('Greedy', lambda p, m: greedy_decode(p, m)),
    ('T=0.7', lambda p, m: sample_with_temperature(p, 0.7, m)),
    ('T=1.0', lambda p, m: sample_with_temperature(p, 1.0, m)),
    ('Top-p=0.9', lambda p, m: top_p_sampling(p, 0.9, 0.8, m)),
]

print("Method Evaluation")
print("=" * 70)
print(f"Prompt: {prompt}\n")

for method_name, method_fn in methods:
    results = evaluate_decoding_method(method_fn, prompt, n_samples=5, max_length=30)
    print(f"\n{method_name}:")
    print(f"  Distinct-1: {results['distinct-1']:.3f}")
    print(f"  Distinct-2: {results['distinct-2']:.3f}")
    print(f"  Repetition: {results['repetition']:.1f}%")
    print(f"  Sample: {results['samples'][0]}")

## Section 9: Visualization

Plot diversity vs quality trade-off.

In [None]:
# Collect data for all methods
methods_data = []
prompt = "The future of AI"

test_methods = [
    ('Greedy', lambda p, m: greedy_decode(p, m)),
    ('Beam-3', lambda p, m: beam_search_decode(p, 3, m)),
    ('T=0.5', lambda p, m: sample_with_temperature(p, 0.5, m)),
    ('T=0.8', lambda p, m: sample_with_temperature(p, 0.8, m)),
    ('T=1.2', lambda p, m: sample_with_temperature(p, 1.2, m)),
    ('Top-k=40', lambda p, m: top_k_sampling(p, 40, 0.8, m)),
    ('Top-p=0.9', lambda p, m: top_p_sampling(p, 0.9, 0.8, m)),
]

for name, fn in test_methods:
    metrics = evaluate_decoding_method(fn, prompt, n_samples=3, max_length=30)
    methods_data.append({
        'name': name,
        'diversity': metrics['distinct-2'],
        'repetition': metrics['repetition']
    })

# Plot
fig, ax = plt.subplots(figsize=(10, 6))

for data in methods_data:
    ax.scatter(data['diversity'], 100 - data['repetition'], s=200, alpha=0.7)
    ax.annotate(data['name'], (data['diversity'], 100 - data['repetition']),
               xytext=(5, 5), textcoords='offset points', fontsize=10)

ax.set_xlabel('Diversity (Distinct-2)', fontsize=12, fontweight='bold')
ax.set_ylabel('Quality (100 - Repetition %)', fontsize=12, fontweight='bold')
ax.set_title('Decoding Strategy Trade-off: Diversity vs Quality', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Section 10: Real-World Task - Story Writing

Compare methods on creative writing.

In [None]:
# Story writing task
story_prompt = "In a distant galaxy"
story_length = 80

print("CREATIVE WRITING COMPARISON")
print("=" * 80)
print(f"Prompt: {story_prompt}\n")

# Greedy (boring)
greedy_story = greedy_decode(story_prompt, max_length=story_length)
print("GREEDY (deterministic, repetitive):")
print(greedy_story)
print("\n" + "-" * 80 + "\n")

# Sampling (creative)
sampling_story = top_p_sampling(story_prompt, p=0.95, temperature=0.9, max_length=story_length)
print("TOP-P SAMPLING (creative, diverse):")
print(sampling_story)
print("\n" + "=" * 80)

## Section 11: Real-World Task - Factual QA

Compare methods on factual questions.

In [None]:
# Factual QA task
factual_prompts = [
    "The capital of France is",
    "Water boils at",
    "The speed of light is",
]

print("FACTUAL QA COMPARISON")
print("=" * 80)

for qa_prompt in factual_prompts:
    greedy_ans = greedy_decode(qa_prompt, max_length=15)
    sampling_ans = sample_with_temperature(qa_prompt, temperature=1.0, max_length=15)
    
    print(f"\nQ: {qa_prompt}")
    print(f"Greedy:   {greedy_ans}")
    print(f"Sampling: {sampling_ans}")
    print("-" * 80)

print("\nConclusion: Greedy gives stable, repeatable answers for factual prompts, but a small model like GPT-2 can still be confidently wrong.")

## Section 12: Summary and Recommendations

What did we learn?

In [None]:
print("""
KEY FINDINGS:
============

1. GREEDY DECODING:
   ✓ Fast and deterministic
   ✗ Repetitive and boring
   Best for: Factual QA, short deterministic completions

2. BEAM SEARCH:
   ✓ Finds higher-probability sequences than greedy
   ✗ Still deterministic
   Best for: Translation, code generation, structured output

3. TEMPERATURE SAMPLING:
   ✓ Controls creativity
   ✓ Easy to understand
   Best for: Most creative tasks

4. TOP-K SAMPLING:
   ✓ Prevents tail sampling
   ✗ Fixed candidate count k, however peaked or flat the distribution
   Best for: Moderate creativity

5. TOP-P SAMPLING:
   ✓ Adapts to distribution
   ✓ Most robust
   Best for: General use (combine with temperature)

RECOMMENDATIONS:
===============
- Factual QA:      Greedy or T=0.1-0.3
- Translation:     Beam-4 or T=0.3
- Code:            Beam-3 to 5, or greedy (T=0)
- Dialogue:        T=0.7-0.9, p=0.9
- Creative:        T=0.9-1.2, p=0.95

DEFAULT SETTINGS:
================
If unsure, use: T=0.7, p=0.9
""")