In [None]:
import torch
import time
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
from core.speculative_engine import SpeculativeEngine
from core.draft_model import DraftModel
from core.target_model import TargetModel


In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only makes sense on CUDA; on the CPU fallback many PyTorch
# builds lack float16 kernels (or run them very slowly), so use float32 there.
DTYPE = torch.float16 if device == "cuda" else torch.float32

MODEL_DRAFT  = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
MODEL_TARGET = "meta-llama/Llama-2-7b-hf"

# The target model's tokenizer is handed to both model wrappers.
# NOTE(review): this assumes draft and target share a vocabulary — confirm.
tokenizer = AutoTokenizer.from_pretrained(MODEL_TARGET)
draft  = DraftModel(tokenizer, MODEL_DRAFT,  device=device, dtype=DTYPE)
target = TargetModel(tokenizer, MODEL_TARGET, device=device, dtype=DTYPE)

# Shared benchmark input: the same prompt and token budget are used by the
# vanilla loop and by the speculative engine further down.
prompt = "The theory of evolution explains"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
MAX_TOKENS = 50


In [None]:
# Vanilla baseline: decode token by token with the target model alone.
# Clear the target's cached state, then rebuild it from the prompt.
# init_kv_cache() runs before the timer starts, so whatever work it does is
# excluded from the measured time.
target.kv_cache = None
target.position = 0
target.init_kv_cache(input_ids)

vanilla_tokens = []
last_token = input_ids[:, -1:]

# Drain queued GPU work so it is not billed to the loop below. Guarded because
# torch.cuda.synchronize() raises on a CPU-only machine, which the device
# fallback in the setup cell explicitly allows.
if torch.cuda.is_available():
    torch.cuda.synchronize()
start = time.time()

with torch.no_grad():
    for _ in range(MAX_TOKENS):
        logits = target.forward_next(last_token)
        next_token = target.select_tokens(logits)
        # Read the id back to the host once and reuse it for both the
        # bookkeeping and the end-of-sequence check.
        token_id = next_token.item()
        vanilla_tokens.append(token_id)
        last_token = next_token
        if token_id == tokenizer.eos_token_id:
            break

In [None]:
# Stop the clock for the vanilla run and report its metrics.
# NOTE: `start` was set in the previous cell, so when cells are executed by
# hand the pause between them is included in vanilla_time. Use "Run All" (or
# run both cells back to back) for a meaningful number.
# Wait for outstanding GPU work before reading the clock; skipped on CPU-only
# machines, where torch.cuda.synchronize() would raise.
if torch.cuda.is_available():
    torch.cuda.synchronize()
vanilla_time = time.time() - start

num_vanilla_tokens = len(vanilla_tokens)
# Report NaN instead of raising ZeroDivisionError when nothing was generated
# (e.g. MAX_TOKENS == 0) or when the elapsed time rounds to zero.
vanilla_latency = (
    vanilla_time / num_vanilla_tokens * 1000 if num_vanilla_tokens else float("nan")
)
vanilla_throughput = (
    num_vanilla_tokens / vanilla_time if vanilla_time > 0 else float("nan")
)

print("=== FAIR VANILLA BASELINE ===")
print(f"Total time (s): {vanilla_time:.4f}")
print(f"Tokens generated: {num_vanilla_tokens}")
print(f"Latency per token (ms): {vanilla_latency:.2f}")
print(f"Throughput (tok/s): {vanilla_throughput:.2f}")

vanilla_text = tokenizer.decode(vanilla_tokens, skip_special_tokens=True)
print(f"Output: {vanilla_text}")

In [None]:
# Speculative run: pair the draft and target models in a SpeculativeEngine and
# decode the same prompt with the same token budget as the vanilla baseline.
# NOTE(review): `target` is the same object the baseline cell mutated
# (kv_cache / position); presumably engine.decode() resets it — confirm.
# NOTE(review): entropy_bins / k_values look like entropy thresholds mapped to
# draft lengths — verify against SpeculativeEngine.
engine_config = dict(
    draft_model=draft,
    target_model=target,
    max_k=4,
    entropy_bins=[1.5, 3.0, 5.0],
    k_values=[4, 3, 2, 1],
)
engine = SpeculativeEngine(**engine_config)

# Tokenize the prompt afresh for this run.
encoded_prompt = tokenizer(prompt, return_tensors="pt")
input_ids = encoded_prompt.input_ids.to(device)
spec_output = engine.decode(input_ids, max_tokens=MAX_TOKENS)

# Per-token latency as reported by the engine's own performance tracker.
spec_summary = engine.performance_tracker.summary()
spec_latency = spec_summary['latency_per_token_ms']


In [None]:
# Report the engine's performance and quality summaries, then the decoded text.
print("\n=== SPECULATIVE ENGINE ===")

perf_summary = engine.performance_tracker.summary()
print(perf_summary)

quality_summary = engine.quality_evaluator.summary()
print(quality_summary)

# spec_output[0] is the first entry of what engine.decode() returned.
spec_text = tokenizer.decode(spec_output[0], skip_special_tokens=True)
print(f"Output: {spec_text}")


In [None]:
# Bar chart of per-token latency for the two decoding strategies, followed by
# the speedup ratio (vanilla latency divided by speculative latency).
bar_labels = ["Vanilla\n(fair)", "Speculative"]
bar_heights = [vanilla_latency, spec_latency]

fig, ax = plt.subplots(figsize=(6, 4))
ax.bar(bar_labels, bar_heights)
ax.set_ylabel("Latency per Token (ms)")
ax.set_title("Latency Comparison (Fair)")
fig.tight_layout()
plt.show()

speedup = vanilla_latency / spec_latency
print(f"\nSpeedup: {speedup:.2f}x")