In [3]:
import pandas as pd
import bm25s
import Stemmer

from metrics import compute_ir_metrics


def preprocess(row: pd.Series, key: str):
    company, year, *_ = row["document_id"].split("-")
    processed = f"{company}-{year}, {" ".join(row[key].split()[:])}"
    return processed


AMOUNT_QUERIES = 10
USE_AUG = True

documents_df = pd.read_csv("../data/processed/documents.csv")
documents_df["processed"] = documents_df.apply(lambda x: preprocess(x, "document"), axis=1)

corpus = []

if USE_AUG:
    documents_aug_df = pd.read_csv("../data/processed/documents_aug.csv")

    for i, document in enumerate(documents_df["processed"]):
        for j in range(AMOUNT_QUERIES):
            document += " " + documents_aug_df["pseudo_query"][i + j]
        corpus.append(document)
else:
    corpus = documents_df["processed"].to_list()

stemmer = Stemmer.Stemmer("english")

# Tokenize the corpus and only keep the ids (faster and saves memory)
corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)

# Create the BM25 model and index the corpus
retriever = bm25s.BM25()
retriever.index(corpus_tokens)

docids = documents_df["document_id"].to_list()
splits = ["train", "eval", "test"]

result_splits = {split: [] for split in splits}

for split in splits:
    query_df = pd.read_csv(f"../data/processed/{split}.csv")
    query_df["processed"] = query_df.apply(lambda x: preprocess(x, "question"), axis=1)
    queries = query_df["processed"].to_list()
    
    query_tokens = bm25s.tokenize(queries, stopwords="en", stemmer=stemmer)
    results, scores = retriever.retrieve(query_tokens, corpus=docids, k=10, n_threads=-1)
    
    metrics = compute_ir_metrics(results, query_df["document_id"].to_list())

    result_splits[split] = metrics

print(result_splits)

Split strings:   0%|          | 0/2789 [00:00<?, ?it/s]

Stem Tokens:   0%|          | 0/2789 [00:00<?, ?it/s]

BM25S Count Tokens:   0%|          | 0/2789 [00:00<?, ?it/s]

BM25S Compute Scores:   0%|          | 0/2789 [00:00<?, ?it/s]

Split strings:   0%|          | 0/6251 [00:00<?, ?it/s]

Stem Tokens:   0%|          | 0/6251 [00:00<?, ?it/s]

BM25S Retrieve:   0%|          | 0/6251 [00:00<?, ?it/s]

Split strings:   0%|          | 0/883 [00:00<?, ?it/s]

Stem Tokens:   0%|          | 0/883 [00:00<?, ?it/s]

BM25S Retrieve:   0%|          | 0/883 [00:00<?, ?it/s]

Split strings:   0%|          | 0/1147 [00:00<?, ?it/s]

Stem Tokens:   0%|          | 0/1147 [00:00<?, ?it/s]

BM25S Retrieve:   0%|          | 0/1147 [00:00<?, ?it/s]

{'train': {'hits@1': 0.6283794592865142, 'hits@10': 0.9609662454007358, 'mrr': 0.7507913019631146, 'ndcg@10': 0.8025851682864346}, 'eval': {'hits@1': 0.6013590033975085, 'hits@10': 0.9580973952434881, 'mrr': 0.7321145445720757, 'ndcg@10': 0.7876996750812651}, 'test': {'hits@1': 0.6190061028770706, 'hits@10': 0.955536181342633, 'mrr': 0.7430145583371389, 'ndcg@10': 0.7954192227718161}}
