In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn.functional as F


class SpeculativeDecoding:
    """Greedy speculative decoding with a draft model and a target model.

    Each round the draft model proposes a short run of tokens by greedy
    decoding, and the target model scores the whole run in a single forward
    pass. The longest prefix of the run whose tokens all receive a target
    probability of at least ``acceptance_threshold`` is appended to the
    sequence. If that prefix is empty, the first draft token is kept anyway
    so that every round makes progress.

    Prompts are encoded with the draft tokenizer and the resulting ids are
    fed to both models, so this assumes the two models share a vocabulary.
    Generation handles a single prompt (batch size 1).
    """

    def __init__(
        self,
        draft_model_name="distilgpt2",
        target_model_name="gpt2-medium",
        acceptance_threshold=0.9,
    ):
        """Load both models and their tokenizers onto the available device.

        Args:
            draft_model_name: Name of the model that proposes tokens.
            target_model_name: Name of the model that verifies the proposals.
            acceptance_threshold: Minimum target-model probability a draft
                token needs in order to be accepted.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        print(f"Loading draft model: {draft_model_name}")
        self.draft_tokenizer, self.draft_model = self._load(draft_model_name)

        print(f"Loading target model: {target_model_name}")
        self.target_tokenizer, self.target_model = self._load(target_model_name)

        self.acceptance_threshold = acceptance_threshold

    def _load(self, model_name):
        """Return ``(tokenizer, model)`` for *model_name*, model on ``self.device``."""
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token_id is None:
            # Reuse EOS as the padding token when the tokenizer defines none.
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)
        return tokenizer, model

    def _draft_greedy(self, input_ids, n_tokens):
        """Return ``n_tokens`` greedy continuation ids from the draft model.

        Tokens are produced one at a time so that every proposal is
        conditioned on the proposals before it. The full sequence is
        re-encoded at each step (no key/value cache) to keep the code simple.
        """
        sequence = input_ids
        with torch.no_grad():
            for _ in range(n_tokens):
                logits = self.draft_model(sequence).logits
                next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
                sequence = torch.cat([sequence, next_token], dim=1)
        return sequence[:, input_ids.shape[1]:]

    def _target_token_probs(self, input_ids, draft_tokens, temperature):
        """Return the target model's probability of each draft token.

        The result has the same shape as ``draft_tokens``. One forward pass
        over ``input_ids + draft_tokens`` scores every proposal at once.
        """
        prefix_length = input_ids.shape[1]
        candidate = torch.cat([input_ids, draft_tokens], dim=1)
        with torch.no_grad():
            logits = self.target_model(candidate).logits
        # The logits at position i are the prediction for token i + 1, so the
        # draft tokens are predicted by positions prefix_length - 1 up to the
        # second-to-last position of the candidate sequence.
        verify_logits = logits[:, prefix_length - 1 : -1, :]
        probs = F.softmax(verify_logits / temperature, dim=-1)
        return probs.gather(-1, draft_tokens.unsqueeze(-1)).squeeze(-1)

    def generate(self, prompt, max_length=50, n_draft_tokens=4, temperature=0.7):
        """Generate text using speculative decoding.

        Args:
            prompt: Text to continue.
            max_length: Upper bound on the total number of tokens in the
                result, prompt included.
            n_draft_tokens: Number of tokens the draft model proposes per
                round.
            temperature: Softmax temperature applied to the target model's
                logits when scoring draft tokens. Drafting itself is greedy.

        Returns:
            The decoded prompt plus continuation, special tokens skipped.

        Raises:
            ValueError: If ``n_draft_tokens`` is below 1, ``temperature`` is
                not positive, or the prompt encodes to zero tokens.
        """
        if n_draft_tokens < 1:
            raise ValueError("n_draft_tokens must be at least 1")
        if temperature <= 0:
            raise ValueError("temperature must be positive")

        # Encode prompt
        input_ids = self.draft_tokenizer(prompt, return_tensors="pt").input_ids.to(
            self.device
        )
        current_length = input_ids.shape[1]
        if current_length == 0:
            # The first draft token is scored from the last prompt position,
            # so at least one prompt token is required.
            raise ValueError("prompt must encode to at least one token")

        eos_token_id = self.draft_tokenizer.eos_token_id

        # Generate until max length or EOS
        while current_length < max_length:
            # Never draft past max_length.
            n_draft = min(n_draft_tokens, max_length - current_length)

            # 1-2. Propose tokens with the draft model (greedy decoding).
            draft_tokens = self._draft_greedy(input_ids, n_draft)

            # 3. Verify with the target model.
            draft_token_probs = self._target_token_probs(
                input_ids, draft_tokens, temperature
            )

            # get mask of accepted tokens
            accepted_mask = draft_token_probs >= self.acceptance_threshold

            # Length of the accepted prefix: the running product turns to 0
            # at the first rejection and stays 0 afterwards.
            accept_length = int(accepted_mask[0].long().cumprod(dim=0).sum().item())

            # If no tokens were accepted, accept at least the first token
            # from the draft model so the loop always makes progress.
            if accept_length == 0:
                accept_length = 1

            new_tokens = draft_tokens[:, :accept_length]

            # Check for EOS only among the newly added tokens (an EOS inside
            # the prompt must not stop generation) and drop anything after it.
            finished = False
            if eos_token_id is not None:
                eos_positions = (new_tokens[0] == eos_token_id).nonzero()
                if eos_positions.numel() > 0:
                    new_tokens = new_tokens[:, : int(eos_positions[0].item()) + 1]
                    finished = True

            input_ids = torch.cat([input_ids, new_tokens], dim=1)
            current_length = input_ids.shape[1]

            if finished:
                break

        # Decode final output
        output_text = self.draft_tokenizer.decode(
            input_ids[0], skip_special_tokens=True
        )
        return output_text


# Build a decoder with the default model names ("distilgpt2" as the draft
# model, "gpt2-medium" as the target model); the constructor loads both.
decoder = SpeculativeDecoding()

Question: Compare standard greedy decoding with greedy decoding that uses speculative decoding.