In [None]:
!rm -rf /content/fuse-qa
# Clone repo
!git clone https://github.com/MinaGabriel/fuse-qa.git
!export HF_HUB_ENABLE_HF_TRANSFER=1
%cd /content/fuse-qa
!pip install -q -e .

In [None]:
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.utils import logging
from huggingface_hub import login

from fuseqa import *

# ─────────────────────────────────────────────────────────────
# Setup
# ─────────────────────────────────────────────────────────────

# Only report errors from the transformers logger.
logging.set_verbosity_error()
# NOTE(review): torch is already imported when this runs; the setting is only
# honoured if nothing has initialised CUDA yet — confirm on a fresh kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Exactly one MODEL_NAME line should be active. Previously two assignments were
# live and the first was silently overwritten by the second.
# MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
# MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
# MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
# MODEL_NAME = "google/gemma-2-2b-it"
# MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.1"
MODEL_NAME = "openai/gpt-oss-20b"
RUN_TYPE = "FUSEQA"
# Retrieved context is used only for these run types.
USE_CONTEXT = RUN_TYPE in ("FUSEQA", "FUSEQA-SRE")

# Show the file stem derived from the model/run pair, and the visible GPU count.
print(hf_model_to_filename(MODEL_NAME + "-" + RUN_TYPE))
print("GPUs:", torch.cuda.device_count())

# ─────────────────────────────────────────────────────────────
# HuggingFace Auth (optional)
# ─────────────────────────────────────────────────────────────

# Authenticate only when HF_TOKEN is present and non-empty in the environment;
# otherwise continue without logging in.
token = os.getenv("HF_TOKEN")
if not token:
    print("HF login: skipped")
else:
    login(token=token)
    print("HF login: done")

# ─────────────────────────────────────────────────────────────
# Load Model + Tokenizer
# ─────────────────────────────────────────────────────────────

# Tokenizer and weights are both resolved from the MODEL_NAME checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# device_map="auto" leaves weight placement to the library, and the model is
# switched to evaluation mode right after loading.
# NOTE(review): no dtype is passed here, so the library's default precision is
# used — confirm that this fits the available GPU memory for the chosen model.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto").eval()

# ─────────────────────────────────────────────────────────────
# Device Info
# ─────────────────────────────────────────────────────────────

# Report the device that holds the model's first parameter.
device = next(model.parameters()).device
print(f"Model device: {device}")

# The placement map is printed only when the loaded model exposes one.
if hasattr(model, "hf_device_map"):
    print(f"Device map: {model.hf_device_map}")



In [None]:
# Loading the dataset
from datasets import load_dataset

DATASET_NAME = "MinaGabriel/popqa-with-retrieval-20"
# Number of leading test rows to evaluate; set to None to use the full test split.
N_EXAMPLES = 125

ds = load_dataset(DATASET_NAME)["test"]
if N_EXAMPLES is not None:
    # Clamp the cap so a value larger than the split cannot index out of range.
    ds = ds.select(range(min(N_EXAMPLES, len(ds))))
len(ds)


In [None]:
#REPORTS:
import os
import tqdm
from contextlib import nullcontext

# Per-group tallies: `counts` holds how many examples were seen in each group,
# `em_hits` holds the summed exact-match values for the same groups.
counts = dict.fromkeys(("ALL", "LONG-TAIL", "INFREQUENT", "FREQUENT"), 0)
em_hits = dict.fromkeys(counts, 0)  # same keys, same order, independent dict

import time

# Wall-clock reference used when the total elapsed time is computed for the report.
start_time = time.time()


def update_metrics(tier, em):
    """Record one evaluated example in the module-level tallies.

    Args:
        tier: Group label of the example. Only labels that are already keys of
            ``counts`` are tallied per group; any other label is counted under
            "ALL" only.
        em: Value added to the exact-match totals (the caller passes 0 or 1).
    """
    # Every example contributes to the overall bucket.
    counts["ALL"] += 1
    em_hits["ALL"] += em

    # Unknown tiers are not tracked individually.
    if tier not in counts:
        return

    counts[tier] += 1
    em_hits[tier] += em


def current_scores():
    """Return the running exact-match ratio for every tracked group.

    Returns:
        dict mapping a display label to ``safe_div(em_hits[g], counts[g])`` for
        the corresponding group ``g``, in the order ALL, LONG-TAIL, INFREQUENT,
        FREQUENT.
    """
    # (display label, key into counts / em_hits)
    label_to_group = (
        ("ALL_EM", "ALL"),
        ("Long_Tail", "LONG-TAIL"),
        ("Infrequent", "INFREQUENT"),
        ("Frequent", "FREQUENT"),
    )
    return {
        label: safe_div(em_hits[group], counts[group])
        for label, group in label_to_group
    }

In [None]:
from dataclasses import dataclass
from typing import Tuple



@dataclass(frozen=True)
class PromptConfig:
    """Immutable bundle of the prompt texts handed to the generation helper.

    Instances are frozen, so fields cannot be reassigned after construction.
    """

    # System message for runs without retrieved context.
    system_no_context: str = (
        "You are a precise factual question answering system.\n"
        "Return only the exact answer span."
    )

    # System message for runs that supply retrieved context.
    system_with_context: str = (
        "You are a strict answer extraction system.\n"
        "Extract the exact answer span from the context."
    )

    # Answer-format rules.
    # NOTE(review): "Do not explain." appears twice (alone and inside the last
    # rule). Left untouched because changing the prompt changes model outputs —
    # confirm whether the duplication is intended.
    rules: Tuple[str, ...] = (
        "Answer with the shortest possible span (1-3 words).",
        "Do not explain.",
        "Do not explain. Do not repeat the question.",
    )


In [None]:

# ─────────────────────────────────────────────────────────────
# Generate + evaluate
# ─────────────────────────────────────────────────────────────
# For every dataset row: build the (optional) context, query the model, score
# the prediction by exact match against the normalised gold answers, and
# optionally append one JSON line per row to the results file.

TOP_K = 3  # passed as k= to build_context

RESULTS_DIR = "results"
os.makedirs(RESULTS_DIR, exist_ok=True)

file_name = hf_model_to_filename(MODEL_NAME + "-" + RUN_TYPE)
outfile   = os.path.join(RESULTS_DIR, file_name + ".jsonl")

model_device = next(model.parameters()).device
WRITE_OUTPUTS = True  # set to False to skip writing the per-example JSONL file

# PromptConfig is a frozen dataclass, so one instance can serve every row.
prompt_cfg = PromptConfig()
n_examples = len(ds)

# Start every run from zero. The tallies are module-level dicts, so without this
# a second execution of the cell would add to the previous run's numbers while
# the output file (opened with "w") is truncated. The dicts are cleared in place
# because update_metrics/current_scores refer to these same objects.
for group in counts:
    counts[group] = 0
    em_hits[group] = 0

# Time only this run, not the idle time since an earlier cell was executed.
start_time = time.time()

# buffering=1 selects line buffering, so each record reaches disk as it is written.
with (open(outfile, "w", encoding="utf-8", buffering=1) if WRITE_OUTPUTS else nullcontext()) as writer:

    # Context manager guarantees the bar is closed even if a row raises.
    with tqdm.tqdm(total=n_examples, desc="Generating + Evaluating", dynamic_ncols=True) as pbar:

        # Iterating the dataset yields one row dict at a time; this avoids
        # pulling every full column for every single row.
        for i, ex in enumerate(ds):

            q, s_pop = ex["question"], int(ex.get("s_pop", 0))
            tier = tier_from_spop(s_pop)

            gold = parse_list(ex.get("possible_answers"))
            # Normalise once per gold answer and drop answers that normalise to empty.
            gold_norm_set = {g_norm for g_norm in map(norm, gold) if g_norm}

            retrieved = ex.get("retrieved_docs") or []
            context = build_context(retrieved, k=TOP_K) if USE_CONTEXT else ""

            pred = ask_llm_generate(
                prompt_cfg=prompt_cfg,
                model=model,
                tokenizer=tokenizer,
                question=q,
                context=context,
                use_context=USE_CONTEXT,
                device=model_device,
                print_prompt=True,
            )

            pred_norm = norm(pred)
            # Exact match: 1 when the normalised prediction equals any gold answer.
            em = int(pred_norm in gold_norm_set) if gold_norm_set else 0

            update_metrics(tier, em)

            if WRITE_OUTPUTS:
                record = {"i": i, "s_pop": s_pop, "tier": tier, "question": q, "gold": gold, "pred": pred, "em": em}
                write_record(writer, record)

            pbar.update(1)

            # Refresh the running scores every 10 rows and once more on the
            # last row so the final bar shows the final numbers.
            if i % 10 == 0 or i == n_examples - 1:
                pbar.set_postfix({k: f"{v:.4f}" for k, v in current_scores().items()})

total_time = time.time() - start_time

generate_report(counts, em_hits, file_name, model_name=MODEL_NAME, run_type=RUN_TYPE, total_time=total_time)


In [None]:
! ls -lh results/*.jsonl
! cat results/*.jsonl | head -3