In [None]:
import re
from typing import List, Dict, Optional, Tuple, Any
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np

# --------------------------
# Prompting & parsing
# --------------------------
# Option labels 'A'..'Z'; position i holds the label for the i-th option.
LETTERS = [chr(code) for code in range(ord("A"), ord("Z") + 1)]

def format_options(options: List[str]) -> str:
    """Render the options as one 'X. text' line each, labelled A, B, C, ... in order.

    Returns the lines joined with newlines (an empty list gives an empty string).
    Labels are looked up in LETTERS, so passing more options than LETTERS has
    entries raises IndexError.
    """
    return "\n".join(
        f"{LETTERS[idx]}. {choice}" for idx, choice in enumerate(options)
    )

def build_prompt(question: str, options: List[str]) -> str:
    """Assemble the full multiple-choice prompt for one question.

    The prompt contains, in order: the task instructions, the set of valid
    answer letters (one per option, wrapped in braces), the question text, the
    lettered option list, and a trailing 'Answer:' cue with no newline after it.
    """
    valid = ", ".join(LETTERS[:len(options)])
    rendered_options = format_options(options)
    sections = [
        "You are given a multiple-choice question.\n",
        "Choose the correct option and answer with a single CAPITAL LETTER only.\n",
        f"Valid answers: {{{valid}}}\n\n",
        f"Question:\n{question}\n\n",
        f"Options:\n{rendered_options}\n\n",
        "Answer:",
    ]
    return "".join(sections)

# Prefer word-boundary matches; avoid picking 'A' in 'Answer'
LETTER_TOKEN_RE = re.compile(r"\b([A-Z])\b")
PAREN_LETTER_RE = re.compile(r"\(([A-Z])\)")

def parse_letter(text: str, k: int) -> Optional[str]:
    """
    Extract the chosen option letter from model output.

    Only the first k capital letters (A, B, ...) are accepted as answers.
    Order of attempts:
      1) exact single-letter (strict): the stripped output is one capital letter
      2) parenthetical pattern such as '(C)'
      3) any standalone single-letter token (word boundaries)

    The parenthetical form is tried before the generic standalone scan because
    every '(C)' is also a standalone token: checked last, it could never change
    the result, and a stray capital earlier in the text (for example the
    pronoun 'I' when there are nine or more options) would win over the
    explicitly marked answer.

    Args:
        text: raw generated text.
        k: number of options of the question.

    Returns:
        The first in-range letter found by the attempts above, or None.
    """
    def first_valid(candidates):
        # 'A' maps to 0, 'B' to 1, ...; keep only indices below k.
        for ch in candidates:
            if 0 <= ord(ch) - ord("A") < k:
                return ch
        return None

    text = text.strip()

    # 1) strict: nothing but a single capital letter
    if re.fullmatch(r"[A-Z]", text):
        return first_valid([text])

    # 2) explicit parenthetical answers like '(C)'
    choice = first_valid(PAREN_LETTER_RE.findall(text))
    if choice is not None:
        return choice

    # 3) fallback: any standalone capital letter
    return first_valid(LETTER_TOKEN_RE.findall(text))

def gold_letter_from_row(row: Dict[str, Any]) -> str:
    """Return the gold answer letter for a dataset row.

    A non-None 'answer_index' takes precedence and is converted to its letter
    through LETTERS; otherwise the 'answer' field is stringified, stripped and
    upper-cased. A row with neither usable field raises KeyError on 'answer'.
    """
    index = row.get("answer_index")
    if index is None:
        # No usable index: fall back to the textual answer.
        return str(row["answer"]).strip().upper()
    return LETTERS[int(index)]

# --------------------------
# Data loading
# --------------------------
import json

def load_jsonl(path: str) -> List[Dict[str, Any]]:
    """Read a UTF-8 JSON Lines file and return one parsed object per line.

    Lines that are empty after stripping whitespace are skipped; any other line
    is decoded with json.loads, so malformed JSON raises json.JSONDecodeError.
    """
    with open(path, "r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(entry) for entry in stripped_lines if entry]

def load_dataset(path: str) -> List[Dict[str, Any]]:
    """Load evaluation rows from a '.jsonl' file or from a JSON file.

    Paths ending in '.jsonl' are delegated to load_jsonl. Any other path is
    parsed as a single JSON document, which must be a list.

    Raises:
        ValueError: if the JSON document is not a list.
    """
    if not path.endswith(".jsonl"):
        with open(path, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
        if not isinstance(payload, list):
            raise ValueError("JSON must be a list of rows for .json")
        return payload
    return load_jsonl(path)

# --------------------------
# Gen helpers (device-safe)
# --------------------------
def _is_sharded(model: torch.nn.Module) -> bool:
    """Return True when the model carries a non-None `hf_device_map` attribute."""
    # Present when loaded with accelerate/device_map="auto"
    device_map = getattr(model, "hf_device_map", None)
    return device_map is not None

def _primary_device(model: torch.nn.Module) -> torch.device:
    """Return the device of the model's first parameter, or CPU if it has none."""
    first_param = next(model.parameters(), None)
    if first_param is None:
        # Parameter-less module: nothing to inspect, default to CPU.
        return torch.device("cpu")
    return first_param.device

@torch.inference_mode()
def generate_letters(
    model, tokenizer, prompts: List[str],
    max_new_tokens: int = 4,
    temperature: float = 0.0,
    top_p: float = 1.0,
    batch_size: int = 8,
) -> List[str]:
    """
    Generate a short completion for every prompt and return the stripped texts.

    Device-safe batching:
      - If model is sharded (device_map="auto"), DO NOT move batches to a single device.
      - Else, move the batch to the model's primary device.

    Args:
        model: object exposing `generate(...)`; also inspected by `_is_sharded`
            and `_primary_device`.
        tokenizer: callable tokenizer exposing `batch_decode`, `eos_token_id`
            and `pad_token_id`.
        prompts: prompt strings, processed in order in chunks of `batch_size`.
        max_new_tokens: forwarded to `model.generate`.
        temperature: sampling is enabled only when this is > 0; otherwise
            decoding runs with `do_sample=False`.
        top_p: nucleus-sampling threshold, used only when sampling.
        batch_size: number of prompts per `generate` call; must be positive.

    Returns:
        One stripped string per prompt, holding only the newly generated tokens.

    Raises:
        ValueError: if `batch_size` is not positive.
    """
    if batch_size <= 0:
        # A negative step would make the batching loop silently yield nothing.
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    out_texts: List[str] = []
    sharded = _is_sharded(model)
    primary = _primary_device(model)

    # Token id 0 is a legitimate pad id, so test for None rather than truthiness.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    sampling = temperature > 0.0
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=sampling,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=pad_id,
    )
    if sampling:
        # Sampling-only knobs are passed only when they take effect.
        gen_kwargs.update(temperature=max(temperature, 1e-8), top_p=top_p)

    for i in range(0, len(prompts), batch_size):
        batch = prompts[i:i+batch_size]
        toks = tokenizer(batch, padding=True, truncation=True, return_tensors="pt")
        if not sharded:
            toks = {k: v.to(primary) for k, v in toks.items()}

        gen = model.generate(**toks, **gen_kwargs)
        # Slice out the newly generated tokens only
        new_tokens = gen[:, toks["input_ids"].shape[1]:]
        texts = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
        out_texts.extend(t.strip() for t in texts)

    return out_texts

# --------------------------
# Evaluation (revised)
# --------------------------
def evaluate(
    model_name: str,                  # path to local dir OR HF id
    data_path: str,
    max_samples: int = -1,
    batch_size: int = 8,
    max_new_tokens: int = 4,
    temperature: float = 0.0,
    top_p: float = 1.0,
    local_files_only: bool = False,
) -> Dict[str, Any]:
    """Run the multiple-choice evaluation end to end and return the metrics.

    Steps: load tokenizer and model, load the rows from `data_path`, build one
    prompt per row, generate completions, parse a letter from each completion
    and score it against the gold letter.

    Args:
        model_name: local directory or hub id passed to `from_pretrained`.
        data_path: dataset file handed to `load_dataset`.
        max_samples: when > 0, only the first `max_samples` rows are used.
        batch_size, max_new_tokens, temperature, top_p: forwarded to
            `generate_letters`.
        local_files_only: forwarded to both `from_pretrained` calls.

    Returns:
        Dict with keys 'overall_accuracy', 'n_examples',
        'per_subject_accuracy', 'predictions', 'gold' and 'raw_generations'.
    """
    # Tokenizer: pad and truncate on the left for decoder-only generation.
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, use_fast=True, local_files_only=local_files_only
    )
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Model: bfloat16 with automatic device placement when CUDA is available,
    # library defaults otherwise.
    has_cuda = torch.cuda.is_available()
    loaded = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16 if has_cuda else None,
        device_map="auto" if has_cuda else None,
        local_files_only=local_files_only,
    )
    model = loaded.eval()
    model.config.pad_token_id = tokenizer.pad_token_id

    # Data, optionally truncated to the first max_samples rows.
    all_rows = load_dataset(data_path)
    rows = all_rows[:max_samples] if max_samples > 0 else all_rows

    # One (prompt, option count, gold letter, subject) record per row.
    samples = []
    for row in rows:
        question = row["question"]
        options = row["options"]
        n_options = len(options)
        prompt = build_prompt(question, options)
        samples.append((prompt, n_options, gold_letter_from_row(row), row.get("subject", None)))
    prompts = [sample[0] for sample in samples]
    k_list = [sample[1] for sample in samples]
    gold_letters = [sample[2] for sample in samples]
    subjects = [sample[3] for sample in samples]

    # Generation
    gens = generate_letters(
        model,
        tokenizer,
        prompts,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        batch_size=batch_size,
    )

    # Parse one predicted letter (or None) per generation.
    preds: List[Optional[str]] = list(map(parse_letter, gens, k_list))

    # Overall accuracy; the max() guard keeps an empty dataset at 0.0.
    correct_flags = [pred == gold for pred, gold in zip(preds, gold_letters)]
    overall = correct_flags.count(True) / max(1, len(correct_flags))

    # Per-subject accuracy over rows that provide a subject.
    hits: Dict[str, int] = {}
    totals: Dict[str, int] = {}
    for is_right, subject in zip(correct_flags, subjects):
        if subject is not None:
            hits[subject] = hits.get(subject, 0) + int(is_right)
            totals[subject] = totals.get(subject, 0) + 1
    per_subject_acc = {subject: hits[subject] / totals[subject] for subject in totals}

    return {
        "overall_accuracy": overall,
        "n_examples": len(rows),
        "per_subject_accuracy": per_subject_acc,
        "predictions": preds,
        "gold": gold_letters,
        "raw_generations": gens,
    }

# Run the evaluation for the local checkpoint. `res` is the metrics dict
# returned by evaluate() and is read by the reporting code that follows.
res = evaluate(
    model_name="./llama8b_mmlu_dpop_id/checkpoint-2407",  # local folder with model + tokenizer
    data_path="mmlu_pro_test.jsonl",
    max_samples=-1,  # not > 0, so evaluate() keeps every row
    batch_size=8,
    max_new_tokens=4,
    temperature=0.0,  # not > 0, so generation runs with do_sample=False
    top_p=1.0,
    local_files_only=True,  # forwarded to from_pretrained for tokenizer and model
)

In [None]:
# Headline figure: overall accuracy as a percentage plus the number of examples.
print(f"\nOverall accuracy (win-rate): {res['overall_accuracy']*100:.2f}%  on N={res['n_examples']}")
# Subject breakdown in sorted key order; skipped when the mapping is empty.
if res["per_subject_accuracy"]:
    print("\nPer-subject accuracy:")
    for k, acc in sorted(res["per_subject_accuracy"].items()):
        print(f"  {k}: {acc*100:.2f}%")