<a href="https://colab.research.google.com/github/TheRadDani/Speculative-LLM-Decoding-Draft-Models/blob/main/Speculative_Decoding_Draft_Models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [60]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import numpy as np
import time
import torch.nn.functional as F # For softmax

In [61]:
# --- Configuration ---
device = "cuda" if torch.cuda.is_available() else "cpu"

# Target (verifier) and draft (proposer) checkpoints. Only one tokenizer (the
# target's) is loaded below, so draft-model token ids are interpreted with the
# target's vocabulary. NOTE(review): assumes both checkpoints use the same
# GPT-2 BPE id mapping — confirm if either model name is changed.
target_model_name = "EleutherAI/gpt-neo-125m"
draft_model_name = "distilbert/distilgpt2"

# Seed every RNG the notebook draws from (torch.multinomial / torch.rand for
# sampling, numpy for any legacy random calls) so sampled outputs are reproducible.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Load tokenizer first to determine the vocabulary size consistently
tokenizer = AutoTokenizer.from_pretrained(target_model_name)

In [62]:
# --- Canonical Vocabulary Size Handling ---

# If the tokenizer defines no pad token, reuse the EOS token as PAD (common for
# causal LMs). The EOS string is already in the vocabulary, so this only sets
# pad_token_id == eos_token_id and does not add a new id (the printed sizes
# below confirm len(tokenizer) is unchanged).
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})

# Read the size once, after any special-token changes, so both models are
# resized against the final tokenizer.
CANONICAL_VOCAB_SIZE = len(tokenizer)

print(f"Final tokenizer vocab size (len): {len(tokenizer)}")
print(f"Tokenizer pad_token_id: {tokenizer.pad_token_id}")
print(f"Tokenizer eos_token_id: {tokenizer.eos_token_id}")
print(f"CANONICAL_VOCAB_SIZE for validation: {CANONICAL_VOCAB_SIZE}")

Final tokenizer vocab size (len): 50257
Tokenizer pad_token_id: 50256
Tokenizer eos_token_id: 50256
CANONICAL_VOCAB_SIZE for validation: 50257


In [64]:
# Instantiate the verifier (target) and proposer (draft), both sized to the tokenizer.
# NOTE(review): this cell (In [64]) sits above the cell that defines
# `load_and_resize_model` (In [63]); under Restart & Run All it raises NameError
# unless the definition cell is moved above this one.
target_model = load_and_resize_model(
    model_name=target_model_name,
    canonical_vocab_size=CANONICAL_VOCAB_SIZE,
)
draft_model = load_and_resize_model(
    model_name=draft_model_name,
    canonical_vocab_size=CANONICAL_VOCAB_SIZE,
)

Loaded EleutherAI/gpt-neo-125m. Final vocab size: 50257, Embedding layer size: 50257
Loaded distilbert/distilgpt2. Final vocab size: 50257, Embedding layer size: 50257


In [63]:
# Helper function to load model and adjust its embedding layer size
def load_and_resize_model(model_name, canonical_vocab_size):
    """
    Load a causal LM checkpoint and make sure its token embeddings can index
    every id of the canonical (tokenizer) vocabulary.

    The checkpoint is loaded with its own config. Growing the embedding matrix
    is done afterwards with `resize_token_embeddings`. Editing
    `config.vocab_size` before `from_pretrained` (as an earlier version did)
    makes the checkpoint's embedding / LM-head weights fail to load with a
    shape mismatch.

    Args:
        model_name: Hugging Face model id or local path.
        canonical_vocab_size: minimum number of embedding rows required.

    Returns:
        The model, moved to the notebook-global `device` and put in eval mode.
    """
    model = AutoModelForCausalLM.from_pretrained(model_name)

    current_size = model.get_input_embeddings().num_embeddings
    if current_size < canonical_vocab_size:
        print(f"Resizing {model_name} embeddings from {current_size} to {canonical_vocab_size}")
        # Added rows carry no trained meaning; resize also updates config.vocab_size.
        model.resize_token_embeddings(canonical_vocab_size)

    model.to(device)
    model.eval()  # inference only: make sure dropout is disabled
    print(f"Loaded {model_name}. Final vocab size: {model.config.vocab_size}, Embedding layer size: {model.get_input_embeddings().num_embeddings}")
    return model

In [65]:
# Helper function for robust sampling
def sample_next_token(logits, temperature=1.0, top_k=0, top_p=1.0, model_vocab_size=None, tokenizer_ref=None):
    """
    Sample the next token id from one row of next-token logits.

    Args:
        logits: tensor of shape (1, vocab). Batch size 1 is assumed because the
            result is returned as a Python int via `.item()`.
        temperature: 0.0 selects greedy decoding (argmax); otherwise the
            logits are divided by it before filtering.
        top_k: if > 0, keep only the `top_k` highest logits.
        top_p: if < 1.0, nucleus filtering: keep the most likely tokens up to
            and including the one whose cumulative probability crosses `top_p`.
        model_vocab_size: if given, the sampled id must lie in [0, model_vocab_size).
        tokenizer_ref: used only for the EOS fallback when filtering leaves an
            all-zero distribution.

    Returns:
        int: the sampled token id.

    Raises:
        IndexError: if `model_vocab_size` is given and the id is out of range.
    """
    # Filled on the sampling path; computed lazily for greedy diagnostics below.
    probabilities = None

    if temperature == 0.0:  # Greedy decoding
        next_token_id = torch.argmax(logits, dim=-1).item()
    else:
        # Apply temperature (creates a new tensor, so the caller's logits are untouched)
        logits = logits / temperature

        # Top-K sampling
        if top_k > 0:
            top_k_actual = min(top_k, logits.size(-1))
            values, _ = torch.topk(logits, top_k_actual)
            min_value = values[:, -1].unsqueeze(-1)
            logits = torch.where(logits < min_value, torch.full_like(logits, -float('Inf')), logits)

        # Top-P (nucleus) sampling
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the token that crosses the threshold is kept,
            # and always keep the single most likely token.
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = False

            # Map the mask from sorted order back to vocabulary order. Indexing
            # `sorted_indices` with the boolean mask would flatten it to 1-D,
            # which a scatter into the 2-D logits rejects.
            indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
            logits = logits.masked_fill(indices_to_remove, float('-Inf'))

        # Convert logits to probabilities
        probabilities = F.softmax(logits, dim=-1)

        # Handle cases where probabilities might become all zero after aggressive filtering
        probabilities = torch.nan_to_num(probabilities, nan=0.0, posinf=0.0, neginf=0.0)

        if probabilities.sum().item() == 0.0:
            print("Warning: All probabilities zero after filtering. Falling back to EOS.")
            if tokenizer_ref is not None and tokenizer_ref.eos_token_id is not None:
                return tokenizer_ref.eos_token_id
            # If no EOS token, return a safe default like 0 (first token)
            return 0

        # Sample from the (possibly filtered) distribution
        next_token_id = torch.multinomial(probabilities, num_samples=1).item()

    # --- Robust validation for sampled token ID ---
    if model_vocab_size is not None and not (0 <= next_token_id < model_vocab_size):
        if probabilities is None:  # greedy path: build a distribution for the diagnostics
            probabilities = F.softmax(logits, dim=-1)
        print(f"!!! CRITICAL ERROR: Sampled token ID {next_token_id} is out of vocabulary range [{0}, {model_vocab_size-1}]")
        print(f"Logits shape: {logits.shape}")
        print(f"Probabilities sum: {probabilities.sum().item():.4f}")
        top_probs, top_indices = torch.topk(probabilities, k=min(10, probabilities.size(-1)))
        print(f"Top 10 Probs: {top_probs.tolist()}")
        print(f"Top 10 Indices: {top_indices.tolist()}")
        raise IndexError("Sampled token ID out of model's vocabulary range.")

    return next_token_id

In [66]:
def autoregressive_decode_with_sampling(prompt: str, target_model, tokenizer,
                                         max_new_tokens: int = 50,
                                         temperature: float = 1.0,
                                         top_k: int = 0,
                                         top_p: float = 1.0):
    """
    Baseline decoder: one target-model forward pass per generated token.

    The first pass consumes the whole prompt; each later pass feeds only the
    newest token and reuses the returned KV cache. Generation stops after
    `max_new_tokens` tokens or right after the tokenizer's EOS id is sampled.

    Returns:
        (decoded prompt + continuation, number of new tokens, wall-clock seconds)

    Raises:
        IndexError: if an id fed to the model is outside [0, CANONICAL_VOCAB_SIZE).
    """
    step_input = tokenizer.encode(prompt, return_tensors="pt").to(device)
    new_token_ids = []
    kv_cache = None

    t0 = time.time()
    for _ in range(max_new_tokens):
        # Fail with a readable error rather than an opaque embedding/device assert.
        has_ids = step_input.numel() > 0
        if has_ids and (step_input.min() < 0 or step_input.max() >= CANONICAL_VOCAB_SIZE):
            print(f"!!! ERROR: AR input_ids min: {step_input.min().item()}, max: {step_input.max().item()}")
            print(f"AR input_ids: {step_input}")
            raise IndexError("AR input_ids contain out-of-range token for target model.")

        with torch.no_grad():
            model_out = target_model(step_input, past_key_values=kv_cache, use_cache=True)
            kv_cache = model_out.past_key_values
            sampled_id = sample_next_token(
                model_out.logits[:, -1, :],
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                model_vocab_size=CANONICAL_VOCAB_SIZE,
                tokenizer_ref=tokenizer,
            )

        new_token_ids.append(sampled_id)
        step_input = torch.tensor([[sampled_id]]).to(device)
        if sampled_id == tokenizer.eos_token_id:
            break

    elapsed = time.time() - t0
    decoded = tokenizer.decode(tokenizer.encode(prompt) + new_token_ids, skip_special_tokens=True)
    return decoded, len(new_token_ids), elapsed

In [68]:
def _filtered_probs(logits, temperature=1.0, top_k=0, top_p=1.0):
    """
    Turn one row of next-token logits into the sampling distribution used by
    `speculative_decode`, mirroring the temperature / top-k / top-p filtering
    of `sample_next_token`.

    Speculative sampling reproduces the target's sampling distribution only if
    the draft (q) and target (p) probabilities both come from this same
    transformed distribution.

    Args:
        logits: tensor of shape (batch, vocab); the decoder uses batch 1.
        temperature: 0.0 gives a one-hot on the argmax (greedy); otherwise the
            logits are divided by it.
        top_k: if > 0, keep only the k largest logits.
        top_p: if < 1.0, keep the nucleus up to and including the token whose
            cumulative probability crosses top_p.

    Returns:
        Float tensor of probabilities with the same shape as `logits`.
    """
    logits = logits.float()
    if temperature == 0.0:
        one_hot = torch.zeros_like(logits)
        return one_hot.scatter_(-1, logits.argmax(dim=-1, keepdim=True), 1.0)

    logits = logits / temperature
    if top_k > 0:
        kth_largest = torch.topk(logits, min(top_k, logits.size(-1))).values[..., -1:]
        logits = logits.masked_fill(logits < kth_largest, float("-inf"))
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove_sorted = cumulative > top_p
        remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()  # keep the token that crosses top_p
        remove_sorted[..., 0] = False  # always keep the most likely token
        remove = remove_sorted.scatter(-1, sorted_idx, remove_sorted)  # back to vocabulary order
        logits = logits.masked_fill(remove, float("-inf"))
    return F.softmax(logits, dim=-1)


def _check_token_ids(ids, vocab_size, where):
    """Raise IndexError (instead of an opaque embedding / device-side assert) if any id is outside [0, vocab_size)."""
    if ids.numel() > 0 and (ids.min() < 0 or ids.max() >= vocab_size):
        raise IndexError(
            f"{where}: token ids must lie in [0, {vocab_size - 1}], "
            f"got min={ids.min().item()}, max={ids.max().item()}"
        )


def speculative_decode(prompt: str, target_model, draft_model, tokenizer,
                       max_new_tokens: int = 50, speculative_lookahead: int = 5,
                       temperature: float = 1.0, top_k: int = 0, top_p: float = 1.0):
    """
    Speculative sampling (Leviathan et al. 2023; Chen et al. 2023).

    The draft model proposes up to `speculative_lookahead` tokens. The target
    model scores them all in a single forward pass, and a rejection step keeps a
    prefix of them. The output is then distributed as if sampled from the target
    alone with the given temperature / top-k / top-p (greedy when temperature == 0).

    Per cycle:
      1. Draft: sample d_1..d_k from the draft distributions q_i, stopping early at EOS.
      2. Verify: one target pass over prefix + d_1..d_k. Logits at position j
         predict token j + 1, so p_i for d_i sits at index len(prefix) - 1 + i.
      3. Accept d_i with probability min(1, p_i(d_i) / q_i(d_i)). On the first
         rejection, sample a replacement from norm(max(0, p_i - q_i)) and end
         the cycle. If every d_i is accepted (and none is EOS), sample one bonus
         token from the target distribution after d_k. Its logits were already
         computed in step 2, so the bonus token costs no extra pass.

    k is capped at the remaining token budget, so at most `max_new_tokens` tokens
    are produced. Both models are run on the device of their own parameters.

    Returns:
        (decoded text, tokens generated, wall-clock seconds,
         tokens per target forward pass, number of target forward passes)

    Raises:
        IndexError: if any id fed to a model is outside its embedding table.
        ValueError: if the draft and target output vocabularies differ in size.
    """
    target_device = next(target_model.parameters()).device
    draft_device = next(draft_model.parameters()).device
    target_vocab = target_model.get_input_embeddings().num_embeddings
    draft_vocab = draft_model.get_input_embeddings().num_embeddings
    eos_id = tokenizer.eos_token_id

    prefix_ids = tokenizer.encode(prompt, return_tensors="pt").to(target_device)
    generated_tokens = []
    n_target_model_calls_total = 0

    if prefix_ids.shape[1] == 0:
        # Position len(prefix) - 1 would wrap to -1; there is nothing to condition on.
        print("Warning: prompt encodes to zero tokens; nothing to generate from.")

    start_time = time.time()
    with torch.no_grad():
        while prefix_ids.shape[1] > 0 and len(generated_tokens) < max_new_tokens:
            remaining = max_new_tokens - len(generated_tokens)
            lookahead = max(0, min(speculative_lookahead, remaining))

            # --- 1. Draft phase (draft KV cache rebuilt from the validated prefix each cycle) ---
            draft_tokens, draft_dists = [], []
            draft_input = prefix_ids.to(draft_device)
            draft_cache = None
            for _ in range(lookahead):
                _check_token_ids(draft_input, draft_vocab, "draft model input")
                draft_out = draft_model(draft_input, past_key_values=draft_cache, use_cache=True)
                draft_cache = draft_out.past_key_values
                q = _filtered_probs(draft_out.logits[:, -1, :], temperature, top_k, top_p)
                token = torch.multinomial(q, num_samples=1).item()
                draft_tokens.append(token)
                draft_dists.append(q)
                if token == eos_id:
                    break
                draft_input = torch.tensor([[token]], dtype=torch.long, device=draft_device)

            # --- 2. Verification: one target pass over prefix + all drafted tokens ---
            draft_ids = torch.tensor([draft_tokens], dtype=torch.long, device=target_device)
            verify_ids = torch.cat((prefix_ids, draft_ids), dim=1)
            _check_token_ids(verify_ids, target_vocab, "target model input")
            target_logits = target_model(verify_ids, use_cache=False).logits
            n_target_model_calls_total += 1

            # logits[:, j] is the distribution of token j + 1, so the distribution
            # for draft token i (at position len(prefix) + i) is at len(prefix) - 1 + i.
            first_pred_pos = prefix_ids.shape[1] - 1

            # --- 3. Rejection sampling ---
            new_tokens = []
            for i, (token, q) in enumerate(zip(draft_tokens, draft_dists)):
                p = _filtered_probs(target_logits[:, first_pred_pos + i, :], temperature, top_k, top_p)
                q = q.to(p.device)
                if p.shape != q.shape:
                    raise ValueError(
                        f"draft vocab ({q.size(-1)}) and target vocab ({p.size(-1)}) differ in size"
                    )
                p_tok, q_tok = p[0, token].item(), q[0, token].item()
                if q_tok > 0.0 and torch.rand(()).item() < min(1.0, p_tok / q_tok):
                    new_tokens.append(token)
                    if token == eos_id:
                        break
                    continue
                # Rejected: resample from the residual max(0, p - q); torch.multinomial
                # accepts unnormalized weights. If p <= q everywhere (possible only
                # through rounding), fall back to p.
                residual = torch.clamp(p - q, min=0.0)
                if residual.sum().item() <= 0.0:
                    residual = p
                new_tokens.append(torch.multinomial(residual, num_samples=1).item())
                break
            else:
                # All drafted tokens accepted and none was EOS: take the bonus token
                # from the already-computed logits after the last draft position.
                if len(generated_tokens) + len(new_tokens) < max_new_tokens:
                    p = _filtered_probs(target_logits[:, first_pred_pos + len(draft_tokens), :],
                                        temperature, top_k, top_p)
                    new_tokens.append(torch.multinomial(p, num_samples=1).item())

            generated_tokens.extend(new_tokens)
            prefix_ids = torch.cat(
                (prefix_ids, torch.tensor([new_tokens], dtype=torch.long, device=target_device)), dim=1
            )
            if new_tokens and new_tokens[-1] == eos_id:
                break

    elapsed = time.time() - start_time
    full_sequence = tokenizer.decode(tokenizer.encode(prompt) + generated_tokens, skip_special_tokens=True)

    tokens_generated = len(generated_tokens)
    effective_tokens_per_target_pass = (
        tokens_generated / n_target_model_calls_total if n_target_model_calls_total > 0 else 0
    )
    return full_sequence, tokens_generated, elapsed, effective_tokens_per_target_pass, n_target_model_calls_total

In [None]:
# --- Demonstration ---
prompt = "The quick brown fox jumps over the lazy dog and"
max_tokens_to_generate = 50
speculative_lookahead = 5

sampling_temperature = 0.7
sampling_top_k = 0
sampling_top_p = 1.0

prompt_advanced = "In the annals of history, the year 1789 stands out for the French Revolution, a pivotal event that reshaped the political landscape of Europe and beyond. The causes were multifaceted, including"


def _print_ar_report(output, n_tokens, elapsed):
    """Print the result of one `autoregressive_decode_with_sampling` run."""
    print(f"Output: '{output}'")
    print(f"Tokens Generated: {n_tokens}")
    print(f"Time Taken: {elapsed:.4f} seconds")
    # Plain AR decoding performs one target forward pass per generated token.
    if n_tokens > 0:
        print(f"Effective Tokens/Target Pass (AR): {n_tokens / n_tokens:.2f}")
    else:
        print("Effective Tokens/Target Pass (AR): N/A")


def _print_sd_report(output, n_tokens, elapsed, tokens_per_pass, n_target_calls):
    """Print the result of one `speculative_decode` run."""
    print(f"Output: '{output}'")
    print(f"Tokens Generated: {n_tokens}")
    print(f"Time Taken: {elapsed:.4f} seconds")
    print(f"Total Target Model Calls: {n_target_calls}")
    print(f"Effective Tokens/Target Pass (SD): {tokens_per_pass:.2f}")


def _print_comparison(ar_elapsed, sd_elapsed, ar_text, sd_text):
    """Print wall-clock timings, the speedup, and whether the two texts coincide."""
    print(f"Autoregressive Time: {ar_elapsed:.4f}s")
    print(f"Speculative Decoding Time: {sd_elapsed:.4f}s")
    if sd_elapsed > 0:
        print(f"Speedup Factor: {ar_elapsed / sd_elapsed:.2f}x")
    # With sampling the two texts are not expected to match token-for-token;
    # speculative sampling only guarantees the same output *distribution*.
    print(f"Autoregressive Output Matches Speculative (content-wise): {ar_text == sd_text}")


# Warm up both models once, untimed, so the first timed run (AR) does not absorb
# one-off costs (e.g. lazy CUDA initialization) and skew the speedup figure.
with torch.no_grad():
    _warmup_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    target_model(_warmup_ids)
    draft_model(_warmup_ids)

print("\n--- Autoregressive Decoding with Sampling ---")
ar_output, ar_tokens, ar_time = autoregressive_decode_with_sampling(
    prompt, target_model, tokenizer, max_tokens_to_generate,
    temperature=sampling_temperature, top_k=sampling_top_k, top_p=sampling_top_p
)
_print_ar_report(ar_output, ar_tokens, ar_time)

print("\n--- Speculative Decoding with Sampling ---")
sd_output, sd_tokens, sd_time, sd_effective_tpt, sd_target_calls = speculative_decode(
    prompt, target_model, draft_model, tokenizer, max_tokens_to_generate, speculative_lookahead,
    temperature=sampling_temperature, top_k=sampling_top_k, top_p=sampling_top_p
)
_print_sd_report(sd_output, sd_tokens, sd_time, sd_effective_tpt, sd_target_calls)

print("\n--- Comparison ---")
_print_comparison(ar_time, sd_time, ar_output, sd_output)

print("\n--- Autoregressive Decoding (Advanced Prompt) ---")
ar_output_adv, ar_tokens_adv, ar_time_adv = autoregressive_decode_with_sampling(
    prompt_advanced, target_model, tokenizer, max_tokens_to_generate,
    temperature=sampling_temperature, top_k=sampling_top_k, top_p=sampling_top_p
)
_print_ar_report(ar_output_adv, ar_tokens_adv, ar_time_adv)

print("\n--- Speculative Decoding (Advanced Prompt) ---")
sd_output_adv, sd_tokens_adv, sd_time_adv, sd_effective_tpt_adv, sd_target_calls_adv = speculative_decode(
    prompt_advanced, target_model, draft_model, tokenizer, max_tokens_to_generate, speculative_lookahead,
    temperature=sampling_temperature, top_k=sampling_top_k, top_p=sampling_top_p
)
_print_sd_report(sd_output_adv, sd_tokens_adv, sd_time_adv, sd_effective_tpt_adv, sd_target_calls_adv)

print("\n--- Comparison (Advanced Prompt) ---")
_print_comparison(ar_time_adv, sd_time_adv, ar_output_adv, sd_output_adv)


--- Autoregressive Decoding with Sampling ---
Output: 'The quick brown fox jumps over the lazy dog and goes for a walk, the dog in the park, the park and the dog pulling the leash. They are both the lovers of the horse.

They have been out for about a week and they haven't Hardcore. They are here to stay'
Tokens Generated: 50
Time Taken: 9.8377 seconds
Effective Tokens/Target Pass (AR): 1.00

--- Speculative Decoding with Sampling ---
