In [3]:
import sys
import os

# Make the project's packages importable from this notebook: the directory two
# levels above the current working directory is registered on sys.path.
src_path = os.path.abspath(
    os.path.join('..', '..')
)

# Append only when missing, so re-running the cell does not add duplicates.
if sys.path.count(src_path) == 0:
    sys.path.append(src_path)

In [4]:
from core.base_urls import SEARCHING_URL
from utils.loading import load_df, load_json
from utils.extracting import extract_doc_ids
from measures.precision import get_precision_at_k, get_avg_precision
from measures.recall import get_recall, get_avg_recall
from measures.map import get_ap, get_ap_at_k, get_map
from measures.mrr import get_reciprocal_rank, get_mrr
import requests

# Location of the wikIR1k training split: Datasets/wikIR1k/training, reached by
# going two directories up from src_path.
dataset_path = os.path.join(
    src_path, '..', '..', 'Datasets', 'wikIR1k', 'training',
)

# Both input files live side by side in the training directory.
queries_path = os.path.join(dataset_path, 'queries.csv')
qrels_path = os.path.join(dataset_path, 'qrels.json')

# Queries are loaded through load_df, relevance judgements through load_json.
queries = load_df(queries_path)
qrels = load_json(qrels_path)

# Evaluate the search service on dataset 2.
#
# Every query in `queries` is posted to SEARCHING_URL; the returned ranking is
# scored against the relevance judgements in `qrels`. Per-query metrics and the
# final averages are written to dataset2.txt, and the averages are also printed.

# Index the relevance judgements by query id once up front, so each query is a
# dictionary lookup instead of a full scan of `qrels`. Doc ids are converted to
# int and keep the order in which they appear in `qrels`.
relevant_ids_by_query = {}
for qrel in qrels:
    relevant_ids_by_query.setdefault(int(qrel['query_id']), []).append(int(qrel['doc_id']))

with open('dataset2.txt', 'w') as text_file:

    text_file.write('--------------------------------------------\n')
    text_file.write('           Second Dataset Results           \n')
    text_file.write('--------------------------------------------\n')

    # Per-query scores, accumulated only for queries that got a 200 response.
    precisions = []
    recalls = []
    aps_at_10 = []
    aps = []
    rrs = []

    for _, row in queries.iterrows():

        query_id = row['id_left']
        query_text = row['text_left']

        text_file.write(f'processing {query_id}\n')

        # get matching results
        request_body = {
            'dataset_id': 2,
            'search_text': query_text,
        }
        response = requests.post(SEARCHING_URL, json=request_body)

        if response.status_code != 200:
            # Record the failure instead of skipping it silently: a failed
            # query contributes nothing to the averages below, so the log
            # should show that it was left out.
            text_file.write(f'request failed with status {response.status_code}\n')
            text_file.write('--------------------------------------------\n')
            continue

        # get retrieved docs
        retrieved_docs = response.json()['data']
        retrieved_ids = extract_doc_ids(retrieved_docs)

        text_file.write('Retrieved IDs\n')
        text_file.write(str(retrieved_ids) + '\n')

        # get relevant docs (copied so the shared index is never handed out)
        relevant_ids = list(relevant_ids_by_query.get(query_id, []))

        text_file.write('Relevant IDs\n')
        text_file.write(str(relevant_ids) + '\n')

        # calc precision@10
        precision = get_precision_at_k(retrieved_ids, relevant_ids, 10)
        text_file.write(f'Precision@10 is: {precision}\n')
        precisions.append(precision)

        # calc recall
        recall = get_recall(retrieved_ids, relevant_ids)
        text_file.write(f'Recall is: {recall}\n')
        recalls.append(recall)

        # calc ap@10
        ap_at_10 = get_ap_at_k(retrieved_ids, relevant_ids, 10)
        text_file.write(f'AP@10 is: {ap_at_10}\n')
        aps_at_10.append(ap_at_10)

        # calc ap
        ap = get_ap(retrieved_ids, relevant_ids)
        # write() does not append a line break; without the explicit '\n' this
        # entry and the ones after it run together on a single line.
        text_file.write(f'AP is: {ap}\n')
        aps.append(ap)

        # calc rr
        rr = get_reciprocal_rank(retrieved_ids, relevant_ids)
        text_file.write(f'RR is: {rr}\n')
        rrs.append(rr)

        text_file.write('--------------------------------------------\n')

    text_file.write('                Final Results               \n')
    text_file.write('--------------------------------------------\n')

    # calc avg precision@10
    avg_precision = get_avg_precision(precisions)
    print(f'Average Precision@10 is: {avg_precision}')
    text_file.write(f'Average Precision@10 is: {avg_precision}\n')

    # calc avg recall
    avg_recall = get_avg_recall(recalls)
    print(f'Average Recall is: {avg_recall}')
    text_file.write(f'Average Recall is: {avg_recall}\n')

    # calc map@10
    map_at_10 = get_map(aps_at_10)
    print(f'MAP@10 is: {map_at_10}')
    text_file.write(f'MAP@10 is: {map_at_10}\n')

    # calc map
    # Named mean_ap rather than `map` so the builtin map() is not shadowed for
    # the rest of the kernel session.
    mean_ap = get_map(aps)
    print(f'MAP is: {mean_ap}')
    text_file.write(f'MAP is: {mean_ap}\n')

    # calc mrr
    mrr = get_mrr(rrs)
    print(f'MRR is: {mrr}')
    text_file.write(f'MRR is: {mrr}\n')

Average Precision@10 is: 0.06080332409972299
Average Recall is: 0.36262852033953086
MAP@10 is: 0.13262045176642526
MAP is: 0.0958140434676988
MRR is: 0.16569886302925552
