In [None]:

import random, torch, numpy as np, pandas as pd
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM


In [None]:
SEED = 0
# Seed every RNG source the script relies on; the per-position distractor
# sampling in the evaluation loop draws from Python's `random`.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Run on the GPU when one is visible; the model-loading code reads this flag
# to pick the dtype.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Hugging Face Hub ids of the instruction-tuned checkpoints, evaluated in order.
MODELS = [
    "Qwen/Qwen2.5-7B-Instruct",
    "meta-llama/Llama-3.1-8B-Instruct",
    "mistralai/Ministral-8B-Instruct-2410",
]

# Instruction header plus two worked examples. build_chat_prompt puts this text
# at the start of every user turn. The string ends with "\n\n".
FEW_SHOT_MSG = "".join([
    # Task instructions.
    "You will be asked multiple-choice questions. ",
    "Each question has **exactly one correct answer**.\n",
    "Return **only** the letter (A, B, C or D) of the option you believe is correct—",
    "no explanation, no punctuation, nothing else.\n\n",
    # Worked example 1: correct option at C.
    "Example 1:\n",
    "Question: What is the capital of France?\n",
    "Choices:\n",
    "A. Berlin\n",
    "B. Madrid\n",
    "C. Paris\n",
    "D. Rome\n",
    "FINAL ANSWER: C\n\n",
    # Worked example 2: correct option at D.
    "Example 2:\n",
    "Question: What gas do plants primarily use for photosynthesis?\n",
    "Choices:\n",
    "A. Nitrogen\n",
    "B. Oxygen\n",
    "C. Hydrogen\n",
    "D. Carbon Dioxide\n",
    "FINAL ANSWER: D\n\n",
])

def build_chat_prompt(tok: AutoTokenizer, question: str, choices: list[str]) -> str:
    """Render one multiple-choice question as a chat-formatted prompt string.

    The user turn is FEW_SHOT_MSG, then "\\n\\n", then the question, the options
    lettered A, B, C, ... in the order given, and a trailing "FINAL ANSWER:"
    cue. A fixed system message precedes it. The tokenizer's chat template is
    applied with the generation prompt appended and returned as text (not ids).

    Args:
        tok: Tokenizer whose ``apply_chat_template`` formats the conversation.
        question: Question text.
        choices: Option texts; the i-th option is labelled ``chr(ord("A") + i)``.

    Returns:
        The rendered prompt string.
    """
    option_block = "\n".join(
        f"{chr(ord('A') + idx)}. {text}" for idx, text in enumerate(choices)
    )
    # NOTE(review): FEW_SHOT_MSG already ends in "\n\n", so with the extra
    # "\n\n" below the question is preceded by three blank lines — confirm
    # that spacing is intended.
    user_msg = (
        f"{FEW_SHOT_MSG}\n\nQuestion: {question}\nChoices:\n"
        f"{option_block}\nFINAL ANSWER:"
    )
    conversation = [
        {"role": "system", "content": "You are a careful, factual assistant."},
        {"role": "user", "content": user_msg},
    ]
    return tok.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=False
    )

def collect_label_token_ids(letter: str, tok: "AutoTokenizer") -> list[int]:
    """Return the distinct vocabulary ids that encode ``letter`` as one token.

    Three surface forms are tried in order: the letter preceded by a space,
    preceded by a newline, and bare. A form contributes its id only when the
    tokenizer encodes it (without special tokens) to exactly one token.

    Ids are de-duplicated in first-seen order. The evaluation loop log-sum-exps
    the probabilities of these ids, so a repeated id would be counted twice.
    For example, SentencePiece-style tokenizers often map "A" and " A" to the
    same piece.

    Args:
        letter: Answer label to look up, e.g. ``"A"``.
        tok: Tokenizer exposing ``encode(text, add_special_tokens=...)``.

    Returns:
        A non-empty list of unique token ids.

    Raises:
        ValueError: If none of the forms encodes to a single token.
    """
    ids: list[int] = []
    for prefix in (" ", "\n", ""):
        # One encode call drives both the length check and the id, so the
        # two can never disagree (the original used tokenize() for the
        # check but __call__() for the id).
        encoded = tok.encode(prefix + letter, add_special_tokens=False)
        if len(encoded) == 1 and encoded[0] not in ids:
            ids.append(encoded[0])
    if not ids:
        raise ValueError(f"No single-token rendition for letter '{letter}'")
    return ids

import gc

# Load the multiple-choice split of TruthfulQA.
df_truthful = load_dataset("EleutherAI/truthful_qa_mc", split="validation").to_pandas()

BATCH = 8      # prompts per forward pass
records = []   # one dict per (model, question, gold position); saved to CSV below


def _build_position_prompts(tok, df):
    """Build prompts that place each question's correct answer at every position.

    Each row of ``df`` is read through its ``question``, ``choices`` and ``label``
    columns. Four prompts are produced per row, with the correct choice inserted
    at index 0, 1, 2 and 3. The remaining slots hold a fresh ``random.sample``
    of up to three incorrect choices, padded with "(None)" when fewer exist.

    Returns:
        ``(prompts, gold_letters, qids)`` as three parallel lists. ``qids``
        holds the DataFrame index of the source row.
    """
    prompts, gold_letters, qids = [], [], []
    for qid, row in df.iterrows():
        all_choices = row["choices"]
        correct_idx = row["label"]
        correct_ans = all_choices[correct_idx]
        incorrects = [a for i, a in enumerate(all_choices) if i != correct_idx]

        # Put the correct answer at each of the four positions in turn.
        for pos in range(4):
            # Distractors are resampled per position, so their order varies too.
            choices = random.sample(incorrects, k=min(3, len(incorrects)))
            choices += ["(None)"] * (3 - len(choices))
            choices.insert(pos, correct_ans)
            prompts.append(build_chat_prompt(tok, row["question"], choices))
            gold_letters.append(chr(65 + pos))
            qids.append(qid)
    return prompts, gold_letters, qids


def _score_prompts(mdl, tok, prompts, label_ids, batch_size, desc):
    """Score each prompt's next-token distribution over the answer letters.

    Returns one dict per prompt, in prompt order. Each dict maps every letter in
    ``label_ids`` to the log of the total probability the model assigns to that
    letter's token ids at the position right after the prompt.

    Returning a list rather than yielding matters: a suspended generator would
    keep ``mdl`` and the last batch's tensors alive after the caller's ``del``.
    """
    scores = []
    for start in tqdm(range(0, len(prompts), batch_size), desc=desc):
        # apply_chat_template(tokenize=False) already renders the template's
        # special tokens (including any BOS) into the text, so don't add them
        # again. No truncation: (right-side) truncation would drop the end of
        # the prompt, which is exactly the position being scored.
        enc = tok(prompts[start:start + batch_size], return_tensors="pt",
                  padding=True, add_special_tokens=False).to(mdl.device)
        with torch.no_grad():
            logits = mdl(**enc).logits

        # Index of the last real (non-pad) token; valid for left or right padding.
        mask = enc["attention_mask"]
        last_idx = (mask.size(1) - 1) - mask.flip(-1).argmax(-1)
        rows = torch.arange(logits.size(0), device=logits.device)
        next_logits = logits[rows, last_idx.to(logits.device)]

        # Upcast before log_softmax: half-precision log-probs are coarse.
        log_probs = torch.log_softmax(next_logits.float(), dim=-1)
        for row_lp in log_probs:
            scores.append({
                letter: torch.logsumexp(row_lp[ids], dim=0).item()
                for letter, ids in label_ids.items()
            })
    return scores


for model_name in MODELS:
    print(f"\n=== {model_name} ===")

    tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    mdl = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    ).eval()

    # Batched padding needs a pad token; fall back to EOS when none is defined.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    if getattr(mdl.config, "pad_token_id", None) is None:
        mdl.config.pad_token_id = tok.pad_token_id

    label_ids = {letter: collect_label_token_ids(letter, tok) for letter in "ABCD"}
    prompts, gold_letters, qids = _build_position_prompts(tok, df_truthful)

    # Evaluate model responses in batches and record predictions + log-probs.
    scores = _score_prompts(mdl, tok, prompts, label_ids, BATCH, model_name)
    for qid, gold, option_lp in zip(qids, gold_letters, scores):
        records.append({
            "qid": qid,
            "model": model_name,
            "gold": gold,
            "pred": max(option_lp, key=option_lp.get),
            **option_lp,
        })

    # Release this model's weights before loading the next checkpoint.
    del mdl, scores
    gc.collect()
    torch.cuda.empty_cache()

# Save results.
df_eval = pd.DataFrame(records)
df_eval["correct"] = df_eval["gold"] == df_eval["pred"]
df_eval.to_csv("pos_bias.csv", index=False)
