In [None]:
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers.utils import logging as hf_logging
from datasets import load_dataset, Dataset, DatasetDict
# Standard library imports
import os
import sys
import re
import json
import math
import time
import random
import subprocess
import unicodedata

# Scientific and numerical computing
import numpy as np
import torch
import torch.nn.functional as F

# Hugging Face and NLP libraries
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
from huggingface_hub import snapshot_download

# Pyserini (for sparse retrieval)
from pyserini.search.lucene import LuceneSearcher
from pyserini.index.lucene import LuceneIndexReader

# Utilities
from tqdm import tqdm

import re
import string
import ast
from typing import Union, List, Set

hf_logging.set_verbosity_error()  # mute HF warnings

# Checkpoint used as the answer-extraction LLM; swap in the commented
# alternative below to try a different instruct model.
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
# model_name = "Qwen/Qwen2.5-7B-Instruct"

# bitsandbytes 8-bit weight loading; fp32 CPU offload is explicitly disabled.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=False,
)

print(f"Loading {model_name} in 8-bit on cuda:0 (RTX 3090)...")

# `tokenizer` and `model` are module-level globals read by get_candidates().
tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map={"": 0},          # <<< use the Ampere GPU
    attn_implementation="eager", # extra safety: disable FlashAttention usage path
)

# Make sure we have a pad token
# (get_candidates() passes tokenizer.pad_token_id to model.generate, so it
# must not be None; EOS is reused as the padding token when none is defined).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Model loaded successfully!")
print(f"Device map: {model.hf_device_map}")


In [None]:
def build_candidates_prompt(question: str, snippets: list) -> str:
    """
    Build a prompt that asks the model to return ALL possible answers
    as a Python list of strings, based only on the snippets.

    Args:
        question: The natural-language question to answer.
        snippets: Evidence snippets. Each item is either a dict whose
            "contents" key holds the text, or a plain string (the evaluation
            loop's logging already tolerates both shapes). ``None`` or an
            empty list produces a prompt with an empty snippet section.

    Returns:
        The full instruction prompt, ending with the line that requests the
        Python list of answers.
    """
    # Turn snippets into a simple bullet list block (one "- ..." line each).
    bullet_lines = []
    for snippet in snippets or []:
        if isinstance(snippet, dict):
            contents = snippet.get("contents", "")
        else:
            contents = snippet
        # A missing/None text still yields an (empty) bullet instead of crashing.
        contents = "" if contents is None else str(contents)
        bullet_lines.append(f"- {contents.strip()}")
    snippets_text = "\n".join(bullet_lines)

    return (
        "You are a question-answer extraction module.\n"
        "You are given a question and several text snippets.\n"
        "ASSUME at least one snippet is relevant to the question.\n"
        "Use ONLY the snippets (ignore all outside knowledge).\n"
        "\n"
        "TASK:\n"
        "- Extract ALL possible answers that directly answer the question.\n"
        "- Answers must be short (1–3 words), normalized (lowercase), and taken from snippet wording.\n"
        "- If multiple labels appear (e.g. \"journalist, broadcaster, writer\"), return them all.\n"
        "- If the question refers to a person, extract all occupations/roles that apply.\n"
        "- The output MUST be a Python string array (list[str]).\n"
        "- If no answers exist in snippets, return: [\"unknown\"].\n"
        "\n"
        "OUTPUT FORMAT (EXTREMELY IMPORTANT):\n"
        "Return ONLY a Python-style list of strings. No explanations.\n"
        "Example valid outputs:\n"
        "  [\"politician\", \"soldier\"]\n"
        "  [\"journalist\"]\n"
        "  [\"actor\", \"producer\", \"director\"]\n"
        "  [\"unknown\"]\n"
        "\n"
        f"Question: {question}\n"
        "Snippets:\n"
        f"{snippets_text.strip()}\n"
        "\n"
        "Return a Python list of strings representing ALL valid answers:"
    )
import ast
import torch

def get_candidates(question: str, snippets: list) -> list[str]:
    """
    Call the LLM to extract all candidate answers.

    Builds the extraction prompt, runs greedy generation with the
    module-level `tokenizer` / `model`, and parses the first non-empty line
    of the reply as a Python list literal.

    Args:
        question: The question to answer.
        snippets: Evidence snippets passed through to build_candidates_prompt.

    Returns:
        A non-empty list of answer strings, e.g. ["politician", "soldier"].
        Returns ["unknown"] when the model produces no usable text or a list
        with no usable entries. If the first line is not a list literal, it
        is returned as a single answer (surrounding quotes removed).
    """
    # Build prompt
    prompt = build_candidates_prompt(question, snippets)
    messages = [{"role": "user", "content": prompt}]

    # Tokenize with the model's chat template
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )

    # Optional: inspect prompt token length (uncomment if needed)
    # print("Prompt tokens:", inputs["input_ids"].shape[-1])

    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Generate greedily. No `temperature` is passed: it only applies to
    # sampling, and combining it with do_sample=False is flagged by the
    # generation-config validation.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=False,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the generated part (after the prompt)
    input_len = inputs["input_ids"].shape[-1]
    new_ids = outputs[0][input_len:]
    response_text = tokenizer.decode(new_ids, skip_special_tokens=True).strip()

    # Take the first non-empty line as the model's "list" output
    first_line = next(
        (ln.strip() for ln in response_text.splitlines() if ln.strip()),
        "",
    )
    if not first_line:
        return ["unknown"]

    # Try to parse as a Python list literal; only the exceptions that
    # ast.literal_eval raises on malformed input are treated as "not a list".
    try:
        parsed = ast.literal_eval(first_line)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        parsed = None

    if isinstance(parsed, list):
        cleaned = []
        for item in parsed:
            # Keep strings; stringify bare numbers (e.g. a year) so they are
            # not lost. bool is excluded because it is an int subclass.
            if isinstance(item, (int, float)) and not isinstance(item, bool):
                item = str(item)
            if isinstance(item, str) and item.strip():
                cleaned.append(item.strip())
        # A well-formed list with nothing usable means "no answer", not the
        # raw text of the list.
        return cleaned or ["unknown"]

    # Fallback: treat the whole line as a single answer string
    fallback = first_line.strip('"').strip("'").strip()
    return [fallback] if fallback else ["unknown"]

In [None]:
import torch
from model_loader import ModelLoader


# Step 1: Load models once
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# ModelLoader comes from the project-local `model_loader` module; presumably it
# loads the retrieval/scoring models onto `device` — confirm against that module.
loader = ModelLoader(device)


In [None]:
def rank_yes_sentences(yes_sents):
    """
    Order candidate sentences from most to least preferred.

    Priority: higher "p_yes" first, then higher "dense_score", then lower
    "dense_rank", then lower "bm25_rank". A missing rank is treated as 1e9,
    so it sorts after any real rank. Returns a new list; the input is left
    untouched.
    """
    def _priority(entry):
        # Scores are negated so the ascending sort puts larger values first.
        return (
            -entry["p_yes"],
            -entry["dense_score"],
            entry.get("dense_rank", 1e9),
            entry.get("bm25_rank", 1e9),
        )

    ordered = list(yes_sents)
    ordered.sort(key=_priority)
    return ordered


def compose_context_adaptive(
    scored_results: list,
    primary_thr: float = 0.8,
    fallback_thr: float = 0.7,
    max_sentences: int = 12,
) -> list:
    """
    From scored_results, build a list of snippet dicts:
        [{"id": 1, "contents": "..."},
         {"id": 2, "contents": "..."},
         ...]
    - Filters by p_yes (primary + fallback thresholds).
    - Falls back to top-K by dense_score if needed.
    - Deduplicates by sentence text.

    Args:
        scored_results: Dicts with at least "p_yes", "dense_score" and
            "sentence" keys ("dense_rank"/"bm25_rank" are used for ranking
            when present).
        primary_thr: Minimum p_yes for a sentence to be selected.
        fallback_thr: Looser p_yes threshold, tried only when nothing passes
            primary_thr.
        max_sentences: Upper bound on returned snippets. Values <= 0 yield an
            empty list.

    Returns:
        At most max_sentences snippet dicts with 1-based incremental "id"
        and stripped, unique "contents", in ranked order.
    """
    # Nothing to select from, or the caller asked for no sentences at all.
    # (Without this guard a limit of 0 would still emit one snippet, because
    # the limit is only checked after appending.)
    if not scored_results or max_sentences <= 0:
        return []

    # 1) Filter by primary threshold
    yes_sents = [s for s in scored_results if s["p_yes"] >= primary_thr]

    # 2) Fallback threshold if nothing passes primary
    if not yes_sents:
        yes_sents = [s for s in scored_results if s["p_yes"] >= fallback_thr]

    # 3) Final fallback: top-K by dense score if still empty
    if not yes_sents:
        yes_sents = sorted(
            scored_results,
            key=lambda s: -s["dense_score"],
        )[:max_sentences]

    # 4) Rank sentences
    ranked = rank_yes_sentences(yes_sents)

    # 5) Deduplicate by sentence text, build snippets
    seen_sentences = set()
    snippets = []
    for s in ranked:
        # A missing/None sentence is treated as empty and skipped.
        sent = (s.get("sentence") or "").strip()
        if not sent or sent in seen_sentences:
            continue

        seen_sentences.add(sent)
        snippets.append(
            {
                "id": len(snippets) + 1,   # simple incremental ID
                "contents": sent,
            }
        )

        if len(snippets) >= max_sentences:
            break

    return snippets


In [None]:
# Load PopQA from the Hugging Face Hub; the evaluation loop reads dataset["test"].
dataset = load_dataset("akariasai/PopQA")
from sre_rag import SreRAG
# SreRAG is project-local; the evaluation loop uses it for BM25 retrieval,
# dense re-ranking, fusion, and sentence scoring. It is built on the shared
# `loader` created earlier.
sreg = SreRAG(loader)


In [None]:
import logging
import re
import string
import ast
from tqdm.notebook import tqdm
from typing import Union, List, Set

# 1) Configure logger to write to file only, simple format (no timestamps/levels)
logger = logging.getLogger("qa_logger")
logger.setLevel(logging.INFO)

# Remove AND close existing handlers (if any). Re-running this cell would
# otherwise leave the previous FileHandler's file descriptor open on
# qa_results.log, since removeHandler() alone does not close it.
for h in list(logger.handlers):
    logger.removeHandler(h)
    h.close()

# mode="w" truncates the log on every (re)configuration, so each run starts
# from an empty file.
file_handler = logging.FileHandler("qa_results.log", mode="w", encoding="utf-8")
formatter = logging.Formatter("%(message)s")  # only message
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)

# Prevent propagation to root logger (nothing goes to console)
logger.propagate = False


def normalize_text(s: str) -> str:
    """
    Canonicalise an answer string for token comparison.

    Steps, in order: lowercase, replace every ASCII punctuation character
    with a space, blank out the standalone articles "a", "an" and "the",
    then collapse all whitespace runs to single spaces. ``None`` maps to "".
    """
    if s is None:
        return ""

    punct_class = rf"[{re.escape(string.punctuation)}]"
    lowered = s.lower()
    depunctuated = re.sub(punct_class, " ", lowered)
    # Articles are removed after punctuation so that e.g. "the," is caught.
    without_articles = re.sub(r"\b(a|an|the)\b", " ", depunctuated)
    tokens = without_articles.split()
    return " ".join(tokens)

def popqa_match(pred: Union[str, List[str]],
                gold_list: Union[str, List[str]]) -> bool:
    """
    Keyword-overlap match between prediction(s) and PopQA possible_answers.

    A prediction matches when, after normalize_text(), it shares at least one
    token with at least one gold alias. This is a lenient overlap test, not a
    strict exact match.

    pred:
        - a single string answer, or
        - a list of answer strings (e.g., ["journalist", "broadcaster"])
        Non-string entries are ignored.
    gold_list:
        - a Python list of strings, or
        - a string representation of a list (e.g., '["journalist","journo"]')
        A string that does not parse as a list literal is used as a single
        alias. Numeric aliases are stringified; other non-string entries are
        ignored (mirroring how `pred` entries are filtered).

    Returns:
        True if any prediction overlaps any gold alias, else False (also for
        missing, empty, or unsupported inputs).
    """
    # Handle missing inputs
    if not pred or not gold_list:
        return False

    # Normalize pred into a Python list of strings
    if isinstance(pred, str):
        pred_list = [pred]
    elif isinstance(pred, (list, tuple)):
        pred_list = [p for p in pred if isinstance(p, str)]
    else:
        return False

    if not pred_list:
        return False

    # Normalize gold_list into a Python list of raw aliases
    if isinstance(gold_list, str):
        try:
            parsed = ast.literal_eval(gold_list)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # Not a Python literal: treat the raw string as the only alias.
            parsed = None
        raw_gold = parsed if isinstance(parsed, list) else [gold_list]
    elif isinstance(gold_list, (list, tuple)):
        raw_gold = list(gold_list)
    else:
        return False

    # Keep only aliases normalize_text() can handle. Without this, a parsed
    # literal such as [1990] or [None] would crash inside normalize_text().
    gold: List[str] = []
    for g in raw_gold:
        if isinstance(g, str):
            gold.append(g)
        elif isinstance(g, (int, float)) and not isinstance(g, bool):
            gold.append(str(g))

    if not gold:
        return False

    # Precompute normalized gold tokens (aliases that normalize to "" are dropped)
    gold_tokens_list: List[Set[str]] = []
    for g in gold:
        g_norm = normalize_text(g)
        if g_norm:
            gold_tokens_list.append(set(g_norm.split()))

    if not gold_tokens_list:
        return False

    # Check each predicted answer against all gold aliases
    for p in pred_list:
        pred_tokens: Set[str] = set(normalize_text(p).split())
        if not pred_tokens:
            continue
        if any(not pred_tokens.isdisjoint(gold_tokens) for gold_tokens in gold_tokens_list):
            return True

    return False



# 2) Main evaluation loop with tqdm, logging to file only
ems = []
BM25_TOP_K = 100       # BM25 candidates requested per question
FUSION_RRF_TOP_K = 3   # documents kept after fusing BM25 and dense rankings

# Only the first 25 test questions are evaluated.
for s in tqdm(dataset["test"].select(range(25)), desc="Evaluating"):
    question = s["question"]
    gold = s.get("possible_answers", [])

    # Retrieval: BM25 candidates, then dense ranking over those candidates.
    bm25 = sreg.bm25_retrieve_and_rank(question, bm25_k=BM25_TOP_K)
    dense = sreg.dense_retrieve_and_rank(question, bm25)

    # Fuse BM25 and dense before classification, keeping the top
    # FUSION_RRF_TOP_K documents.
    fused_docs = sreg.fuse_bm25_and_dense(dense, top_k=FUSION_RRF_TOP_K)

    # Now score sentences only inside those fused docs
    scored = sreg.score_candidates(question, fused_docs) # yes/no threshold

    snippets = compose_context_adaptive(scored)
    pred = get_candidates(question, snippets)

    # Compute the per-question score. It is labelled "EM" in the log, but
    # popqa_match is a keyword-overlap test, not a strict exact match.
    em = 1.0 if popqa_match(pred, gold) else 0.0
    ems.append(em)

    # ---- Logging block (goes ONLY to qa_results.log) ----
    logger.info(f"Q: {question}")
    logger.info("Snippets:")
    for sn in snippets:
        # handle both dict and plain string snippets
        if isinstance(sn, dict):
            text = sn.get("contents", "").strip()
        else:
            text = str(sn).strip()
        logger.info(f"- {text}")
    logger.info("Pred: %s", json.dumps(pred, ensure_ascii=False))
    logger.info(f"Gold: {gold}")
    logger.info(f"EM: {em}")
    logger.info("-" * 40)

# Overall score as a percentage (0.0 when no example was evaluated)
overall = 100.0 * sum(ems) / len(ems) if ems else 0.0
logger.info(f"Overall EM {overall:.2f} on {len(ems)} examples")
