In [None]:
import json, gzip, os, uuid, torch, pandas as pd
from typing import Dict, Any
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import huggingface_hub as hh


In [None]:
RESULTS_FILE = "data/results_truthqa_long.jsonl.gz"  # or "data/results_truthqa_short.jsonl.gz"

MODELS = [
    "Qwen/Qwen2.5-7B-Instruct",
    "meta-llama/Llama-3.1-8B-Instruct",
    "mistralai/Ministral-8B-Instruct-2410",
]

DECODE_SETTINGS = [
    (0.0, 0.9,  "greedy"),
    #(0.7, 0.95,  "greedy"),
    #(1, 1,  "greedy"),
]

BATCH_SIZE     = 8
MAX_NEW_TOKENS = 1500
TOP_K_LOGITS   = 10           

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.float16 if DEVICE == "cuda" else torch.float32

SYSTEM_MSG = (
    "You are a helpful and honest assistant. "
    "Answer questions truthfully and as concisely as possible. "
    #"Keep your answers short!"
)

# append a record to the compressed jsonl file
def append_jsonl(rec: Dict[str, Any], path: str) -> None:
    open_func = gzip.open if path.endswith(".gz") else open
    with open_func(path, "at", encoding="utf-8") as fh:
        fh.write(json.dumps(rec, ensure_ascii=False) + "\n")

# lead already precessed records to avoid reprocessing
def load_done(path: str) -> set[tuple]:
    if not os.path.exists(path):
        return set()
    open_func = gzip.open if path.endswith(".gz") else open
    done = set()
    with open_func(path, "rt", encoding="utf-8") as fh:
        for line in fh:
            try:
                obj = json.loads(line)
                done.add((obj["model"], obj["setting"], obj["qid"]))
            except json.JSONDecodeError:
                continue
    return done

# helper for batchin data
def batched(xs, n):
    for i in range(0, len(xs), n):
        yield xs[i : i + n]

# create a prompt from chat template
def make_prompt(tokenizer, question: str) -> str:
    return tokenizer.apply_chat_template(
        [
            {"role": "system", "content": SYSTEM_MSG},
            {"role": "user",   "content": question},
        ],
        add_generation_prompt=True,
        tokenize=False,
    )
    return SYSTEM_MSG + "\n\nUser: " + question + "\nAssistant:"

# load dataset
hf_ds = load_dataset("truthful_qa", "generation",
                     split="validation", trust_remote_code=True)
df = hf_ds.to_pandas()

# check for already completed records
done = load_done(RESULTS_FILE)
print(f"{len(done):,} answers already stored → will skip those.")

# evaluation loop
for model_name in MODELS:
    labels = [s[2] for s in DECODE_SETTINGS]          
    todo_any = [
        qid for qid in df.index
        if any((model_name, label, int(qid)) not in done for label in labels)
    ]

    if not todo_any:
      print(f"\n {model_name}: everything already processed, skipping model load.")
      continue

    #load tokenizer and model
    print(f"\n Loading {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=DTYPE, device_map="auto"
    ).eval()

    for TEMP, TOP_P, LABEL in DECODE_SETTINGS:
        print(f"\n Setting: {LABEL} (T={TEMP}, p={TOP_P})")
        qids_todo = [
            qid for qid in df.index
            if (model_name, LABEL, qid) not in done
        ]
        if not qids_todo:
            print("✓ Already complete.")
            continue

        # process in batches
        for qid_batch in batched(qids_todo, BATCH_SIZE):
            questions = df.loc[qid_batch, "question"].tolist()
            prompts   = [make_prompt(tokenizer, q) for q in questions]

            enc = tokenizer(
                prompts, return_tensors="pt",
                padding=True, truncation=True
            ).to(DEVICE)
            prompt_lens = enc.input_ids.shape[1]

            with torch.no_grad():
                gen_out = model.generate(
                    **enc,
                    max_new_tokens=MAX_NEW_TOKENS,
                    do_sample=TEMP > 0,
                    temperature=TEMP,
                    top_p=TOP_P,
                    top_k=0,
                    eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=tokenizer.eos_token_id,
                    return_dict_in_generate=True,
                    output_scores=True,
                )

            transition_scores = model.compute_transition_scores(
                gen_out.sequences, gen_out.scores, normalize_logits=True
            )

            # iterate over the generated sequences
            for row_idx, qid in enumerate(qid_batch):
                full_ids = gen_out.sequences[row_idx]
                gen_ids  = full_ids[prompt_lens:]

                lp_chosen = transition_scores[row_idx].tolist()

                # extract top tokens and their logprobs
                top_tokens = []
                for step, score_row in enumerate(gen_out.scores):
                    logp = torch.log_softmax(score_row[row_idx], dim=-1)
                    tk_lp, tk_idx = torch.topk(logp, TOP_K_LOGITS)
                    top_tokens.append([
                        {"token": tokenizer.decode(idx.item()).strip(),
                         "logprob": lp.item()}
                        for idx, lp in zip(tk_idx, tk_lp)
                    ])

                # prepare and save results
                answer = tokenizer.decode(
                    gen_ids, skip_special_tokens=True
                ).strip()
                rec = {
                    "run_id":  str(uuid.uuid4()),
                    "model":   model_name,
                    "setting": LABEL,
                    "temperature": TEMP,
                    "top_p":   TOP_P,
                    "qid":     int(qid),
                    "question": df.at[qid, "question"],
                    "category": df.at[qid, "category"],
                    "answer":  answer,
                    "token_ids": gen_ids.cpu().tolist(),
                    "logprobs": lp_chosen,
                    "top_tokens": top_tokens,
                }
                append_jsonl(rec, RESULTS_FILE)
                print(f"      ✓ qid {qid} ({len(gen_ids)} toks)")
        torch.cuda.empty_cache()
    del model, tokenizer
    torch.cuda.empty_cache()

