# CSE 234 Programming Assignment 3: Speculative Decoding

## Setup

In [48]:
import os
import torch
import time
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Tuple, Dict, Optional

## Speculative Decoding

In [49]:
class SpeculativeDecoder:
    def __init__(self, target_model_name: str, draft_model_name: str, device: str = "cuda"):
        """
        Initialize the speculative decoder with target and draft models.

        Args:
            target_model_name: HuggingFace model ID for the larger target model.
            draft_model_name: HuggingFace model ID for the smaller draft model.
            device: Device to run models on ("cuda" or "cpu").
        """
        self.device = device
        self.cache_directory = './cache/'
        self.target_model, self.target_tokenizer = self.initialize_target_model(target_model_name)
        self.draft_model, self.draft_tokenizer = self.initialize_draft_model(draft_model_name)
        self.target_model.to(self.device)
        self.draft_model.to(self.device)

        # Ensure tokenizers are compatible
        assert self.target_tokenizer.vocab == self.draft_tokenizer.vocab, "Tokenizers must be compatible"

    def initialize_target_model(self, model_name: str):
        """Initialize the larger target model with caching enabled and proper pad token."""
        print(f"Loading target model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # TODO: Implement target model initialization
        # 1. Set the pad token if it doesn't exist
        # 2. Load the model with appropriate settings for inference
        # 3. Enable any optimizations that might help with performance

        # Set pad token if not already set
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load model with optimizations
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,  # Use FP16 for better performance
            device_map=self.device,         # Auto-assign to GPU/CPU
            use_cache=True,             # Enable KV-caching for faster inference
            cache_dir=self.cache_directory
        ).eval()  # Set to evaluation mode

        return model, tokenizer

    def initialize_draft_model(self, model_name: str):
        """
        Initialize a smaller, faster draft model with proper pad token.
        Uses lower precision and additional optimizations.
        """
        print(f"Loading draft model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # TODO: Implement draft model initialization
        # 1. Set the pad token if it doesn't exist
        # 2. Load the model with appropriate settings for inference
        # 3. Enable any optimizations that might help with performance

        # Set pad token if not already set
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load model with optimizations
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,  # Use FP16 for speed
            device_map=self.device,         # Auto-assign to GPU/CPU
            use_cache=True,             # Enable KV-caching for faster inference
            cache_dir=self.cache_directory
        ).eval()  # Set to evaluation mode

        return model, tokenizer

    def generate_draft_tokens(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                             num_speculative_tokens: int = 10) -> torch.Tensor:
        """
        Generate speculative tokens in one forward call using the draft model.

        Args:
            input_ids: Input token IDs (tensor of shape [1, seq_len]).
            attention_mask: Corresponding attention mask.
            num_speculative_tokens: Number of tokens to speculate.

        Returns:
            Tensor of shape [1, num_speculative_tokens] containing the draft tokens.
        """
        # TODO: Implement draft token generation
        # 1. Use the draft model to generate tokens
        # 2. Extract only the new tokens (not including the input)
        # 3. Return the newly generated tokens

        # pass

        with torch.no_grad():
            outputs = self.draft_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=input_ids.shape[1] + num_speculative_tokens,
                do_sample=False,  # Greedy decoding for higher acceptance rate
                pad_token_id=self.draft_tokenizer.pad_token_id,
                use_cache=True    # Leverage KV-caching
            )

        # Extract only the new tokens (excluding input)
        draft_tokens = outputs[:, input_ids.shape[1]:].to(torch.long)
        return draft_tokens

    def verify_tokens_vectorized(self, input_ids: torch.Tensor, draft_tokens: torch.Tensor,
                               attention_mask: torch.Tensor) -> Tuple[List[int], int]:
        """
        Vectorized verification: verify all draft tokens in one forward pass using the target model.

        Args:
            input_ids: The current input token IDs (shape [1, L]).
            draft_tokens: Draft tokens from the draft model (shape [1, k]).
            attention_mask: The current attention mask for input_ids.

        Returns:
            accepted_tokens: List of accepted token IDs.
            accepted_position: Index of the first rejected token (if all accepted, equals draft_tokens.shape[1]).
        """
        # TODO: Implement efficient verification of draft tokens
        # 1. Run target model on input_ids concatenated with draft_tokens
        # 2. Extract the logits for positions where draft tokens would be predicted
        # 3. Compare target model predictions with draft tokens
        # 4. Determine how many consecutive tokens were accepted before first mismatch

        # pass

        # Ensure draft_tokens are in long format
        draft_tokens = draft_tokens.to(torch.long)

        # Concatenate input_ids with draft_tokens for verification
        extended_input_ids = torch.cat([input_ids, draft_tokens], dim=1).to(torch.long)
        # extended_attention_mask = torch.cat(
        #     [attention_mask, torch.ones_like(draft_tokens, device=self.device)], dim=1
        # )

        extended_attention_mask = torch.ones_like(extended_input_ids, device=self.device)
        extended_attention_mask[:, :input_ids.shape[1]] = attention_mask

        with torch.no_grad():
            outputs = self.target_model(
                extended_input_ids,
                attention_mask=extended_attention_mask,
                use_cache=True  # Use KV-caching
            )
            logits = outputs.logits  # Shape: [1, L + k, vocab_size]

        # Extract logits for the positions where draft tokens are predicted
        draft_start_idx = input_ids.shape[1]
        draft_logits = logits[:, draft_start_idx - 1:-1, :]  # Predict tokens at draft positions

        # Get target model's predicted tokens
        target_predicted_tokens = draft_logits.argmax(dim=-1).to(torch.long)

        # Compare with draft tokens
        matches = (target_predicted_tokens == draft_tokens).squeeze(0)
        accepted_position = matches.tolist().index(False) if False in matches else draft_tokens.shape[1]
        accepted_tokens = draft_tokens[:, :accepted_position].tolist()[0]

        return accepted_tokens, accepted_position

    def speculative_decode(self, prompt: str, max_tokens: int = 100,
                          num_speculative_tokens: int = 20) -> str:
        """
        Main speculative decoding algorithm with vectorized verification.

        Args:
            prompt: Input text.
            max_tokens: Maximum number of tokens to generate (excluding prompt).
            num_speculative_tokens: Number of tokens to speculate per iteration.

        Returns:
            Generated text.
        """
        # Tokenize prompt
        # inputs = self.target_tokenizer(prompt, return_tensors="pt", padding=True)
        inputs = self.target_tokenizer(prompt, return_tensors="pt", padding=False)
        input_ids = inputs["input_ids"].to(self.device).to(torch.long)
        attention_mask = inputs["attention_mask"].to(self.device)
        prompt_length = input_ids.shape[1]


        # Initialize counters for performance tracking
        total_tokens_generated = prompt_length
        total_draft_tokens_proposed = 0
        total_draft_tokens_accepted = 0
        start_time = time.time()

        # TODO: Implement the core speculative decoding loop
        # 1. Generate draft tokens using the draft model
        # 2. Verify draft tokens using the target model
        # 3. Accept verified tokens and append to the sequence
        # 4. For rejected tokens or if all tokens are accepted, generate a new token with the target model
        # 5. Stop when max_tokens is reached or an EOS token is generated

        while total_tokens_generated < max_tokens:
            # Generate draft tokens
            if total_draft_tokens_proposed > 0:
              recent_acceptance = total_draft_tokens_accepted / total_draft_tokens_proposed
              num_speculative_tokens = max(15, int(20 * recent_acceptance))
            draft_tokens = self.generate_draft_tokens(input_ids, attention_mask, num_speculative_tokens)
            # draft_tokens = self.generate_draft_tokens(input_ids, attention_mask, 10)
            total_draft_tokens_proposed += draft_tokens.shape[1]

            # Verify draft tokens
            accepted_tokens, accepted_position = self.verify_tokens_vectorized(
                input_ids, draft_tokens, attention_mask
            )
            total_draft_tokens_accepted += len(accepted_tokens)

            # Append accepted tokens
            if accepted_tokens:
                accepted_tokens_tensor = torch.tensor(accepted_tokens, device=self.device).unsqueeze(0).to(torch.long)
                input_ids = torch.cat([input_ids, accepted_tokens_tensor], dim=1)
                attention_mask = torch.cat([attention_mask, torch.ones_like(accepted_tokens_tensor, device=self.device)], dim=1)
                total_tokens_generated += len(accepted_tokens)

            # If not all tokens accepted, generate one token with target model
            if accepted_position < draft_tokens.shape[1]:
                with torch.no_grad():
                    outputs = self.target_model.generate(
                        input_ids,
                        attention_mask=attention_mask,
                        max_length=input_ids.shape[1] + 1,
                        do_sample=False,
                        pad_token_id=self.target_tokenizer.pad_token_id,
                        # use_cache=True
                    )
                new_token = outputs[:, -1:].to(torch.long)
                input_ids = torch.cat([input_ids, new_token], dim=1)
                attention_mask = torch.cat([attention_mask, torch.ones_like(new_token, device=self.device)], dim=1)
                total_tokens_generated += 1

            # Stop if EOS token is generated
            if input_ids[0, -1].item() == self.target_tokenizer.eos_token_id:
                break

        # Calculate performance metrics
        elapsed_time = time.time() - start_time
        acceptance_rate = total_draft_tokens_accepted / total_draft_tokens_proposed if total_draft_tokens_proposed > 0 else 0

        print(f"Generated {total_tokens_generated - prompt_length} tokens in {elapsed_time:.2f} seconds")
        print(f"Tokens per second: {(total_tokens_generated - prompt_length) / elapsed_time:.2f}")
        print(f"Draft token acceptance rate: {acceptance_rate:.2%}")

        return self.target_tokenizer.decode(input_ids[0], skip_special_tokens=True)

    def benchmark(self, prompt: str, max_tokens: int = 100,
                  num_runs: int = 3, compare_baseline: bool = True) -> Dict:
        """
        Benchmark the speculative decoder against baseline decoding.

        Args:
            prompt: Input text.
            max_tokens: Maximum number of tokens to generate.
            num_runs: Number of benchmark runs.
            compare_baseline: Whether to compare with baseline (non-speculative) decoding.

        Returns:
            Dictionary with benchmark results.
        """
        results = {
            "speculative": {"times": [], "tokens_per_second": []},
            "baseline": {"times": [], "tokens_per_second": []} if compare_baseline else None
        }

        # Benchmark speculative decoding.
        for _ in range(num_runs):
            start_time = time.time()
            output = self.speculative_decode(prompt, max_tokens=max_tokens)
            elapsed = time.time() - start_time
            prompt_len = len(self.target_tokenizer(prompt)["input_ids"])
            output_tokens = len(self.target_tokenizer.encode(output)) - prompt_len
            tps = output_tokens / elapsed
            results["speculative"]["times"].append(elapsed)
            results["speculative"]["tokens_per_second"].append(tps)

        # Benchmark baseline decoding.
        if compare_baseline:
            for _ in range(num_runs):
                inputs = self.target_tokenizer(prompt, return_tensors="pt", padding=True)
                input_ids = inputs["input_ids"].to(self.device)
                attention_mask = inputs["attention_mask"].to(self.device)
                start_time = time.time()
                with torch.no_grad():
                    output_ids = self.target_model.generate(
                        input_ids,
                        attention_mask=attention_mask,
                        max_length=input_ids.shape[1] + max_tokens,
                        do_sample=False,
                        pad_token_id=self.target_tokenizer.pad_token_id
                    )
                elapsed = time.time() - start_time
                output_tokens = output_ids.shape[1] - input_ids.shape[1]
                tps = output_tokens / elapsed
                results["baseline"]["times"].append(elapsed)
                results["baseline"]["tokens_per_second"].append(tps)

        for method in results.keys():
            if results[method] is not None:
                avg_time = sum(results[method]["times"]) / num_runs
                avg_tps = sum(results[method]["tokens_per_second"]) / num_runs
                results[method]["avg_time"] = avg_time
                results[method]["avg_tokens_per_second"] = avg_tps

        if compare_baseline:
            speedup = results["baseline"]["avg_time"] / results["speculative"]["avg_time"]
            results["speedup"] = speedup
            results["latency_reduction"] = (1 - results["speculative"]["avg_time"] / results["baseline"]["avg_time"]) * 100
            # print(f"Speculative decoding speedup: {speedup:.2f}x")
            # print(f"Latency reduction: {results['latency_reduction']:.2f}%")

        return results

## Test

In [50]:
target_model_name = "EleutherAI/pythia-1.4b-deduped"  # Larger target model
draft_model_name = "EleutherAI/pythia-160m-deduped"   # Smaller draft model


# Initialize speculative decoder
decoder = SpeculativeDecoder(
    target_model_name=target_model_name,
    draft_model_name=draft_model_name,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

# Test prompts
test_prompts = [
    "The future of Artificial Intelligence is",
    "Write a short story about a robot learning to feel emotions:",
    "Write the lyrics to the song 'Happy Birthday'."
]

# Run benchmark on test prompts
for i, prompt in enumerate(test_prompts):
    print(f"\nBenchmarking Prompt {i+1}:")
    print(f"Prompt: {prompt}")

    results = decoder.benchmark(
        prompt=prompt,
        max_tokens=100,
        num_runs=3,
        compare_baseline=True
    )

    print(f"Average speculative decoding time: {results['speculative']['avg_time']:.2f} seconds")
    print(f"Average speculative tokens per second: {results['speculative']['avg_tokens_per_second']:.2f}")

    if results["baseline"] is not None:
        print(f"Average baseline decoding time: {results['baseline']['avg_time']:.2f} seconds")
        print(f"Average baseline tokens per second: {results['baseline']['avg_tokens_per_second']:.2f}")
        print(f"Speedup: {results['speedup']:.2f}x")
        print(f"Latency reduction: {results['latency_reduction']:.2f}%")

Loading target model: EleutherAI/pythia-1.4b-deduped
Loading draft model: EleutherAI/pythia-160m-deduped

Benchmarking Prompt 1:
Prompt: The future of Artificial Intelligence is
Generated 107 tokens in 1.56 seconds
Tokens per second: 68.76
Draft token acceptance rate: 84.13%
Generated 107 tokens in 1.52 seconds
Tokens per second: 70.16
Draft token acceptance rate: 84.13%
Generated 107 tokens in 1.53 seconds
Tokens per second: 70.15
Draft token acceptance rate: 84.13%
Average speculative decoding time: 1.54 seconds
Average speculative tokens per second: 69.65
Average baseline decoding time: 2.24 seconds
Average baseline tokens per second: 44.86
Speedup: 1.46x
Latency reduction: 31.42%

Benchmarking Prompt 2:
Prompt: Write a short story about a robot learning to feel emotions:
Generated 104 tokens in 1.39 seconds
Tokens per second: 74.92
Draft token acceptance rate: 89.57%
Generated 104 tokens in 1.39 seconds
Tokens per second: 74.94
Draft token acceptance rate: 89.57%
Generated 104 toke

## Bonus

In [51]:
# target_model_name = ...  # Larger target model
# draft_model_name = ...   # Smaller draft model

target_model_name = "EleutherAI/pythia-2.8b-deduped"  # Large model
draft_model_name = "EleutherAI/pythia-70m-deduped"   # Small model


# Initialize speculative decoder
decoder = SpeculativeDecoder(
    target_model_name=target_model_name,
    draft_model_name=draft_model_name,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

# Test prompts
test_prompts = [
    "The future of Artificial Intelligence is",
    "Write a short story about a robot learning to feel emotions:",
    "Write the lyrics to the song 'Happy Birthday'."
]

# Run benchmark on test prompts
for i, prompt in enumerate(test_prompts):
    print(f"\nBenchmarking Prompt {i+1}:")
    print(f"Prompt: {prompt}")

    results = decoder.benchmark(
        prompt=prompt,
        max_tokens=100,
        num_runs=3,
        compare_baseline=True
    )

    print(f"Average speculative decoding time: {results['speculative']['avg_time']:.2f} seconds")
    print(f"Average speculative tokens per second: {results['speculative']['avg_tokens_per_second']:.2f}")

    if results["baseline"] is not None:
        print(f"Average baseline decoding time: {results['baseline']['avg_time']:.2f} seconds")
        print(f"Average baseline tokens per second: {results['baseline']['avg_tokens_per_second']:.2f}")
        print(f"Speedup: {results['speedup']:.2f}x")
        print(f"Latency reduction: {results['latency_reduction']:.2f}%")

Loading target model: EleutherAI/pythia-2.8b-deduped
Loading draft model: EleutherAI/pythia-70m-deduped

Benchmarking Prompt 1:
Prompt: The future of Artificial Intelligence is
Generated 103 tokens in 1.71 seconds
Tokens per second: 60.28
Draft token acceptance rate: 71.43%
Generated 103 tokens in 1.35 seconds
Tokens per second: 76.52
Draft token acceptance rate: 71.43%
Generated 103 tokens in 1.25 seconds
Tokens per second: 82.25
Draft token acceptance rate: 71.43%
Average speculative decoding time: 1.44 seconds
Average speculative tokens per second: 72.25
Average baseline decoding time: 2.87 seconds
Average baseline tokens per second: 34.93
Speedup: 1.99x
Latency reduction: 49.87%

Benchmarking Prompt 2:
Prompt: Write a short story about a robot learning to feel emotions:
Generated 95 tokens in 2.94 seconds
Tokens per second: 32.31
Draft token acceptance rate: 31.92%
Generated 95 tokens in 2.59 seconds
Tokens per second: 36.62
Draft token acceptance rate: 31.92%
Generated 95 tokens i