In [None]:
!pip install -q -r requirements.txt

In [None]:
import os
import re
import string
from collections import Counter
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional

import faiss
import numpy as np
import requests
from datasets import load_dataset
from openai import OpenAI
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

In [None]:
# Core record types shared by the dataset loaders, retrievers and the QA pipeline.
@dataclass
class Passage:
    """A retrievable text unit: one context paragraph attached to a dataset example."""
    pid: str  # passage id; loaders derive it from the example id or the dataset's paragraph id
    text: str  # text that is indexed by the retrievers and shown to the answer model
    meta: Dict[str, Any]  # loader-specific extras, e.g. {"title": ...} (HotpotQA) or {"url": ..., ...}

@dataclass
class Example:
    """One QA item: a question, its gold answer and the passages shipped with it."""
    qid: str  # question id
    question: str
    answer: Optional[str]  # gold answer; may be None, in which case run_eval skips EM/F1 for it
    context: List[Passage]  # passages provided with the example; examples_to_passages indexes these
    time: Optional[str]  # timestamp if the dataset has one (ChroniclingAmericaQA publication date), else None
    meta: Dict[str, Any]  # dataset-specific fields (question type/level, supporting facts, URLs, ...)

@dataclass
class RetrievalResult:
    """A retrieved passage together with the score the retriever assigned to it."""
    pid: str
    score: float  # retriever-specific scale: raw BM25, cosine similarity, or fused hybrid score
    text: str  # passage text looked up in the passage store ("" for an unknown pid)
    meta: Dict[str, Any]  # passage metadata ({} for an unknown pid)

@dataclass
class QATrace:
    """Everything produced while answering one question, kept for evaluation and debugging."""
    qid: str
    question: str
    hop_queries: List[str]  # one query per retrieval hop (QASystem currently performs a single hop)
    retrieved: List[RetrievalResult]  # full ranked retrieval list
    evidence: List[RetrievalResult]  # results kept by the EvidenceSelector and passed to the answer generator
    predicted_answer: str

In [None]:
def load_hotpotqa(hf_id: str = "hotpotqa/hotpot_qa", config: str = "fullwiki", train_split: str = "train", test_split: str = "validation", max_context_passages: Optional[int] = None) -> Dict[str, List[Example]]:
    ds_train = load_dataset(hf_id, config, split=train_split)
    ds_test = load_dataset(hf_id, config, split=test_split)

    def row_to_example(row) -> Example:
        qid = str(row["id"])
        question = row["question"]
        answer = row.get("answer", None)

        titles, sents_lists = row["context"]
        passages: List[Passage] = []
        for i, (title, sents) in enumerate(zip(titles, sents_lists)):
            text = " ".join(sents) if isinstance(sents, list) else str(sents)
            passages.append(Passage(
                pid=f"{qid}::ctx::{i}",
                text=text,
                meta={"title": title}
            ))
            if max_context_passages and len(passages) >= max_context_passages:
                break

        meta = {
            "type": row.get("type"),
            "level": row.get("level"),
            "supporting_facts": row.get("supporting_facts"),
        }

        return Example(
            qid=qid,
            question=question,
            answer=answer,
            context=passages,
            time=None,
            meta=meta,
        )

    return {
        "train": [row_to_example(r) for r in ds_train],
        "test":  [row_to_example(r) for r in ds_test],
    }

In [None]:
def load_chronicling_america_qa(hf_id: str = "Bhawna/ChroniclingAmericaQA", train_split: str = "train", test_split: str = "test", use_clean_context: bool = True) -> Dict[str, List[Example]]:
    """Load ChroniclingAmericaQA train/test splits as ``Example`` lists (one passage each).

    Args:
        hf_id: Hugging Face dataset id.
        train_split: Split returned under ``"train"``.
        test_split: Split returned under ``"test"``.
        use_clean_context: Use the ``context`` column as passage text; otherwise ``raw ocr``.

    Returns:
        ``{"train": [...], "test": [...]}``.

    Passages are keyed by the ``para id`` column. ``examples_to_passages`` dedupes by pid,
    so questions about the same paragraph share one indexed passage. The publication
    date is stored both on the Example (``time``) and on the passage
    (``meta["publication_date"]``). It must be on the passage because
    ``temporal_score`` only inspects retrieved passages' metadata.

    NOTE(review): the space-separated column names ("para id", "publication date",
    "raw ocr", ...) are taken as-is; verify them against the dataset card.
    """
    ds_train = load_dataset(hf_id, split=train_split)
    ds_test = load_dataset(hf_id, split=test_split)

    def row_to_example(row) -> Example:
        # Stringify like the other loaders so qids (and fallback pids) are always str.
        qid = str(row["query_id"])
        question = row["question"]

        answer = row.get("answer") or row.get("org_answer")

        text = row.get("context") if use_clean_context else row.get("raw ocr")
        if text is None:
            text = ""

        # Fall back to the qid when the paragraph id is absent *or* None; previously a
        # None para id became the literal pid "None", shared by unrelated passages.
        para_id = row.get("para id")
        pid = str(para_id) if para_id is not None else qid

        pub_date = row.get("publication date")
        time = str(pub_date) if pub_date is not None else None

        passage = Passage(
            pid=pid,
            text=text,
            meta={
                "url": row.get("url"),
                "raw_ocr": row.get("raw ocr") if use_clean_context else None,
                # parse_year_from_meta looks for this key, enabling Temporal@evidence.
                "publication_date": time,
            },
        )

        meta = {
            "trans_que": row.get("trans que"),
            "trans_ans": row.get("trans ans"),
            "url": row.get("url"),
            "para_id": para_id,
        }

        return Example(
            qid=qid,
            question=question,
            answer=answer,
            context=[passage],
            time=time,
            meta=meta,
        )

    return {
        "train": [row_to_example(r) for r in ds_train],
        "test":  [row_to_example(r) for r in ds_test],
    }


In [None]:
def load_timeqa(hf_id: str = "hugosousa/TimeQA", train_split: str = "train", test_split: str = "test") -> Dict[str, List[Example]]:
    """Load TimeQA train/test splits from the Hugging Face Hub as ``Example`` lists.

    A list-valued ``context`` yields one Passage per item (pid ``"{qid}::p{i}"``).
    Any other value yields a single passage ``"{qid}::p0"``, with ``None`` becoming
    empty text. The answer is ``targets`` itself when it is a string (or None), the
    first target for a list, and ``str(targets)`` otherwise.
    """
    datasets_by_name = {
        "train": load_dataset(hf_id, split=train_split),
        "test": load_dataset(hf_id, split=test_split),
    }

    def normalize_answer(targets) -> Optional[str]:
        # Collapse the `targets` field to a single answer string.
        if targets is None or isinstance(targets, str):
            return targets
        if isinstance(targets, list):
            return targets[0] if targets else None
        return str(targets)

    def row_to_example(row) -> Example:
        qid = str(row["idx"])
        question = row["question"]
        gold = normalize_answer(row.get("targets"))

        raw_ctx = row.get("context")
        if isinstance(raw_ctx, list):
            chunks = [str(piece) for piece in raw_ctx]
        else:
            chunks = ["" if raw_ctx is None else str(raw_ctx)]
        passages: List[Passage] = [
            Passage(pid=f"{qid}::p{n}", text=chunk, meta={})
            for n, chunk in enumerate(chunks)
        ]

        return Example(
            qid=qid,
            question=question,
            answer=gold,
            context=passages,
            time=None,
            meta={"level": row.get("level")},
        )

    return {
        name: [row_to_example(r) for r in ds]
        for name, ds in datasets_by_name.items()
    }

In [None]:
def examples_to_passages(examples: List[Example]) -> List[Dict[str, Any]]:
    """Flatten the passages of ``examples`` into plain dicts, deduplicated by pid.

    The first occurrence of each pid wins and input order is preserved. Each dict
    has keys ``"pid"``, ``"text"`` and ``"meta"`` (the passage's own meta object).
    """
    by_pid: Dict[str, Dict[str, Any]] = {}
    for passage in (p for ex in examples for p in ex.context):
        if passage.pid not in by_pid:
            by_pid[passage.pid] = {"pid": passage.pid, "text": passage.text, "meta": passage.meta}
    # dicts keep insertion order, so this matches first-seen order.
    return list(by_pid.values())

def build_passage_store(passages: List[Dict[str, Any]]):
    """Split passage dicts into parallel lists plus pid lookup tables.

    Args:
        passages: Dicts with ``"pid"`` and ``"text"`` keys and an optional ``"meta"``
            (missing meta becomes ``{}``).

    Returns:
        ``(pids, texts, metas, pid_to_text, pid_to_meta)``. The lists are aligned
        index-for-index. In the lookups, a repeated pid maps to its last occurrence.
    """
    pids: List[str] = []
    texts: List[str] = []
    metas: List[Dict[str, Any]] = []
    for record in passages:
        pids.append(record["pid"])
        texts.append(record["text"])
        metas.append(record.get("meta", {}))
    pid_to_text = dict(zip(pids, texts))
    pid_to_meta = dict(zip(pids, metas))
    return pids, texts, metas, pid_to_text, pid_to_meta

In [None]:
# BM25 retriever
class BM25Retriever:
    """Okapi BM25 lexical retriever over an in-memory list of passages.

    Args:
        pids: Passage ids, aligned index-for-index with ``texts``.
        texts: Passage texts; tokenised with ``_tokenize`` and indexed with ``BM25Okapi``.

    ``retrieve`` returns ``(pid, score, corpus_index)`` triples, best first.

    Raises:
        ValueError: if ``pids`` and ``texts`` differ in length, or the corpus is empty.
    """

    # Word characters only. With plain str.split(), "Apple?" and "apple" were different
    # tokens, so e.g. the final word of a question (followed by "?") never matched.
    _TOKEN_RE = re.compile(r"\w+")

    def __init__(self, pids, texts):
        if len(pids) != len(texts):
            raise ValueError(f"pids and texts must have the same length ({len(pids)} != {len(texts)})")
        if len(texts) == 0:
            # Fail early with a clear message instead of an opaque error from the BM25 index.
            raise ValueError("BM25Retriever needs at least one passage to index")
        self.pids = pids
        self.texts = texts
        self.bm25 = BM25Okapi([self._tokenize(t) for t in texts])

    @classmethod
    def _tokenize(cls, text):
        """Lower-case ``text`` and split it into word tokens, dropping punctuation."""
        return cls._TOKEN_RE.findall(text.lower())

    def retrieve(self, query, k=10):
        """Return the top-``k`` passages for ``query`` as ``(pid, score, index)`` triples.

        ``k`` is clamped to the corpus size (``k <= 0`` gives ``[]``). Ties are broken by
        corpus order, so results are deterministic.
        """
        scores = np.asarray(self.bm25.get_scores(self._tokenize(query)), dtype=float)
        k = min(int(k), len(scores))
        if k <= 0:
            return []
        top_idx = np.argsort(-scores, kind="stable")[:k]
        return [(self.pids[i], float(scores[i]), int(i)) for i in top_idx]

In [None]:
class DenseRetriever:
    """Exact inner-product retrieval over sentence-transformer embeddings (FAISS ``IndexFlatIP``).

    Embeddings are L2-normalised at encode time, so inner product equals cosine similarity.

    Args:
        pids: Passage ids, aligned index-for-index with ``texts``.
        texts: Passage texts to embed and index; must be non-empty.
        model_name: Sentence-transformers model id. ``None`` falls back to ``$DENSE_MODEL``
            and then ``DEFAULT_MODEL``, resolved when the retriever is constructed.
            Previously the env var was read once at class-definition time and could yield
            ``None``, which is not a usable model.
        batch_size: Encoding batch size.
        device: Device for the encoder (default ``"cpu"``, as before).

    Raises:
        ValueError: if ``pids`` and ``texts`` differ in length, or the corpus is empty.
    """

    DEFAULT_MODEL = "BAAI/bge-small-en-v1.5"

    def __init__(self, pids, texts, model_name=None, batch_size=64, device="cpu"):
        if len(pids) != len(texts):
            raise ValueError(f"pids and texts must have the same length ({len(pids)} != {len(texts)})")
        if len(texts) == 0:
            raise ValueError("DenseRetriever needs at least one passage to index")
        self.pids = pids
        self.texts = texts
        resolved_model = model_name or os.getenv("DENSE_MODEL") or self.DEFAULT_MODEL
        self.model = SentenceTransformer(resolved_model, device=device)

        emb = self.model.encode(
            texts,
            batch_size=batch_size,
            normalize_embeddings=True,
            show_progress_bar=True
        ).astype("float32")

        self.dim = emb.shape[1]
        self.index = faiss.IndexFlatIP(self.dim)
        self.index.add(emb)

    def retrieve(self, query, k=10):
        """Return up to ``k`` ``(pid, cosine_score, index)`` triples for ``query``, best first.

        ``k`` is clamped to the number of indexed vectors. FAISS pads missing hits with
        index ``-1``, and ``self.pids[-1]`` would silently return the *last* passage, so
        such entries are skipped.
        """
        k = min(int(k), self.index.ntotal)
        if k <= 0:
            return []
        q = self.model.encode([query], normalize_embeddings=True).astype("float32")
        scores, idxs = self.index.search(q, k)
        results = []
        for score, i in zip(scores[0], idxs[0]):
            if i < 0:
                continue
            results.append((self.pids[int(i)], float(score), int(i)))
        return results

In [None]:
class HybridRetriever:
    """Fuse BM25 and dense rankings with a weighted sum of per-query normalised scores.

    Both retrievers return ``candidate_k`` candidates. Their scores are z-scored and
    squashed through a sigmoid (``_normalize``), then combined as
    ``alpha * bm25 + (1 - alpha) * dense``.

    Args:
        bm25_ret: Lexical retriever; ``retrieve(query, k=...)`` yields ``(pid, score, idx)``.
        dense_ret: Dense retriever with the same interface.
        alpha: Weight of the BM25 component, in ``[0, 1]``.

    Raises:
        ValueError: if ``alpha`` is outside ``[0, 1]``.
    """

    def __init__(self, bm25_ret: "BM25Retriever", dense_ret: "DenseRetriever", alpha=0.5):
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")
        self.bm25_ret = bm25_ret
        self.dense_ret = dense_ret
        self.alpha = alpha

    @staticmethod
    def _normalize(scores):
        """Z-score ``scores`` and map them into (0, 1) with a sigmoid; constant input gives zeros."""
        s = np.array(scores, dtype="float32")
        if len(s) == 0:
            return s

        mean = s.mean()
        std = s.std()
        if std < 1e-8:
            return np.zeros_like(s)

        z = (s - mean) / std
        return 1 / (1 + np.exp(-z))

    def retrieve(self, query, k=10, candidate_k=50):
        """Return up to ``k`` ``(pid, hybrid_score, None)`` triples for ``query``, best first.

        ``candidate_k`` is raised to at least ``k``. Otherwise, when the two candidate
        lists overlap, fewer than ``k`` fused results could come back.
        """
        candidate_k = max(candidate_k, k)
        bm = self.bm25_ret.retrieve(query, k=candidate_k)
        de = self.dense_ret.retrieve(query, k=candidate_k)

        bm_map = {pid: score for pid, score, _ in bm}
        de_map = {pid: score for pid, score, _ in de}

        # Ordered union: BM25 hits first, then dense-only hits. The previous set() union made
        # tie-breaking depend on string-hash randomisation, so rankings changed between runs.
        all_pids = list(dict.fromkeys([*bm_map, *de_map]))
        # NOTE(review): a candidate missing from one list gets 0.0 there before normalisation;
        # for cosine scores this acts as a heuristic lower bound.
        bm_scores = [bm_map.get(pid, 0.0) for pid in all_pids]
        de_scores = [de_map.get(pid, 0.0) for pid in all_pids]

        bm_norm = self._normalize(bm_scores)
        de_norm = self._normalize(de_scores)

        hybrid = self.alpha * bm_norm + (1.0 - self.alpha) * de_norm
        # Stable sort on negated scores: best first, ties keep union order.
        order = np.argsort(-hybrid, kind="stable")[:k]

        return [(all_pids[i], float(hybrid[i]), None) for i in order]

In [None]:
class QueryAgent:
    """Turns a user question into a retrieval query; currently the identity mapping."""

    def make_query(self, question: str, hop_context: Optional[str] = None) -> str:
        """Return the retrieval query for ``question``.

        ``hop_context`` is accepted for future multi-hop query rewriting but is ignored;
        the question is returned verbatim.
        """
        query = question
        return query

In [None]:
class RetrieverAdapter:
    """Wraps a low-level retriever so its hits come back as ``RetrievalResult`` objects.

    ``retriever.retrieve(query, k=...)`` must yield ``(pid, score, extra)`` triples. Text
    and metadata are looked up by pid, with ``""`` and ``{}`` for unknown pids.
    """

    def __init__(self, retriever, pid_to_text, pid_to_meta):
        self.retriever = retriever
        self.pid_to_text = pid_to_text
        self.pid_to_meta = pid_to_meta

    def retrieve(self, query: str, k: int) -> List[RetrievalResult]:
        """Retrieve ``k`` hits for ``query`` and attach their text and metadata, keeping rank order."""
        hits = self.retriever.retrieve(query, k=k)
        return [
            RetrievalResult(
                pid=pid,
                score=score,
                text=self.pid_to_text.get(pid, ""),
                meta=self.pid_to_meta.get(pid, {}),
            )
            for pid, score, _extra in hits
        ]

In [None]:
class EvidenceSelector:
    """Keeps the leading retrieval results as evidence for the answer generator."""

    def __init__(self, top_n: int = 3):
        # Number of results kept by `select`.
        self.top_n = top_n

    def select(self, retrieved: List[RetrievalResult]) -> List[RetrievalResult]:
        """Return the first ``top_n`` entries of ``retrieved``.

        Assumes the list is already ranked best-first by the retriever.
        """
        limit = self.top_n
        return retrieved[:limit]

In [None]:
import ollama

def _build_context(evidence, max_chars_per_passage: int = 2000) -> str:
    blocks: List[str] = []
    for i, r in enumerate(evidence, start=1):
        txt = (r.text or "")[:max_chars_per_passage]
        blocks.append(f"[Passage {i}]\n{txt}")
    return "\n\n".join(blocks)

def _build_prompt(question: str, context: str) -> List[dict[str, str]]:
    return [
        {
            "role": "system",
            "content": "You are a question answering assistant. Use only the provided passages as evidence. If the answer is not in the passages, say 'Insufficient information.'"
        },
        {
            "role": "user",
            "content": f"Question: {question}\n\nPassages:\n{context}\n\nAnswer (concise):"
        }
    ]

class AnswerGenerator:
    """Generate an answer from retrieved evidence with a local Ollama model or the OpenAI API.

    Args:
        backend: ``"ollama"`` or ``"openai"``; anything else raises ``ValueError`` at construction.
        ollama_model: Ollama model tag. ``None`` falls back to ``$OLLAMA_MODEL``, then
            ``"gemma3:4b"``, resolved at construction time. Previously the env var was read
            once, when the class was defined.
        ollama_url: Ollama server URL. ``None`` lets the ``ollama`` client choose its host
            (``$OLLAMA_HOST`` or its local default), as the old module-level ``ollama.chat``
            did. Previously any value passed here was stored but silently ignored.
        openai_model: OpenAI chat model name.
        openai_api_key: API key; falls back to ``$OPENAI_API_KEY``.
        temperature: Sampling temperature passed to either backend.
        max_passages: At most this many evidence items are put into the prompt.
        max_chars_per_passage: Each passage is truncated to this many characters.
    """

    _BACKENDS = ("ollama", "openai")

    def __init__(
        self,
        backend: Literal["ollama", "openai"] = "ollama",
        ollama_model: Optional[str] = None,
        ollama_url: Optional[str] = None,
        openai_model: str = "gpt-4o-mini",
        openai_api_key: Optional[str] = None,
        temperature: float = 0.0,
        max_passages: int = 5,
        max_chars_per_passage: int = 2000,
    ):
        if backend not in self._BACKENDS:
            raise ValueError(f"Unknown backend: {backend}")
        self.backend = backend
        self.ollama_model = ollama_model or os.getenv("OLLAMA_MODEL", "gemma3:4b")
        self.ollama_url = ollama_url.rstrip("/") if ollama_url else None
        self.openai_model = openai_model
        self.openai_api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
        self.temperature = temperature
        self.max_passages = max_passages
        self.max_chars_per_passage = max_chars_per_passage
        # Clients are created lazily on first use and then reused for every question.
        self._ollama_client = None
        self._openai_client = None

    def answer(self, question: str, evidence) -> str:
        """Answer ``question`` using the first ``max_passages`` items of ``evidence``.

        Raises:
            ValueError: if ``self.backend`` is not a known backend (checked before any work).
        """
        if self.backend not in self._BACKENDS:
            raise ValueError(f"Unknown backend: {self.backend}")
        evidence = evidence[: self.max_passages]
        context = _build_context(evidence, max_chars_per_passage=self.max_chars_per_passage)
        message = _build_prompt(question, context)

        if self.backend == "ollama":
            return self._answer_ollama(message)
        return self._answer_openai(message)

    def _get_ollama_client(self):
        """Return a cached ``ollama.Client`` bound to ``self.ollama_url`` (``None`` = client default)."""
        client = getattr(self, "_ollama_client", None)
        if client is None:
            client = ollama.Client(host=self.ollama_url)
            self._ollama_client = client
        return client

    def _get_openai_client(self):
        """Return a cached ``OpenAI`` client; previously a new client was built per question."""
        client = getattr(self, "_openai_client", None)
        if client is None:
            client = OpenAI(api_key=self.openai_api_key)
            self._openai_client = client
        return client

    def _answer_ollama(self, message: List[dict[str, str]]) -> str:
        """Send the chat ``message`` list to Ollama and return the stripped reply text."""
        resp = self._get_ollama_client().chat(
            model=self.ollama_model,
            messages=message,
            options={"temperature": self.temperature},
        )
        return (resp.message.content or "").strip()

    def _answer_openai(self, messages: List[dict[str, str]]) -> str:
        """Send ``messages`` to the OpenAI chat completions API and return the stripped reply.

        Raises:
            RuntimeError: if no API key was provided or found in the environment.
        """
        if not self.openai_api_key:
            raise RuntimeError("OPENAI_API_KEY not set. Set env var or pass openai_api_key=...")

        resp = self._get_openai_client().chat.completions.create(
            model=self.openai_model,
            messages=messages,
            temperature=self.temperature,
        )
        # The OpenAI SDK types `message.content` as Optional; guard before stripping.
        return (resp.choices[0].message.content or "").strip()


In [None]:
class QASystem:
    """Single-hop retrieve-then-read pipeline: query, retrieve, select evidence, answer."""

    def __init__(self, query_agent, retriever_adapter, evidence_selector, answer_generator):
        # Components are duck-typed: make_query / retrieve / select / answer are all that is used.
        self.query_agent = query_agent
        self.retriever = retriever_adapter
        self.evidence_selector = evidence_selector
        self.answer_generator = answer_generator

    def run_one(self, qid: str, question: str, top_k: int = 10) -> QATrace:
        """Answer one question and return the full trace (query, hits, evidence, prediction)."""
        query = self.query_agent.make_query(question)
        hits = self.retriever.retrieve(query, k=top_k)
        chosen = self.evidence_selector.select(hits)
        prediction = self.answer_generator.answer(question, chosen)
        return QATrace(
            qid=qid,
            question=question,
            hop_queries=[query],
            retrieved=hits,
            evidence=chosen,
            predicted_answer=prediction,
        )

In [None]:
def _normalize_text(s: str) -> str:
    s = s.lower().strip()
    s = re.sub(r"[^a-z0-9\s]", " ", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s

def exact_match(pred: str, gold: str) -> float:
    return 1.0 if _normalize_text(pred) == _normalize_text(gold) else 0.0

def f1_score(pred: str, gold: str) -> float:
    pred_toks = _normalize_text(pred).split()
    gold_toks = _normalize_text(gold).split()
    if not pred_toks and not gold_toks:
        return 1.0
    if not pred_toks or not gold_toks:
        return 0.0
    common = Counter(pred_toks) & Counter(gold_toks)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_toks)
    recall = num_same / len(gold_toks)
    return 2 * precision * recall / (precision + recall)

def extract_years(text: str, min_year: int = 1000, max_year: int = 2099) -> List[int]:
    """Return the standalone 4-digit years in ``text``, in order of appearance.

    Only whole 4-digit tokens whose value lies in ``[min_year, max_year]`` count.
    The window used to be hard-wired to 1900-2099, which silently dropped every
    pre-1900 date. That matters for historical-newspaper data such as
    ChroniclingAmericaQA and for pre-1900 TimeQA facts. Pass ``min_year=1900`` to
    restore the old behaviour. A wider window can also admit non-year numbers such
    as "1500 soldiers".

    Empty or ``None`` text yields ``[]``.
    """
    if not text:
        return []
    return [y for y in map(int, re.findall(r"\b\d{4}\b", text)) if min_year <= y <= max_year]

def parse_year_from_meta(meta: Dict[str, Any]) -> Optional[int]:
    """Best-effort year for a passage, read from its metadata.

    The lookup order is:
    1. ``meta["year"]``, converted with ``int()``. If that fails and the value is a
       string, the first year ``extract_years`` finds in it is used.
    2. The first non-empty field among ``publication_date``, ``publication date``,
       ``date`` and ``time``, using the first year ``extract_years`` finds there.

    Returns ``None`` when nothing parses, or when ``meta`` is empty or ``None``.
    Only ``TypeError``/``ValueError`` from ``int()`` are swallowed. The previous bare
    ``except:`` also hid ``KeyboardInterrupt`` and real bugs.
    """
    if not meta:
        return None
    if "year" in meta:
        raw = meta["year"]
        try:
            return int(raw)
        except (TypeError, ValueError):
            # e.g. None or "1889-05-12": try a regex on strings, else fall through to date fields.
            if isinstance(raw, str):
                years = extract_years(raw)
                if years:
                    return years[0]
    for key in ("publication_date", "publication date", "date", "time"):
        value = meta.get(key)
        if value:
            years = extract_years(str(value))
            if years:
                return years[0]
    return None

def temporal_score(question: str, evidence: List[RetrievalResult]) -> Optional[float]:
    """Check whether the evidence covers a year mentioned in the question.

    Returns ``None`` when the question mentions no year (not scorable). Otherwise
    returns 1.0 if any evidence passage's metadata year (``parse_year_from_meta``)
    equals a question year, and 0.0 if none does, including when no evidence has a
    parsable year.
    """
    asked = set(extract_years(question))
    if not asked:
        return None
    found = {
        year
        for year in (parse_year_from_meta(r.meta) for r in evidence)
        if year is not None
    }
    # An empty `found` intersects to the empty set, which scores 0.0 as well.
    return 1.0 if asked & found else 0.0

In [None]:
def run_eval(qa_system: QASystem, examples: List[Example], top_k: int = 10, max_examples: Optional[int] = None):
    """Answer ``examples`` with ``qa_system`` and aggregate EM/F1 and temporal metrics.

    Args:
        qa_system: Pipeline whose ``run_one(qid, question, top_k=...)`` returns a QATrace.
        examples: Examples to evaluate, in order.
        top_k: Passages retrieved per question.
        max_examples: Evaluate only the first N examples; ``None`` evaluates all.
            Previously ``0`` also meant "all"; it now means zero examples.

    Returns:
        ``(metrics, traces)``.
        - EM/F1 are averaged over examples with a gold answer.
        - ``Temporal@evidence`` is averaged over questions that mention a year, and is
          ``None`` if no question does.
        - Each trace dict holds the question, gold/predicted answers, the query, and the
          evidence pids and years.
        - Each trace also holds per-example ``em``/``f1`` (``None`` without a gold
          answer) and ``temporal`` (``None`` when the question has no year), for
          error analysis.
    """
    em_sum = f1_sum = 0.0
    n_scored = 0
    t_sum = 0.0
    t_n = 0
    traces = []

    use = examples if max_examples is None else examples[:max_examples]
    for ex in tqdm(use):
        trace = qa_system.run_one(ex.qid, ex.question, top_k=top_k)
        pred = trace.predicted_answer

        em = f1 = None
        if ex.answer is not None:
            em = exact_match(pred, ex.answer)
            f1 = f1_score(pred, ex.answer)
            em_sum += em
            f1_sum += f1
            n_scored += 1

        ts = temporal_score(ex.question, trace.evidence)
        if ts is not None:
            t_sum += ts
            t_n += 1

        traces.append({
            "qid": ex.qid,
            "question": ex.question,
            "gold": ex.answer,
            "pred": pred,
            "query": trace.hop_queries[0],
            "evidence_pids": [r.pid for r in trace.evidence],
            "evidence_years": [parse_year_from_meta(r.meta) for r in trace.evidence],
            "em": em,
            "f1": f1,
            "temporal": ts,
        })

    metrics = {
        "EM": em_sum / max(n_scored, 1),
        "F1": f1_sum / max(n_scored, 1),
        "Temporal@evidence": (t_sum / t_n) if t_n > 0 else None,
        "n_scored": n_scored,
        "n_temporal_scored": t_n,
        "top_k": top_k,
    }
    return metrics, traces

In [None]:
def run_dataset(
    train_examples: List[Example],
    test_examples: List[Example],
    retriever_kind="hybrid",
    dense_model_name: str = "BAAI/bge-small-en-v1.5",
    alpha: float = 0.5,
    top_k: int = 20,
    max_examples: Optional[int] = 1000,
    index_examples: Optional[List[Example]] = None,
):
    """Index a passage corpus, wire up the QA pipeline and evaluate it on ``test_examples``.

    Args:
        train_examples: Examples whose passages are indexed (unless ``index_examples`` is given).
        test_examples: Examples to answer and score.
        retriever_kind: ``"bm25"``, ``"dense"`` or ``"hybrid"``.
        dense_model_name: Sentence-transformers model for the dense retriever.
        alpha: BM25 weight for the hybrid retriever.
        top_k: Passages retrieved per question; also the evidence-selector cut-off.
        max_examples: Evaluate only the first N test examples; ``None`` evaluates all.
        index_examples: Override the indexed corpus. NOTE(review): the loaders build
            passages from each example's own context. With the default (train only), a
            test question's attached passages are not indexed unless the same pid also
            occurs in train. Consider ``train_examples + test_examples``.

    Returns:
        ``(metrics, traces)`` from ``run_eval``; the metrics are also printed.

    Raises:
        ValueError: for an unknown ``retriever_kind``, checked before any indexing work.
    """
    kinds = ("bm25", "dense", "hybrid")
    if retriever_kind not in kinds:
        raise ValueError(f"retriever_kind must be one of {kinds}, got {retriever_kind!r}")

    corpus_examples = train_examples if index_examples is None else index_examples
    passages = examples_to_passages(corpus_examples)
    pids, texts, _metas, pid_to_text, pid_to_meta = build_passage_store(passages)

    # Build only what the chosen retriever needs. Encoding the whole corpus with the dense
    # model on CPU is typically the most expensive step. It was previously paid even for
    # retriever_kind="bm25".
    bm25 = BM25Retriever(pids, texts) if retriever_kind in ("bm25", "hybrid") else None
    dense = (
        DenseRetriever(pids, texts, model_name=dense_model_name)
        if retriever_kind in ("dense", "hybrid")
        else None
    )
    if retriever_kind == "bm25":
        base = bm25
    elif retriever_kind == "dense":
        base = dense
    else:
        base = HybridRetriever(bm25, dense, alpha=alpha)

    retr_adapter = RetrieverAdapter(base, pid_to_text, pid_to_meta)

    # AnswerGenerator still trims evidence to its own max_passages (default 5) before prompting.
    answer_ollama = AnswerGenerator(backend="ollama")
    qa_system = QASystem(QueryAgent(), retr_adapter, EvidenceSelector(top_n=top_k), answer_ollama)

    metrics, traces = run_eval(qa_system, test_examples, top_k=top_k, max_examples=max_examples)
    print(retriever_kind, metrics)
    return metrics, traces

In [None]:
# Evaluate the BM25 pipeline on HotpotQA: passages from the train split are indexed and
# questions from hot["test"] (the "validation" split by default) are answered and scored.
# NOTE: this downloads and converts the full HotpotQA train split, which can take a while.
hot = load_hotpotqa()
metrics, traces = run_dataset(hot["train"], hot["test"], retriever_kind="bm25")