# CSE 234 Programming Assignment 3: Speculative Decoding

## Setup

In [2]:
import os
import torch
import time
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from typing import List, Tuple, Dict, Optional

  from .autonotebook import tqdm as notebook_tqdm


## Speculative Decoding

In [3]:
class SpeculativeDecoder:
    """Greedy speculative decoding with a small draft model and a larger target model.

    Each iteration, the draft model proposes a short run of tokens. The target model
    scores prompt + draft in a single forward pass and keeps the longest prefix of the
    draft that matches its own argmax choices. It then appends one token of its own:
    the correction for the first rejected draft token, or a "bonus" token when the
    whole draft was accepted. Every iteration therefore emits at least one token.
    The result should equal greedy decoding with the target model alone, up to
    floating-point differences between full-sequence and incremental forward passes.

    NOTE: the target model re-encodes the full sequence on every verification (no KV
    cache is carried across iterations).
    """

    def __init__(self, target_model_name: str, draft_model_name: str, device: str = "cuda"):
        """
        Initialize the speculative decoder with target and draft models.

        Args:
            target_model_name: HuggingFace model ID for the larger target model.
            draft_model_name: HuggingFace model ID for the smaller draft model.
            device: Device to run models on ("cuda" or "cpu").
        """
        self.device = device
        # Keep the model IDs around so the decoder can be inspected or re-created.
        self.target_model_name = target_model_name
        self.draft_model_name = draft_model_name
        self.target_model, self.target_tokenizer = self.initialize_target_model(target_model_name)
        self.draft_model, self.draft_tokenizer = self.initialize_draft_model(draft_model_name)

        # Draft token IDs are fed straight into the target model, so both tokenizers
        # must map tokens to the same IDs.
        assert self.target_tokenizer.vocab == self.draft_tokenizer.vocab, "Tokenizers must be compatible"

    def _load_model_and_tokenizer(self, model_name: str):
        """Load a causal LM plus tokenizer, ensure a pad token exists, and move the model to self.device."""
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Without a pad token, tokenizing with padding=True fails; reuse EOS for padding.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(model_name)
        model.to(self.device)
        # Inference only: make sure dropout and similar training behavior are off.
        model.eval()
        return model, tokenizer

    def initialize_target_model(self, model_name: str):
        """Initialize the larger target model and its tokenizer (with a pad token set)."""
        print(f"Loading target model: {model_name}")
        return self._load_model_and_tokenizer(model_name)

    def initialize_draft_model(self, model_name: str):
        """Initialize the smaller draft model and its tokenizer (with a pad token set)."""
        print(f"Loading draft model: {model_name}")
        return self._load_model_and_tokenizer(model_name)

    def generate_draft_tokens(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                             num_speculative_tokens: int = 10) -> torch.Tensor:
        """
        Generate speculative tokens with the draft model.

        Drafting is greedy (no sampling) because verification compares against the
        target's argmax. A greedy draft maximizes the chance of a match and keeps the
        output deterministic.

        Args:
            input_ids: Input token IDs (tensor of shape [1, seq_len]).
            attention_mask: Corresponding attention mask.
            num_speculative_tokens: Maximum number of tokens to speculate.

        Returns:
            Tensor of shape [1, n] with n <= num_speculative_tokens new draft tokens.
            n can be smaller if the draft model emits EOS early.
        """
        # Pass the config per call instead of overwriting draft_model.generation_config.
        generation_config = GenerationConfig(
            max_new_tokens=num_speculative_tokens,
            do_sample=False,
            pad_token_id=self.draft_tokenizer.pad_token_id,
            eos_token_id=self.draft_tokenizer.eos_token_id,
            bos_token_id=self.draft_tokenizer.bos_token_id,
        )
        with torch.no_grad():
            draft = self.draft_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                generation_config=generation_config,
            )
        # generate() returns prompt + continuation; keep only the continuation.
        return draft[:, input_ids.shape[-1]:]

    def _verify_with_correction(self, input_ids: torch.Tensor, draft_tokens: torch.Tensor,
                                attention_mask: torch.Tensor) -> Tuple[torch.Tensor, int, torch.Tensor]:
        """
        Score prompt + draft with one target forward pass.

        Args:
            input_ids: Current token IDs, shape [1, L] with L >= 1.
            draft_tokens: Draft token IDs, shape [1, k] (k may be 0).
            attention_mask: Attention mask for input_ids, shape [1, L].

        Returns:
            accepted_tokens: Longest prefix of draft_tokens matching the target's argmax
                predictions, shape [1, n].
            accepted_position: n, the index of the first rejected draft token
                (k when every draft token is accepted).
            next_token: The target's argmax token right after the accepted prefix,
                shape [1, 1]. This is the correction for the rejected token, or a
                bonus token when the whole draft was accepted.
        """
        num_draft = draft_tokens.shape[-1]
        draft_tokens = draft_tokens.to(self.device)
        full_ids = torch.cat([input_ids.to(self.device), draft_tokens], dim=-1)
        full_mask = torch.cat([attention_mask.to(self.device), torch.ones_like(draft_tokens)], dim=-1)
        with torch.no_grad():
            logits = self.target_model(full_ids, attention_mask=full_mask).logits

        # Logits at position i predict token i + 1. Starting one position before the
        # first draft token yields k + 1 predictions. The first k score the draft
        # tokens; the last one continues after the complete draft.
        first_pred_pos = input_ids.shape[-1] - 1
        target_preds = logits[:, first_pred_pos:, :].argmax(dim=-1)

        matches = target_preds[0, :num_draft] == draft_tokens[0]
        # cumprod zeroes everything from the first mismatch onward, so the sum is the
        # length of the leading run of matches (0 for an empty draft).
        accepted_position = int(matches.long().cumprod(dim=0).sum())

        accepted_tokens = draft_tokens[:, :accepted_position]
        next_token = target_preds[:, accepted_position:accepted_position + 1]
        return accepted_tokens, accepted_position, next_token

    def verify_tokens_vectorized(self, input_ids: torch.Tensor, draft_tokens: torch.Tensor,
                               attention_mask: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """
        Vectorized verification: verify all draft tokens in one forward pass using the target model.

        Args:
            input_ids: The current input token IDs (shape [1, L]).
            draft_tokens: Draft tokens from the draft model (shape [1, k]).
            attention_mask: The current attention mask for input_ids.

        Returns:
            accepted_tokens: Tensor [1, n] of accepted draft token IDs.
            accepted_position: Index of the first rejected token as a Python int
                (equals draft_tokens.shape[1] if all are accepted).
        """
        accepted_tokens, accepted_position, _ = self._verify_with_correction(
            input_ids, draft_tokens, attention_mask
        )
        return accepted_tokens, accepted_position

    def speculative_decode(self, prompt: str, max_tokens: int = 100,
                          num_speculative_tokens: int = 15) -> str:
        """
        Main speculative decoding algorithm with vectorized verification.

        Args:
            prompt: Input text (must tokenize to at least one token).
            max_tokens: Maximum number of tokens to generate (excluding prompt).
            num_speculative_tokens: Number of tokens to speculate per iteration.

        Returns:
            Generated text (prompt + continuation, special tokens skipped).

        Raises:
            ValueError: If the prompt tokenizes to zero tokens.
        """
        inputs = self.target_tokenizer(prompt, return_tensors="pt", padding=True)
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)
        if input_ids.shape[-1] == 0:
            raise ValueError("prompt must tokenize to at least one token")
        eos_id = self.target_tokenizer.eos_token_id

        num_generated = 0
        total_draft_tokens_proposed = 0
        total_draft_tokens_accepted = 0
        # A prompt that already ends in EOS produces no continuation.
        finished = eos_id is not None and int(input_ids[0, -1]) == eos_id
        start_time = time.time()

        while num_generated < max_tokens and not finished:
            remaining = max_tokens - num_generated
            # Never draft more tokens than we are still allowed to emit.
            num_draft = min(num_speculative_tokens, remaining)
            if num_draft > 0:
                draft_ids = self.generate_draft_tokens(input_ids, attention_mask, num_draft)
            else:
                draft_ids = input_ids.new_empty((1, 0))
            total_draft_tokens_proposed += draft_ids.shape[-1]

            accepted_ids, num_accepted, next_token = self._verify_with_correction(
                input_ids, draft_ids, attention_mask
            )
            total_draft_tokens_accepted += num_accepted

            # Accepted draft prefix plus the target's own token. This guarantees at least
            # one new token per iteration and never exceeds max_tokens.
            new_tokens = torch.cat([accepted_ids, next_token], dim=-1)[:, :remaining]

            # Stop right after the first EOS, even if it appears mid-run.
            if eos_id is not None:
                eos_hits = (new_tokens[0] == eos_id).nonzero()
                if eos_hits.numel() > 0:
                    new_tokens = new_tokens[:, :int(eos_hits[0, 0]) + 1]
                    finished = True

            input_ids = torch.cat([input_ids, new_tokens], dim=-1)
            attention_mask = torch.cat([attention_mask, torch.ones_like(new_tokens)], dim=-1)
            num_generated += new_tokens.shape[-1]

        # Calculate performance metrics
        elapsed_time = time.time() - start_time
        acceptance_rate = total_draft_tokens_accepted / total_draft_tokens_proposed if total_draft_tokens_proposed > 0 else 0

        print(f"Generated {num_generated} tokens in {elapsed_time:.2f} seconds")
        # Guard against a zero-length timer interval on very fast runs.
        print(f"Tokens per second: {num_generated / max(elapsed_time, 1e-9):.2f}")
        print(f"Draft token acceptance rate: {acceptance_rate:.2%}")

        return self.target_tokenizer.decode(input_ids[0], skip_special_tokens=True)

    def benchmark(self, prompt: str, max_tokens: int = 100,
                  num_runs: int = 3, compare_baseline: bool = True) -> Dict:
        """
        Benchmark the speculative decoder against baseline decoding.

        Args:
            prompt: Input text.
            max_tokens: Maximum number of tokens to generate.
            num_runs: Number of benchmark runs (must be >= 1).
            compare_baseline: Whether to compare with baseline (non-speculative) decoding.

        Returns:
            Dictionary with per-run times and tokens/sec, their averages, and (when
            compare_baseline) "speedup" and "latency_reduction" (percent).

        Raises:
            ValueError: If num_runs < 1 (the averages would divide by zero).
        """
        if num_runs < 1:
            raise ValueError("num_runs must be at least 1")

        results = {
            "speculative": {"times": [], "tokens_per_second": []},
            "baseline": {"times": [], "tokens_per_second": []} if compare_baseline else None
        }

        # Benchmark speculative decoding. The prompt length is loop-invariant.
        prompt_len = len(self.target_tokenizer(prompt)["input_ids"])
        for _ in range(num_runs):
            start_time = time.time()
            output = self.speculative_decode(prompt, max_tokens=max_tokens)
            elapsed = time.time() - start_time
            # speculative_decode returns text, so re-tokenize it to approximate the new-token count.
            output_tokens = len(self.target_tokenizer.encode(output)) - prompt_len
            tps = output_tokens / elapsed
            results["speculative"]["times"].append(elapsed)
            results["speculative"]["tokens_per_second"].append(tps)

        # Benchmark baseline (plain greedy) decoding with the target model.
        if compare_baseline:
            for _ in range(num_runs):
                inputs = self.target_tokenizer(prompt, return_tensors="pt", padding=True)
                input_ids = inputs["input_ids"].to(self.device)
                attention_mask = inputs["attention_mask"].to(self.device)
                start_time = time.time()
                with torch.no_grad():
                    output_ids = self.target_model.generate(
                        input_ids,
                        attention_mask=attention_mask,
                        max_new_tokens=max_tokens,
                        do_sample=False,
                        pad_token_id=self.target_tokenizer.pad_token_id
                    )
                elapsed = time.time() - start_time
                output_tokens = output_ids.shape[1] - input_ids.shape[1]
                tps = output_tokens / elapsed
                results["baseline"]["times"].append(elapsed)
                results["baseline"]["tokens_per_second"].append(tps)

        for method in results.keys():
            if results[method] is not None:
                results[method]["avg_time"] = sum(results[method]["times"]) / num_runs
                results[method]["avg_tokens_per_second"] = sum(results[method]["tokens_per_second"]) / num_runs

        if compare_baseline:
            speedup = results["baseline"]["avg_time"] / results["speculative"]["avg_time"]
            results["speedup"] = speedup
            results["latency_reduction"] = (1 - results["speculative"]["avg_time"] / results["baseline"]["avg_time"]) * 100

        return results

## Test

In [4]:
target_model_name = "EleutherAI/pythia-1.4b-deduped"  # Larger target model
draft_model_name = "EleutherAI/pythia-160m-deduped"   # Smaller draft model

# Build the decoder on the GPU when available, otherwise on the CPU.
run_device = "cuda" if torch.cuda.is_available() else "cpu"
decoder = SpeculativeDecoder(
    target_model_name=target_model_name,
    draft_model_name=draft_model_name,
    device=run_device,
)

# Prompts used for the speculative-vs-baseline comparison.
test_prompts = [
    "The future of Artificial Intelligence is",
    "Write a short story about a robot learning to feel emotions:",
    "Write the lyrics to the song 'Happy Birthday'.",
]

# Benchmark each prompt and report averaged timings.
for prompt_number, prompt in enumerate(test_prompts, start=1):
    print(f"\nBenchmarking Prompt {prompt_number}:")
    print(f"Prompt: {prompt}")

    results = decoder.benchmark(prompt=prompt, max_tokens=100, num_runs=3, compare_baseline=True)

    spec = results["speculative"]
    print(f"Average speculative decoding time: {spec['avg_time']:.2f} seconds")
    print(f"Average speculative tokens per second: {spec['avg_tokens_per_second']:.2f}")

    base = results["baseline"]
    if base is None:
        continue
    print(f"Average baseline decoding time: {base['avg_time']:.2f} seconds")
    print(f"Average baseline tokens per second: {base['avg_tokens_per_second']:.2f}")
    print(f"Speedup: {results['speedup']:.2f}x")
    print(f"Latency reduction: {results['latency_reduction']:.2f}%")

Loading target model: EleutherAI/pythia-1.4b-deduped


The `GPTNeoXSdpaAttention` class is deprecated in favor of simply modifying the `config._attn_implementation`attribute of the `GPTNeoXAttention` class! It will be removed in v4.48


Loading draft model: EleutherAI/pythia-160m-deduped

Benchmarking Prompt 1:
Prompt: The future of Artificial Intelligence is


KeyboardInterrupt: 

In [None]:
# # test function
# target_model_name = "EleutherAI/pythia-1.4b-deduped"  # Larger target model
# draft_model_name = "EleutherAI/pythia-160m-deduped"   # Smaller draft model


# # Initialize speculative decoder
# decoder = SpeculativeDecoder(
#     target_model_name=target_model_name,
#     draft_model_name=draft_model_name,
#     device="cuda" if torch.cuda.is_available() else "cpu"
# )

# # Test prompts
# test_prompts = [
#     "The future of Artificial Intelligence is",
#     "Write a short story about a robot learning to feel emotions:",
#     "Write the lyrics to the song 'Happy Birthday'."
# ]


Loading target model: EleutherAI/pythia-1.4b-deduped


The `GPTNeoXSdpaAttention` class is deprecated in favor of simply modifying the `config._attn_implementation`attribute of the `GPTNeoXAttention` class! It will be removed in v4.48


Loading draft model: EleutherAI/pythia-160m-deduped


In [None]:
# decoder.speculative_decode(test_prompts[0])

Generated 0 tokens in 0.19 seconds
Tokens per second: 0.00
Draft token acceptance rate: 0.00%


'The future of Artificial Intelligence is'

In [41]:
# def generate_draft_tokens(model, draft_tokenizer,input_ids: torch.Tensor, attention_mask: torch.Tensor,
#                             num_speculative_tokens: int = 10) -> torch.Tensor:
#     """
#     Generate speculative tokens in one forward call using the draft model.

#     Args:
#         input_ids: Input token IDs (tensor of shape [1, seq_len]).
#         attention_mask: Corresponding attention mask.
#         num_speculative_tokens: Number of tokens to speculate.

#     Returns:
#         Tensor of shape [1, num_speculative_tokens] containing the draft tokens.
#     """
#     # TODO: Implement draft token generation
#     # 1. Use the draft model to generate tokens
#     # 2. Extract only the new tokens (not including the input)
#     # 3. Return the newly generated tokens
#     # baseline: directly generate tokens

    
#     max_tokens = num_speculative_tokens + input_ids.shape[-1]
#     generation_config = GenerationConfig(
#         max_new_tokens=num_speculative_tokens,
#         do_sample=True,
#         pad_token_id=draft_tokenizer.pad_token_id,
#         eos_token_id=draft_tokenizer.eos_token_id,
#         bos_token_id=draft_tokenizer.bos_token_id,
#     )
#     model.generation_config = generation_config
#     draft = model.generate(
#         input_ids=input_ids, 
#         attention_mask=attention_mask, 
#         # use_cache=True
#     )

#     return draft[:, input_ids.shape[-1]:]

# inputs = decoder.target_tokenizer(test_prompts[0], return_tensors="pt", padding=True)
# input_ids = inputs["input_ids"].to(decoder.device)
# attention_mask = inputs["attention_mask"].to(decoder.device)
# prompt_length = input_ids.shape[1]

# # Initialize counters for performance tracking
# total_tokens_generated = prompt_length
# total_draft_tokens_proposed = 0
# total_draft_tokens_accepted = 0
# start_time = time.time()


In [None]:
# total_tokens = 0
# while total_tokens < 100 and input_ids[0, -1] != decoder.target_tokenizer.eos_token_id:
#     draft_ids = generate_draft_tokens(decoder.draft_model, decoder.draft_tokenizer, input_ids, attention_mask, 10)

#     # verification:
#     target_input_ids = torch.cat([input_ids, draft_ids], dim=-1)
#     target_attention_mask = torch.cat([attention_mask, torch.ones_like(draft_ids)], dim=-1)
#     target_input_ids = target_input_ids.to(decoder.device)
#     target_attention_mask = target_attention_mask.to(decoder.device)
#     with torch.no_grad():
#         logits = decoder.target_model(target_input_ids, attention_mask=target_attention_mask).logits
        
#     start_idx = input_ids.shape[-1]-1
#     target_draft_logits = logits[:, start_idx:-1, :]

#     # get target tokens
#     target_predicted_tokens = target_draft_logits.argmax(dim=-1)

#     # matches tokens
#     match_mask = ~(target_predicted_tokens == draft_ids).int()
#     if match_mask.sum() != 0:
#         accept_position = torch.argmax(match_mask, dim=-1)
#     else:
#         accept_position = draft_ids.shape[-1]

#     accepted_tokens = draft_ids[:,:accept_position]

#     # append accepted tokens to the sequence
#     print(accepted_tokens.shape)

#     input_ids = torch.cat([input_ids, accepted_tokens], dim=-1)
#     attention_mask = torch.cat([attention_mask, torch.ones_like(accepted_tokens)], dim=-1)
#     total_tokens = input_ids.shape[-1]



torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([1, 0])
torch.Size([

## Bonus

In [None]:
target_model_name = ...  # Larger target model (TODO: set a HuggingFace model ID)
draft_model_name = ...   # Smaller draft model (TODO: set a HuggingFace model ID)

# The names above are Ellipsis placeholders. Fail fast with an actionable message
# instead of letting `from_pretrained` reject a non-string model ID deep inside the loader.
if not isinstance(target_model_name, str) or not isinstance(draft_model_name, str):
    raise ValueError(
        "Bonus: set target_model_name and draft_model_name to HuggingFace model IDs "
        "whose tokenizers share a vocabulary before running this cell."
    )

# Initialize speculative decoder
decoder = SpeculativeDecoder(
    target_model_name=target_model_name,
    draft_model_name=draft_model_name,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

# Test prompts
test_prompts = [
    "The future of Artificial Intelligence is",
    "Write a short story about a robot learning to feel emotions:",
    "Write the lyrics to the song 'Happy Birthday'."
]

# Run benchmark on test prompts
for i, prompt in enumerate(test_prompts):
    print(f"\nBenchmarking Prompt {i+1}:")
    print(f"Prompt: {prompt}")

    results = decoder.benchmark(
        prompt=prompt,
        max_tokens=100,
        num_runs=3,
        compare_baseline=True
    )

    print(f"Average speculative decoding time: {results['speculative']['avg_time']:.2f} seconds")
    print(f"Average speculative tokens per second: {results['speculative']['avg_tokens_per_second']:.2f}")

    if results["baseline"] is not None:
        print(f"Average baseline decoding time: {results['baseline']['avg_time']:.2f} seconds")
        print(f"Average baseline tokens per second: {results['baseline']['avg_tokens_per_second']:.2f}")
        print(f"Speedup: {results['speedup']:.2f}x")
        print(f"Latency reduction: {results['latency_reduction']:.2f}%")