In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn.functional as F


class SpeculativeDecoding:
    def __init__(
        self,
        draft_model_name="distilgpt2",  # change
        target_model_name="gpt2",  # change
        acceptance_threshold=0.9,
    ):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Load models and tokenizers
        print(f"Loading draft model: {draft_model_name}")
        self.draft_tokenizer = AutoTokenizer.from_pretrained(draft_model_name)
        if self.draft_tokenizer.pad_token_id is None:
            self.draft_tokenizer.pad_token = self.draft_tokenizer.eos_token
        self.draft_model = AutoModelForCausalLM.from_pretrained(draft_model_name).to(
            self.device
        )

        print(f"Loading target model: {target_model_name}")
        self.target_tokenizer = AutoTokenizer.from_pretrained(target_model_name)
        if self.target_tokenizer.pad_token_id is None:
            self.target_tokenizer.pad_token = self.target_tokenizer.eos_token
        self.target_model = AutoModelForCausalLM.from_pretrained(target_model_name).to(
            self.device
        )

        self.acceptance_threshold = acceptance_threshold

    def generate(self, prompt, max_length=50, n_draft_tokens=4, temperature=0.7):
        """
        Generate text using speculative decoding.
        """
        # Encode prompt
        input_ids = self.draft_tokenizer(prompt, return_tensors="pt").input_ids.to(
            self.device
        )
        current_length = input_ids.shape[1]

        # Generate until max length or EOS
        while current_length < max_length:
            # 1. Generate draft tokens using the draft model
            with torch.no_grad():                
                # Generate n_draft_tokens autoregressively with the draft model
                draft_tokens = []
                draft_probs = []
                current_draft_input = input_ids.clone()
                
                # Greedy decoding
                for _ in range(n_draft_tokens):
                    # Get probabilities for next token
                    current_logits = self.draft_model(current_draft_input).logits[:, -1:]
                    current_probs = F.softmax(current_logits / temperature, dim=-1)
                    draft_probs.append(current_probs)
                    
                    # Sample or take argmax
                    next_token = current_probs.argmax(dim=-1)  # You could also use sampling here
                    draft_tokens.append(next_token)
                    
                    # Append to input for next iteration
                    current_draft_input = torch.cat([current_draft_input, next_token], dim=1)
                
                # Stack all draft tokens
                draft_tokens = torch.cat(draft_tokens, dim=1)  # Shape: (batch_size, n_draft_tokens)
                draft_probs = torch.cat(draft_probs, dim=1)    # Shape: (batch_size, n_draft_tokens, vocab_size)

            # 2. Verify with target model
            proposed_sequence = torch.cat([input_ids, draft_tokens], dim=1)
            with torch.no_grad():
                target_outputs = self.target_model(proposed_sequence)
                target_logits = target_outputs.logits[
                    :, current_length - 1 : current_length + n_draft_tokens - 1
                ]
                target_probs = F.softmax(target_logits / temperature, dim=-1)

            # Calculate acceptance probabilities
            draft_token_probs = torch.gather(
                target_probs, 2, draft_tokens.unsqueeze(-1)
            ).squeeze(-1)
            accepted_mask = draft_token_probs >= self.acceptance_threshold

            # Find first rejection or accept all if no rejections
            first_rejected = torch.where(~accepted_mask[0])[0]
            accept_length = (
                n_draft_tokens if len(first_rejected) == 0 else first_rejected[0].item()
            )

            if accept_length > 0:
                # Add accepted tokens to input
                input_ids = torch.cat(
                    [input_ids, draft_tokens[:, :accept_length]], dim=1
                )
                current_length = input_ids.shape[1]

            # Check for EOS token
            if (input_ids == self.draft_tokenizer.eos_token_id).any():
                break

            # If no tokens were accepted, generate one token from the draft model
            if accept_length == 0:
                input_ids = torch.cat([input_ids, draft_tokens[:, :1]], dim=1)
                current_length = input_ids.shape[1]

        # Decode final output
        output_text = self.draft_tokenizer.decode(
            input_ids[0], skip_special_tokens=True
        )
        return output_text

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
decoder = SpeculativeDecoding()

Using device: cuda
Loading draft model: distilgpt2
Loading target model: gpt2


In [3]:
decoder.generate("Once upon a time,")

'Once upon a time, the world was a place of great beauty and great danger. The world was a place of great danger, and the world was a place of great danger. The world was a place of great danger, and the world was a place of great'