In [None]:
!git clone https://github.com/RegNLP/ObliQADataset.git
!pip install faiss-cpu pyserini==0.22.1
!wget https://gist.githubusercontent.com/kwang2049/63ed76eb0f4d79ca81caecdb06897bfb/raw/1d86978275d666dff904fba65a34ce3e71b3cf1d/bm25.py

Cloning into 'ObliQADataset'...
remote: Enumerating objects: 68, done.
remote: Counting objects: 100% (68/68), done.
remote: Compressing objects: 100% (65/65), done.
remote: Total 68 (delta 11), reused 47 (delta 1), pack-reused 0 (from 0)
Receiving objects: 100% (68/68), 11.83 MiB | 13.52 MiB/s, done.
Resolving deltas: 100% (11/11), done.
--2024-10-03 16:51:35--  https://gist.githubusercontent.com/kwang2049/63ed76eb0f4d79ca81caecdb06897bfb/raw/1d86978275d666dff904fba65a34ce3e71b3cf1d/bm25.py
Resolving gist.githubusercontent.com (gist.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to gist.githubusercontent.com (gist.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 10866 (11K) [text/plain]
Saving to: ‘bm25.py’


2024-10-03 16:51:35 (28.6 MB/s) - ‘bm25.py’ saved [10866/10866]



In [None]:
import os
import json
from typing import Dict

def load_qrels(docs_dir: str, fqrels: str, ndocs: int = 40) -> Dict[str, Dict[str, int]]:
    """Build relevance judgements (qrels) for a question file.

    Args:
        docs_dir: Directory holding the structured documents, stored as
            ``1.json`` .. ``{ndocs}.json``. Each file is a JSON list of
            passage records carrying the keys ``"ID"``, ``"DocumentID"`` and
            ``"PassageID"``.
        fqrels: Path to a JSON list of question records. Each record has a
            ``"QuestionID"`` and a ``"Passages"`` list whose items carry
            ``"DocumentID"`` and ``"PassageID"``.
        ndocs: Number of document files to read (defaults to 40, the value
            that used to be hard-coded).

    Returns:
        A mapping ``QuestionID -> {passage ID -> 1}``. Questions without any
        passage do not appear in the result.

    Raises:
        FileNotFoundError: If one of the numbered document files is missing.
        KeyError: If a question references a (DocumentID, PassageID) pair
            that does not occur in the documents.
    """
    # DocumentID -> PassageID -> collection-wide passage "ID".
    did2pid2id: Dict[str, Dict[str, str]] = {}
    for i in range(1, ndocs + 1):
        # JSON is UTF-8 by specification; do not rely on the platform default.
        with open(os.path.join(docs_dir, f"{i}.json"), encoding="utf-8") as f:
            doc = json.load(f)
        for psg in doc:
            # When a PassageID repeats inside a document, the first passage
            # wins (same as before). The former assertion tested the passage
            # "ID" against this PassageID-keyed dict, so it never guarded
            # against such repeats and has been removed.
            did2pid2id.setdefault(psg["DocumentID"], {}).setdefault(psg["PassageID"], psg["ID"])

    with open(fqrels, encoding="utf-8") as f:
        data = json.load(f)

    qrels: Dict[str, Dict[str, int]] = {}
    for e in data:
        qid = e["QuestionID"]
        for psg in e["Passages"]:
            pid = did2pid2id[psg["DocumentID"]][psg["PassageID"]]
            # Every listed passage is judged relevant (binary relevance 1).
            qrels.setdefault(qid, {})[pid] = 1
    return qrels

In [None]:
# Derive the relevance judgements for the test questions and store them as a
# whitespace-separated file, one "<qid> Q0 <pid> <rel>" line per judged passage.
qrels = load_qrels(
    "ObliQADataset/StructuredRegulatoryDocuments",
    "ObliQADataset/ObliQA_test.json",
)
with open("qrels", "w") as f:
    for qid, rels in qrels.items():
        for pid, rel in rels.items():
            # print() joins the fields with single spaces and appends "\n".
            print(qid, "Q0", pid, rel, file=f)

In [None]:
from bm25 import BM25, Document, Query

# Collect every passage of the numbered document files and index them with BM25.
bm25 = BM25()
ndocs = 40
collection = []
for doc_no in range(1, ndocs + 1):
    doc_path = os.path.join("ObliQADataset/StructuredRegulatoryDocuments", f"{doc_no}.json")
    with open(doc_path) as f:
        passages = json.load(f)
    # The indexed text is the PassageID followed by the passage body. The
    # second Document argument is left empty (presumably the title field;
    # confirm against bm25.py).
    collection.extend(
        Document(psg["ID"], "", psg["PassageID"] + " " + psg["Passage"])
        for psg in passages
    )
bm25.index(iter(collection), len(collection), "./index")

Converting to pyserini format: 100%|██████████| 13732/13732 [00:00<00:00, 47408.95it/s]


In [None]:
# Turn every test question into a Query, retrieve the top 10 passages for each
# from the index built above, and write the results as a run file.
with open("ObliQADataset/ObliQA_test.json") as f:
    data = json.load(f)
queries = [Query(e["QuestionID"], e["Question"]) for e in data]

retrieved = bm25.search(queries=queries, index_path="./index", topk=10, batch_size=1)

with open("rankings.trec", "w") as f:
    for qid, hits in retrieved.items():
        # Best score first; ranks are 1-based.
        ranked_hits = sorted(hits, key=lambda hit: hit.score, reverse=True)
        for rank, hit in enumerate(ranked_hits, start=1):
            f.write(f"{qid} 0 {hit.docid} {rank} {hit.score} bm25\n")

SimpleSearcher class has been deprecated, please use LuceneSearcher from pyserini.search.lucene instead


Query batch: 100%|██████████| 2786/2786 [00:11<00:00, 249.56it/s]


In [None]:
!git clone https://github.com/usnistgov/trec_eval.git && cd trec_eval && make
!trec_eval/trec_eval -m recall.10 -m map_cut.10 ./qrels ./rankings.trec

Cloning into 'trec_eval'...
remote: Enumerating objects: 1142, done.
remote: Counting objects: 100% (291/291), done.
remote: Compressing objects: 100% (97/97), done.
remote: Total 1142 (delta 225), reused 237 (delta 188), pack-reused 851 (from 1)
Receiving objects: 100% (1142/1142), 755.17 KiB | 5.63 MiB/s, done.
Resolving deltas: 100% (769/769), done.
gcc -g -I.  -Wall -Wno-macro-redefined -DVERSIONID=\"10.0-rc2\"  -o trec_eval trec_eval.c formats.c meas_init.c meas_acc.c meas_avg.c meas_print_single.c meas_print_final.c gain_init.c get_qrels.c get_trec_results.c get_prefs.c get_qrels_prefs.c get_qrels_jg.c form_res_rels.c form_res_rels_jg.c form_prefs_counts.c utility_pool.c get_zscores.c convert_zscores.c measures.c  m_map.c m_P.c m_num_q.c m_num_ret.c m_num_rel.c m_num_rel_ret.c m_gm_map.c m_Rprec.c m_recip_rank.c m_bpref.c m_iprec_at_recall.c m_recall.c m_Rprec_mult.c m_utility.c m_11pt_avg.c m_ndcg.c m_ndcg_cut.c m_Rndcg.c m_ndcg_rel.c m_binG.c m_G.c m_rel_P.c m_succe