In [33]:
import os
import pandas as pd
import sys
import json
from google.protobuf.json_format import Parse

sys.path.insert(0, 'compiled_protobufs')
from taskmap_pb2 import TaskMap

queries_path = os.path.join(os.getcwd(), "datasets", "queries", "cooking_queries.csv")
qrels_path = os.path.join(os.getcwd(), 'datasets', 'qrles', 'qrles.qrles')
annotations_path = os.path.join(os.getcwd(), 'datasets', 'judgments', 'cooking-annotations.csv')


In [34]:
## save qrels files

with open(qrels_path, "w") as f:
    pd_annotations = pd.read_csv(annotations_path)
    lines = []
    # print(pd_annotations.head())
    for idx, annotation in pd_annotations.iterrows():
        q_id, doc_id, score = annotation["query-id"], annotation["doc-id"], annotation["relevance"]
        lines.append(f'{q_id} Q0 {doc_id} {score}\n')
    lines[-1] = lines[-1].replace("\n","")
    f.writelines(lines)
    

In [54]:
run_path = os.path.join(os.getcwd(), 'datasets', 'qrles')
taskmap_cooking_index_path = os.path.join(os.getcwd(), "indexes", "food", "system_index_sparse")


In [69]:
from pyserini.search import LuceneSearcher
sys.path.insert(0, './pygaggle')
from pygaggle.rerank.base import Query, Text, hits_to_texts
from pygaggle.rerank.transformer import MonoT5

config = [
    "bm25",
    "bm25+rm3",
    "bm25+t5",
    "bm25+rm3+t5",
]

In [79]:


cooking_queries = pd.read_csv(queries_path).iloc[:10]
reranker =  MonoT5()

def get_searcher(search_model):
    if search_model == "bm25" or search_model == "bm25+t5":
        searcher = LuceneSearcher(index_dir=taskmap_cooking_index_path)
        searcher.set_bm25(b=0.4, k1=0.9)
    if search_model == "bm25+rm3" or search_model == "bm25+rm3+t5":
        searcher = LuceneSearcher(index_dir=taskmap_cooking_index_path)
        searcher.set_bm25(b=0.4, k1=0.9)
        searcher.set_rm3(fb_terms=10, fb_docs=10, original_query_weight=0.5)
    return searcher

for model in config:
    lines = []
    for idx, query in cooking_queries.iterrows():
        print(f"Initialize searcher {model}")
        searcher = get_searcher(model)
        hits = searcher.search(q=query["target query"], k=50)
        if "t5" in model:
            hits = reranker.rerank(Query(query["target query"]), hits_to_texts(hits))
        for rank, hit in enumerate(hits):
            if type(hit) == Text:
                doc_json = json.loads(hit.text)
            else:
                doc_json = json.loads(hit.raw)
            taskmap_json = doc_json['recipe_document_json']
            taskmap = Parse(json.dumps(taskmap_json), TaskMap())
            doc_id = taskmap.taskmap_id
            line = f'query-{idx} Q0 {doc_id} {rank+1} {hit.score} bm25\n'
            lines.append(line)
    lines[-1] = lines[-1].replace("\n","")

    with open(os.path.join(run_path, model+".run"), "w") as f:
        f.writelines(lines)
        



Initialize searcher bm25
Initialize searcher bm25
Initialize searcher bm25
Initialize searcher bm25
Initialize searcher bm25
Initialize searcher bm25
Initialize searcher bm25
Initialize searcher bm25
Initialize searcher bm25
Initialize searcher bm25
Initialize searcher bm25+rm3
Initialize searcher bm25+rm3
Initialize searcher bm25+rm3
Initialize searcher bm25+rm3
Initialize searcher bm25+rm3
Initialize searcher bm25+rm3
Initialize searcher bm25+rm3
Initialize searcher bm25+rm3
Initialize searcher bm25+rm3
Initialize searcher bm25+rm3
Initialize searcher bm25+t5
Initialize searcher bm25+t5
Initialize searcher bm25+t5
Initialize searcher bm25+t5
Initialize searcher bm25+t5
Initialize searcher bm25+t5
Initialize searcher bm25+t5
Initialize searcher bm25+t5
Initialize searcher bm25+t5
Initialize searcher bm25+t5
Initialize searcher bm25+rm3+t5
Initialize searcher bm25+rm3+t5
Initialize searcher bm25+rm3+t5
Initialize searcher bm25+rm3+t5
Initialize searcher bm25+rm3+t5
Initialize searcher 

In [80]:
import ir_measures
from ir_measures import *

# qrles = ir_measures.read_trec_qrels('qrels/qrls.qrles')
# run = ir_measures.read_trec_run('qrels.run')

for model in config:
    print(os.path.join(run_path, model+".run"))
    qrles = ir_measures.read_trec_qrels(qrels_path)
    run = ir_measures.read_trec_run(os.path.join(run_path, model+".run"))
    accuracy = ir_measures.calc_aggregate([nDCG@3, Precision@3, Recall@3], qrles, run)
    print(f"{model}: {accuracy}")

/home/philip/task-search-quality/datasets/qrles/bm25.run
bm25: {P@3: 0.5666666666666667, R@3: 0.49499999999999994, nDCG@3: 0.5578509669245213}
/home/philip/task-search-quality/datasets/qrles/bm25+rm3.run
bm25+rm3: {P@3: 0.39999999999999997, R@3: 0.35, nDCG@3: 0.3953221626202607}
/home/philip/task-search-quality/datasets/qrles/bm25+t5.run
bm25+t5: {P@3: 0.2333333333333333, R@3: 0.19666666666666663, nDCG@3: 0.2092284198678045}
/home/philip/task-search-quality/datasets/qrles/bm25+rm3+t5.run
bm25+rm3+t5: {P@3: 0.3, R@3: 0.2633333333333333, nDCG@3: 0.2561562924700801}
