# The Revenge of Rocchio's Angels - COLAB EDITION

We will be #1 this time

### Setup

In [1]:
# Install OpenJDK 21
# A JDK is presumably needed by Pyserini's Lucene backend (JAVA_HOME is pointed
# at this install in the next cell) — TODO confirm the minimum Java version.
!apt-get update
# -qq plus the redirect to /dev/null hide the installer's normal output.
!apt-get install openjdk-21-jdk-headless -qq > /dev/null

Hit:1 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease
Hit:2 https://cli.github.com/packages stable InRelease                         
Get:3 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease [1,581 B]
Get:4 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  Packages [2,302 kB]
Hit:5 http://security.ubuntu.com/ubuntu jammy-security InRelease               
Hit:6 http://archive.ubuntu.com/ubuntu jammy InRelease                         
Hit:7 http://archive.ubuntu.com/ubuntu jammy-updates InRelease                 
Hit:8 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease   
Hit:9 https://r2u.stat.illinois.edu/ubuntu jammy InRelease          
Hit:10 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease
Hit:11 http://archive.ubuntu.com/ubuntu jammy-backports InRelease
Hit:12 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease
Fetched 2,303 kB in 2s (1

In [3]:
import os
# Point JAVA_HOME at the JDK installed by the previous cell; this must be set
# before any cell imports a JVM-backed library.
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-21-openjdk-amd64"
# Sanity check: print the Java version that is actually on PATH.
!java -version

openjdk version "21.0.9" 2025-10-21
OpenJDK Runtime Environment (build 21.0.9+10-Ubuntu-122.04)
OpenJDK 64-Bit Server VM (build 21.0.9+10-Ubuntu-122.04, mixed mode, sharing)


In [5]:
!pip install --extra-index-url https://download.pytorch.org/whl/cu126 accelerate torch python-dotenv faiss-cpu --no-cache torchvision pyserini==0.36.0 python-dotenv tqdm matplotlib seaborn sentence-transformers langchain-text-splitters

Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu126


## All necessary scripts, inlined in the following cells for use in Colab

Not very convenient, but it works.

### processing.py

In [8]:
import os
import ast
from langchain_text_splitters import RecursiveCharacterTextSplitter
from dotenv import load_dotenv
import re
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple, List, Union
from pathlib import Path
import json
from collections import defaultdict

# Load variables from a local .env file (if any) into the process environment.
load_dotenv()
# Passage splitter configuration: chunk size and overlap between consecutive
# chunks (measured with the splitter's default length function — presumably
# characters; TODO confirm).
SPLITTER_ARGS = {'chunk_size': 768, 'chunk_overlap': 50}
# Shared splitter instance, used as the default by split_passages().
SPLITTER_SINGLETON = RecursiveCharacterTextSplitter(**SPLITTER_ARGS)

# Generic SGML-ish block: <TAG ...> ... </TAG>
# TAG names: letters/digits/_/-
# The content group is non-greedy and the closing tag is a backreference to the
# opening tag name, so a match ends at the first closing tag with the same name.
# DOTALL lets the content span multiple lines.
_BLOCK = re.compile(
    r"<(?P<tag>[A-Za-z][A-Za-z0-9_-]*)(?P<attrs>\s+[^>]*)?>\s*(?P<content>.*?)\s*</(?P=tag)>",
    re.DOTALL
)

# Remove any remaining tags (inline or otherwise)
_ANY_TAG = re.compile(r"</?[^>]+>")

# Common “annotation-like” remnants you may want to drop from body
# Only a "[Text]" marker at the very start of the string is matched (case-insensitive).
_TEXT_MARKER = re.compile(r"^\s*\[Text\]\s*", re.IGNORECASE)

@dataclass(slots=True)
class Hit:
    docid: str
    score: float
    qid: Optional[int] = None
    query: Optional[str] = None
    text: Optional[str] = None
    meta: Dict[str, Any] = field(default_factory=dict)

    def __getattr__(self, name: str):
        try:
            return self.meta[name]
        except KeyError:
            raise AttributeError(name)

    # For LangChain
    @property
    def page_content(self) -> str:
        return self.text

    @property
    def metadata(self) -> dict:
        return {"docid": self.docid, "query": self.query, **self.meta}



def create_llm_generated_queries(
    input_path: str | Path,
    out_paths: List[str | Path] | None = None,
    expected_cols: int = 3,
    sep: str = " ",
    encoding: str = "utf-8",
) -> Tuple[Path, ...]:
    """
    Read The query permutations generated by the llm:
        <qid>: option1, option2, option3, ...
    and write N separate TREC query files (one per option/"column"):
        qid<tab>option_i
    """
    input_path = Path(input_path)
    if out_paths is None:
        out_paths = [Path(f"queries_col{i+1}.txt") for i in range(expected_cols)]
    else:
        out_paths = [Path(p) for p in out_paths]

    columns: List[List[str]] = [[] for _ in range(expected_cols)]

    with input_path.open("r", encoding=encoding) as f:
        for line_no, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue

            if ":" not in line:
                raise ValueError(f"Malformed line {line_no}: missing ':' -> {line!r}")

            qid, rest = line.split(":", 1)
            qid = qid.strip()
            if not qid:
                raise ValueError(f"Malformed line {line_no}: empty qid -> {line!r}")

            options = [opt.strip() for opt in rest.split(",") if opt.strip() != ""]
            if len(options) != expected_cols:
                raise ValueError(
                    f"Line {line_no} (qid={qid}) has {len(options)} options, expected {expected_cols}: {line!r}"
                )

            for i in range(expected_cols):
                columns[i].append(f"{qid}{sep}{options[i]}")

    for out_path in out_paths:
        out_path.parent.mkdir(parents=True, exist_ok=True)

    for out_path, lines in zip(out_paths, columns):
        out_path.write_text("\n".join(lines) + ("\n" if lines else ""), encoding=encoding)

    return tuple(out_paths)

def write_topk_jsonl_query(hits, out_path, qid):
    """
    Checkpoint the hits of a single query by appending one JSONL record:
      {"query": ..., "qid": ..., "hits": [{"docid": ..., "score": ..., "text": ...}, ...]}

    Does nothing for an empty hit list. Repeated docids are reported on stdout
    and only their first occurrence is kept. ``text`` is stored through
    ``str()``, so a missing text is written as the string "None".

    :raises ValueError: if the hits do not all share the query of the first hit.
    """
    if not hits:
        return

    target = Path(out_path)
    target.parent.mkdir(parents=True, exist_ok=True)

    expected_query = hits[0].query
    emitted_ids = set()
    serialised = []
    for hit in hits:
        # defensive: all hits must belong to the same query
        if hit.query != expected_query:
            raise ValueError("Hits contain multiple queries")

        if hit.docid not in emitted_ids:
            emitted_ids.add(hit.docid)
            serialised.append(
                {"docid": str(hit.docid), "score": float(hit.score), "text": str(hit.text)}
            )
        else:
            print("DUPLICATE (HOW?)")

    with target.open("a", encoding="utf-8") as sink:
        payload = {"query": expected_query, "qid": qid, "hits": serialised}
        sink.write(json.dumps(payload, ensure_ascii=False) + "\n")


def iter_query_hits(jsonl_path: str | Path):
    """
    Stream a JSONL file line-by-line to yield previous resutls.
    """
    jsonl_path = Path(jsonl_path)

    with jsonl_path.open("r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue

            try:
                rec = json.loads(line)
            except json.JSONDecodeError as e:
                raise ValueError(f"Bad JSON on line {line_no} in {jsonl_path}") from e

            q = rec["query"]
            hits_raw = rec.get("hits", [])
            hits = [Hit(query=q, docid=str(h["docid"]), score=float(h["score"])) for h in hits_raw]

            yield rec["qid"], hits


def _normalize_ws(s: str) -> str:
    s = s.replace("\r\n", "\n").replace("\r", "\n")
    # Collapse spaces/tabs
    s = re.sub(r"[ \t]+", " ", s)
    # Collapse many blank lines
    s = re.sub(r"\n{3,}", "\n\n", s)
    return s.strip()

def clean_inner_text(s: str) -> str:
    """
    Clean the text found inside a tag. In order:
    1. remove every remaining tag (e.g. nested <F P=105> ... </F> markers),
    2. remove a leading [Text] marker (common in newswire),
    3. normalise whitespace.
    """
    return _normalize_ws(_TEXT_MARKER.sub("", _ANY_TAG.sub("", s)))

def clean_robust(raw) -> Tuple[str, Dict[str, Union[str, List[str]]]]:
    """
    Split a raw SGML-ish document into body text and metadata.

    Every block matched by ``_BLOCK`` is cleaned with ``clean_inner_text``;
    blocks that clean to an empty string are skipped.
      - <TEXT> blocks form the body (joined with a blank line if several)
      - any other tag goes to metadata[TAG] (upper-cased name): a string for a
        single occurrence, a list of strings for repeated tags
    If no non-empty <TEXT> block exists, the body is whatever remains of
    ``raw`` after all matched blocks are removed and the rest is cleaned.

    :return: ``(body, metadata)``; ``("", {})`` for empty/None input.
    """
    metadata: Dict[str, Any] = {}
    if not raw:
        return "", metadata

    text_sections: List[str] = []
    for match in _BLOCK.finditer(raw):
        value = clean_inner_text(match.group("content") or "")
        if not value:
            continue

        tag_name = match.group("tag").strip().upper()
        if tag_name == "TEXT":
            text_sections.append(value)
            continue

        # Stored values are never None (non-empty str or list), so None means "first time".
        existing = metadata.get(tag_name)
        if existing is None:
            metadata[tag_name] = value
        elif isinstance(existing, list):
            existing.append(value)
        else:
            # second occurrence: promote the single string to a list
            metadata[tag_name] = [existing, value]

    if text_sections:
        return "\n\n".join(text_sections), metadata

    # No usable <TEXT>: drop all matched blocks and clean the leftover text.
    leftover = clean_inner_text(_BLOCK.sub("", raw))
    return leftover, metadata

def split_passages(hits: List[Hit], splitter=SPLITTER_SINGLETON):
    """
    Split the given hits into passages with ``splitter.split_documents``.

    Hits are handed to the splitter as-is (they expose ``page_content`` and
    ``metadata``); the splitter's return value is passed straight through.
    """
    passages = splitter.split_documents(hits)
    return passages


### engine.py

In [9]:
import os
from dotenv import load_dotenv
load_dotenv()
#os.environ["JAVA_HOME"] = os.getenv("JAVA_HOME")
from tqdm import tqdm
from pyserini.index.lucene import IndexReader
from pyserini.search.lucene import LuceneSearcher
#from processing import split_passages, clean_robust, Hit
from sentence_transformers import CrossEncoder
# from mxbai_rerank import MxbaiRerankV2
# from inranker import T5Ranker
import re
from collections import defaultdict
import torch


# Use the GPU when torch can see one, otherwise fall back to CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)
# Default cross-encoder model name, read from the environment (None when unset).
CROSS_ENCODER = os.getenv("CROSS_ENCODER")
# Reranker type names accepted by Reranker; its constructor currently only
# builds a model for "CE" and "QWEN_CE" and rejects the rest as not implemented.
SUPPORTED_RERANKERS = ["CE", "QWEN_CE", "mxbai", "monot5", "twolar", "inranker"]


def format_queries(query, instruction=None):
    """
    Build the query half of the chat-style reranker prompt: the system turn
    followed by the start of the user turn carrying the instruction and the query.

    :param query: query text.
    :param instruction: task description; a generic web-search instruction is
        used when None (an empty string is kept as given).
    """
    prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
    task = (
        instruction
        if instruction is not None
        else "Given a web search query, retrieve relevant passages that answer the query"
    )
    return "".join((prefix, f"<Instruct>: {task}\n", f"<Query>: {query}\n"))


def format_document(document):
    """
    Build the document half of the reranker prompt: the document text, the end
    of the user turn, and an assistant turn opened with an empty <think> block.
    """
    suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
    return "<Document>: {}{}".format(document, suffix)


def weighted_rrf_fuse(runs, weights=None, rrf_k=60, save_text=False):
    """
    Weighted Reciprocal Rank Fusion of several ranked lists.

    Each document receives ``sum_i w_i / (rrf_k + rank_i)`` over the runs it
    appears in (ranks start at 1); input scores are ignored.

    runs: list[list[Hit]] docids ordered best->worst
    weights: list[float] paired positionally with runs, defaults to 1/len(runs) each
    rrf_k: RRF smoothing constant
    save_text: Save the text field (relevant if this isn't the final step);
        the first non-None text seen for a docid is kept
    returns: list[Hit] (docid, fused score, optional text) sorted best->worst
    raises: AssertionError if the weights do not sum to 1 (within 1e-9)
    """
    runs = list(runs)
    # Resolve the default BEFORE validating: the previous order called
    # sum(None) and raised TypeError whenever weights were omitted.
    if weights is None:
        if not runs:
            return []
        weights = [1 / len(runs)] * len(runs)
    # Tolerance instead of ==: sums such as 0.1 + 0.7 + 0.2 are not exactly 1.0.
    assert abs(sum(weights) - 1.0) <= 1e-9, "Weights must sum to 1.0"

    scores = defaultdict(float)
    texts = {}
    for run, w in zip(runs, weights):
        for rank, hit in enumerate(run, start=1):
            scores[hit.docid] += w * (1.0 / (rrf_k + rank))
            # Do not let a later run that carries no text erase one found earlier.
            if save_text and texts.get(hit.docid) is None:
                texts[hit.docid] = hit.text

    # sorted() is stable: ties keep first-seen order.
    fused = sorted(scores.items(), key=lambda x: x[1], reverse=True)
    return [Hit(docid=docid, score=score, text=texts.get(docid)) for docid, score in fused]



class Reranker:
    """
    Passage-level reranker that aggregates passage scores back to documents.

    Only the "CE" and "QWEN_CE" types are constructible at the moment (both
    load a sentence-transformers CrossEncoder); the other names listed in
    SUPPORTED_RERANKERS raise NotImplementedError in the constructor.
    """

    # Compiled once for the class instead of on every call.
    _WS_NEWLINES = re.compile(r"\s*\n\s*")
    _WS_SPACES = re.compile(r"[ \t]+")

    def __init__(self, reranker_type, cross_encoder_name=None, device=DEVICE):
        """
        :param reranker_type: one of SUPPORTED_RERANKERS.
        :param cross_encoder_name: model name; falls back to CROSS_ENCODER.
        :param device: device the model is loaded on (defaults to DEVICE).
        :raises ValueError: unknown reranker type, or no model name available.
        :raises NotImplementedError: supported name without an implementation.
        """
        self.reranker_type = reranker_type
        self.model_name = cross_encoder_name if cross_encoder_name else CROSS_ENCODER
        if self.reranker_type not in SUPPORTED_RERANKERS:
            raise ValueError(f"reranker_type must be in {SUPPORTED_RERANKERS}")
        if self.reranker_type in ('CE', 'QWEN_CE'):
            if not self.model_name:
                raise ValueError("No cross-encoder name given and CROSS_ENCODER is not set")
            # Honour the `device` argument (the global DEVICE used to be hard-wired here).
            self.model = CrossEncoder(self.model_name, device=device)
        # elif self.reranker_type == 'mxbai':
        #     self.model = MxbaiRerankV2("mixedbread-ai/mxbai-rerank-large-v2", device=device)
        #     print(device)
        #     self.model.to(device)
        # elif self.reranker_type == "inranker":
        #     self.model = T5Ranker(model_name_or_path="unicamp-dl/InRanker-3B", device=device)
        #     print(device)
        else:
            raise NotImplementedError("Type not implemented yet sry :(")

    @classmethod
    def _remove_whitespaces(cls, text):
        """Flatten newlines (with surrounding whitespace) and runs of spaces/tabs to single spaces."""
        text = cls._WS_NEWLINES.sub(" ", text)
        return cls._WS_SPACES.sub(" ", text)

    @staticmethod
    def _collated_doc_score(scores, max_weight=0.8):
        """
        Combine the passage scores of one document:
        ``max_weight * best + (1 - max_weight) * mean(other passages)``,
        or just the single score when the document has one passage.
        """
        best = max(scores)
        if len(scores) <= 1:
            return best
        return best * max_weight + (1 - max_weight) * (sum(scores) - best) / (len(scores) - 1)

    def rerank(self, query, retrieval_candidates, max_weight=0.8):
        """
        Score every candidate passage against ``query`` and return one Hit per
        document, sorted by the collated document score (best first).

        :param query: query text.
        :param retrieval_candidates: passages exposing ``page_content`` and
            ``metadata['docid']``.
        :param max_weight: weight of a document's best passage when collating.
        :return: list[Hit] with docid and score only; [] when there are no candidates.
        """
        if not retrieval_candidates:
            # Nothing to score; avoid calling the model with an empty batch.
            return []

        cleaned_docs = [self._remove_whitespaces(doc.page_content) for doc in retrieval_candidates]
        docids = [doc.metadata['docid'] for doc in retrieval_candidates]
        per_doc_scores = defaultdict(list)

        if self.reranker_type == 'CE':
            pairs = [[query, cleaned_doc] for cleaned_doc in cleaned_docs]
            for score, docid in zip(self.model.predict(pairs), docids):
                per_doc_scores[docid].append(score)

        elif self.reranker_type == 'QWEN_CE':
            task = "Given a web search query, retrieve relevant passages that answer the query"
            query_prompt = format_queries(query, task)
            pairs = [[query_prompt, format_document(doc)] for doc in cleaned_docs]
            for score, docid in zip(self.model.predict(pairs), docids):
                per_doc_scores[docid].append(score)

        # The two branches below are currently unreachable because __init__
        # rejects these types; kept for when those models are enabled again.
        elif self.reranker_type == 'mxbai':
            cross_scores = self.model.rank(query, cleaned_docs, return_documents=False)
            for score in cross_scores:
                per_doc_scores[docids[score.index]].append(score.score)

        elif self.reranker_type == 'inranker':
            scores = self.model.get_scores(
                query=query,
                docs=cleaned_docs
            )
            # Assumes get_scores returns one score per doc in input order — TODO confirm.
            # (The old loop unpacked (score, doc) tuples as `i, (score, _)` and could not run.)
            for idx, score in enumerate(scores):
                per_doc_scores[docids[idx]].append(score)

        collated_doc_scores = {
            docid: self._collated_doc_score(scores, max_weight=max_weight)
            for docid, scores in per_doc_scores.items()
        }
        ranked = sorted(collated_doc_scores.items(), key=lambda x: x[1], reverse=True)
        # float(): model scores may be numpy scalars; store plain Python floats.
        return [Hit(docid=docid, score=float(score)) for docid, score in ranked]



class SearchEngine:
    """
    Retrieval pipeline over the prebuilt 'robust04' Pyserini index:
    first-stage lexical search (optionally fused over several query variants),
    optional passage reranking of the top-m documents, and TREC run output.
    """

    def __init__(self):
        self.reader = IndexReader.from_prebuilt_index('robust04')
        self.searcher = LuceneSearcher.from_prebuilt_index('robust04')
        self.reranker = None

    @staticmethod
    def _as_weight_list(weights, default):
        """Return ``default`` for None, wrap a bare number in a list, else pass through."""
        if weights is None:
            return list(default)
        if isinstance(weights, (int, float)):
            return [weights]
        return weights

    def set_searcher(self, approach="qld", fb_terms=5, fb_docs=10, original_query_weight=0.8, mu=1000,
                     reranker_type='CE', reranker=CROSS_ENCODER):
        """
        Configure the first-stage ranker (with RM3 expansion) and the reranker.

        :param approach: "qld" (query likelihood, Dirichlet prior ``mu``) or
            "bm25" (fixed k1=0.5, b=0.36); any other value leaves the ranking
            function unchanged.
        :param fb_terms, fb_docs, original_query_weight: RM3 parameters.
        :param mu: Dirichlet prior, only used when approach is "qld".
        :param reranker_type: type passed to Reranker.
        :param reranker: cross-encoder model name; None keeps the current
            ``self.reranker`` (no reranking on a fresh engine).
        """
        if approach=="qld":
            # Setting query likelihood with dirichlet prior
            self.searcher.set_qld(mu=mu)
            # Setting RM3 expanding the query, with a safe alpha
            self.searcher.set_rm3(fb_terms=fb_terms, fb_docs=fb_docs, original_query_weight=original_query_weight)
        elif approach=="bm25":
            self.searcher.set_bm25(k1=0.5, b=0.36)
            self.searcher.set_rm3(fb_terms=fb_terms,fb_docs=fb_docs,original_query_weight=original_query_weight)
        if reranker is not None:
            self.reranker = Reranker(reranker_type,cross_encoder_name=reranker)

    def get_top_k(self, query, k=5, clean=True, qid=None):
        """
        Get the top k ranked (full) documents using the searcher
        :param query: the query
        :param k: top results to retrieve (default: 5)
        :param clean: Whether to clean the retrieved docs and extract metadata (default: True)
        :param qid: query id
        :return: list[Hit] in the order returned by the searcher
        """
        context = []
        hits = self.searcher.search(query, k)
        # Get text from hits
        for hit in hits:
            doc = self.searcher.doc(hit.docid)
            raw_doc = doc.raw()
            if clean:
                cleaned_doc, doc_metadata = clean_robust(raw_doc)
                context.append(Hit(qid=qid, query=query, docid=hit.docid, score=hit.score, meta=doc_metadata, text=cleaned_doc))
            else:
                context.append(Hit(qid=qid, query=query, docid=hit.docid, score=hit.score, text=raw_doc))
        return context


    def multi_query_fuse(self, qid, topics_list, llm_query_fusion_weights, k=1000, rrf_k=60):
        """
        Retrieve top-k for every query variant of ``qid`` and fuse them with weighted RRF.

        :param qid: topic id, looked up in each dict of ``topics_list``.
        :param topics_list: list of {qid: query} dicts, one per query variant.
        :param llm_query_fusion_weights: one weight per entry of ``topics_list``;
            variants with weight 0 are not retrieved at all.
        :param k: documents retrieved per variant.
        :param rrf_k: RRF constant for the fusion.
        :return: list[Hit]; with a single topics dict the plain top-k is returned.
        """
        assert len(llm_query_fusion_weights) == len(topics_list), "Weight & lists mismatch"
        if len(topics_list) <= 1:
            return self.get_top_k(topics_list[0][qid], k, clean=True, qid=qid)

        runs, run_weights = [], []
        for topics, weight in zip(topics_list, llm_query_fusion_weights):
            if weight == 0:
                continue
            runs.append(self.get_top_k(topics[qid], k, clean=True, qid=qid))
            # Keep weights aligned with the runs actually retrieved: passing the
            # unfiltered list paired a skipped variant's zero weight with the
            # next run whenever the zero was not last.
            run_weights.append(weight)
        return weighted_rrf_fuse(runs, weights=run_weights, rrf_k=rrf_k, save_text=True)

    def retrieve_rerank(self, query, hits, m=100, fusion_weights=None, rrf_k=60):
        """
        Rerank the top ``m`` hits and fuse the reranked order with the original order.

        :param fusion_weights: weights given to the ORIGINAL order in the RRF
            fusion (the reranker gets ``1 - w``); one output ranking per weight.
            None (or omitted) means ``[0]``, i.e. reranker only.
        :return: list of rankings (each a list[Hit]), one per fusion weight;
            ``[hits]`` unchanged when no reranker is configured.
        """
        fusion_weights = self._as_weight_list(fusion_weights, [0])
        if not self.reranker:
            return [hits]

        top_m = hits[:m]
        tail = hits[m:]
        passages_top_m = split_passages(top_m)
        top_m_reranked = self.reranker.rerank(query, passages_top_m)
        # NOTE(review): the fused head carries RRF scores while the tail keeps its
        # first-stage scores, so scores need not decrease with rank. The evaluation
        # in this notebook orders by the rank column; tools that order by score
        # (trec_eval presumably does) could rank differently — confirm before submitting.
        return [
            weighted_rrf_fuse([top_m_reranked, top_m], weights=[1-fusion_weight,fusion_weight], rrf_k=rrf_k) + tail
            for fusion_weight in fusion_weights
        ]


    def search_and_write_trec_run(self, query, k, topic_id, run_tag, output_file, fusion_weights=None,
                                  query_weights=None,
                                  topics_lists=None,
                                  m=100,
                                  rrf_k_queries=9,
                                  rrf_k_reranker=60):
        """
        Run the full pipeline for one topic and append the rankings to
        ``Results/{output_file}_rrf_{w}.txt`` (one file per fusion weight ``w``)
        in TREC run format.

        :param query: query text handed to the reranker.
        :param k: first-stage depth; must be >= m.
        :param topic_id: topic id written to the run and used to look up query variants.
        :param fusion_weights: see retrieve_rerank (default [0]).
        :param query_weights: see multi_query_fuse (default: first variant only).
        :param topics_lists: list of {qid: query} dicts (default: just ``query``).
        """
        fusion_weights = self._as_weight_list(fusion_weights, [0])
        if topics_lists is None:
            topics_lists = [{topic_id: query}]
        if query_weights is None:
            query_weights = [1] + [0] * (len(topics_lists) - 1)
        assert k >= m, "initial retrieval k must be bigger-equal than fine reranker m"
        hits = self.multi_query_fuse(topic_id, topics_lists, query_weights, k=k, rrf_k=rrf_k_queries)  # Hits are score-sorted by default
        hits_per_fusion_weight = self.retrieve_rerank(query, hits, m, fusion_weights, rrf_k=rrf_k_reranker)
        os.makedirs("Results", exist_ok=True)
        # Without a reranker only one ranking comes back; zip then writes a single file.
        for fusion_weight, ranked_hits in zip(fusion_weights, hits_per_fusion_weight):
            with open(f"Results/{output_file}_rrf_{fusion_weight}.txt", "a", encoding="utf-8") as f:
                for rank, hit in enumerate(ranked_hits, start=1):
                    f.write(
                        f"{topic_id} Q0 {hit.docid} {rank} {hit.score:.6f} {run_tag}\n"
                    )



    def search_all_queries(self, topics_lists, k=1000, run_tag="run1", output_file="run.txt", m=100,
                           llm_query_fusion_weights=None,
                           rerank_fusion_weights=None,
                           rrf_k_queries=9,
                           rrf_k_reranker=60):
        """
        Search all queries according to topics list
        :param topics_lists: list of {query id: query} dicts, one per query variant. Each is loaded from a .txt listing all queries; the first one drives the iteration and supplies the reranker query.
        :param k: top results to retrieve (default: 1000)
        :param run_tag: name of run to write as the format
        :param output_file: output file stem; "_rrf_<weight>.txt" is appended (default: run.txt)
        :param m: reranking threshold (default: 100)
        :param llm_query_fusion_weights: list of fusion weights on multiple query ablations (default: [1,0...,0])
        :param rerank_fusion_weights: rrf weights to experiment with; a single number is accepted too (default: [0])
        :param rrf_k_queries: Query fusion RRF constant
        :param rrf_k_reranker: RRF reranking constant
        """
        rerank_fusion_weights = self._as_weight_list(rerank_fusion_weights, [0])
        if llm_query_fusion_weights is None:
            llm_query_fusion_weights = [1]+[0]*(len(topics_lists) - 1)
        for qid, query in tqdm(topics_lists[0].items(), desc="Searching topics"):
            self.search_and_write_trec_run(query, k, qid, run_tag, output_file, m=m,
                                           fusion_weights=rerank_fusion_weights, query_weights=llm_query_fusion_weights,
                                           topics_lists=topics_lists, rrf_k_queries=rrf_k_queries, rrf_k_reranker=rrf_k_reranker)


cuda


### evaluate_map.py

In [10]:
from collections import defaultdict
from typing import Dict, List, Tuple, Iterable, Optional
from collections import defaultdict
import pandas as pd


def load_qrels(qrels_path: str) -> Dict[str, Dict[str, int]]:
    """
    Read a TREC qrels file into ``{qid: {docid: relevance}}``.

    Expected whitespace-separated columns per line (extra columns are ignored):
      qid  unused  docid  rel
    e.g.
      301  0       FBIS3-10555  0
    Blank lines are skipped; a line with fewer than four fields raises ValueError.
    """
    judgments = defaultdict(dict)
    with open(qrels_path, "r", encoding="utf-8") as handle:
        for fields in map(str.split, handle):
            if not fields:
                continue
            qid, _ignored, docid, rel = fields[:4]
            judgments[qid][docid] = int(rel)
    return judgments

def load_run(run_path: str) -> Dict[str, List[str]]:
    """
    Read a TREC run file into ``{qid: [docid, ...]}`` ordered by rank.

    Expected whitespace-separated columns per line:
      qid  Q0  docid  rank  score  tag
    Only qid, docid and rank are used. Docids are ordered by the integer rank
    column (ties keep file order), so the file itself need not be sorted.
    """
    entries_by_qid = defaultdict(list)  # qid -> [(rank, docid), ...]
    with open(run_path, "r", encoding="utf-8") as handle:
        for fields in (raw.split() for raw in handle):
            if not fields:
                continue
            entries_by_qid[fields[0]].append((int(fields[3]), fields[2]))

    # stable sort on rank only, then drop the rank
    return {
        qid: [docid for _rank, docid in sorted(entries, key=lambda entry: entry[0])]
        for qid, entries in entries_by_qid.items()
    }


def average_precision(ranked_docids: List[str], qrels_for_q: Dict[str, int]) -> float:
    """
    Average precision of one ranking.

    Precision is taken at every rank holding a relevant document (rel > 0),
    summed, and divided by the number of judged-relevant documents for the
    query. Unjudged documents count as non-relevant. Returns 0.0 when the
    query has no relevant documents.
    """
    relevant = {docid for docid, grade in qrels_for_q.items() if grade > 0}
    if not relevant:
        return 0.0

    found_so_far = 0
    precision_total = 0.0
    for position, docid in enumerate(ranked_docids, start=1):
        if docid not in relevant:
            continue
        found_so_far += 1
        precision_total += found_so_far / position
    return precision_total / len(relevant)


def mean_average_precision(
    qrels: Dict[str, Dict[str, int]],
    run: Dict[str, List[str]],
    qids: Optional[Iterable[str]] = None
) -> Tuple[float, Dict[str, float]]:
    """
    Returns (MAP, per_query_AP_dict).

    If qids is None: evaluate the intersection of qrels and run query ids,
    numeric ids first in numeric order, then the remaining ids alphabetically.
    Otherwise evaluate exactly the given qids (converted to str); a qid missing
    from the run or the qrels contributes an AP of 0.0.
    MAP is 0.0 when there is nothing to evaluate.
    """
    if qids is None:
        def _qid_order(qid: str):
            # Tuples keep ints and strings from being compared with each other,
            # which raised TypeError when numeric and non-numeric ids were mixed.
            return (0, int(qid), "") if qid.isdigit() else (1, 0, qid)

        eval_qids = sorted(set(qrels.keys()) & set(run.keys()), key=_qid_order)
    else:
        eval_qids = [str(q) for q in qids]

    ap_by_q = {}
    ap_values = []
    for qid in eval_qids:
        ap = average_precision(run.get(qid, []), qrels.get(qid, {}))
        ap_by_q[qid] = ap
        ap_values.append(ap)

    map_score = sum(ap_values) / len(ap_values) if ap_values else 0.0
    return map_score, ap_by_q

def get_map_by_paths(qrels_path, run_path):
    """
    Load a qrels file and a run file and return the MAP over the query ids
    present in both (the per-query AP values are discarded).
    """
    map_score, _per_query_ap = mean_average_precision(
        load_qrels(qrels_path),  # or "qrel301.txt"
        load_run(run_path),
    )
    return map_score

def load_topics(path):
    """
    Load a topics file into ``{qid: query}`` (both strings).

    Input format, one topic per line:
      qid<TAB>query text
    Lines without a tab are split at the first run of whitespace instead, so
    space-separated files (the default output of
    create_llm_generated_queries) load as well. Blank lines are skipped.

    :raises ValueError: if a line holds a qid but no query text.
    """
    topics = {}
    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            # Tab takes precedence so existing tab-separated files parse exactly as before.
            parts = line.split("\t", 1) if "\t" in line else line.split(None, 1)
            if len(parts) != 2:
                raise ValueError(f"Malformed topic on line {line_no} in {path}: {line!r}")
            qid, query = parts
            topics[qid] = query
    return topics


def precision_at_k(ranked_docids, qrels_for_q, k):
    """
    Fraction of the top-k slots filled with relevant (rel > 0) documents.
    The denominator is always k, even if fewer than k documents were ranked;
    an empty ranking scores 0.0.
    """
    relevant = {docid for docid, grade in qrels_for_q.items() if grade > 0}
    if not ranked_docids:
        return 0.0
    relevant_in_top = [docid for docid in ranked_docids[:k] if docid in relevant]
    return len(relevant_in_top) / k


def recall_at_k(ranked_docids, qrels_for_q, k):
    """
    Share of the query's relevant (rel > 0) documents found in the top k.
    Returns 0.0 when the query has no relevant documents.
    """
    relevant = {docid for docid, grade in qrels_for_q.items() if grade > 0}
    if relevant:
        retrieved_relevant = [docid for docid in ranked_docids[:k] if docid in relevant]
        return len(retrieved_relevant) / len(relevant)
    return 0.0


def max_ap_at_k(qrels_for_q: Dict[str, int], k: int) -> float:
    """
    Upper bound on AP when only k documents can be returned:
    ``min(#relevant, k) / #relevant`` (0.0 if the query has no relevant documents).
    """
    n_relevant = len([grade for grade in qrels_for_q.values() if grade > 0])
    return min(n_relevant, k) / n_relevant if n_relevant else 0.0


def first_relevant_rank_at_k(ranked_docids: List[str], qrels_for_q: Dict[str, int], k: int) -> int:
    """
    1-based rank of the first relevant (rel > 0) document within the top k.
    Returns 0 both when the query has no relevant documents judged and when
    none of them appears in the top k.
    """
    relevant = {docid for docid, grade in qrels_for_q.items() if grade > 0}
    if not relevant:
        return 0  # nothing relevant judged for this query
    positions = (
        position
        for position, docid in enumerate(ranked_docids[:k], start=1)
        if docid in relevant
    )
    return next(positions, 0)  # 0 -> none found within top-k


def reciprocal_rank_at_k(ranked_docids: List[str], qrels_for_q: Dict[str, int], k: int) -> float:
    """
    Reciprocal of the first relevant rank within the top k, or 0.0 when no
    relevant document is found there.
    """
    first_rank = first_relevant_rank_at_k(ranked_docids, qrels_for_q, k)
    if first_rank > 0:
        return 1.0 / first_rank
    return 0.0


def evaluate_run(
    qrels: Dict[str, Dict[str, int]],
    run: Dict[str, List[str]],
    name: str,                 # stored in df.attrs["name"]
    ks=range(100, 1001, 100),
) -> pd.DataFrame:
    """
    Evaluate a run at several cut-offs.

    Long format: one row per k.
    Columns: k, MAP, P, Recall, MaxAP, FirstRel, MRR
    Run identity stored as df.attrs["name"].

    Every metric is averaged over ALL queries in ``qrels`` (queries missing
    from the run count as empty rankings). With empty ``qrels`` every metric
    is 0.0 rather than a division by zero; with empty ``ks`` the frame has the
    columns above and no rows.
    """
    columns = ["k", "MAP", "P", "Recall", "MaxAP", "FirstRel", "MRR"]

    def _mean(values):
        # Guard against an empty qrels dict.
        return sum(values) / len(values) if values else 0.0

    rows = []

    for k in ks:
        ap_vals, p_vals, r_vals, max_ap_vals = [], [], [], []
        first_vals, rr_vals = [], []

        for qid, qrels_for_q in qrels.items():
            ranked = run.get(qid, [])[:k]

            ap_vals.append(average_precision(ranked, qrels_for_q))
            p_vals.append(precision_at_k(ranked, qrels_for_q, k))
            r_vals.append(recall_at_k(ranked, qrels_for_q, k))
            max_ap_vals.append(max_ap_at_k(qrels_for_q, k))

            fr = first_relevant_rank_at_k(ranked, qrels_for_q, k)
            first_vals.append(fr)
            rr_vals.append(1.0 / fr if fr > 0 else 0.0)

        rows.append({
            "k": int(k),
            "MAP": _mean(ap_vals),
            "P": _mean(p_vals),
            "Recall": _mean(r_vals),
            "MaxAP": _mean(max_ap_vals),
            "FirstRel": _mean(first_vals),  # mean first relevant rank (0 if none)
            "MRR": _mean(rr_vals),          # mean reciprocal rank@k
        })

    df = pd.DataFrame(rows, columns=columns)
    df.attrs["name"] = name
    return df





## Actual Scripting

Local Script Dependencies

In [11]:
# from engine import SearchEngine
# from evaluate_map import *
# from optimizing import Optimize
import shutil
import os

Full (Current) Pipeline

In [None]:
# Colab only: mount Google Drive at /content/drive.
# NOTE(review): the data paths used below are relative ("Data/...", "Results/"),
# so the working directory presumably has to contain them — confirm.
from google.colab import drive
drive.mount('/content/drive')

In [2]:
# Load the query sets ({qid: query} dicts) and the relevance judgments.
# Presumably: original queries, LLM-expanded queries and LLM-rewritten queries
# (inferred from the file names — TODO confirm).
topics = load_topics("Data/queriesROBUST.txt")
topics_expanded = load_topics("Data/chatExpandedQueries.txt")
topics_thes = load_topics("Data/chatQueries.txt")
# {qid: {docid: relevance}}
qrels = load_qrels("Data/qrels_50_Queries")

In [3]:
def subset_topics(topics, qids_list):
    """
    Keep only the topics whose id, converted to int, is contained in
    ``qids_list``. Keys and values are copied unchanged, in original order.
    """
    selected = {}
    for qid, query in topics.items():
        if int(qid) in qids_list:
            selected[qid] = query
    return selected

# Topic ids 301..350 inclusive (the upper bound of range() is exclusive).
train_qids = list(range(301,351))
HARD_QUERIES =[309, 308, 338, 344, 348, 320, 328, 334, 303, 339] # From EDA, queries with low amounts of relevant documents.
# The same id filter applied to each of the three query sets.
topics_subset = subset_topics(topics, train_qids)
topics_expanded_subset = subset_topics(topics_expanded, train_qids)
topics_thes_subset = subset_topics(topics_thes, train_qids)

# Hard-query counterparts of the three sets above.
topics_hard = subset_topics(topics, HARD_QUERIES)
topics_expanded_hard = subset_topics(topics_expanded, HARD_QUERIES)
topics_thes_hard = subset_topics(topics_thes, HARD_QUERIES)


In [4]:
def compare_rerankers(topics_lists, qrels, reranker_type, rerankers, fusion_weights, query_fusion_weights, rrf_k_queries, rrf_k_reranker):
    """
    Run the full pipeline once per reranker and print MAP for every fusion weight.

    WARNING: the ``Results`` directory is deleted and recreated on every call.

    :param topics_lists: list of {qid: query} dicts (query variants).
    :param qrels: relevance judgments used for MAP.
    :param reranker_type: Reranker type applied to every model in ``rerankers``.
    :param rerankers: cross-encoder model names; ``None`` means pure lexical
        retrieval (only the first fusion weight is evaluated in that case,
        because a single run file is written).
    :param fusion_weights: weights of the original order in the rerank fusion.
    :param query_fusion_weights: weights of the query variants.
    :param rrf_k_queries: RRF constant for query fusion.
    :param rrf_k_reranker: RRF constant for rerank fusion.
    """
    # Start from an empty Results directory: run files are opened in append mode.
    if os.path.isdir("Results"):
        shutil.rmtree("Results")
    os.makedirs("Results", exist_ok=True)
    for reranker in rerankers:
        # Last path component of the model name; handles None (pure lexical) and
        # names without an organisation prefix, which used to crash on split('/')[1].
        save_name = "lexical" if reranker is None else reranker.split('/')[-1].replace('-','_')
        print(f"Starting retrieval with reranker {reranker}")
        se = SearchEngine()
        se.set_searcher(approach="bm25",fb_terms=20, fb_docs=5, original_query_weight=0.6, mu=340, reranker_type=reranker_type, reranker=reranker)
        se.search_all_queries(topics_lists, k=1000, m=50, output_file=f"run_{save_name}", rerank_fusion_weights=fusion_weights, llm_query_fusion_weights=query_fusion_weights, rrf_k_queries=rrf_k_queries, rrf_k_reranker=rrf_k_reranker)
        for fusion_weight in fusion_weights:
            run = load_run(f"Results/run_{save_name}_rrf_{fusion_weight}.txt")
            map_score, ap_by_q = mean_average_precision(qrels, run)
            print(f"MAP for reranker {reranker} with rrf {fusion_weight}: {map_score}")
            if reranker is None:
                # Without a reranker only the first weight's file exists.
                break

In [5]:
def fuse_rerankers(reranker_type, pre_fusion_weights, post_fusion_weights, topics, rrf_k=60, qrels=None):
    """
    Performs rrf fusion on all the variants of one reranker.

    Reads ``tofuse/run_{reranker_type}_rrf_{w}.txt`` for every ``w`` in
    ``pre_fusion_weights``, fuses them per query with weighted RRF, writes the
    result to ``Results/fused_fused_rrf_{rrf_k}.txt`` and prints its MAP.

    :param reranker_type: name used in the input run file names.
    :param pre_fusion_weights: fusion weights identifying the input run files.
    :param post_fusion_weights: RRF weight of each input run (must sum to 1).
    :param topics: {qid: query}; only the keys are used. Every qid must be
        present in every input run (KeyError otherwise).
    :param rrf_k: RRF constant.
    :param qrels: relevance judgments; defaults to the notebook-level ``qrels``.
    """
    # Hit and weighted_rrf_fuse are defined in earlier notebook cells; the
    # `processing` / `engine` modules are not importable in this Colab edition.
    if qrels is None:
        qrels = globals()["qrels"]

    all_runs = []
    for fusion_weight in pre_fusion_weights:
        base = load_run(f"tofuse/run_{reranker_type}_rrf_{fusion_weight}.txt")
        # Only the order matters for RRF, so the score is a placeholder.
        all_runs.append({
            qid: [Hit(docid=docid, score=0) for docid in docids]
            for qid, docids in base.items()
        })

    os.makedirs("Results", exist_ok=True)
    out_path = f"Results/fused_fused_rrf_{rrf_k}.txt"
    # "w" instead of "a": re-running the cell used to append a second copy of
    # every ranking to the same file, corrupting the evaluation below.
    with open(out_path, "w", encoding="utf-8") as f:
        for qid in topics.keys():
            fused = weighted_rrf_fuse([run[qid] for run in all_runs], weights=post_fusion_weights, rrf_k=rrf_k)
            for rank, hit in enumerate(fused, start=1):
                f.write(
                    f"{qid} Q0 {hit.docid} {rank} {hit.score:.6f} run42\n"
                )

    run = load_run(out_path)
    map_score, ap_by_q = mean_average_precision(qrels, run)
    print(f"MAP for fusion of rerankers of type {reranker_type} with rrf {rrf_k}: {map_score}")


In [6]:
# NOTE - fusion_weight = 0 means we take only the reranker, 1 means we take none of the reranker's inputs and it should be identical to pure lexical
# For pure lexical only, specify rerankers = [None]
# BEST RESULT IS rrf_k = 9
#cross-encoder/ms-marco-MiniLM-L-6-v2" 0.2/0.5
# mixedbread-ai/mxbai-rerank-large-v1 0.2
# tomaarsen/Qwen3-Reranker-0.6B-seq-cls ??

# Third query list now uses the *_subset variant like its siblings. Its weight is
# 0.0 here, so it is skipped either way and the recorded results are unaffected.
compare_rerankers([topics_subset, topics_expanded_subset, topics_thes_subset], qrels,reranker_type='CE', rerankers= ["mixedbread-ai/mxbai-rerank-large-v1"], fusion_weights=[0, 0.2, 0.5, 0.7, 1], query_fusion_weights=[0.8, 0.2, 0.0], rrf_k_queries=9, rrf_k_reranker=60)
# for i in range(3,40,2):
#     compare_rerankers([topics_subset, topics_expanded_subset, topics_thes], qrels, [None], fusion_weights=[0, 0.2, 0.5, 0.7, 1], rrf_k=i, should_rerank_embedded=True)

#
# run = load_run(f"Results/run_CE_rrf_ariel_hits.txt")
# map_score, ap_by_q = mean_average_precision(qrels, run)
# print(f"MAP is: {map_score}")


Starting retrieval with reranker mixedbread-ai/mxbai-rerank-large-v1


Searching topics: 100%|██████████| 50/50 [09:56<00:00, 11.92s/it]


MAP for reranker mixedbread-ai/mxbai-rerank-large-v1 with rrf 0: 0.30438937954163175
MAP for reranker mixedbread-ai/mxbai-rerank-large-v1 with rrf 0.2: 0.30490620359703724
MAP for reranker mixedbread-ai/mxbai-rerank-large-v1 with rrf 0.5: 0.301047600529769
MAP for reranker mixedbread-ai/mxbai-rerank-large-v1 with rrf 0.7: 0.2934765819653454
MAP for reranker mixedbread-ai/mxbai-rerank-large-v1 with rrf 1: 0.28523687998539804


In [None]:
# Fuse previously saved runs: expects tofuse/run_CE_rrf_{w}.txt for each w in the
# first list. The second list gives each of those runs its RRF weight (the w=0
# run gets weight 0.0).
fuse_rerankers('CE', [0, 0.2, 0.5, 0.7, 1], [0.0,0.25,0.25,0.25,0.25], topics_subset, rrf_k=10)

Get LLM datasets and optimize

In [None]:
# from processing import create_llm_generated_queries  # not importable in Colab; defined in the processing.py cell above
# Writes queries_col1.txt ... queries_col3.txt (one file per LLM option) to the working directory.
create_llm_generated_queries("Data/LLM_outputs.txt")

In [None]:
def compare_llm_weights(queries_paths, qrels, rerankers, fusion_weights):
    """
    Run the pipeline on query files generated from LLM output and print MAP
    per reranker and rerank-fusion weight.

    WARNING: the ``Results`` directory is deleted and recreated on every call.

    :param queries_paths: topic files, one per query variant; the first file
        drives the iteration and supplies the reranker query.
    :param qrels: relevance judgments used for MAP.
    :param rerankers: cross-encoder model names (``None`` = pure lexical).
    :param fusion_weights: list of rerank-fusion weights to produce and evaluate.

    NOTE(review): the query-variant weights are left at the search_all_queries
    default (first variant only); pass ``llm_query_fusion_weights`` there if the
    variants are meant to be fused — confirm intent.
    """
    # Previously the loaded topics were never used and the notebook-level
    # `topics` dict was passed where a list of dicts is required.
    topics_per_path = [load_topics(path) for path in queries_paths]
    if os.path.isdir("Results"):
        shutil.rmtree("Results")
    os.makedirs("Results", exist_ok=True)
    for reranker in rerankers:
        # Model names contain '/', which cannot be part of the run file name.
        save_name = "lexical" if reranker is None else reranker.split('/')[-1].replace('-','_')
        print(f"Starting retrieval with reranker {reranker}")
        se = SearchEngine()
        # "CE" is the reranker TYPE; the model name comes from the loop variable.
        se.set_searcher(approach="bm25",fb_terms=20, fb_docs=5, original_query_weight=0.6, mu=300, reranker_type="CE", reranker=reranker)
        # Pass the full list of weights so that a run file exists for each weight evaluated below.
        se.search_all_queries(topics_per_path, k=1000, m=100, output_file=f"run_{save_name}", rerank_fusion_weights=fusion_weights)
        for fusion_weight in fusion_weights:
            run = load_run(f"Results/run_{save_name}_rrf_{fusion_weight}.txt")
            map_score, ap_by_q = mean_average_precision(qrels, run)
            print(f"MAP for reranker {reranker} with rrf {fusion_weight}: {map_score}")
            if reranker is None:
                # Without a reranker only the first weight's file exists.
                break
