In [14]:
import os
import pandas as pd
from collections import defaultdict

# Directory of the unified test split; the binary qrels file lives inside it.
in_dir = "/home/jupyter/unity_jointly_rec_and_search/datasets/amazon_esci_dataset/task_2_multiclass_product_classification/unified_test/"
binary_qrel_path = os.path.join(in_dir, "qrels.test.tsv")

# query id -> set of product ids listed for that query in the qrels file.
# Each line is tab-separated; column 0 is the query id and column 2 the
# product id (the remaining columns are not used here).
qid_to_relpids = defaultdict(set)
with open(binary_qrel_path) as qrel_file:
    for record in qrel_file:
        fields = record.strip().split("\t")
        query_id, product_id = int(fields[0]), int(fields[2])
        qid_to_relpids[query_id].add(product_id)


In [2]:
root_dir = "/home/jupyter/unity_jointly_rec_and_search/datasets/amazon_esci_dataset/task_2_multiclass_product_classification/"
product_catalogue_path = os.path.join(root_dir, "product_catalogue-v0.3.csv")

# Product catalogue: missing values become the string 'unknown', then only
# US-locale rows are kept.
product_df = pd.read_csv(product_catalogue_path).fillna('unknown')
product_df = product_df[product_df.product_locale=="us"]

# Training judgements restricted to US-locale queries. The explicit .copy()
# makes the filtered frame independent of the frame read from CSV, so adding
# the query_id column below does not write into a slice (which would raise
# SettingWithCopyWarning).
train_df = pd.read_csv(os.path.join(root_dir, "train-v0.3.csv"))
train_df = train_df[train_df.query_locale=="us"].copy()

# Query ids follow the order of first appearance of each distinct query text,
# offset by the number of US catalogue rows
# (presumably so they do not overlap with product ids -- TODO confirm).
query_to_qid = {query: qid+len(product_df) for qid, query in enumerate(train_df["query"].unique())}
train_df["query_id"] = train_df["query"].apply(lambda x: query_to_qid[x])

In [16]:
from tqdm import tqdm

# Integer product ids are the positions of the distinct product_id values of
# the US catalogue, in order of first appearance; ivm_to_pid is the inverse map.
pid_to_ivm = dict(enumerate(product_df.product_id.unique()))
ivm_to_pid = {ivm: pid for pid, ivm in pid_to_ivm.items()}

# Graded score stored for each ESCI label. Pairs labelled "irrelevant" are
# validated but not stored.
label_to_grade = {"exact": 3.0, "substitute": 2.0, "complement": 1.0}

# query id -> list of (product id, graded score), restricted to the queries
# that appear in the binary qrels.
qid_to_reldata = defaultdict(list)

# Iterate over the three needed columns directly instead of
# DataFrame.iterrows(): iterrows builds a Series per row, which is very slow
# on a frame of this size, and none of the other columns are used.
rows = zip(train_df["query_id"], train_df["product_id"], train_df["esci_label"])
for qid, ivm, label in tqdm(rows, total=len(train_df)):
    pid = ivm_to_pid[ivm]
    if qid not in qid_to_relpids:
        continue
    if label == "exact":
        # Consistency check: every "exact" pair must already be in the binary qrels.
        assert pid in qid_to_relpids[qid], (qid, pid, qid_to_relpids[qid])
    grade = label_to_grade.get(label)
    if grade is None:
        # The only label without a grade is "irrelevant"; anything else is unexpected.
        assert label == "irrelevant", label
        continue
    qid_to_reldata[qid].append((pid, grade))

100%|██████████| 1272626/1272626 [01:17<00:00, 16466.25it/s]


In [23]:
# Write the graded judgements next to the binary qrels, one tab-separated
# line per (query id, product id) pair: qid, the literal "Q0", pid, score.
with open(os.path.join(in_dir, "grade_qrels.test.tsv"), "w") as fout:
    fout.writelines(
        f"{qid}\tQ0\t{pid}\t{score}\n"
        for qid, reldatas in qid_to_reldata.items()
        for pid, score in reldatas
    )

In [27]:
import sys 
sys.path.append("/home/jupyter/unity_jointly_rec_and_search/kgc-dr/")

from evaluation.retrieval_evaluator import RankingEvaluator

# Run file of the dot-v5 experiment checkpoint to be scored.
expeirment_folder="/home/jupyter/unity_jointly_rec_and_search/experiments/amazon_esci/task2/dot-v5/"
ranking_path = os.path.join(
    expeirment_folder,
    "experiment_10-05_010529/runs/checkpoint_latest.test.query.small.run",
)

# Evaluator built on the graded qrels file located in in_dir.
qrel_path = os.path.join(in_dir, "grade_qrels.test.tsv")
evaluator = RankingEvaluator(qrel_path)

# Last expression of the cell, so the returned metrics are displayed.
evaluator.compute_metrics(ranking_path)

{'MRR@10': 0.5340458020601843,
 'QueriesWithRelevant@10': 5284,
 'MRR@1000': 0.541832411671326,
 'QueriesWithRelevant@1000': 6684,
 'Recall@50': 0.4225834406835569,
 'Recall@1000': 0.7892604184552738,
 'nDCG@10': 0.33660811364223714,
 'nDCG@100': 0.43181201635523137,
 'MAP@1000': 0.2524931985012174,
 'QueriesRanked': 6814}

In [26]:
# Run file under the bm25 experiments folder, scored with the same evaluator
# instance so its metrics can be compared with the previous run.
ranking_path = "/home/jupyter/jointly_rec_and_search/experiments/bm25/amazon_esci/usearch.test.dr.run"
evaluator.compute_metrics(ranking_path)

{'MRR@10': 0.5388963794729906,
 'QueriesWithRelevant@10': 5115,
 'MRR@1000': 0.5465898001488014,
 'QueriesWithRelevant@1000': 6577,
 'Recall@50': 0.4136988847695458,
 'Recall@1000': 0.7269261664885075,
 'nDCG@10': 0.3436112498540656,
 'nDCG@100': 0.4232996162266999,
 'MAP@1000': 0.25679039880712223,
 'QueriesRanked': 6781}