In [None]:
import torch
import matplotlib.pyplot as plt

from core.speculative_engine import SpeculativeEngine
from core.draft_model import DraftModel
from core.target_model import TargetModel


In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Speculative-decoding configuration, gathered here so it is easy to tune.
# Three entropy thresholds are paired with four draft lengths.
# NOTE(review): presumably one k per entropy bucket (low entropy -> long draft,
# high entropy -> no draft) -- confirm against SpeculativeEngine.
MAX_K = 8
ENTROPY_BINS = [1.2, 2.2, 3.0]
K_VALUES = [8, 4, 2, 0]

# NOTE(review): these look like Hugging Face Hub ids; verify that both resolve
# to checkpoints in a loadable format and that the two models share a
# tokenizer -- how DraftModel/TargetModel load them is not visible here.
draft = DraftModel("TinyLlama/TinyLlama-1.1B", device=device)
target = TargetModel("meta-llama/Llama-2-7b", device=device)

engine = SpeculativeEngine(
    draft_model=draft,
    target_model=target,
    max_k=MAX_K,
    entropy_bins=ENTROPY_BINS,
    k_values=K_VALUES,
)


In [None]:
# Baseline: plain greedy decoding with the target model only.
# Keep MAX_NEW_TOKENS equal to the `max_tokens` passed to the speculative run
# so both paths are asked for the same amount of work.
MAX_NEW_TOKENS = 50

prompt = "The theory of evolution explains"
input_ids = target.tokenizer(
    prompt, return_tensors="pt"
).input_ids.to(device)


def _sync_device():
    """Wait for queued CUDA work to finish; do nothing when running on CPU.

    CUDA kernels are launched asynchronously, so a wall-clock timer that is
    started or stopped without synchronizing can miss work still in flight.
    """
    if str(device).startswith("cuda") and torch.cuda.is_available():
        torch.cuda.synchronize()


# Untimed warm-up: the first generation on a device pays one-time start-up
# costs (context creation, kernel/library initialisation). Without this the
# baseline would absorb that cost and look slower than it really is.
target.model.generate(
    input_ids,
    max_new_tokens=4,
    do_sample=False,
)

# Vanilla decoding (timed). The warm-up runs before reset() so it cannot leak
# into the tracker's statistics.
target.performance_tracker.reset()
_sync_device()  # make sure nothing from the warm-up is still running
target.performance_tracker.start()

outputs = target.model.generate(
    input_ids,
    max_new_tokens=MAX_NEW_TOKENS,
    do_sample=False,  # greedy, so the result is deterministic
)

_sync_device()  # include all GPU work in the measured interval
target.performance_tracker.stop()

# NOTE(review): nothing in this cell tells the tracker how many tokens were
# generated; confirm that latency_per_token_ms is meaningful for a run driven
# through `target.model.generate` directly.
baseline_latency = target.performance_tracker.latency_per_token_ms


In [None]:
# Speculative decoding over the same prompt tokens and token budget as the
# vanilla baseline. Timing is read from the engine's own tracker afterwards.
# NOTE(review): presumably `decode` resets/starts/stops that tracker itself --
# confirm, otherwise re-running this cell may accumulate stale statistics.
spec_output = engine.decode(input_ids, max_tokens=50)

spec_tracker = engine.performance_tracker
spec_latency = spec_tracker.latency_per_token_ms


In [None]:
# Bar chart comparing per-token latency of the two decoding strategies.
strategy_labels = ["Vanilla", "Speculative"]
latencies_ms = [baseline_latency, spec_latency]

fig, ax = plt.subplots()
ax.bar(strategy_labels, latencies_ms)
ax.set_ylabel("Latency per Token (ms)")
ax.set_title("Latency Comparison")
plt.show()


In [None]:
# Report the target-model forward-call counter kept by the engine's tracker.
target_call_count = engine.performance_tracker.target_forward_calls
print("Speculative target calls:", target_call_count)


In [None]:
# Correctness check: speculative decoding must reproduce the text produced by
# vanilla greedy decoding. Both sequences are decoded with the target
# model's tokenizer, with special tokens stripped.
vanilla_text = target.tokenizer.decode(
    outputs[0], skip_special_tokens=True
)
spec_text = target.tokenizer.decode(
    spec_output[0], skip_special_tokens=True
)

# Show both strings on failure so a divergence can be diagnosed from the
# notebook output without re-running anything.
assert vanilla_text == spec_text, (
    "Speculative output diverged from vanilla greedy decoding:\n"
    f"  vanilla:     {vanilla_text!r}\n"
    f"  speculative: {spec_text!r}"
)
print("✅ Output equivalence verified")
