In [1]:
import sys
import os

# Adding src path so we can import modules
src_path = os.path.abspath(os.path.join('..', '..',))

if src_path not in sys.path:
    sys.path.append(src_path)

In [12]:
from core.base_urls import SEARCHING_URL
from utils.loading import load_df, load_json
from utils.extracting import extract_doc_ids
from evaluation.precision import get_precision_at_k
from evaluation.recall import get_recall
import requests

dataset_path = os.path.join(src_path, '..', '..', 'Datasets', 'wikIR1k', 'training')

queries_path = os.path.join(dataset_path, 'queries.csv')
queries = load_df(queries_path)

qrels_path = os.path.join(dataset_path, 'qrels.json')
qrels = load_json(qrels_path)

recalls = []

for _, row in queries.head(10).iterrows():

    query_id = row['id_left']
    query_text = row['text_left']

    print(f'processing {query_id}')

    # get matching results
    search_text = query_text
    request_body = {
        'dataset_id': 2,
        'search_text': search_text,
    }
    response = requests.post(SEARCHING_URL, json = request_body)    

    if response.status_code == 200:
        # get retrieved docs
        retrieved_docs = response.json()['data']
        retrieved_ids = extract_doc_ids(retrieved_docs)

        print('Retrieved IDs')
        print(retrieved_ids)

        # get relevant docs
        relevant_ids = [qrel['doc_id'] for qrel in qrels if int(qrel['query_id']) == query_id]
        # convert ids to int
        relevant_ids = [int(id) for id in relevant_ids]

        print('Relevant IDs')
        print(relevant_ids)

        # calc precision@10
        precision = get_precision_at_k(retrieved_ids, relevant_ids, 10)
        print(f'Precision@10 is: {precision}')

        # calc recall@10
        recall = get_recall(retrieved_ids, relevant_ids)
        print(f'Recall is: {recall}')
        recalls.append(recall)
        print('--------------------------------')

# calc for all queries recall
avg_recall = sum(recalls) / len(recalls) if len(recalls) > 0 else 0
print(f'Average Recall is: {avg_recall}')

processing 123839
Retrieved IDs
[1727316, 817579, 1793430, 900201, 468338, 1072773, 93549, 905177, 1994126, 1491858, 540688, 2381018, 92466, 672746, 1646521, 191461, 1447188, 2013412, 1892619, 2070100, 501100, 545889, 1674970, 2202229, 201815, 719686, 2073847, 960235, 1231277, 2289161, 497760, 1097141, 1117174, 1548643, 186283, 2125265, 429670, 1134193, 664130, 180883, 877423, 1753411, 1122434, 1193848, 1047363, 268763, 34714, 1514328, 1167816, 1863649, 2049643, 1605887, 12139, 96990, 1994238, 999241, 1828915, 1381907, 600548, 2272306, 2133227, 269418, 1287537, 2440203, 1033576, 1358774, 476292, 1106460, 1875229, 516483, 822220, 1555809, 143009, 2251700, 1589191, 1104602, 1617012, 1834436, 2370154, 637857, 2059124, 246794, 228704, 1385195, 367717, 549702, 824825, 413246, 603292, 1657883, 57100, 806326, 799188, 806300, 1901730, 123839, 806075, 836567, 806263, 1230943]
Relevant IDs
[123839, 1793430, 806300, 806075, 836567, 806263]
Retrieved at 10
[1727316, 817579, 1793430, 900201, 468338