<b>1) Minimal working example (HuggingFace): draft + target (speculative decoding)

Concept: a small assistant (draft) model proposes several tokens ahead; the large target model then verifies them together in a single forward pass. Each accepted token saves an expensive target pass, which raises tokens/sec. This is the classic speculative decoding idea (Leviathan et al.).</b>
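The propose-then-verify loop can be sketched without any real model. The toy below is only an illustration under stated assumptions: `target_next` and `draft_next` are hypothetical stand-ins for greedy next-token functions, and the optional "bonus" token a real target pass can emit after a fully accepted draft is left out for brevity.

```python
from typing import Callable, List, Tuple

NextToken = Callable[[List[int]], int]  # toy stand-in for a greedy model step


def target_next(seq: List[int]) -> int:
    # Hypothetical "large" model: a deterministic rule over the context.
    return (3 * seq[-1] + len(seq)) % 11


def draft_next(seq: List[int]) -> int:
    # Hypothetical "small" model: agrees with the target except when the
    # context length is a multiple of 4, where it guesses a different token.
    guess = target_next(seq)
    return (guess + 1) % 11 if len(seq) % 4 == 0 else guess


def greedy_decode(next_tok: NextToken, prompt: List[int], n_new: int) -> List[int]:
    # Plain autoregressive greedy decoding: one model call per new token.
    seq = list(prompt)
    for _ in range(n_new):
        seq.append(next_tok(seq))
    return seq


def speculative_greedy(
    target: NextToken, draft: NextToken, prompt: List[int], n_new: int, k: int = 4
) -> Tuple[List[int], int]:
    seq = list(prompt)
    goal = len(prompt) + n_new
    target_passes = 0
    while len(seq) < goal:
        # 1) The draft proposes up to k tokens autoregressively (cheap).
        proposal: List[int] = []
        for _ in range(min(k, goal - len(seq))):
            proposal.append(draft(seq + proposal))
        # 2) The target checks every proposed position. A real model scores
        #    all of them in one batched forward pass, so count one pass here.
        target_passes += 1
        for i, tok in enumerate(proposal):
            expected = target(seq + proposal[:i])
            if tok != expected:
                # First mismatch: keep the accepted prefix plus the target's
                # own token, and discard the rest of the draft.
                seq.extend(proposal[:i] + [expected])
                break
        else:
            seq.extend(proposal)  # every drafted token was accepted
    return seq, target_passes


out, passes = speculative_greedy(target_next, draft_next, [1, 2, 3], n_new=20)
print(out == greedy_decode(target_next, [1, 2, 3], 20), passes)
```

In this greedy setting the output matches plain greedy decoding with the target alone; only the number of target passes changes, and it shrinks as the draft agrees with the target more often.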

<b>A) Greedy speculative decoding (deterministic)</b>

In [None]:
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"

# Example pair from HF docs (swap in your own target/draft pair that shares a tokenizer)
target_id = "HuggingFaceTB/SmolLM-1.7B"
draft_id  = "HuggingFaceTB/SmolLM-135M"

tok = AutoTokenizer.from_pretrained(target_id)
target = AutoModelForCausalLM.from_pretrained(target_id, dtype="auto").to(device).eval()
draft  = AutoModelForCausalLM.from_pretrained(draft_id,  dtype="auto").to(device).eval()

prompt = "Explain speculative decoding in one paragraph for LLM serving."
inputs = tok(prompt, return_tensors="pt").to(device)

with torch.inference_mode():
    t0 = time.perf_counter()
    out = target.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=False,                 # greedy
        assistant_model=draft            # <-- speculative decoding trigger
    )
    t1 = time.perf_counter()

text = tok.decode(out[0], skip_special_tokens=True)
gen_tokens = out.shape[1] - inputs["input_ids"].shape[1]
tok_per_s = gen_tokens / (t1 - t0)

print(text)
print(f"\nGenerated tokens: {gen_tokens}, time: {t1-t0:.3f}s, tokens/sec: {tok_per_s:.2f}")


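HF enables speculative decoding (called assisted decoding in its docs) when assistant_model is passed to generate(). The docs note that it works with both greedy search and sampling, and list some limitations, such as restrictions on batched inputs.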
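Because assistant_model is the only switch, a fair check of the speedup is to time the same generate() call with and without it. The helper below is a minimal sketch: `time_generate` is a hypothetical name, and it only assumes that the callable returns an object with a `.shape` of `(batch, sequence_length)`, as the tensor from `generate()` does.

```python
import time
from typing import Any, Callable, Tuple


def time_generate(
    generate: Callable[..., Any], prompt_len: int, **gen_kwargs: Any
) -> Tuple[int, float]:
    """Call generate(**gen_kwargs) once; return (new_tokens, tokens_per_second)."""
    t0 = time.perf_counter()
    out = generate(**gen_kwargs)
    elapsed = time.perf_counter() - t0
    new_tokens = out.shape[1] - prompt_len  # generated length excludes the prompt
    return new_tokens, new_tokens / elapsed


# Intended use with the objects from the cell above (not executed here):
# n_in = inputs["input_ids"].shape[1]
# common = dict(max_new_tokens=200, do_sample=False)
# base = time_generate(target.generate, n_in, **inputs, **common)
# spec = time_generate(target.generate, n_in, **inputs, **common, assistant_model=draft)
# print(f"baseline: {base[1]:.2f} tok/s, speculative: {spec[1]:.2f} tok/s")
```

The first call on a GPU usually includes one-off warm-up cost, so running each variant once before measuring gives a more representative comparison.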

<b>B) Speculative sampling (stochastic)</b>

When you want “sampling behavior” (top-p, temperature), set do_sample=True. When the target rejects a drafted token, HF resamples a replacement token at that position instead of keeping the draft's choice.
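The accept-or-resample rule from the speculative sampling literature can be sketched for a single position as follows. This is an illustration of the rule described by Leviathan et al., not HF's internal code: `speculative_sample_step` is a hypothetical helper, `p` is assumed to be the target's distribution and `q` the draft's, both given as plain token-to-probability dicts.

```python
import random
from typing import Dict

Dist = Dict[str, float]  # token -> probability


def speculative_sample_step(p: Dist, q: Dist, rng: random.Random) -> str:
    """One position: p is the target distribution, q the draft distribution."""
    # The draft proposes a token x sampled from q.
    tokens = list(q)
    x = rng.choices(tokens, weights=[q[t] for t in tokens])[0]
    # Accept x with probability min(1, p(x) / q(x)).
    if rng.random() < min(1.0, p.get(x, 0.0) / q[x]):
        return x
    # Rejected: resample from the residual, proportional to max(0, p - q).
    residual = {t: max(0.0, p[t] - q.get(t, 0.0)) for t in p}
    res_tokens = list(residual)
    return rng.choices(res_tokens, weights=[residual[t] for t in res_tokens])[0]


# Toy distributions (made up for illustration).
p = {"a": 0.6, "b": 0.3, "c": 0.1}  # target
q = {"a": 0.2, "b": 0.5, "c": 0.3}  # draft
rng = random.Random(0)
n = 20000
counts = {t: 0 for t in p}
for _ in range(n):
    counts[speculative_sample_step(p, q, rng)] += 1
print({t: round(c / n, 3) for t, c in counts.items()})
```

The empirical frequencies should land close to `p` rather than `q`, which is the point of the rule: the accepted-or-resampled token follows the target's distribution even though the proposal came from the draft.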

In [None]:
with torch.inference_mode():
    t0 = time.perf_counter()
    out = target.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        assistant_model=draft
    )
    t1 = time.perf_counter()

gen_tokens = out.shape[1] - inputs["input_ids"].shape[1]
print(f"tokens/sec: {gen_tokens/(t1-t0):.2f}")
