In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch
from torch.nn.functional import log_softmax
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline

MODEL_PATH  = "antoine-444/m3_sft_2e-6_model"
DEVICE      = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE  = 8
SPLITS      = "test"   # a single split name, passed as load_dataset(split=...)
# Output directory for saved figures. Kept relative (the same "eval_out" directory
# the plotting cell writes to) rather than a machine-specific absolute path.
OUT_DIR     = "eval_out"

# Display name -> Hugging Face dataset id for each MCQA benchmark.
TASK_REGISTRY = {
    "MMLU-STEM" : "antoine-444/mmlu_stem_dataset",
    "SciQ"      : "antoine-444/sciq_dataset",
    "AQuA-RAT"  : "antoine-444/aqua_rat_dataset",
    "MedMCQA"   : "antoine-444/medmcqa_dataset",
    "AI2-ARC"   : "antoine-444/ai2_arc_dataset",
}

# Answer letters, in the order choices are listed in prompts and scored.
LETTER_INDICES = ["A", "B", "C", "D"]

# Subjects evaluated one by one from the MMLU STEM split (one heatmap row each).
# NOTE(review): not every entry looks like a standard MMLU subject name (e.g.
# "biology"); a subject absent from the dataset simply yields an empty subset —
# confirm against ds.unique("subject").
STEM_SUBJECTS = [
    "abstract_algebra", "anatomy", "astronomy", "biology", "clinical_knowledge",
    "college_biology", "college_chemistry", "college_computer_science", "college_mathematics",
    "college_medicine", "college_physics", "computer_security", "conceptual_physics",
    "electrical_engineering", "elementary_mathematics", "formal_logic", "high_school_biology",
    "high_school_chemistry", "high_school_computer_science", "high_school_mathematics",
    "high_school_physics", "high_school_statistics", "logical_fallacies", "machine_learning",
    "medical_genetics", "nutrition", "professional_medicine", "virology"
]

# Load the STEM split once; the dataset id comes from TASK_REGISTRY so the
# registry stays the single source of truth for benchmark locations.
ds = load_dataset(TASK_REGISTRY["MMLU-STEM"], split=SPLITS)

In [None]:
def mcqa_prompt(ex, topic="knowledge and skills in advanced master-level STEM courses"):
    """Build a zero-shot multiple-choice prompt for one example.

    The prompt is an instruction header naming ``topic``, the question, one
    ``"<letter>. <choice>"`` line per option, and a trailing ``"Answer:"`` so the
    model's next token can be scored against the answer letters.

    Args:
        ex (dict): Example with a 'question' string and a 'choices' sequence.
        topic (str): Subject phrase inserted into the instruction header.

    Returns:
        str: The formatted prompt string.

    Raises:
        ValueError: If the example has more choices than LETTER_INDICES has
            letters (zip would otherwise drop the extra choices silently).
    """
    choices = ex["choices"]
    if len(choices) > len(LETTER_INDICES):
        raise ValueError(
            f"Example has {len(choices)} choices but only "
            f"{len(LETTER_INDICES)} answer letters are defined"
        )

    header = f"The following are multiple choice questions (with answers) about {topic}.\n\n"
    options = "".join(f"{letter}. {choice}\n" for letter, choice in zip(LETTER_INDICES, choices))
    return header + ex["question"] + "\n" + options + "Answer:"

def gold_index(ex):
    """Return the 0-based index of the correct answer.

    Args:
        ex (dict): Example whose 'answer' field is either a letter from
            LETTER_INDICES (surrounding whitespace and letter case are ignored)
            or an integer index into LETTER_INDICES.

    Returns:
        int: Position of the gold answer, in ``range(len(LETTER_INDICES))``.

    Raises:
        ValueError: If the answer is an unknown letter, an out-of-range index,
            or neither a string nor an integer.
    """
    import operator

    answer = ex["answer"]
    if isinstance(answer, str):
        letter = answer.strip().upper()
        if letter not in LETTER_INDICES:
            raise ValueError(f"Unknown answer letter {answer!r}; expected one of {LETTER_INDICES}")
        return LETTER_INDICES.index(letter)

    # Also accept integer class labels (operator.index rejects floats and
    # other non-integral values instead of silently truncating them).
    try:
        idx = operator.index(answer)
    except TypeError:
        raise ValueError(f"Unsupported answer value {answer!r}") from None
    if not 0 <= idx < len(LETTER_INDICES):
        raise ValueError(f"Answer index {idx} out of range for {len(LETTER_INDICES)} choices")
    return idx


@torch.inference_mode()
def batched_ll(model, tok, prompts, choice_ids, device):
    """Score each answer-choice token as the next token after every prompt.

    Args:
        model (AutoModelForCausalLM): Language model; called as ``model(**enc)``
            and expected to return an object with a ``.logits`` attribute.
        tok (AutoTokenizer): Tokenizer used to batch-encode ``prompts`` with padding.
        prompts (list of str): Prompts to evaluate, one per batch row.
        choice_ids (torch.Tensor): 1-D tensor of vocabulary ids, one per choice.
        device (str): Device the encoded batch is moved to ('cpu' or 'cuda').

    Returns:
        torch.Tensor: Float32 tensor of shape (len(prompts), len(choice_ids))
        holding the log-probability of each choice token at the position that
        follows each prompt's last real (non-padding) token.
    """
    enc = tok(prompts, return_tensors="pt", padding=True).to(device)
    logits = model(**enc).logits  # presumably (batch, seq_len, vocab)

    mask = enc.get("attention_mask")
    if mask is None:
        # No mask available: assume no padding, so the last column is the last real token.
        last_logits = logits[:, -1]
    else:
        # Index of the last attended position per row. This is the final column
        # under left padding and the end of the real tokens under right padding,
        # so correctness no longer depends on the tokenizer's padding_side.
        positions = torch.arange(mask.shape[-1], device=mask.device)
        last_pos = (positions * mask).argmax(dim=-1).to(logits.device)
        rows = torch.arange(logits.shape[0], device=logits.device)
        last_logits = logits[rows, last_pos]

    # Normalize in float32 so half-precision model outputs do not lose precision here.
    return log_softmax(last_logits.float(), dim=-1)[:, choice_ids]

def evaluate_ds(model, tok, ds, prompt_fn, batch_size, device):
    """Predict an answer letter for every example and collect gold labels.

    Each prompt is scored with ``batched_ll`` against the first token of
    " A", " B", ... and the highest-scoring letter is taken as the prediction.

    Args:
        model (AutoModelForCausalLM): Language model to evaluate.
        tok (AutoTokenizer): Tokenizer providing ``encode`` for the answer letters.
        ds (Dataset): Indexable, sized collection of examples.
        prompt_fn (function): Maps one example to its prompt string.
        batch_size (int): Number of prompts scored per forward pass.
        device (str): Device to run the model on ('cpu' or 'cuda').

    Returns:
        tuple: ``(gold, pred)`` numpy arrays of 0-based answer indices, in dataset order.

    Raises:
        ValueError: If two answer letters share the same first token id, which
            would make every prediction collapse onto the first such letter.
    """
    letter_ids = [tok.encode(f" {letter}", add_special_tokens=False)[0] for letter in LETTER_INDICES]
    # Some tokenizers split " A" into a shared leading-space token plus the letter;
    # taking [0] then yields identical ids and argmax is meaningless.
    if len(set(letter_ids)) != len(letter_ids):
        raise ValueError(
            "Answer letters do not map to distinct first tokens: "
            f"{dict(zip(LETTER_INDICES, letter_ids))}"
        )
    choice_ids = torch.tensor(letter_ids, device=device)

    gold_list, pred_list = [], []
    n = len(ds)
    for start in range(0, n, batch_size):
        examples = [ds[j] for j in range(start, min(start + batch_size, n))]
        ll = batched_ll(model, tok, [prompt_fn(ex) for ex in examples], choice_ids, device)
        pred_list.extend(ll.argmax(-1).tolist())
        gold_list.extend(gold_index(ex) for ex in examples)

    return np.array(gold_list), np.array(pred_list)

In [None]:
print("🔄 Loading model …")

# Tokenizer: pad on the left so the final prompt token of every row ends up in
# the last column of the padded batch.
tok = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="left")
# Reuse EOS as the padding id; padded positions are excluded via attention_mask.
tok.pad_token_id = tok.eos_token_id

# Half precision with automatic device placement on GPU, full precision on CPU.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map=("auto" if DEVICE == "cuda" else None),
    torch_dtype=(torch.float32 if DEVICE != "cuda" else torch.float16),
)
model.eval();

In [None]:
import os

# Fixed labels so every subject yields a square len(LETTER_INDICES) x
# len(LETTER_INDICES) matrix even if some letters never appear.
label_ids = list(range(len(LETTER_INDICES)))
# Column names "A→A", "A→B", … in row-major (true, predicted) order, matching cm.flatten().
pairs = [f"{t}→{p}" for t in LETTER_INDICES for p in LETTER_INDICES]

cms = []
data = []
for subj in STEM_SUBJECTS:
    sub_ds = ds.filter(lambda ex: ex["subject"] == subj)
    if len(sub_ds) == 0:
        # Subject absent from this split: keep a None placeholder in `cms` and an
        # all-zero row in `data` so rows still line up with STEM_SUBJECTS.
        cms.append(None)
        data.append([0] * len(pairs))
        continue
    golds, preds = evaluate_ds(model, tok, sub_ds, mcqa_prompt, BATCH_SIZE, DEVICE)
    cm = confusion_matrix(golds, preds, labels=label_ids)
    cms.append(cm)
    data.append(cm.flatten().tolist())

missing_subjects = [s for s, m in zip(STEM_SUBJECTS, cms) if m is None]
if missing_subjects:
    print(f"No examples found for: {', '.join(missing_subjects)}")

# One row per subject, one column per (true, predicted) letter pair.
df = pd.DataFrame(data, index=STEM_SUBJECTS, columns=pairs)

# Row-normalize so each subject's cells sum to 1; empty subjects (0/0 -> NaN) become 0.
df_norm = df.div(df.sum(axis=1), axis=0).fillna(0)

out_path = "eval_out"
os.makedirs(out_path, exist_ok=True)

fig, ax = plt.subplots(figsize=(16, 8))
sns.heatmap(
    df_norm,
    cmap="YlGnBu",
    cbar_kws={"label": "Proportion"},
    xticklabels=pairs,
    yticklabels=[s.replace("_", " ").title() for s in STEM_SUBJECTS],
    ax=ax,
)
plt.setp(ax.get_xticklabels(), rotation=90, ha="center")
ax.set_title("Normalized True→Pred Confusion by STEM Subject")
ax.set_xlabel("True → Predicted")
ax.set_ylabel("Subject")
fig.tight_layout()

save_file = os.path.join(out_path, "combined_true_pred_confusion_heatmap.png")
fig.savefig(save_file, dpi=300, bbox_inches="tight")
print(f"Saved heatmap to {save_file}")

plt.show()