In [None]:
# One-time environment setup — uncomment and run once per fresh runtime
# (the /content path assumes Google Colab).
# ! git clone https://github.com/MinaGabriel/fuse-qa.git
# %cd /content/fuse-qa
# NOTE: a bare `export` line is a Python syntax error in a code cell, and `!export`
# runs in a throw-away subshell, so the variable would not reach this kernel.
# Use `%env HF_TOKEN=...` instead, or export it before launching Jupyter.
# export HF_TOKEN=
# ! pip install -e .


In [1]:

import os

# Pin the visible GPU *before* importing torch (or fuseqa, which may import it).
# CUDA only honours CUDA_VISIBLE_DEVICES when it is first initialised, so setting it
# after a library has already queried the devices can be silently ignored.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import sys
import json
from datetime import datetime

import torch
import tqdm
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.utils import logging

# Explicit imports (instead of `from fuseqa import *`) so it is clear which project
# helpers this notebook depends on.
from fuseqa import (
    ask_llm_generate,
    build_context,
    generate_report,
    hf_model_to_filename,
    norm,
    parse_list,
    safe_div,
    tier_from_spop,
    write_record,
)

logging.set_verbosity_error()  # silence transformers' info/warning messages

now = datetime.now().strftime("%Y-%m-%d %H:%M")
# Auth: provide the token via the environment before starting the kernel, e.g.
#   $ export HF_TOKEN=your_token_here

MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
# MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
# MODEL_NAME = "google/gemma-2-2b-it"
# MODEL_NAME = "openai/gpt-oss-20b"

# Run types:
#   - PARAMETRIC: no retrieved context is given to the model (USE_CONTEXT is False).
#   - FUSEQA:     the top-TOP_K retrieved documents are passed as context.
#   - FUSEQA-SRE: also passes retrieved context; TODO document how it differs from FUSEQA.
RUN_TYPES = ("PARAMETRIC", "FUSEQA", "FUSEQA-SRE")
RUN_TYPE = "FUSEQA"
# Fail fast on a typo: an unknown RUN_TYPE would otherwise silently run without context.
if RUN_TYPE not in RUN_TYPES:
    raise ValueError(f"Unknown RUN_TYPE {RUN_TYPE!r}; expected one of {RUN_TYPES}")
USE_CONTEXT = RUN_TYPE in ("FUSEQA", "FUSEQA-SRE")

# Preview of the output file stem. hf_model_to_filename appears to append a timestamp,
# so the eval cell's final name may differ by a few minutes.
print(hf_model_to_filename(MODEL_NAME + "-" + RUN_TYPE))
print("GPUs:", torch.cuda.device_count())

# Note: an HF_TOKEN environment variable takes precedence over the token cached by
# login() (huggingface_hub prints a notice saying so).
token = os.getenv("HF_TOKEN")
if token:
    login(token=token)
    print("Logged in to Hugging Face.")
else:
    print("HF_TOKEN not set.")



Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


meta-llama_Meta-Llama-3-8B-FUSEQA_20260219-1011
GPUs: 1
Logged in to Hugging Face.


In [2]:

# Load the tokenizer and the fp16 model; device_map="auto" lets Accelerate place the
# weights on the visible GPU(s).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, device_map="auto", torch_dtype=torch.float16,
)
model.eval()  # inference mode (disables dropout)

print("Model device:", next(model.parameters()).device)

# Models dispatched via a device map expose `hf_device_map`; show it when present.
if not hasattr(model, "hf_device_map"):
    print("Single-device load (no map needed)")
else:
    print("Device map:", model.hf_device_map)

Loading weights:   0%|          | 0/291 [00:00<?, ?it/s]

Model device: cuda:0
Single-device load (no map needed)


In [9]:
# Loading the dataset
from datasets import load_dataset
ds = load_dataset("MinaGabriel/popqa-with-retrieval-20")["test"].select(range(125))
ds = load_dataset("MinaGabriel/popqa-with-retrieval-20")["test"]
len(ds)

TOP_K = 3

In [10]:
from contextlib import nullcontext

import time

def run_popqa_eval(
    model,
    tokenizer,
    ds,
    write_outputs=False,
    groups=("ALL", "LONG-TAIL", "INFREQUENT", "FREQUENT"),
    *,
    model_name=None,
    run_type=None,
    use_context=None,
    top_k=None,
):
    """Generate an answer for every PopQA example in ``ds`` and score it by exact match.

    A prediction is a hit (EM = 1) when ``norm(pred)`` equals any non-empty
    ``norm(gold)``. Hits are accumulated overall ("ALL") and per popularity tier
    (``tier_from_spop(s_pop)``).

    Args:
        model, tokenizer: passed straight to ``ask_llm_generate``.
        ds: dataset with ``question``, ``s_pop``, ``possible_answers`` and
            ``retrieved_docs`` columns; assumes ``ds[i]`` returns a row dict
            (as for a Hugging Face ``datasets.Dataset``).
        write_outputs: if True, stream one record per example to
            ``<file_name>.jsonl`` via ``write_record``.
        groups: metric buckets to track. "ALL" is always tracked; tiers not
            listed are counted only in "ALL".
        model_name, run_type, use_context, top_k: optional overrides for the
            notebook-level MODEL_NAME / RUN_TYPE / USE_CONTEXT / TOP_K
            (default: read the globals at call time).

    Returns:
        (counts, em_hits, file_name, total_time): per-group example counts,
        per-group EM hit counts, the output file stem, and elapsed seconds.
    """
    # Resolve defaults from notebook globals so existing call sites keep working.
    model_name  = MODEL_NAME  if model_name  is None else model_name
    run_type    = RUN_TYPE    if run_type    is None else run_type
    use_context = USE_CONTEXT if use_context is None else use_context
    top_k       = TOP_K       if top_k       is None else top_k

    start_time = time.perf_counter()  # monotonic clock, suited to measuring durations

    # "ALL" is always tracked, even if the caller's `groups` omits it.
    groups = tuple(groups)
    if "ALL" not in groups:
        groups = ("ALL",) + groups

    counts  = {g: 0 for g in groups}
    em_hits = {g: 0 for g in groups}

    # Progress-bar labels for the default groups; any other group shows its own name.
    postfix_labels = {
        "ALL": "ALL_EM",
        "LONG-TAIL": "Long_Tail",
        "INFREQUENT": "Infrequent",
        "FREQUENT": "Frequent",
    }

    def update_metrics(tier, em):
        """Add one example's EM (0/1) to "ALL" and, if tracked, to its tier."""
        counts["ALL"] += 1
        em_hits["ALL"] += em
        if tier in counts and tier != "ALL":  # guard against double-counting "ALL"
            counts[tier] += 1
            em_hits[tier] += em

    def current_scores():
        """Running EM for every tracked group, keyed by its progress-bar label."""
        return {
            postfix_labels.get(g, g): safe_div(em_hits[g], counts[g])
            for g in groups
        }

    file_name = hf_model_to_filename(model_name + "-" + run_type)
    outfile   = file_name + ".jsonl"

    model_device = next(model.parameters()).device
    n_examples = len(ds)

    # Line-buffered so already-written records survive an interrupted run.
    writer_cm = (
        open(outfile, "w", encoding="utf-8", buffering=1)
        if write_outputs else nullcontext()
    )

    # Both the output file and the progress bar are closed even if generation raises.
    with writer_cm as writer, tqdm.tqdm(
        total=n_examples,
        desc="Generating + Evaluating",
        dynamic_ncols=True,
    ) as progress_bar:
        for i in range(n_examples):
            # Fetch one row as a dict. `ds[k][i]` could materialise the entire
            # column k on every iteration in older `datasets` releases.
            ex = ds[i]

            q      = ex["question"]
            s_pop  = int(ex.get("s_pop", 0))
            tier   = tier_from_spop(s_pop)

            gold = parse_list(ex.get("possible_answers"))
            gold_norm_set = {g_norm for g_norm in map(norm, gold) if g_norm}

            retrieved = ex.get("retrieved_docs") or []
            context   = build_context(retrieved, k=top_k) if use_context else ""

            pred = ask_llm_generate(
                model,
                tokenizer,
                q,
                context,
                use_context=use_context,
                device=model_device,
            )

            # An empty gold set can never contain pred_norm, so this yields 0 there.
            em = int(norm(pred) in gold_norm_set)

            update_metrics(tier, em)

            if write_outputs:
                record = {
                    "i": i,
                    "s_pop": s_pop,
                    "tier": tier,
                    "question": q,
                    "gold": gold,
                    "pred": pred,
                    "em": em,
                }
                write_record(writer, record)

            progress_bar.update(1)

            # Refresh every 10 examples and on the last one, so the final postfix
            # matches the scores handed to the report.
            if i % 10 == 0 or i == n_examples - 1:
                progress_bar.set_postfix({k: f"{v:.4f}" for k, v in current_scores().items()})

    total_time = time.perf_counter() - start_time

    return counts, em_hits, file_name, total_time


# RUN
# Run the evaluation (set write_outputs=True to also save per-example predictions as JSONL).
counts, em_hits, file_name, total_time = run_popqa_eval(
    model, tokenizer, ds, write_outputs=False
)

# Summarise EM overall and per popularity tier; the report is saved as <file_name>.report.txt.
generate_report(
    counts, em_hits, file_name,
    model_name=MODEL_NAME, run_type=RUN_TYPE, total_time=total_time,
)


Generating + Evaluating: 100%|██████████| 14267/14267 [55:24<00:00,  4.29it/s, ALL_EM=0.4669, Long_Tail=0.4682, Infrequent=0.4271, Frequent=0.6087]

Saved report: meta-llama_Meta-Llama-3-8B-FUSEQA_20260219-1014.report.txt | Time=3324.05s | ALL=0.4670 | LONG-TAIL=0.4682 | INFREQUENT=0.4272 | FREQUENT=0.6086



