# Speculative Decoding Demonstration on GPT‑2

In this notebook, we will:

1. Install required libraries  
2. Load two models: a **draft** (small) model and a **target** (full) model  
3. Tokenize a prompt  
4. Implement a simple speculative decoding loop:  
   - Use the draft model to propose *k* tokens at once  
   - Use the target model to “verify” which of those *k* tokens match its top‐predictions  
   - Accept the matching prefix, then repeat until we’ve generated the desired length  
5. Compare to standard greedy decoding  
6. Generate sample outputs and observe speed/quality trade‐offs  

Speculative decoding can speed up inference by batching more tokens through the larger model at once, relying on a cheaper “draft” model to propose candidates.


## 1. Install Dependencies

We’ll use Hugging Face **transformers** for both models and **torch** for tensor ops.


In [1]:
!pip install transformers torch --quiet


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m30.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m26.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m21.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m12.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

## 2. Import Libraries


In [2]:
import time
import torch
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer


## 3. Load Models and Tokenizer

- **Draft model**: `distilgpt2` (smaller, faster)  
- **Target model**: `gpt2` (full)  
- Both in evaluation mode on CPU/GPU as available.


In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer    = GPT2Tokenizer.from_pretrained("gpt2")
draft_model  = GPT2LMHeadModel.from_pretrained("distilgpt2").to(device).eval()
target_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device).eval()


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/762 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/353M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

## 4. Prepare Prompt & Warm‑Up Target Cache

1. Tokenize prompt  
2. Run the target **once** to get `past_key_values` for caching  
3. Keep a copy of `input_ids` for generation


In [4]:
prompt = "In a distant future, AI and humans"
enc = tokenizer(prompt, return_tensors="pt").to(device)

# Warm up target cache
with torch.no_grad():
    out = target_model(**enc, use_cache=True)
past = out.past_key_values  # tuple of length num_layers

# We'll build on this:
generated = enc.input_ids.clone()


## 5. Speculative Decoding Function

**Logic**:  
- While we need more tokens:  
  1. Ask **draft** to propose `k` tokens with its own `generate(use_cache=True)`  
  2. Take those `k` tokens, run **one** forward of **target** on **only** those tokens, passing in `past`  
  3. Compare draft’s proposals vs. target’s top‑1 at each position  
  4. Accept longest matching prefix, update `generated` & `past`  
  5. If none match, fall back to a single‑token greedy step from the target  
- Return the full generated sequence


In [5]:
def speculative_decode(draft, target, generated, past, max_new_tokens=20, k=5):
    """
    draft.generate: proposals for k tokens
    target(..., past_key_values): scores only new tokens
    """
    total_generated = 0

    while total_generated < max_new_tokens:
        # 1. Draft proposes k tokens
        draft_out = draft.generate(
            generated,
            max_new_tokens=min(k, max_new_tokens - total_generated),
            do_sample=False,
            use_cache=True,
            return_dict_in_generate=True
        )
        # Extract new tokens & skip past tokens
        new_tokens = draft_out.sequences[0, generated.shape[-1]:]

        # 2. Target scores only new tokens (one forward)
        with torch.no_grad():
            tgt_out = target(
                input_ids=new_tokens.unsqueeze(0),
                past_key_values=past,
                use_cache=True
            )
        logits = tgt_out.logits  # [1, k, vocab]
        new_past = tgt_out.past_key_values

        # 3. Compute target’s greedy next tokens
        tgt_next = logits.argmax(-1).squeeze(0)  # [k]

        # 4. Find longest matching run
        matches = (new_tokens.cpu() == tgt_next.cpu()).tolist()
        # first mismatch index
        if False in matches:
            first_mismatch = matches.index(False)
        else:
            first_mismatch = len(matches)

        # 5a. If any match, accept them
        if first_mismatch > 0:
            to_accept = new_tokens[:first_mismatch].unsqueeze(0)
            generated = torch.cat([generated, to_accept.to(device)], dim=-1)
            # update past to the new_past at the matching point
            # past holds key/values for entire generated; new_past holds for extra k tokens
            past = new_past
            total_generated += first_mismatch
        else:
            # 5b. No match: fallback to single-token greedy
            with torch.no_grad():
                tgt_step = target(
                    input_ids=generated[:, -1:].to(device),
                    past_key_values=past,
                    use_cache=True
                )
            # take its greedy next
            step_token = tgt_step.logits[:, -1, :].argmax(-1).unsqueeze(0)
            generated = torch.cat([generated, step_token.to(device)], dim=-1)
            past = tgt_step.past_key_values
            total_generated += 1

    return generated


## 6. Compare to Standard Greedy Decoding

We’ll measure time and outputs from both methods.


In [6]:
# Standard greedy
start = time.time()
greedy = target_model.generate(
    enc.input_ids, max_new_tokens=20, do_sample=False, use_cache=True
)
greedy_time = time.time() - start

# Speculative
start = time.time()
spec = speculative_decode(draft_model, target_model, generated.clone(), past, max_new_tokens=20, k=5)
spec_time = time.time() - start

print(f"Greedy time:      {greedy_time:.2f}s")
print(f"Speculative time: {spec_time:.2f}s\n")

print("=== Greedy Output ===")
print(tokenizer.decode(greedy[0], skip_special_tokens=True))
print("\n=== Speculative Output ===")
print(tokenizer.decode(spec[0], skip_special_tokens=True))


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask

Greedy time:      1.23s
Speculative time: 8.89s

=== Greedy Output ===
In a distant future, AI and humans will be able to communicate with each other, and the AI will be able to communicate with humans.

=== Speculative Output ===
In a distant future, AI and humans would be able to communicate with each other.
The AI would be able to communicate with each other
