In [2]:
import pandas as pd
import torch
import numpy as np
from tqdm import tqdm
from cherche import retrieve
import sys
sys.path.append("../utils")
from data_utils import train_val_test_split_df
from file_utils import mkdir
from metrics import MRR, ACCURACY, RECALL, MAP

In [16]:
df_test = pd.read_json("../../../data-set_pre_processed/test/articles_test.json")

In [17]:
df_test["query"] = df_test["query"].apply(lambda x : x.replace("///"," "))
df_test["id"] = df_test["id"].apply(lambda x_list : [int(x) for x in x_list])

In [18]:
corpus_test = pd.read_json("../../../data-set_pre_processed/test/corpus_test.json", dtype={"id" : str})

In [19]:
corpus_test["id"] = corpus_test["id"].apply(int)

In [23]:
retriever = retrieve.BM25Okapi(key="id",
                               on=["first_sentence"],
                               documents=corpus_test.to_dict(orient="records"),
                               k=10)

In [24]:
df_test["query"][0:1].apply(lambda q : [r["id"] for r in retriever(q=q)])

0    [62278349818940349863543266908722591203062559004]
Name: query, dtype: object

In [25]:
metrics = {"ACCURACY": [],
           "MRR@10": [], "MRR@25": [],
           "RECALL@10": [], "RECALL@25": [], "RECALL@50": [], "RECALL@200": [],
           "MAP": []}
for query, reference_ids in tqdm(zip(df_test["query"], df_test["id"])) :
        retrieved_ids = retriever(q=query)
        retrieved_ids = [d["id"] for d in retrieved_ids]
        acc, correctness = ACCURACY(retrieved_ids, reference_ids,
                                    k=len(reference_ids),
                                    return_list=True)
        metrics["ACCURACY"].append(acc)
        metrics["MRR@10"].append(MRR(retrieved_ids, reference_ids,
                                          k=10))
        metrics["MRR@25"].append(MRR(retrieved_ids, reference_ids,
                                          k=25))
        metrics["RECALL@10"].append(RECALL(retrieved_ids, reference_ids,
                                                k=10))
        metrics["RECALL@25"].append(RECALL(retrieved_ids, reference_ids,
                                                k=25))
        metrics["RECALL@50"].append(RECALL(retrieved_ids, reference_ids,
                                                k=50))
        metrics["RECALL@200"].append(RECALL(retrieved_ids, reference_ids,
                                                 k=200))
        metrics["MAP"].append(MAP(retrieved_ids, reference_ids,
                                       k=len(reference_ids)))
        

for k, v in metrics.items():
    if len(v) > 0:
        metrics[k] = sum(v) / len(v)
for k, v in metrics.items():
    print(f"{str(k).ljust(12)} : {v * 100:.2f}")

198it [00:01, 152.24it/s]

ACCURACY     : 23.64
MRR@10       : 17.69
MRR@25       : 7.08
RECALL@10    : 23.72
RECALL@25    : 23.72
RECALL@50    : 23.72
RECALL@200   : 23.72
MAP          : 22.37



