In [None]:

!pip install beir rank_bm25

"""
STAGE 2 — RERANKER (MRF with frozen BERT)

- Uses BEIR datasets (quora: first 100 test queries; trec-covid: first 8 test queries)
- First stage: BM25 (rank_bm25 over whitespace-tokenized title + text) produces
  the top-K candidates per query; a saved DeepImpact top-K list can be
  substituted.
- Second stage (this file): re-ranks those candidates using the MRF model of
  Metzler & Croft with three feature families:

    * T(q_i, d)  : unigram (independent terms)
    * O(q_i,q_{i+1}, d): ordered term pairs (phrases)
    * U(q_i,q_j, d): unordered term pairs (proximity)

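- Potentials are computed from frozen BERT (bert-base-uncased) embeddings:
    - Document vector v_d = mean of the last-layer token embeddings of the
      document (title + text, truncated to 256 tokens)
    - Query-term vectors e_i = last-layer embeddings of the query's
      non-special tokens
    - φ_T, φ_O, φ_U are cosine similarities with v_d, clamped to be non-negative.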

- We directly optimize λ_T, λ_O, λ_U to maximize MAP using greedy
  coordinate ascent under the constraint:
      λ_T + λ_O + λ_U = 1,  λ_k >= 0

- Finally, we report MAP and NDCG@10 """

import heapq
import json
import math
import os
import random
from collections import defaultdict
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F
from beir import util
from beir.datasets.data_loader import GenericDataLoader
from rank_bm25 import BM25Okapi
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

# General utilities

def set_seed(seed: int = 42):
    """Seed Python's ``random`` module and torch (CPU and every CUDA device).

    Args:
        seed: value handed unchanged to each seeding function.
    """
    seeders = (random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)


def load_beir_dataset(name: str, split: str = "test", out_dir: str = "./datasets"):
    """
    Download and load the BEIR dataset ``name``.

    The archive is fetched from the public BEIR mirror and unzipped into
    ``out_dir`` via ``beir.util.download_and_unzip``, then read with
    ``GenericDataLoader``.

    Args:
        name: BEIR dataset name, e.g. "quora" or "trec-covid".
        split: qrels split to load; presumably must match a qrels file the
            dataset ships (e.g. "test" or "dev") — confirm per dataset.
        out_dir: directory the archive is downloaded/unzipped into
            (default "./datasets", the previously hard-coded location).

    Returns:
        (corpus, queries, qrels) exactly as returned by
        ``GenericDataLoader.load``. No query subsetting happens here; callers
        that want e.g. only the first 8 trec-covid queries must slice
        the results themselves.
    """
    print(f"Downloading/loading BEIR dataset: {name}")
    url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{name}.zip"
    data_path = util.download_and_unzip(url, out_dir)
    corpus, queries, qrels = GenericDataLoader(data_path).load(split=split)
    return corpus, queries, qrels

# Stage 1 candidates via BM25

def build_bm25(corpus: Dict[str, Dict]) -> Tuple[BM25Okapi, List[str]]:
    """Build a BM25Okapi index over every document in ``corpus``.

    Each document is represented as ``title + " " + text`` (missing or None
    fields count as empty) split on whitespace; no lowercasing or punctuation
    stripping is applied.

    Returns:
        (index, doc_ids) where ``doc_ids[i]`` is the corpus key of the i-th
        indexed document, i.e. the i-th entry of ``index.get_scores(...)``.
    """
    def _tokens(entry: Dict) -> List[str]:
        title = entry.get("title", "") or ""
        body = entry.get("text", "") or ""
        # str.split() with no argument already ignores leading/trailing blanks.
        return (title + " " + body).split()

    doc_ids = list(corpus)
    index = BM25Okapi([_tokens(corpus[doc_id]) for doc_id in doc_ids])
    return index, doc_ids


def get_bm25_candidates(
    corpus: Dict[str, Dict],
    queries: Dict[str, str],
    top_k: int = 100
) -> Dict[str, List[str]]:
    """
    First-stage retrieval with BM25, just to get candidate sets.
    Replace this with your DeepImpact top-100 list if you have it saved.

    Builds a fresh BM25 index over ``corpus`` (see ``build_bm25``) and
    scores each query tokenized with ``str.split()``, matching the index's
    whitespace tokenization.

    Args:
        corpus: doc_id -> {"title": ..., "text": ...}.
        queries: qid -> query text.
        top_k: number of candidates kept per query. Values <= 0 yield an
            empty list.

    Returns:
        qid -> doc ids ordered by descending BM25 score; equal scores keep
        corpus order.
    """
    bm25, doc_ids = build_bm25(corpus)
    candidates: Dict[str, List[str]] = {}
    print("Building BM25 candidates...")
    for qid, qtext in tqdm(queries.items()):
        scores = bm25.get_scores(qtext.split())
        # heapq.nlargest is documented as equivalent to
        # sorted(iterable, key=key, reverse=True)[:n] (so tie order is kept),
        # but it only maintains a top_k-sized heap instead of materialising
        # and fully sorting a score pair for every document in the corpus.
        top = heapq.nlargest(top_k, zip(doc_ids, scores), key=lambda pair: pair[1])
        candidates[qid] = [doc_id for doc_id, _ in top]
    return candidates

# BERT encoder for doc/query

class FrozenBERTEncoder:
    """
    Wrapper around a frozen BERT model (eval mode, no gradients) producing:
      - document vectors v_d: mean of the last-layer token embeddings over
        every non-padding token of the truncated document ([CLS] and [SEP]
        included)
      - query term vectors [e_1, ..., e_n]: last-layer embeddings of each
        non-special query token (sub-word pieces for WordPiece tokenizers),
        each paired with its index in the tokenized sequence
    """

    def __init__(self, model_name: str = "bert-base-uncased"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        self.model.eval()
        # Token ids excluded from query terms. A tokenizer may report None
        # for a token type it lacks, so drop those entries.
        self.special_ids = {
            tid
            for tid in (
                self.tokenizer.cls_token_id,
                self.tokenizer.sep_token_id,
                self.tokenizer.pad_token_id,
            )
            if tid is not None
        }

    def doc_vector(self, text: str, max_length: int = 256) -> torch.Tensor:
        """Return the mean-pooled last-layer embedding of ``text`` as a CPU tensor (H,).

        ``text`` is truncated to ``max_length`` tokens. No padding is
        requested: a single sequence needs none, and padding to
        ``max_length`` made every short document pay for a full
        ``max_length``-token forward pass. The attention-mask weighting is
        kept so the mean still ignores padding if it is ever reintroduced.
        """
        enc = self.tokenizer(
            text,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            h = self.model(**enc).last_hidden_state  # (1, L, H)

        mask = enc["attention_mask"].unsqueeze(-1)  # (1, L, 1)
        summed = (h * mask).sum(dim=1)              # (1, H)
        counts = mask.sum(dim=1).clamp(min=1)       # (1, 1)
        return (summed / counts)[0].cpu()           # (H,)

    def query_term_vectors(self, text: str, max_length: int = 32):
        """
        Encode ``text`` (truncated to ``max_length`` tokens) and return its terms.

        Returns:
            list of ``(vector, position)`` pairs, one per token whose id is
            not in ``self.special_ids``: ``vector`` is the token's last-layer
            embedding (CPU tensor) and ``position`` its index in the full
            tokenized sequence (special tokens count toward the index).

        No padding is requested: pad tokens were filtered out anyway, so
        padding only added wasted compute.
        """
        enc = self.tokenizer(
            text,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            h = self.model(**enc).last_hidden_state[0]  # (L, H)

        ids = enc["input_ids"][0].tolist()
        return [
            (h[pos].cpu(), pos)
            for pos, tid in enumerate(ids)
            if tid not in self.special_ids
        ]

# Potential functions φ_T, φ_O, φ_U

def cosine_nonneg(a: torch.Tensor, b: torch.Tensor) -> float:
    """Cosine similarity of ``a`` and ``b`` with negative values floored to 0.

    Both inputs get a leading batch axis and are compared along dim 1; since
    ``.item()`` needs a single result, ``a`` and ``b`` are expected to be 1-D
    vectors of equal length. The result lies in [0, 1].
    """
    similarity = F.cosine_similarity(a[None], b[None], dim=1).item()
    # Negative cosines are clamped so potentials stay non-negative.
    return max(similarity, 0.0)


def compute_potentials_for_doc(
    q_terms: List[Tuple[torch.Tensor, int]],
    v_d: torch.Tensor,
    window: int = 2
) -> Tuple[float, float, float]:
    """
    Sum the three MRF feature families for one document.

    Args:
        q_terms: query terms as ``(vector, position)`` pairs, in query order.
        v_d: the document vector.
        window: maximum allowed ``pos_j - pos_i`` for a pair to count.

    Returns:
        (sum_T, sum_O, sum_U), all 0.0 when ``q_terms`` is empty, where

        φ_T(q_i, d)          = max( cos(e_i, v_d), 0 )            for every term
        φ_O(q_i, q_{i+1}, d) = max( cos((e_i + e_{i+1}) / 2, v_d), 0 )
                               for list-adjacent terms with pos gap <= window
        φ_U(q_i, q_j, d)     = max( cos((e_i + e_j) / 2, v_d), 0 )
                               for every j > i with pos gap <= window

    Note: the pair vector (e_i + e_j) / 2 is symmetric, so φ_O is not actually
    order-sensitive. Every pair counted in sum_O is also counted, with the
    same value, in sum_U, hence sum_U >= sum_O always.
    """
    if not q_terms:
        return 0.0, 0.0, 0.0

    def phi(vec: torch.Tensor) -> float:
        return cosine_nonneg(vec, v_d)

    # Explicit += loops (not sum()) keep the exact left-to-right float order.
    unigram = 0.0
    for vec, _ in q_terms:
        unigram += phi(vec)

    ordered = 0.0
    for (left, left_pos), (right, right_pos) in zip(q_terms, q_terms[1:]):
        if right_pos - left_pos <= window:
            ordered += phi((left + right) / 2.0)

    unordered = 0.0
    for i, (left, left_pos) in enumerate(q_terms):
        for right, right_pos in q_terms[i + 1:]:
            if right_pos - left_pos <= window:
                unordered += phi((left + right) / 2.0)

    return unigram, ordered, unordered


def mrf_score(
    lambdas: Tuple[float, float, float],
    sum_T: float,
    sum_O: float,
    sum_U: float
) -> float:
    """Linear MRF ranking score: λ_T·sum_T + λ_O·sum_O + λ_U·sum_U.

    ``lambdas`` must unpack into exactly three weights (T, O, U order).
    """
    w_term, w_ordered, w_unordered = lambdas
    weighted_term = w_term * sum_T
    return weighted_term + w_ordered * sum_O + w_unordered * sum_U

# Evaluation helpers (MAP, NDCG)

def average_precision(ranked_docs: List[str], rel_set: set) -> float:
    """Average precision of one ranking.

    Sums precision@rank at each relevant hit and divides by the total number
    of relevant documents (``len(rel_set)``), so relevant documents missing
    from ``ranked_docs`` count as zero-precision. Returns 0.0 when
    ``rel_set`` is empty or nothing relevant is retrieved.
    """
    if not rel_set:
        return 0.0

    precision_sum = 0.0
    found = 0
    for position, doc_id in enumerate(ranked_docs, 1):
        if doc_id not in rel_set:
            continue
        found += 1
        precision_sum += found / position

    return precision_sum / len(rel_set) if found else 0.0


def ndcg_at_k(ranked_docs: List[str], rel_scores: Dict[str, int], k: int = 10) -> float:
    """NDCG@k with exponential gain (2^rel - 1) and log2(rank + 1) discount.

    The ideal ranking is built from every judged grade in ``rel_scores``;
    unjudged documents have grade 0. Returns 0.0 when the ideal DCG is 0.
    """
    def dcg(grades) -> float:
        # Sequential += keeps the same float accumulation order as a plain loop.
        total = 0.0
        for rank, grade in enumerate(grades, start=1):
            total += (2 ** grade - 1) / math.log2(rank + 1)
        return total

    actual = dcg(rel_scores.get(doc, 0) for doc in ranked_docs[:k])
    ideal = dcg(sorted(rel_scores.values(), reverse=True)[:k])
    if ideal == 0.0:
        return 0.0
    return actual / ideal


def evaluate_for_lambdas(
    lambdas: Tuple[float, float, float],
    encoder: FrozenBERTEncoder,
    doc_vecs: Dict[str, torch.Tensor],
    queries: Dict[str, str],
    qrels: Dict[str, Dict[str, int]],
    candidates: Dict[str, List[str]],
    window: int = 2,
) -> Tuple[float, float]:
    """
    Rerank each query's candidates by MRF score under ``lambdas`` and
    return ``(mean AP, mean NDCG@10)``.

    Queries missing from ``candidates``, or with no document judged > 0 in
    ``qrels``, are excluded from both means; ``(0.0, 0.0)`` is returned when
    no query remains. Query term vectors are re-encoded on every call, and
    every candidate id must be present in ``doc_vecs``.
    """
    per_query_ap: List[float] = []
    per_query_ndcg: List[float] = []

    for qid, qtext in queries.items():
        if qid not in candidates:
            continue
        judgments = qrels.get(qid, {})
        positives = {doc for doc, grade in judgments.items() if grade > 0}
        if not positives:
            continue

        term_vectors = encoder.query_term_vectors(qtext)
        scored = []
        for doc_id in candidates[qid]:
            features = compute_potentials_for_doc(term_vectors, doc_vecs[doc_id], window)
            scored.append((doc_id, mrf_score(lambdas, *features)))

        # sorted() is stable, so equal scores keep first-stage order.
        ranking = [doc_id for doc_id, _ in sorted(scored, key=lambda pair: pair[1], reverse=True)]

        per_query_ap.append(average_precision(ranking, positives))
        per_query_ndcg.append(ndcg_at_k(ranking, judgments, k=10))

    if not per_query_ap:
        return 0.0, 0.0

    n_queries = len(per_query_ap)
    return sum(per_query_ap) / n_queries, sum(per_query_ndcg) / n_queries

# Coordinate ascent over λ_T, λ_O, λ_u

def normalize_lambdas(l: List[float]) -> Tuple[float, float, float]:
    """Clamp negative weights to 0 and rescale so the weights sum to 1.

    Falls back to the unigram-only model ``(1.0, 0.0, 0.0)`` when no
    positive weight remains.
    """
    def clamp(x):
        return max(x, 0.0)

    total = sum(map(clamp, l))
    if total == 0:
        return (1.0, 0.0, 0.0)
    return tuple(clamp(x) / total for x in l)


def _precompute_query_features(
    encoder: FrozenBERTEncoder,
    doc_vecs: Dict[str, torch.Tensor],
    queries: Dict[str, str],
    qrels: Dict[str, Dict[str, int]],
    candidates: Dict[str, List[str]],
    window: int,
) -> List[Tuple[List[str], List[Tuple[float, float, float]], set, Dict[str, int]]]:
    """Compute the λ-independent (sum_T, sum_O, sum_U) features once per query.

    Applies the same query filtering as ``evaluate_for_lambdas``: queries
    missing from ``candidates`` or with no document judged > 0 are dropped.
    Each entry is (candidate ids, per-candidate features, relevant ids,
    graded judgments for that query).
    """
    prepared = []
    for qid, qtext in queries.items():
        if qid not in candidates:
            continue
        judged = qrels.get(qid, {})
        rel_docs = {d for d, r in judged.items() if r > 0}
        if not rel_docs:
            continue
        q_terms = encoder.query_term_vectors(qtext)
        cands = candidates[qid]
        feats = [compute_potentials_for_doc(q_terms, doc_vecs[did], window) for did in cands]
        prepared.append((cands, feats, rel_docs, judged))
    return prepared


def _evaluate_precomputed(
    lambdas: Tuple[float, float, float],
    prepared: List[Tuple[List[str], List[Tuple[float, float, float]], set, Dict[str, int]]],
) -> Tuple[float, float]:
    """MAP and NDCG@10 for ``lambdas`` over features from ``_precompute_query_features``.

    Produces the same numbers ``evaluate_for_lambdas`` would for the same
    inputs, without re-running BERT or the potential functions.
    """
    ap_list = []
    ndcg_list = []
    for cands, feats, rel_docs, judged in prepared:
        doc_scores = [(did, mrf_score(lambdas, *f)) for did, f in zip(cands, feats)]
        doc_scores.sort(key=lambda x: x[1], reverse=True)  # stable: ties keep stage-1 order
        ranked = [d for d, _ in doc_scores]
        ap_list.append(average_precision(ranked, rel_docs))
        ndcg_list.append(ndcg_at_k(ranked, judged, k=10))

    if not ap_list:
        return 0.0, 0.0
    return sum(ap_list) / len(ap_list), sum(ndcg_list) / len(ndcg_list)


def coordinate_ascent_lambdas(
    encoder: FrozenBERTEncoder,
    doc_vecs: Dict[str, torch.Tensor],
    queries: Dict[str, str],
    qrels: Dict[str, Dict[str, int]],
    candidates: Dict[str, List[str]],
    window: int = 2,
    step: float = 0.1,
    max_iter: int = 20,
) -> Tuple[Tuple[float, float, float], float, float]:
    """
    Greedy hill-climbing for λ = (λ_T, λ_O, λ_U), maximizing MAP (primary).
    We still track NDCG@10 for reporting.

    Starting from (1, 0, 0), each iteration tries ±``step`` on each weight
    (re-normalised with ``normalize_lambdas``). It accepts any move that raises
    MAP by more than 1e-4 and stops after an iteration with no accepted move or
    after ``max_iter`` iterations.

    The potentials do not depend on λ, so BERT query encodings and
    (sum_T, sum_O, sum_U) for every candidate are computed once up front. The
    previous version recomputed them for every trial. Trials that normalise
    back onto the current λ are skipped because they cannot improve MAP.

    NOTE: λ is tuned on exactly the queries passed in. If the same queries
    are later used for reporting, those numbers are optimistic.

    Returns:
        (best lambdas, best MAP, NDCG@10 at the best lambdas)
    """
    prepared = _precompute_query_features(
        encoder, doc_vecs, queries, qrels, candidates, window
    )

    # Initialize with purely independent model
    lambdas = (1.0, 0.0, 0.0)
    best_map, best_ndcg = _evaluate_precomputed(lambdas, prepared)
    print(f"Initial lambdas: {lambdas}, MAP={best_map:.4f}, NDCG@10={best_ndcg:.4f}")

    for it in range(max_iter):
        improved = False
        print(f"\n[CoordAscent] Iteration {it+1}")

        for idx in range(3):  # 0: T, 1: O, 2: U
            for delta in (+step, -step):
                trial = list(lambdas)
                trial[idx] += delta
                trial_tup = tuple(normalize_lambdas(trial))

                # Clamping/renormalising can land back on the current point
                # (e.g. +step on λ_T from (1, 0, 0)); re-scoring it is wasted work.
                if trial_tup == lambdas:
                    continue

                trial_map, trial_ndcg = _evaluate_precomputed(trial_tup, prepared)

                print(
                    f"  Trying λ={trial_tup}, "
                    f"MAP={trial_map:.4f}, NDCG@10={trial_ndcg:.4f}"
                )

                if trial_map > best_map + 1e-4:  # tiny tolerance
                    print("   -> Improvement found, accepting.")
                    lambdas = trial_tup
                    best_map, best_ndcg = trial_map, trial_ndcg
                    improved = True

        if not improved:
            print("No further improvement, stopping.")
            break

    print(
        f"\nFinal lambdas: {lambdas}, "
        f"Best MAP={best_map:.4f}, Best NDCG@10={best_ndcg:.4f}"
    )
    return lambdas, best_map, best_ndcg

# Precompute doc vectors for all candidate docs

def precompute_doc_vectors(
    encoder: FrozenBERTEncoder,
    corpus: Dict[str, Dict],
    candidates: Dict[str, List[str]],
) -> Dict[str, torch.Tensor]:
    """Encode each distinct candidate document once with ``encoder.doc_vector``.

    The text encoded is ``title + " " + text`` (missing or None fields count
    as empty). Every candidate id must exist in ``corpus``.

    Returns:
        doc_id -> document vector, for the union of all candidate lists.
    """
    unique_ids = set().union(*candidates.values())

    vectors: Dict[str, torch.Tensor] = {}
    print(f"Precomputing doc vectors for {len(unique_ids)} candidate docs...")
    for doc_id in tqdm(unique_ids):
        fields = corpus[doc_id]
        title = fields.get("title", "") or ""
        body = fields.get("text", "") or ""
        vectors[doc_id] = encoder.doc_vector(title + " " + body)
    return vectors


# MAIN

def _namespace_ids(
    prefix: str,
    queries: Dict[str, str],
    qrels: Dict[str, Dict[str, int]],
    candidates: Dict[str, List[str]],
    doc_vecs: Dict[str, torch.Tensor],
):
    """Return copies of the four dicts with every query id and doc id prefixed "prefix::".

    Nothing here checks that ids are disjoint across datasets, so pooling
    datasets by raw id could let one overwrite another. Prefixing makes
    collisions impossible without changing any values.
    """
    def ns(raw_id: str) -> str:
        return f"{prefix}::{raw_id}"

    return (
        {ns(q): text for q, text in queries.items()},
        {ns(q): {ns(d): r for d, r in rels.items()} for q, rels in qrels.items()},
        {ns(q): [ns(d) for d in docs] for q, docs in candidates.items()},
        {ns(d): vec for d, vec in doc_vecs.items()},
    )


def main():
    """Run the pipeline: load data, BM25 stage 1, tune λ jointly, report per dataset."""
    set_seed(42)

    # 1. Load datasets
    corpus_quora, queries_quora_full, qrels_quora = load_beir_dataset("quora", split="test")
    corpus_trec, queries_trec_full, qrels_trec = load_beir_dataset("trec-covid", split="test")

    # For trec-covid, use only first 8 queries as per assignment
    trec_qids = sorted(queries_trec_full.keys())[:8]
    queries_trec = {qid: queries_trec_full[qid] for qid in trec_qids}
    qrels_trec = {qid: qrels_trec.get(qid, {}) for qid in trec_qids}

    # Limit quora queries for faster execution during demonstration
    quora_qids = sorted(queries_quora_full.keys())[:100]
    queries_quora = {qid: queries_quora_full[qid] for qid in quora_qids}
    qrels_quora = {qid: qrels_quora.get(qid, {}) for qid in quora_qids}

    # 2. Stage 1: candidate generation
    print("\nBuilding candidates for quora...")
    candidates_quora = get_bm25_candidates(corpus_quora, queries_quora, top_k=100)

    print("Building candidates for trec-covid (8 queries)...")
    candidates_trec = get_bm25_candidates(corpus_trec, queries_trec, top_k=100)

    # 3. Frozen BERT encoder + doc vectors, computed per dataset against its
    #    own corpus, so a doc id shared by both corpora cannot pick up the
    #    wrong text (the old merged-corpus approach let one overwrite the other).
    encoder = FrozenBERTEncoder("bert-base-uncased")
    doc_vecs_quora = precompute_doc_vectors(encoder, corpus_quora, candidates_quora)
    doc_vecs_trec = precompute_doc_vectors(encoder, corpus_trec, candidates_trec)

    # Pool both datasets for λ tuning under dataset-prefixed ids (quora first,
    # as before) so colliding query or doc ids cannot silently drop data.
    queries_all, qrels_all, candidates, doc_vecs = {}, {}, {}, {}
    for part in (
        _namespace_ids("quora", queries_quora, qrels_quora, candidates_quora, doc_vecs_quora),
        _namespace_ids("trec-covid", queries_trec, qrels_trec, candidates_trec, doc_vecs_trec),
    ):
        for merged, piece in zip((queries_all, qrels_all, candidates, doc_vecs), part):
            merged.update(piece)

    # 4. Coordinate ascent to learn λ_T, λ_O, λ_U
    best_lambdas, best_map, best_ndcg = coordinate_ascent_lambdas(
        encoder=encoder,
        doc_vecs=doc_vecs,
        queries=queries_all,
        qrels=qrels_all,
        candidates=candidates,
        window=2,     # you can experiment with different window sizes
        step=0.1,
        max_iter=15,
    )

    # 5. Per-dataset reporting (MAP, NDCG@10). NOTE: these are the same
    #    queries λ was tuned on, so the numbers are optimistic.
    print("\n=== Final Evaluation per dataset with learned lambdas ===")
    print(f"Learned lambdas: λ_T={best_lambdas[0]:.3f}, "
          f"λ_O={best_lambdas[1]:.3f}, λ_U={best_lambdas[2]:.3f}\n")

    # Quora
    map_q, ndcg_q = evaluate_for_lambdas(
        best_lambdas, encoder, doc_vecs_quora, queries_quora, qrels_quora, candidates_quora
    )
    print(f"[QUORA] MAP = {map_q:.4f}, NDCG@10 = {ndcg_q:.4f}")

    # TREC-COVID (8 queries)
    map_t, ndcg_t = evaluate_for_lambdas(
        best_lambdas, encoder, doc_vecs_trec, queries_trec, qrels_trec, candidates_trec
    )
    print(f"[TREC-COVID-8] MAP = {map_t:.4f}, NDCG@10 = {ndcg_t:.4f}")

    print("\nStage 2 reranking complete ✔")


if __name__ == "__main__":
    main()


Collecting beir
  Downloading beir-2.2.0-py3-none-any.whl.metadata (28 kB)
Collecting rank_bm25
  Downloading rank_bm25-0.2.2-py3-none-any.whl.metadata (3.2 kB)
Collecting pytrec-eval-terrier (from beir)
  Downloading pytrec_eval_terrier-0.5.10-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (1.1 kB)
Downloading beir-2.2.0-py3-none-any.whl (77 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m77.4/77.4 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading rank_bm25-0.2.2-py3-none-any.whl (8.6 kB)
Downloading pytrec_eval_terrier-0.5.10-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (304 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m304.8/304.8 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rank_bm25, pytrec-eval-terrier, beir
Successfully installed beir-2.2.0 pytrec-eval-terrier-0.5.10 rank_bm25-0.2.2
Downloading/loading BEIR dataset: quora


./datasets/quora.zip:   0%|          | 0.00/15.1M [00:00<?, ?iB/s]

  0%|          | 0/522931 [00:00<?, ?it/s]

Downloading/loading BEIR dataset: trec-covid


./datasets/trec-covid.zip:   0%|          | 0.00/70.5M [00:00<?, ?iB/s]

  0%|          | 0/171332 [00:00<?, ?it/s]


Building candidates for quora...
Building BM25 candidates...


100%|██████████| 100/100 [02:57<00:00,  1.78s/it]


Building candidates for trec-covid (8 queries)...
Building BM25 candidates...


100%|██████████| 8/8 [00:08<00:00,  1.01s/it]
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Precomputing doc vectors for 10579 candidate docs...


100%|██████████| 10579/10579 [2:59:32<00:00,  1.02s/it]


Initial lambdas: (1.0, 0.0, 0.0), MAP=0.5387, NDCG@10=0.5961

[CoordAscent] Iteration 1
  Trying λ=(1.0, 0.0, 0.0), MAP=0.5387, NDCG@10=0.5961
  Trying λ=(1.0, 0.0, 0.0), MAP=0.5387, NDCG@10=0.5961
  Trying λ=(0.9090909090909091, 0.09090909090909091, 0.0), MAP=0.5387, NDCG@10=0.5961
  Trying λ=(1.0, 0.0, 0.0), MAP=0.5387, NDCG@10=0.5961
  Trying λ=(0.9090909090909091, 0.0, 0.09090909090909091), MAP=0.5390, NDCG@10=0.5962
   -> Improvement found, accepting.
  Trying λ=(1.0, 0.0, 0.0), MAP=0.5387, NDCG@10=0.5961

[CoordAscent] Iteration 2
  Trying λ=(0.9173553719008265, 0.0, 0.08264462809917357), MAP=0.5387, NDCG@10=0.5959
  Trying λ=(0.898989898989899, 0.0, 0.10101010101010101), MAP=0.5390, NDCG@10=0.5962
  Trying λ=(0.8264462809917354, 0.09090909090909091, 0.08264462809917356), MAP=0.5392, NDCG@10=0.5964
   -> Improvement found, accepting.
  Trying λ=(0.9090909090909092, 0.0, 0.09090909090909093), MAP=0.5390, NDCG@10=0.5962
  Trying λ=(0.7513148009015778, 0.08264462809917357, 0.1660405