In [15]:
!pip install rank_bm25



In [28]:
from rank_bm25 import BM25Okapi
import numpy as np

In [29]:
class BM25Retriever:
    """BM25 ranking over a tagged-text test collection.

    The document and query files are expected to consist of records that each
    start on a line beginning with ``.`` (``.I <id>``, ``.T``, ``.A``, ``.W``,
    ``.X`` ...); following lines that do not begin with ``.`` continue the
    current record.  The relevance file holds one ``<query id> <doc id>`` pair
    per line, optionally followed by tab-separated columns.

    Attributes:
        doc_set: dict mapping document id (int) to its concatenated text.
        qry_set: dict mapping query id (int) to its ``.W`` text.
        rel_set: dict mapping query id (int) to a list of relevant doc ids.
        bm25: ``BM25Okapi`` index built over the values of ``doc_set``.
    """

    def __init__(self, doc_path, qry_path, rel_path):
        """Load the three collection files and build the BM25 index.

        Args:
            doc_path: path of the document file.
            qry_path: path of the query file.
            rel_path: path of the relevance-judgment file.
        """
        self.doc_set = self._load_documents(doc_path)
        self.qry_set = self._load_queries(qry_path)
        self.rel_set = self._load_relevance(rel_path)
        self.bm25 = self._initialize_bm25()

    @staticmethod
    def _read_records(path):
        """Return the file's records, one string per ``.``-tagged entry.

        A line that starts with ``.`` opens a new record; every other line is
        stripped and appended to the current record after a single space.
        Lines that appear before the first tagged line belong to no record
        and are ignored.
        """
        records = []
        with open(path) as f:
            for raw in f:
                text = raw.strip()
                if raw.startswith("."):
                    records.append(text)
                elif records:
                    records[-1] += " " + text
        return records

    def _load_documents(self, path):
        """Parse the document file into ``{doc_id: text}``.

        The text of a document is every field between its ``.I`` record and
        its ``.X`` record, with the 3-character tag prefix removed and each
        field followed by one space (so the text keeps a trailing space).
        A document is also closed by the next ``.I`` record or by the end of
        the file, so an entry without a ``.X`` section is not lost.
        """
        doc_set = {}
        doc_id, fields = None, []

        def store():
            # Same text layout as before: "<field> <field> ... " minus leading blanks.
            doc_set[doc_id] = "".join(field + " " for field in fields).lstrip(" ")

        for record in self._read_records(path):
            if record.startswith(".I"):
                if doc_id is not None:
                    store()  # previous document had no .X terminator
                doc_id, fields = int(record.split()[1]), []
            elif record.startswith(".X"):
                if doc_id is not None:
                    store()
                doc_id, fields = None, []
            elif doc_id is not None:
                # Drop the tag and its separator (".T ", ".A ", ".W ", ...).
                fields.append(record.strip()[3:])

        if doc_id is not None:
            store()  # last document in the file had no .X terminator

        return doc_set

    def _load_queries(self, path):
        """Parse the query file into ``{qry_id: text}``.

        Only the first ``.W`` record following each ``.I`` record is kept;
        all other fields are ignored.
        """
        qry_set = {}
        qry_id = None

        for record in self._read_records(path):
            if record.startswith(".I"):
                qry_id = int(record.split()[1])
            elif record.startswith(".W") and qry_id is not None:
                qry_set[qry_id] = record.strip()[3:]
                qry_id = None

        return qry_set

    def _load_relevance(self, path):
        """Parse the relevance file into ``{qry_id: [doc_id, ...]}``.

        Only the part of each line before the first tab is read: its first
        token is the query id and its last token the document id.  Blank
        lines are skipped.
        """
        rel_set = {}

        with open(path) as f:
            for line in f:
                pair = line.split("\t")[0].split()
                if not pair:
                    continue
                rel_set.setdefault(int(pair[0]), []).append(int(pair[-1]))

        return rel_set

    def _initialize_bm25(self):
        """Build the BM25 index over the documents, in ``doc_set`` order.

        Tokenization is a plain split on single spaces (case-sensitive, no
        punctuation handling), matching the tokenization applied to queries
        in ``retrieve_BM25``.
        """
        tokenized_corpus = [doc.split(" ") for doc in self.doc_set.values()]
        return BM25Okapi(tokenized_corpus)

    def retrieve_BM25(self, idx):
        """Rank all documents for one query.

        Args:
            idx: id of the query, a key of ``qry_set``.

        Returns:
            numpy array of document ids ordered from highest to lowest BM25
            score.

        Raises:
            KeyError: if ``idx`` is not a known query id.
        """
        tokenized_query = self.qry_set[idx].split(" ")
        scores = self.bm25.get_scores(tokenized_query)
        order = np.argsort(scores)[::-1]

        # Scores are positional (corpus order == doc_set insertion order), so
        # translate positions to the real ids instead of assuming ids are 1..N.
        doc_ids = np.array(list(self.doc_set))
        return doc_ids[order]

In [30]:
# Index the CISI collection, then rank every document against query 1.
retriever = BM25Retriever(
    doc_path='../dataset/CISI.ALL',
    qry_path='../dataset/CISI.QRY',
    rel_path='../dataset/CISI.REL',
)
retriever.retrieve_BM25(idx=1)

array([  60,   24,  364, ...,  880, 1302, 1086])