In [1]:
import sys
import os
import torch
import tqdm
import json
from fuseqa import *
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
from transformers.utils import logging
logging.set_verbosity_error()
from datetime import datetime

# Timestamp captured once, when this cell runs; it is reported later as the run time.
now = datetime.now().strftime("%Y-%m-%d %H:%M")
# Before starting Jupyter, run: export HF_TOKEN=your_token_here
# Restrict this process to the first GPU. This has to be set before the first
# CUDA call for it to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Model under evaluation; uncomment exactly one line.
#MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
#MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
MODEL_NAME = "mistralai/Mistral-Nemo-Instruct-2407"

# Retrieval depth setting (presumably the number of retrieved documents to put
# in the prompt - confirm against the evaluation cell).
TOP_K = 10


# Run types:
#   - PARAMETRIC:  USE_CONTEXT is False, so no retrieved context is used.
#   - FUSEQA:      USE_CONTEXT is True, so retrieved context is used.
#   - FUSEQA-SRE:  USE_CONTEXT is True as well; whatever else distinguishes it
#                  from FUSEQA is not visible in this notebook.
RUN_TYPE = "FUSEQA"
VALID_RUN_TYPES = ("PARAMETRIC", "FUSEQA", "FUSEQA-SRE")
if RUN_TYPE not in VALID_RUN_TYPES:
    # Without this check a misspelt run type would silently disable the
    # context and still be written out under the misspelt label.
    raise ValueError(
        f"Unknown RUN_TYPE {RUN_TYPE!r}; expected one of {VALID_RUN_TYPES}"
    )
USE_CONTEXT = RUN_TYPE in ("FUSEQA", "FUSEQA-SRE")
print(hf_model_to_filename(MODEL_NAME + "-" + RUN_TYPE))
print("GPUs:", torch.cuda.device_count())

# Authenticate against the Hugging Face Hub only when a token is provided via
# the environment; the token itself is never written into the notebook.
token = os.getenv("HF_TOKEN")
if token:
    login(token=token)
    print("Logged in to Hugging Face.")
else:
    print("HF_TOKEN not set.")


Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


mistralai_Mistral-Nemo-Instruct-2407-FUSEQA_20260218-1314
GPUs: 1
Logged in to Hugging Face.


In [None]:
# Load the tokenizer and the causal LM for the configured checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Half-precision weights; device placement is delegated to device_map="auto".
load_kwargs = {
    "torch_dtype": torch.float16,
    "device_map": "auto",
}
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **load_kwargs)

# Report where the first parameter tensor ended up.
first_param = next(model.parameters())
print("Model device:", first_param.device)

# hf_device_map is only printed when the loaded model exposes it.
has_device_map = hasattr(model, "hf_device_map")
if not has_device_map:
    print("Single-device load (no map needed)")
else:
    print("Device map:", model.hf_device_map)

config.json:   0%|          | 0.00/622 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Downloading (incomplete total...): 0.00B [00:00, ?B/s]

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

In [None]:
# Load the "test" split of the PopQA-with-retrieval dataset.
# For a quick smoke test, append `.select(range(25))` to evaluate a small subset.
from datasets import load_dataset

popqa_splits = load_dataset("MinaGabriel/popqa-with-retrieval-20")
ds = popqa_splits["test"]
len(ds)

In [None]:
# Buckets that exact-match statistics are tracked for: the whole dataset plus
# one bucket per popularity tier.
groups = ["ALL", "LONG-TAIL", "INFREQUENT", "FREQUENT"]

# Running totals per bucket: examples seen and exact-match hits.
counts = dict.fromkeys(groups, 0)
em_hits = dict.fromkeys(groups, 0)


def update_metrics(tier, em):
    """Record one evaluated example in the running totals.

    The example is counted twice: once under "ALL" and once under ``tier``.

    Args:
        tier: Key of the tier bucket in ``counts`` / ``em_hits``.
        em: Exact-match score of the example, added to the hit totals.

    Raises:
        KeyError: If ``tier`` is not a key of ``counts`` / ``em_hits``.
    """
    buckets = ["ALL", tier]
    for bucket in buckets:
        counts[bucket] = counts[bucket] + 1
        em_hits[bucket] = em_hits[bucket] + em


def current_scores():
    """Return the running exact-match ratio for every bucket.

    Returns:
        A dict mapping the display labels "ALL_EM", "Long_Tail", "Infrequent"
        and "Frequent" to ``safe_div(hits, count)`` of the matching bucket.
    """
    label_to_bucket = (
        ("ALL_EM", "ALL"),
        ("Long_Tail", "LONG-TAIL"),
        ("Infrequent", "INFREQUENT"),
        ("Frequent", "FREQUENT"),
    )
    return {
        label: safe_div(em_hits[bucket], counts[bucket])
        for label, bucket in label_to_bucket
    }


# Evaluation loop: generate one answer per example, score it by exact match,
# stream the per-example record to a JSONL file and show the running tier
# scores on the progress bar.
num_examples = len(ds)

outfile = hf_model_to_filename(MODEL_NAME + "-" + RUN_TYPE) + ".jsonl"

progress_bar = tqdm.tqdm(
    total=num_examples,
    desc="Generating + Evaluating",
    dynamic_ncols=True
)

try:
    # buffering=1 selects line buffering, so every record is flushed to disk
    # as soon as its line is written.
    with open(outfile, "w", encoding="utf-8", buffering=1) as f:

        # Iterate row by row. Selecting a column and then indexing into it
        # (ds[k][i]) for every column on every step can materialise whole
        # columns over and over again.
        for i, ex in enumerate(ds):

            q      = ex["question"]
            # Treat a missing or null popularity value as 0.
            s_pop  = int(ex.get("s_pop") or 0)
            tier   = tier_from_spop(s_pop)

            gold            = parse_list(ex.get("possible_answers"))
            # Normalise each gold answer once and drop the ones that
            # normalise to an empty value.
            gold_norm_set   = {g_norm for g_norm in map(norm, gold) if g_norm}

            retrieved = ex.get("retrieved_docs") or []
            context   = build_context(retrieved, k=TOP_K) if USE_CONTEXT else ""

            # NOTE(review): the device is hard-coded although the model is
            # loaded with device_map="auto" - confirm they agree.
            pred = ask_llm_generate(
                model,
                tokenizer,
                q,
                context,
                use_context=USE_CONTEXT,
                llama_device="cuda:0"
            )

            pred_norm = norm(pred)
            # Membership in an empty set is False, so an example without
            # usable gold answers scores 0.
            em        = int(pred_norm in gold_norm_set)

            update_metrics(tier, em)

            record = {
                "i":        i,
                "s_pop":    s_pop,
                "tier":     tier,
                "question": q,
                "gold":     gold,
                "pred":     pred,
                "em":       em,
            }

            f.write(json.dumps(record, ensure_ascii=False) + "\n")

            progress_bar.update(1)

            scores = current_scores()
            progress_bar.set_postfix({k: f"{v:.4f}" for k, v in scores.items()})
finally:
    # Close the bar even when generation fails part-way through, so an
    # interrupted run does not leave a stale bar in the notebook output.
    progress_bar.close()


# =========================
# REPORT
# =========================
# Build a plain-text summary of the run from the accumulated counters, write
# it next to the JSONL results and print a one-line recap.

rule = "=" * 80

lines = [
    rule,
    "PopQA Exact Match Report — Tiers (Long-Tail / Infrequent / Frequent)",
    rule,
    f"Model: {MODEL_NAME}",
    f"Mode: {RUN_TYPE}",
    f"Run Time: {now}",
    f"Slice: [0,{len(ds)}) | n={len(ds)}",
    "",
    "Tier Definitions (by s_pop):",
    "- LONG-TAIL:  s_pop < 100",
    "- INFREQUENT: 100 <= s_pop < 10000",
    "- FREQUENT:   s_pop >= 10000",
    "",
    "Exact Match (EM):",
]

for name in groups:
    n = counts[name]
    lines.append(f"- {name:<10} n={n:<6} EM={safe_div(em_hits[name], n):.4f}")


# Reuse the stem of the results file instead of calling hf_model_to_filename
# again: the generated name appears to carry a timestamp, so a second call
# made after the long evaluation loop could yield a different stem and
# separate the report from the results it describes.
report_file = os.path.splitext(outfile)[0] + ".report.txt"

with open(report_file, "w", encoding="utf-8") as f:
    f.write("\n".join(lines))


print(
    f"Saved report: {report_file} | "
    f"LT EM={safe_div(em_hits['LONG-TAIL'], counts['LONG-TAIL']):.4f} | "
    f"INF EM={safe_div(em_hits['INFREQUENT'], counts['INFREQUENT']):.4f} | "
    f"FREQ EM={safe_div(em_hits['FREQUENT'], counts['FREQUENT']):.4f}"
)
