In [None]:
import torch
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
from core.speculative_engine import SpeculativeEngine
from core.draft_model import DraftModel
from core.target_model import TargetModel


In [None]:
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# A single tokenizer instance is handed to both model wrappers below.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

draft = DraftModel(
    tokenizer,
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    device=device,
)
target = TargetModel(
    tokenizer,
    "meta-llama/Llama-2-7b-hf",
    device=device,
)

# NOTE(review): three entropy thresholds paired with four k values --
# presumably one draft length per entropy bucket; confirm against
# SpeculativeEngine before changing either list.
engine = SpeculativeEngine(
    draft_model=draft,
    target_model=target,
    max_k=4,
    entropy_bins=[0.5, 1.5, 2.5],
    k_values=[4, 3, 2, 1],
)


In [None]:
import time

# Vanilla baseline: greedy decoding with the target model alone, timed
# end to end so it can be compared with the speculative engine.
prompt = "The theory of evolution explains"
input_ids = target.tokenizer(
    prompt, return_tensors="pt"
).input_ids.to(device)

# CUDA kernels launch asynchronously, so the GPU has to be drained before
# each clock reading. torch.cuda.synchronize() raises on a CPU-only setup,
# which the device selection explicitly allows, so only call it when CUDA
# is actually present.
cuda_available = torch.cuda.is_available()

if cuda_available:
    torch.cuda.synchronize()
# perf_counter is monotonic and high resolution; time.time() can jump if
# the system clock is adjusted mid-measurement.
start = time.perf_counter()

with torch.no_grad():
    outputs = target.model.generate(
        input_ids,
        max_new_tokens=50,
        do_sample=False,
        use_cache=True,
    )

if cuda_available:
    torch.cuda.synchronize()
end = time.perf_counter()

total_time = end - start
# generate() returns prompt + continuation, so subtract the prompt length.
tokens_generated = outputs.shape[1] - input_ids.shape[1]

latency_per_token_ms = (total_time / tokens_generated) * 1000
throughput = tokens_generated / total_time

print("=== VANILLA BASELINE ===")
print(f"Total time (s): {total_time:.4f}")
print(f"Tokens generated: {tokens_generated}")
print(f"Latency per token (ms): {latency_per_token_ms:.2f}")
print(f"Throughput (tok/s): {throughput:.2f}")


In [None]:
# Decode the same prompt tokens through the speculative engine;
# max_tokens=50 mirrors the baseline's max_new_tokens=50.
spec_output = engine.decode(
    input_ids,
    max_tokens=50,
)

# Per-token latency as reported by the engine's own performance tracker.
spec_latency = engine.performance_tracker.latency_per_token_ms


In [None]:
# Side-by-side bars: per-token latency of vanilla vs. speculative decoding.
ax = plt.gca()
ax.bar(
    ["Vanilla", "Speculative"],
    [latency_per_token_ms, spec_latency],
)
ax.set_ylabel("Latency per Token (ms)")
ax.set_title("Latency Comparison")
plt.show()


In [None]:
# Report the target_forward_calls counter kept by the engine's tracker.
target_calls = engine.performance_tracker.target_forward_calls
print("Speculative target calls:", target_calls)


In [None]:
# Turn the first sequence of each result back into text, dropping special
# tokens, then print both so they can be compared by eye.
vanilla_text = target.tokenizer.decode(outputs[0], skip_special_tokens=True)
spec_text = target.tokenizer.decode(spec_output[0], skip_special_tokens=True)

for heading, decoded in (
    ("\n=== VANILLA OUTPUT ===", vanilla_text),
    ("\n=== SPECULATIVE OUTPUT ===", spec_text),
):
    print(heading)
    print(decoded)
