<a href="https://colab.research.google.com/github/Siddhinita/Speculative_Decoding/blob/main/Speculative_Decoding_Sync.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install bitsandbytes
!pip install accelerate

Collecting bitsandbytes
  Downloading bitsandbytes-0.49.1-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Downloading bitsandbytes-0.49.1-py3-none-manylinux_2_24_x86_64.whl (59.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m14.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.49.1


In [43]:
# @title Imports

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from huggingface_hub import login
import time
import random
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch
from torchvision.transforms import ToTensor
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from datasets import load_dataset
import matplotlib.pyplot as plt
from torchvision import datasets
from google.colab import userdata


In [4]:
# @title Init env

login(token=userdata.get('HF_TOKEN'))
print("CUDA available:", torch.cuda.is_available())
print("Device count:", torch.cuda.device_count())
print("Current device:", torch.cuda.current_device() if torch.cuda.is_available() else None)
print("Device name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else None)


CUDA available: True
Device count: 1
Current device: 0
Device name: Tesla T4


In [48]:
# @title Download data
dataset = load_dataset("web_questions")
eval_data = [{"question": row['question'], "answer": row['answers']} for row in dataset['test']]


In [6]:
# @title Download models

# teacher_model_name = "google/gemma-3-4b-it"
# drafter_model_name = "google/gemma-3-1b-it"
teacher_model_name = "Qwen/Qwen2.5-1.5B-Instruct"
drafter_model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
teacher_model = AutoModelForCausalLM.from_pretrained(
    teacher_model_name,
    device_map="auto",
)
teacher_model.eval()
drafter_model = AutoModelForCausalLM.from_pretrained(
    drafter_model_name,
    device_map="auto",
)
drafter_model.eval()
teacher_model.eval()

config.json:   0%|          | 0.00/660 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/3.09G [00:00<?, ?B/s]

Loading weights:   0%|          | 0/338 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/659 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/988M [00:00<?, ?B/s]

Loading weights:   0%|          | 0/290 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1536)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
          (k_proj): Linear(in_features=1536, out_features=256, bias=True)
          (v_proj): Linear(in_features=1536, out_features=256, bias=True)
          (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((1536,), eps=1e-06)
    (rotar

In [38]:
# @title Test models

eos_token_ids = [tokenizer.eos_token_id]
models = {"teacher_model": teacher_model, "drafter_model": drafter_model}
for name, m in models.items():
  print(f"Testing model {name}")

  messages = [
      {"role": "user", "content": "hi"}
  ]

  inputs = tokenizer.apply_chat_template(
      messages,
      return_tensors="pt",
      add_generation_prompt=True
  ).to(m.device)['input_ids']
  output = m.generate(
      inputs,
      max_new_tokens=120,
      eos_token_id=eos_token_ids,
      pad_token_id=tokenizer.pad_token_id,
  )
  print(tokenizer.decode(output[0][inputs.shape[1]:], skip_special_tokens=True).strip())
  print()


Testing model teacher_model
Hello! How can I assist you today?

Testing model drafter_model
Hello! How can I assist you today? Feel free to ask me any questions or let me know if there's anything specific you'd like help with.



In [39]:
# @title Sync Speculative Decoding Algo
# Batch size = 1

def sample_from_model(model, prefix, window, past_kv=None):
  generated_tokens, generated_token_probs  = [], []
  for step in range(window):
    with torch.no_grad():
      if past_kv is not None:
        model_input = prefix[:, -1:]
      else:
         model_input = prefix
      out = model(model_input, use_cache=True, output_hidden_states=False, past_key_values=past_kv)
    logits = out.logits[:, -1, :]
    past_kv = out.past_key_values   # ← KV cache
    probs = torch.softmax(logits, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)
    prefix = torch.cat([prefix, next_token], dim = -1)
    generated_tokens.append(next_token)
    # selected_prob = torch.gather(probs, -1, next_token)
    generated_token_probs.append(probs)
    if next_token in eos_token_ids:
      break
  generated_tokens = torch.cat(generated_tokens, dim=1) # Shape: [1, window]
  generated_token_probs = torch.stack(generated_token_probs, dim=1) # Shape: [1, window, Vocab]
  return generated_tokens, generated_token_probs, prefix, past_kv

def run_parallel_verification(teacher, prefix, drafter_tokens, drafter_probs, past_kv):

  with torch.no_grad():
    if past_kv is not None:
      # the 1st drafter token gets verified by the previous token, so we need logits for it again.
      input_tokens = torch.cat([prefix[:, -1:], drafter_tokens], dim=-1)
      out = teacher(input_tokens, past_key_values=past_kv, use_cache=True)
    else:
      input_tokens = torch.cat([prefix, drafter_tokens], dim=-1)
      out = teacher(input_tokens, past_key_values=past_kv, use_cache=True)
  past_kv = out.past_key_values   # ← KV cache
  # Excluding logits for last token
  logits_for_scoring = out.logits[:, -len(drafter_tokens[0])-1:-1, :]
  probs_for_scoring = torch.softmax(logits_for_scoring, dim=-1)
  accepted_tokens = []
  all_accepted = True
  for i in range(len(drafter_probs[0])):
    teacher_prob_dist = probs_for_scoring[0][i]
    # print(max(teacher_prob_dist))
    # print(tokenizer.decode(torch.argmax(teacher_prob_dist)))
    drafter_prob_dist = drafter_probs[0][i]
    drafter_token = drafter_tokens[0][i]
    teacher_prob = teacher_prob_dist[drafter_token]
    drafter_prob = drafter_prob_dist[drafter_token]
    # print(teacher_prob, drafter_prob)
    if teacher_prob >= drafter_prob or torch.rand(1).to(m.device) < (teacher_prob / drafter_prob):
      # print("Accepting drafter token")
      accepted_tokens.append(drafter_token)
    else:
      # print("Accepting teacher token")
      residual_dist = torch.max(torch.zeros_like(teacher_prob_dist), teacher_prob_dist - drafter_prob_dist)
      if residual_dist.sum() > 0:
        residual_dist = residual_dist / residual_dist.sum()
        new_token = torch.multinomial(residual_dist, 1)
      else:
        new_token = torch.multinomial(teacher_prob_dist, 1)
      accepted_tokens.append(new_token.squeeze())
      all_accepted = False
      break
    # print(accepted_tokens)
  # For example, prefix = 11, drafter_tokens = 5, accepted_tokens = 3
  # We need logits at 11 + 3 -1  = 13
  # -1 -(5 -3) = -3
  # If total logits is 11 + 5, then -3 is 13.
  if all_accepted and not (accepted_tokens[-1].item() in eos_token_ids):
    # print("Accepting bonus token")
    last_accepted_token_idx = -1 - (len(drafter_tokens[0]) - len(accepted_tokens))
    # Logits size is prefix + drafter_tokens OR 1 + drafter_tokens
    final_logits = out.logits[:, last_accepted_token_idx, :]
    final_probs = torch.softmax(final_logits, dim=-1)
    bonus_token = torch.multinomial(final_probs, 1)
    accepted_tokens.append(bonus_token.squeeze())
  return accepted_tokens, past_kv

def spec_decoding(teacher, drafter, prefix, window, max_tokens):
  num_tokens = 0
  drafter_kv = None
  teacher_kv = None
  while True:
    drafter_tokens, drafter_token_probs, new_prefix, drafter_kv = sample_from_model(drafter, prefix, window, drafter_kv)
    accepted_tokens, teacher_kv = run_parallel_verification(teacher, prefix, drafter_tokens, drafter_token_probs, teacher_kv)
    # print(prefix.shape)
    if len(accepted_tokens) != 1:
      accepted_tokens = torch.stack(accepted_tokens, dim = -1).squeeze()[None, :]
    else:
      accepted_tokens = accepted_tokens[0].unsqueeze(0).unsqueeze(0)
    # print(accepted_tokens)
    prefix = torch.cat([prefix, accepted_tokens], dim = -1)
    seq_len = prefix.shape[-1]
    # Look at past_key_values definition in https://huggingface.co/docs/transformers/model_doc/llama
    # Only unprocessed key values are given as model input
    teacher_kv = teacher_kv.crop(seq_len-1)
    drafter_kv = drafter_kv.crop(seq_len-1)
    if prefix[0, -1] in eos_token_ids or seq_len > max_tokens:
      break
  return prefix

In [49]:
# @title Eval utils


TEACHER_KEY = "teacher"
DRAFTER_KEY = "drafter"
SPEC_DECODING_KEY = "spec_decoding"
ASYNC_SPEC_DECODING_KEY = "async_spec_decoding"

def call_model(key, inputs, max_new_tokens):
  if key == TEACHER_KEY:
    st_time = time.time()
    outputs = teacher_model.generate(
        inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False, # Deterministic for eval
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=eos_token_ids,
    )
    end_time = time.time()
  elif key == DRAFTER_KEY:
    st_time = time.time()
    outputs = drafter_model.generate(
        inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False, # Deterministic for eval
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=eos_token_ids,
    )
    end_time = time.time()
  elif key == SPEC_DECODING_KEY:
    st_time = time.time()
    outputs = spec_decoding(teacher_model, drafter_model, inputs, 12, max_new_tokens + len(inputs[0]))
    end_time = time.time()
  elif key == ASYNC_SPEC_DECODING_KEY:
    st_time = time.time()
    outputs = async_spec_decoding(teacher_model, drafter_model, inputs, 12, max_new_tokens + len(inputs[0]))
    end_time = time.time()
  else:
    raise  NotImplementedError(f"Invalid key: {key}")
  response = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True).strip()
  output_len = len(outputs[0]) - len(inputs[0])
  time_per_token = (end_time - st_time) / output_len
  return response, time_per_token

def filter_eval_set(eval_data):
  new_eval_data = []
  for item in eval_data:
    a = item['answer']
    q = item['question']
    if (isinstance(a, list) and len(a) == 1):
      new_eval_data.append({"question": q, "answer": a[0]})
    elif isinstance(a, str):
      new_eval_data.append({"question": q, "answer": a})
  print(f"Eval data size: {len(new_eval_data)}")
  random.shuffle(new_eval_data)
  return new_eval_data

def eval_model(key, model_description, limit=10):
  results = []
  for item in tqdm(eval_data[:limit]):
      question = item['question']
      ground_truth = item['answer']
      messages = [
          {"role": "user", "content": f"Answer this question directly. Question: {question}"}
      ]
      input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(m.device)['input_ids']
      with torch.no_grad():
          response, time_per_token = call_model(key, input_ids, 120)

      # Simple exact/substring match check
      is_correct = ground_truth.lower() in response.lower()

      results.append({
          "Question": question,
          "Target": ground_truth,
          "Prediction": response,
          "Time": time_per_token,
          "Correct": "✅" if is_correct else "❌"
      })

  # --- 6. Compare Results ---
  df = pd.DataFrame(results)
  accuracy = (df['Correct'] == "✅").mean() * 100
  avg_time = df['Time'].mean()
  print(f"\n\n=== Evaluation Results for {model_description}")
  print(f"Accuracy on subset: {accuracy:.1f}%")
  print(f"Avg time per token: {avg_time}s")
  print("-" * 60)
  print(df.to_markdown(index=False))

eval_data = filter_eval_set(eval_data)


Eval data size: 1348


# Model Evals

In [54]:
# @title Drafter model eval
LIMIT = 50

eval_model(DRAFTER_KEY, "Qwen 0.5B Drafter", LIMIT)


100%|██████████| 50/50 [04:19<00:00,  5.19s/it]



=== Evaluation Results for Qwen 0.5B Drafter
Accuracy on subset: 26.0%
Avg time per token: 0.06394145128285039s
------------------------------------------------------------
| Question                                                            | Target                                                    | Prediction                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 |      Time | Correct   |
|:----------------------------------------------




In [55]:
# @title Teacher model eval

eval_model(TEACHER_KEY, "Qwen 1.5B Teacher", LIMIT)

100%|██████████| 50/50 [04:10<00:00,  5.02s/it]



=== Evaluation Results for Qwen 1.5B Teacher
Accuracy on subset: 36.0%
Avg time per token: 0.08482621522593595s
------------------------------------------------------------
| Question                                                            | Target                                                    | Prediction                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               |      Time | Correct   |
|:--------------------------------------------------------------------|:----------




In [56]:
# @title Spec Decoding eval (sync)

eval_model(SPEC_DECODING_KEY, "Qwen 1.5B/0.5B Spec decoding", LIMIT)

100%|██████████| 50/50 [12:02<00:00, 14.46s/it]



=== Evaluation Results for Qwen 1.5B/0.5B Spec decoding
Accuracy on subset: 36.0%
Avg time per token: 0.2451794434470147s
------------------------------------------------------------
| Question                                                            | Target                                                    | Prediction                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                              |      Time | Correct   |
|:--------------------------------------------------------------------|:---------------------------------


