In [1]:
import os
import json
import time
import pandas as pd
from IPython.display import display

from lkae.utils.data_loading import pkl_dir, load_pkls, root_dir, AuredDataset
from lkae.retrieval.retrieve import get_retriever, retrieve_evidence, AuredDataset

datasets = load_pkls(pkl_dir)

# possilbe splits: train, dev, train_dev_combined
# (test, all_combined don't have "labels")
split = 'train_dev_combined'

dataset_split = f'English_{split}'
qrel_filename = f'{dataset_split}_qrels.txt'

dataset_variations_dict = datasets[dataset_split]
print(dataset_variations_dict.keys())

import pyterrier as pt
import pyterrier.io as ptio
import pyterrier.pipelines as ptpipelines
from ir_measures import R, MAP    

if not pt.started():
    pt.init()

dict_keys(['nopre-nam-bio', 'nopre-nam-nobio', 'nopre-nonam-bio', 'nopre-nonam-nobio', 'pre-nam-bio', 'pre-nam-nobio', 'pre-nonam-bio', 'pre-nonam-nobio'])


PyTerrier 0.10.1 has loaded Terrier 5.10 (built by craigm on 2024-08-22 17:33) and terrier-helper 0.0.8



In [2]:
# ground truth RQ1
golden = ptio.read_qrels(os.path.join(root_dir, 'data', qrel_filename))

# select a set of variations of the dataset
# these selected variations are selected for these reasons:
# - pre-nonam-nobio     ("raw" data, but preprocessed)
# - pre-nam-bio         (we would expect lexical retrieval to be best here)
# - nopre-nonam-nobio   ("raw" data)
# - nopre-nam-bio       (we would expect semantic retrieval to be best here, most information contained here)
selected_variations = ["pre-nonam-nobio", "pre-nam-bio", "nopre-nonam-nobio", "nopre-nam-bio"]


In [3]:
# load each config and construct its retriever

retrievers = {}

with open('config.json', 'r') as file:
    configs = json.load(file)

    for config in configs['configs']:
        retriever_label = get_retriever(**config)
        retrievers[config['retriever_method']] = retriever_label

retrievers

{'bm25': <lkae.retrieval.methods.bm25.BM25Retriever at 0x2965acbb1f0>,
 'tfidf': <lkae.retrieval.methods.tfidf.TFIDFRetriever at 0x2965acbb160>,
 'openai': <lkae.retrieval.methods.openai_embeddings.OpenAIRetriever at 0x2965246cf70>,
 'rerank-sbert-crossencoder': <lkae.retrieval.methods.rerank_sbert.CrossEncoderRerankRetriever at 0x29698702a40>,
 'rerank-nv-embed-v1': <lkae.retrieval.methods.rerank_bm25_nv.RerankingRetriever at 0x296987022f0>}

In [4]:
# then for every variation of the dataset in ds, run the experiment with each retriever and save the results
import pickle as pkl

out_dir = 'results'
data = []

for selected_variation in selected_variations:
    dataset: AuredDataset = dataset_variations_dict[selected_variation]
    for retriever_label in retrievers:
        start = time.time()

        run_filename = f'{out_dir}/{selected_variation}_{retriever_label}.pkl'

        # check if the file already exists from a previous run
        if os.path.exists(run_filename):
            print(f'found {run_filename}, loading from file')
            retrieved_data = pkl.load(open(run_filename, 'rb'))
        else:
            retrieved_data = retrieve_evidence(dataset[:], retrievers[retriever_label])
            pkl.dump(retrieved_data, open(run_filename, 'wb'))

        pred = pd.DataFrame([[*d, retriever_label] for d in retrieved_data], columns=['qid', 'docno', 'rank', 'score', 'name']) 

        eval = ptpipelines.Evaluate(pred, golden, metrics = [R@5,MAP], perquery=False)
        r5, meanap = [v for v in eval.values()]

        score = r5

        wall_time = time.time() - start

        print(f'result for retrieval run - R@5: {r5:.4f} MAP: {meanap:.4f} with config\tretriever: {retriever_label};\tds: {selected_variation}, took {wall_time:.2f} seconds')
        
        data.append({
            'R5': r5,
            'MAP': meanap,
            'Retrieval_Method': retriever_label, 
            'DS_Settings': selected_variation,
            'Time (s)': wall_time,
        })

# Convert the list of dictionaries to a DataFrame
df_retrieval = pd.DataFrame(data)

df_retrieval.to_csv(f'{out_dir}/df_retrieval.csv')
print(f'saved df to {out_dir}/df_retrieval.csv')

# Display the DataFrame
display(df_retrieval.sort_values(by='R5', ascending=False))

found results/pre-nonam-nobio_bm25.pkl, loading from file
result for retrieval run - R@5: 0.6345 MAP: 0.5679 with config	retriever: bm25;	ds: pre-nonam-nobio, took 0.03 seconds
found results/pre-nonam-nobio_tfidf.pkl, loading from file
result for retrieval run - R@5: 0.6022 MAP: 0.5045 with config	retriever: tfidf;	ds: pre-nonam-nobio, took 0.01 seconds
found results/pre-nonam-nobio_openai.pkl, loading from file
result for retrieval run - R@5: 0.6186 MAP: 0.5719 with config	retriever: openai;	ds: pre-nonam-nobio, took 0.01 seconds
found results/pre-nonam-nobio_rerank-sbert-crossencoder.pkl, loading from file
result for retrieval run - R@5: 0.5317 MAP: 0.4994 with config	retriever: rerank-sbert-crossencoder;	ds: pre-nonam-nobio, took 0.02 seconds
found results/pre-nonam-nobio_rerank-nv-embed-v1.pkl, loading from file
result for retrieval run - R@5: 0.7077 MAP: 0.6808 with config	retriever: rerank-nv-embed-v1;	ds: pre-nonam-nobio, took 0.01 seconds
found results/pre-nam-bio_bm25.pkl, loa

  attn_output = torch.nn.functional.scaled_dot_product_attention(


result for retrieval run - R@5: 0.5272 MAP: 0.4970 with config	retriever: rerank-sbert-crossencoder;	ds: nopre-nonam-nobio, took 86.30 seconds
result for retrieval run - R@5: 0.7145 MAP: 0.6713 with config	retriever: rerank-nv-embed-v1;	ds: nopre-nonam-nobio, took 570.06 seconds
result for retrieval run - R@5: 0.6378 MAP: 0.5651 with config	retriever: bm25;	ds: nopre-nam-bio, took 3.36 seconds
result for retrieval run - R@5: 0.6190 MAP: 0.5319 with config	retriever: tfidf;	ds: nopre-nam-bio, took 1.46 seconds
result for retrieval run - R@5: 0.6288 MAP: 0.5683 with config	retriever: openai;	ds: nopre-nam-bio, took 190.31 seconds
result for retrieval run - R@5: 0.5165 MAP: 0.4946 with config	retriever: rerank-sbert-crossencoder;	ds: nopre-nam-bio, took 147.46 seconds
result for retrieval run - R@5: 0.6796 MAP: 0.6505 with config	retriever: rerank-nv-embed-v1;	ds: nopre-nam-bio, took 574.42 seconds
saved df to results/df_retrieval.csv


Unnamed: 0,R5,MAP,Retrieval_Method,DS_Settings,Time (s)
14,0.714475,0.671333,rerank-nv-embed-v1,nopre-nonam-nobio,570.055045
4,0.707723,0.680776,rerank-nv-embed-v1,pre-nonam-nobio,0.014526
19,0.679609,0.650549,rerank-nv-embed-v1,nopre-nam-bio,574.421053
9,0.675846,0.647604,rerank-nv-embed-v1,pre-nam-bio,544.344645
5,0.645675,0.580895,bm25,pre-nam-bio,0.01452
15,0.637772,0.56511,bm25,nopre-nam-bio,3.355851
0,0.634475,0.567864,bm25,pre-nonam-nobio,0.028035
7,0.63395,0.575602,openai,pre-nam-bio,0.012514
12,0.63184,0.588819,openai,nopre-nonam-nobio,188.820764
17,0.628753,0.568263,openai,nopre-nam-bio,190.31397
