In [None]:
import torch
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM

from adaptation.entropy_calculator import EntropyCalculator


In [None]:
# Run on the GPU when one is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Half precision is only requested on CUDA. On CPU, float16 inference is slow
# and some PyTorch builds lack half-precision kernels for ops the model needs,
# so the weights are loaded in float32 there instead.
model_dtype = torch.float16 if device == "cuda" else torch.float32

model_name = "TinyLlama/TinyLlama-1.1B"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=model_dtype
).to(device)
# Inference only: put the module in evaluation mode.
model.eval()

entropy_calc = EntropyCalculator()


In [None]:
# Prompts to compare, keyed by the label used in the plot legend.
# The labels are the author's difficulty ranking of the prompts.
prompts = dict(
    easy="The capital of France is",
    medium="The theory of evolution explains",
    hard="In a universe governed by probabilistic causality",
)


In [None]:
def compute_entropy_trajectory(prompt: str, steps: int = 20) -> list:
    """Greedy-decode from ``prompt`` and record one entropy value per step.

    Relies on the notebook-level ``tokenizer``, ``model``, ``entropy_calc``
    and ``device`` objects.

    Args:
        prompt: Text that is tokenized and fed to the model as the prefix.
        steps: Number of decoding steps to measure.

    Returns:
        A list with one entry per step (empty when ``steps`` is zero or
        negative). Each entry is ``entropy_calc.compute(logits).item()`` for
        the model's last-position logits at that step.
    """
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    entropies = []

    with torch.no_grad():
        # Prefill: run the whole prompt once and keep the key/value cache so
        # that each later step only has to process a single new token.
        outputs = model(input_ids, use_cache=True, return_dict=True)
        kv_cache = outputs.past_key_values
        logits = outputs.logits[:, -1, :]

        for step in range(steps):
            # NOTE(review): logits keep the model's dtype here; whether
            # EntropyCalculator upcasts half-precision input is not visible
            # from this notebook - confirm.
            entropies.append(entropy_calc.compute(logits).item())

            # The last measurement needs no further logits, so skip the
            # forward pass whose output would be discarded.
            if step + 1 == steps:
                break

            # Greedy decoding: feed back the single most likely token.
            token = torch.argmax(logits, dim=-1, keepdim=True)
            outputs = model(
                token,
                past_key_values=kv_cache,
                use_cache=True,
                return_dict=True
            )
            kv_cache = outputs.past_key_values
            logits = outputs.logits[:, -1, :]

    return entropies


In [None]:
# Measure every prompt first, then draw all curves on one shared axes.
trajectories = {
    name: compute_entropy_trajectory(prompt)
    for name, prompt in prompts.items()
}

fig, ax = plt.subplots(figsize=(8, 4))

for name, entropies in trajectories.items():
    ax.plot(entropies, label=name)

ax.set_xlabel("Decoding Step")
ax.set_ylabel("Entropy")
ax.set_title("Entropy over Decoding Steps")
ax.legend()
ax.grid(True)
plt.show()
