# BM25 Retrieval (copied from the RIRAG challenge)

## Imports

In [None]:
# Environment setup (IPython shell escapes):
#   1. clone the ObliQA dataset repository into ./ObliQADataset,
#   2. install faiss-cpu and pyserini (pinned to 0.22.1),
#   3. download bm25.py, the helper module the index-building cell imports from.
!git clone https://github.com/RegNLP/ObliQADataset.git
!pip install faiss-cpu pyserini==0.22.1
!wget https://gist.githubusercontent.com/kwang2049/63ed76eb0f4d79ca81caecdb06897bfb/raw/1d86978275d666dff904fba65a34ce3e71b3cf1d/bm25.py

In [None]:
import os
import json
from typing import Dict

## Data Loading

In [None]:
def load_qrels(docs_dir: str, fqrels: str, ndocs: int = 40) -> Dict[str, Dict[str, int]]:
    """Build relevance judgments (qrels) from the document and question JSON files.

    Args:
        docs_dir: Directory containing the structured documents, stored as
            ``1.json`` .. ``{ndocs}.json``. Each file is a JSON list of
            passage records carrying ``DocumentID``, ``PassageID`` and ``ID``.
        fqrels: Path to a JSON list of question records. Each record carries
            ``QuestionID`` and ``Passages``; every passage entry carries
            ``DocumentID`` and ``PassageID``.
        ndocs: Number of consecutively numbered document files to read.
            Defaults to 40, the value that used to be hard-coded.

    Returns:
        A mapping ``question ID -> {passage ID: 1}`` where the passage ID is
        the ``ID`` field of the referenced passage record. Questions whose
        ``Passages`` list is empty do not appear in the result.

    Raises:
        FileNotFoundError: If a document file or ``fqrels`` does not exist.
        KeyError: If a question references a (DocumentID, PassageID) pair
            that is not present in the loaded documents.
    """
    # DocumentID -> PassageID -> ID
    did2pid2id: Dict[str, Dict[str, str]] = {}
    for i in range(1, ndocs + 1):
        # Read as UTF-8 explicitly so decoding does not depend on the
        # platform's default locale encoding.
        with open(os.path.join(docs_dir, f"{i}.json"), encoding="utf-8") as f:
            doc = json.load(f)
        for psg in doc:
            pid2id = did2pid2id.setdefault(psg["DocumentID"], {})
            # NOTE(review): this compares the passage's "ID" against a dict
            # keyed by "PassageID", so it likely never fires; kept as-is to
            # preserve behaviour -- confirm whether a duplicate-PassageID
            # check was intended.
            assert psg["ID"] not in pid2id
            # First occurrence of a PassageID within a document wins.
            pid2id.setdefault(psg["PassageID"], psg["ID"])

    with open(fqrels, encoding="utf-8") as f:
        data = json.load(f)

    qrels: Dict[str, Dict[str, int]] = {}
    for e in data:
        qid = e["QuestionID"]
        for psg in e["Passages"]:
            pid = did2pid2id[psg["DocumentID"]][psg["PassageID"]]
            # Every listed passage is treated as relevant (label 1).
            qrels.setdefault(qid, {})[pid] = 1
    return qrels

In [None]:
# Load the test-split judgments and dump them to ./qrels, one judgment per
# line in the form "<question id> Q0 <passage id> <relevance>".
qrels = load_qrels("ObliQADataset/StructuredRegulatoryDocuments", "ObliQADataset/ObliQA_test.json")
with open("qrels", "w") as out:
    out.writelines(
        f"{question_id} Q0 {passage_id} {label}" + "\n"
        for question_id, judged in qrels.items()
        for passage_id, label in judged.items()
    )

## Building Index

In [None]:
from bm25 import BM25, Document, Query

# Gather every passage of the numbered document files into one collection
# and index it under ./index.
bm25 = BM25()
ndocs = 40
collection = []
for doc_no in range(1, ndocs + 1):
    doc_path = os.path.join("ObliQADataset/StructuredRegulatoryDocuments", f"{doc_no}.json")
    with open(doc_path) as fh:
        passages = json.load(fh)
    # The indexed text is the PassageID followed by the passage body; the
    # second positional argument of Document is left as an empty string.
    collection.extend(
        Document(p["ID"], "", p["PassageID"] + " " + p["Passage"])
        for p in passages
    )
bm25.index(iter(collection), len(collection), "./index")

## Querying

In [None]:
# Turn every test question into a Query, search the index built under
# ./index, and write the hits to bm25.trec.
with open("ObliQADataset/ObliQA_test.json") as fh:
    data = json.load(fh)
queries = [Query(entry["QuestionID"], entry["Question"]) for entry in data]

retrieved = bm25.search(
    queries=queries,
    index_path="./index",
    topk=100,
    batch_size=1,
)

# One line per hit: "<qid> 0 <docid> <rank> <score> bm25", with ranks
# starting at 1 and hits ordered by descending score.
with open("bm25.trec", "w") as run_file:
    for qid, hits in retrieved.items():
        ranked = sorted(hits, key=lambda h: h.score, reverse=True)
        for rank, h in enumerate(ranked, start=1):
            run_file.write(f"{qid} 0 {h.docid} {rank} {h.score} bm25" + "\n")

## Evaluation

In [None]:
# Clone and build trec_eval, then score the BM25 run (./bm25.trec) against
# ./qrels with the recall.10 and map_cut.10 measures; stdout is redirected
# to ind_bm25_scores.csv.
# NOTE(review): -q (presumably per-query output) is passed after the two file
# arguments -- confirm this trec_eval build accepts trailing options. The
# redirected output is whatever trec_eval prints, which may not be
# comma-separated despite the .csv extension -- verify before parsing.
!git clone https://github.com/usnistgov/trec_eval.git && cd trec_eval && make
!trec_eval/trec_eval -m recall.10 -m map_cut.10 ./qrels ./bm25.trec -q > ind_bm25_scores.csv