# CSE 234 Programming Assignment 3: Speculative Decoding

## Setup

In [8]:
import os
import torch
import time
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Tuple, Dict, Optional

## Speculative Decoding

In [9]:
class SpeculativeDecoder:
    def __init__(self, target_model_name: str, draft_model_name: str, device: str = "cuda", precision_: torch.dtype = torch.float32, cache: bool=False):
        """
        Initialize the speculative decoder with target and draft models.

        Args:
            target_model_name: HuggingFace model ID for the larger target model.
            draft_model_name: HuggingFace model ID for the smaller draft model.
            device: Device to run models on ("cuda" or "cpu").
        """
        self.device = device
        self.precision_ = precision_
        self.target_model, self.target_tokenizer = self.initialize_target_model(target_model_name)
        self.draft_model, self.draft_tokenizer = self.initialize_draft_model(draft_model_name)
        self.cache = cache

        # Ensure tokenizers are compatible
        assert self.target_tokenizer.vocab == self.draft_tokenizer.vocab, "Tokenizers must be compatible"

    def initialize_target_model(self, model_name: str):
        """Initialize the larger target model with caching enabled and proper pad token."""
        print(f"Loading target model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # TODO: Implement target model initialization
        # 1. Set the pad token if it doesn't exist
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # 2. Load the model with appropriate settings for inference
        # 3. Enable any optimizations that might help with performance
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=self.precision_,
            device_map="auto"
        )

        return model, tokenizer

    def initialize_draft_model(self, model_name: str):
        """
        Initialize a smaller, faster draft model with proper pad token.
        Uses lower precision and additional optimizations.
        """
        print(f"Loading draft model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # TODO: Implement draft model initialization
        # 1. Set the pad token if it doesn't exist
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # 2. Load the model with appropriate settings for inference
        # 3. Enable any optimizations that might help with performance
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=self.precision_,
            device_map="auto"
        )

        return model, tokenizer

    def generate_draft_tokens(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                             num_speculative_tokens: int = 10) -> torch.Tensor:
        """
        Generate speculative tokens in one forward call using the draft model.

        Args:
            input_ids: Input token IDs (tensor of shape [1, seq_len]).
            attention_mask: Corresponding attention mask.
            num_speculative_tokens: Number of tokens to speculate.

        Returns:
            Tensor of shape [1, num_speculative_tokens] containing the draft tokens.
        """
        # TODO: Implement draft token generation
        # 1. Use the draft model to generate tokens
        with torch.no_grad():
            output = self.draft_model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=num_speculative_tokens,
                eos_token_id=None,
                do_sample=False
            )

        # 2. Extract only the new tokens (not including the input)
        # 3. Return the newly generated tokens
        output = output[:, input_ids.shape[1]:]
        return output

    def verify_tokens_vectorized(self, input_ids: torch.Tensor, draft_tokens: torch.Tensor,
                               attention_mask: torch.Tensor) -> Tuple[List[int], int]:
        """
        Vectorized verification: verify all draft tokens in one forward pass using the target model.

        Args:
            input_ids: The current input token IDs (shape [1, L]).
            draft_tokens: Draft tokens from the draft model (shape [1, k]).
            attention_mask: The current attention mask for input_ids.

        Returns:
            accepted_tokens: List of accepted token IDs.
            accepted_position: Index of the first rejected token (if all accepted, equals draft_tokens.shape[1]).
        """
        # TODO: Implement efficient verification of draft tokens
        # 1. Run target model on input_ids concatenated with draft_tokens
        og_shape_ip = input_ids.shape[1]
        input_ids = torch.cat([input_ids, draft_tokens], dim=1)
        attention_mask = torch.cat([attention_mask, torch.ones_like(draft_tokens)], dim=1)
        with torch.no_grad():
            output = self.target_model(
                input_ids,
                attention_mask=attention_mask,
                use_cache=self.cache,
                do_sample=False
            )
            logits = output.logits[:, og_shape_ip-1:-1, :] # logits[:, og_shape_ip:, :] -> [d1, d2, ..., d10, next]

        # 2. Extract the logits for positions where draft tokens would be predicted
        target_draft_tokens = logits.argmax(dim=-1) # next likely tokens in 1FP

        # 3. Compare target model predictions with draft tokens
        matches = (target_draft_tokens == draft_tokens).int()

        # 4. Determine how many consecutive tokens were accepted before first mismatch
        accepted_position = draft_tokens.shape[1] if matches.all() else torch.where(matches == 0)[1][0].item()
        return draft_tokens[:, :accepted_position].tolist()[0], accepted_position

    def speculative_decode(self, prompt: str, max_tokens: int = 100,
                          num_speculative_tokens: int = 15) -> str:
        """
        Main speculative decoding algorithm with vectorized verification.

        Args:
            prompt: Input text.
            max_tokens: Maximum number of tokens to generate (excluding prompt).
            num_speculative_tokens: Number of tokens to speculate per iteration.

        Returns:
            Generated text.
        """
        # Tokenize prompt
        inputs = self.target_tokenizer(prompt, return_tensors="pt", padding=True)
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)
        prompt_length = input_ids.shape[1]
        eos_token_id = self.target_tokenizer.eos_token_id

        # Initialize counters for performance tracking
        total_tokens_generated = prompt_length
        total_draft_tokens_proposed = 0
        total_draft_tokens_accepted = 0
        start_time = time.time()

        # TODO: Implement the core speculative decoding loop
        while total_tokens_generated < max_tokens:
            # 1. Generate draft tokens using the draft model
            draft_tokens = self.generate_draft_tokens(input_ids, attention_mask)
            total_draft_tokens_proposed += draft_tokens.shape[1]

            # 2. Verify draft tokens using the target model
            accepted_tokens, accepted_position = self.verify_tokens_vectorized(input_ids, draft_tokens, attention_mask)
            total_draft_tokens_accepted += len(accepted_tokens)

            # 3. Accept verified tokens and append to the sequence
            if accepted_position > 0:
                input_ids = torch.cat([input_ids, draft_tokens[:, :accepted_position]], dim=1)
                attention_mask = torch.cat([attention_mask, torch.ones((1, accepted_position), device=input_ids.device)], dim=1)
                total_tokens_generated += accepted_position

            # 4. For rejected tokens or if all tokens are accepted, generate a new token with the target model
            with torch.no_grad():
                outputs = self.target_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )
                new_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)

            input_ids = torch.cat([input_ids, new_token], dim=1)
            attention_mask = torch.cat([attention_mask, torch.ones((1, 1), device=input_ids.device)], dim=1)
            total_tokens_generated += 1

            # 5. Stop when max_tokens is reached or an EOS token is generated
            if new_token[0, -1] == eos_token_id:
                break

        # Calculate performance metrics
        elapsed_time = time.time() - start_time
        acceptance_rate = total_draft_tokens_accepted / total_draft_tokens_proposed if total_draft_tokens_proposed > 0 else 0

        print(f"Generated {total_tokens_generated - prompt_length} tokens in {elapsed_time:.2f} seconds")
        print(f"Tokens per second: {(total_tokens_generated - prompt_length) / elapsed_time:.2f}")
        print(f"Draft token acceptance rate: {acceptance_rate:.2%}")

        return self.target_tokenizer.decode(input_ids[0], skip_special_tokens=True)

    def benchmark(self, prompt: str, max_tokens: int = 100,
                  num_runs: int = 3, compare_baseline: bool = True) -> Dict:
        """
        Benchmark the speculative decoder against baseline decoding.

        Args:
            prompt: Input text.
            max_tokens: Maximum number of tokens to generate.
            num_runs: Number of benchmark runs.
            compare_baseline: Whether to compare with baseline (non-speculative) decoding.

        Returns:
            Dictionary with benchmark results.
        """
        results = {
            "speculative": {"times": [], "tokens_per_second": []},
            "baseline": {"times": [], "tokens_per_second": []} if compare_baseline else None
        }

        # Benchmark speculative decoding.
        for _ in range(num_runs):
            start_time = time.time()
            output = self.speculative_decode(prompt, max_tokens=max_tokens)
            elapsed = time.time() - start_time
            prompt_len = len(self.target_tokenizer(prompt)["input_ids"])
            output_tokens = len(self.target_tokenizer.encode(output)) - prompt_len
            tps = output_tokens / elapsed
            results["speculative"]["times"].append(elapsed)
            results["speculative"]["tokens_per_second"].append(tps)

        # Benchmark baseline decoding.
        if compare_baseline:
            for _ in range(num_runs):
                inputs = self.target_tokenizer(prompt, return_tensors="pt", padding=True)
                input_ids = inputs["input_ids"].to(self.device)
                attention_mask = inputs["attention_mask"].to(self.device)
                start_time = time.time()
                with torch.no_grad():
                    output_ids = self.target_model.generate(
                        input_ids,
                        attention_mask=attention_mask,
                        max_length=input_ids.shape[1] + max_tokens,
                        do_sample=False,
                        pad_token_id=self.target_tokenizer.pad_token_id
                    )
                elapsed = time.time() - start_time
                output_tokens = output_ids.shape[1] - input_ids.shape[1]
                tps = output_tokens / elapsed
                results["baseline"]["times"].append(elapsed)
                results["baseline"]["tokens_per_second"].append(tps)

        for method in results.keys():
            if results[method] is not None:
                avg_time = sum(results[method]["times"]) / num_runs
                avg_tps = sum(results[method]["tokens_per_second"]) / num_runs
                results[method]["avg_time"] = avg_time
                results[method]["avg_tokens_per_second"] = avg_tps

        if compare_baseline:
            speedup = results["baseline"]["avg_time"] / results["speculative"]["avg_time"]
            results["speedup"] = speedup
            results["latency_reduction"] = (1 - results["speculative"]["avg_time"] / results["baseline"]["avg_time"]) * 100
            # print(f"Speculative decoding speedup: {speedup:.2f}x")
            # print(f"Latency reduction: {results['latency_reduction']:.2f}%")

        return results

## Test

In [10]:
target_model_name = "EleutherAI/pythia-1.4b-deduped"  # Larger target model
draft_model_name = "EleutherAI/pythia-160m-deduped"   # Smaller draft model

# Initialize speculative decoder
decoder = SpeculativeDecoder(
    target_model_name=target_model_name,
    draft_model_name=draft_model_name,
    device="cuda" if torch.cuda.is_available() else "cpu",
    precision_=torch.float16,
    cache=True
)

# Test prompts
test_prompts = [
    "The future of Artificial Intelligence is",
    "Write a short story about a robot learning to feel emotions:",
    "Write the lyrics to the song 'Happy Birthday'."
]

# Run benchmark on test prompts
for i, prompt in enumerate(test_prompts):
    print(f"\nBenchmarking Prompt {i+1}:")
    print(f"Prompt: {prompt}")

    results = decoder.benchmark(
        prompt=prompt,
        max_tokens=100,
        num_runs=3,
        compare_baseline=True
    )

    print(f"Average speculative decoding time: {results['speculative']['avg_time']:.2f} seconds")
    print(f"Average speculative tokens per second: {results['speculative']['avg_tokens_per_second']:.2f}")

    if results["baseline"] is not None:
        print(f"Average baseline decoding time: {results['baseline']['avg_time']:.2f} seconds")
        print(f"Average baseline tokens per second: {results['baseline']['avg_tokens_per_second']:.2f}")
        print(f"Speedup: {results['speedup']:.2f}x")
        print(f"Latency reduction: {results['latency_reduction']:.2f}%")

Loading target model: EleutherAI/pythia-1.4b-deduped
Loading draft model: EleutherAI/pythia-160m-deduped

Benchmarking Prompt 1:
Prompt: The future of Artificial Intelligence is
Generated 100 tokens in 1.52 seconds
Tokens per second: 65.92
Draft token acceptance rate: 90.00%
Generated 100 tokens in 1.49 seconds
Tokens per second: 67.18
Draft token acceptance rate: 90.00%
Generated 100 tokens in 1.53 seconds
Tokens per second: 65.32
Draft token acceptance rate: 90.00%
Average speculative decoding time: 1.51 seconds
Average speculative tokens per second: 66.08
Average baseline decoding time: 2.23 seconds
Average baseline tokens per second: 45.10
Speedup: 1.47x
Latency reduction: 32.07%

Benchmarking Prompt 2:
Prompt: Write a short story about a robot learning to feel emotions:
Generated 97 tokens in 1.31 seconds
Tokens per second: 73.94
Draft token acceptance rate: 97.78%
Generated 97 tokens in 1.33 seconds
Tokens per second: 72.78
Draft token acceptance rate: 97.78%
Generated 97 tokens 

## Bonus

In [11]:
target_model_name = "gpt2-xl"  # Larger target model
draft_model_name = "gpt2"   # Smaller draft model

# Initialize speculative decoder
decoder = SpeculativeDecoder(
    target_model_name=target_model_name,
    draft_model_name=draft_model_name,
    device="cuda" if torch.cuda.is_available() else "cpu",
    precision_=torch.float32,
    cache=True
)

# Test prompts
test_prompts = [
    "The future of Artificial Intelligence is",
    "Write a short story about a robot learning to feel emotions:",
    "Write the lyrics to the song 'Happy Birthday'."
]

# Run benchmark on test prompts
for i, prompt in enumerate(test_prompts):
    print(f"\nBenchmarking Prompt {i+1}:")
    print(f"Prompt: {prompt}")

    results = decoder.benchmark(
        prompt=prompt,
        max_tokens=100,
        num_runs=3,
        compare_baseline=True
    )

    print(f"Average speculative decoding time: {results['speculative']['avg_time']:.2f} seconds")
    print(f"Average speculative tokens per second: {results['speculative']['avg_tokens_per_second']:.2f}")

    if results["baseline"] is not None:
        print(f"Average baseline decoding time: {results['baseline']['avg_time']:.2f} seconds")
        print(f"Average baseline tokens per second: {results['baseline']['avg_tokens_per_second']:.2f}")
        print(f"Speedup: {results['speedup']:.2f}x")
        print(f"Latency reduction: {results['latency_reduction']:.2f}%")

Loading target model: gpt2-xl




Loading draft model: gpt2

Benchmarking Prompt 1:
Prompt: The future of Artificial Intelligence is
Generated 94 tokens in 15.71 seconds
Tokens per second: 5.98
Draft token acceptance rate: 94.44%
Generated 94 tokens in 9.19 seconds
Tokens per second: 10.22
Draft token acceptance rate: 94.44%
Generated 94 tokens in 9.16 seconds
Tokens per second: 10.26
Draft token acceptance rate: 94.44%
Average speculative decoding time: 11.36 seconds
Average speculative tokens per second: 8.82
Average baseline decoding time: 44.10 seconds
Average baseline tokens per second: 2.27
Speedup: 3.88x
Latency reduction: 74.25%

Benchmarking Prompt 2:
Prompt: Write a short story about a robot learning to feel emotions:
Generated 91 tokens in 10.24 seconds
Tokens per second: 8.88
Draft token acceptance rate: 81.00%
Generated 91 tokens in 10.33 seconds
Tokens per second: 8.81
Draft token acceptance rate: 81.00%
Generated 91 tokens in 10.25 seconds
Tokens per second: 8.88
Draft token acceptance rate: 81.00%
Avera