In [1]:
import sys
import os

# Make the project source root importable: resolve the directory two levels
# above the notebook's current working directory and put it on sys.path (once),
# so the project packages imported in the next cell can be found.
src_path = os.path.abspath(os.path.join(os.pardir, os.pardir))
already_on_path = src_path in sys.path
if not already_on_path:
    sys.path.append(src_path)

In [2]:
from core.base_urls import SEARCHING_URL
from core.datasets import dataset_test_2
from utils.extracting import extract_doc_ids
from measures.precision import get_precision_at_k
from measures.recall import get_recall, get_avg_recall
from measures.map import get_ap_at_k, get_map
import requests

# Evaluate the search service on the first MAX_QUERIES successfully answered
# queries of dataset_test_2:
# - for each query, POST its text to SEARCHING_URL and compare the returned
#   document IDs with the qrels;
# - report Precision@K, Recall and AP@K per query;
# - aggregate the per-query recall and AP into Average Recall and MAP@K.
MAX_QUERIES = 10      # stop after this many successfully evaluated queries
K = 10                # cut-off rank for Precision@K and AP@K
DATASET_ID = 1        # dataset identifier sent to the search service
REQUEST_TIMEOUT = 60  # seconds; without a timeout a stalled request hangs the cell forever

queries = dataset_test_2.queries_iter()
qrels = dataset_test_2.qrels_iter()

# Group the relevance judgements by query once, up front.
# `qrels_iter()` presumably returns a one-shot iterator (ir_datasets-style).
# Re-scanning it inside the loop consumed it on the first query, which left
# every later query with no relevant documents. A dict lookup also avoids a
# full scan of the qrels per query.
# NOTE(review): every judged document is treated as relevant, whatever graded
# relevance value the qrel may carry. Confirm the qrels hold only positive
# judgements.
relevant_by_query = {}
for qrel in qrels:
    relevant_by_query.setdefault(qrel.query_id, []).append(qrel.doc_id)

count = 0     # queries evaluated successfully so far
recalls = []  # per-query recall, in evaluation order
aps = []      # per-query AP@K, in evaluation order

for query in queries:
    if count >= MAX_QUERIES:
        break

    query_id = query.query_id
    print(f'processing {query_id}')

    # get matching results from the search service
    request_body = {
        'dataset_id': DATASET_ID,
        'search_text': query.text,
    }
    try:
        response = requests.post(SEARCHING_URL, json=request_body, timeout=REQUEST_TIMEOUT)
    except requests.RequestException as exc:
        # Skip this query (it does not count towards MAX_QUERIES), but say so.
        print(f'request for query {query_id} failed: {exc}')
        continue

    if response.status_code != 200:
        # Previously skipped silently; report it so missing queries are visible.
        print(f'search for query {query_id} returned HTTP {response.status_code}; skipping')
        continue

    # get retrieved docs
    # NOTE(review): the recorded output lists retrieved IDs in strictly ascending
    # order. If that is doc-ID order rather than score order, the @K metrics below
    # look at the K smallest IDs, not the K top-ranked documents. Confirm that
    # extract_doc_ids / the search service preserve ranking order.
    retrieved_ids = extract_doc_ids(response.json()['data'])
    print('Retrieved IDs')
    print(retrieved_ids)

    # get relevant docs, converted to int to match the retrieved IDs
    relevant_ids = [int(doc_id) for doc_id in relevant_by_query.get(query_id, [])]
    print('Relevant IDs')
    print(relevant_ids)

    # calc precision@K
    precision = get_precision_at_k(retrieved_ids, relevant_ids, K)
    print(f'Precision@{K} is: {precision}')

    # calc recall
    recall = get_recall(retrieved_ids, relevant_ids)
    print(f'Recall is: {recall}')
    recalls.append(recall)

    # calc ap@K
    ap = get_ap_at_k(retrieved_ids, relevant_ids, K)
    print(f'AP@{K} is: {ap}')
    aps.append(ap)
    print('-------------------------------------')

    count += 1

# average recall over all evaluated queries
avg_recall = get_avg_recall(recalls)
print(f'Average Recall is: {avg_recall}')

# mean of the per-query AP@K values; named so it does not shadow the builtin `map`
map_at_k = get_map(aps)
print(f'MAP@{K} is: {map_at_k}')

processing 318
Retrieved IDs
[166713, 235401, 461212, 656307, 656349, 819359, 1596915, 1705696, 1718235, 2052128, 2205067, 4203560, 5148322, 6074403, 6118532, 7111108, 7740142, 8087084, 8149395, 8285812, 8442897, 9065270, 9470181, 10878224, 10962607, 12013019, 12082605, 12893070, 13262874, 14267188, 14363681, 15174186, 16232795, 16782848, 18953509, 20119327, 20288395, 25464283, 27420689, 28175961, 28193317, 28239331, 28362277, 28376975, 28487148, 28781093, 29147378, 31189253, 31213113, 31252141, 31520143, 32437202, 33136280, 33136641, 33713583, 34995029, 35097421, 35631568, 35687146, 37188077, 38043257, 38414781, 39755708, 41048976, 41957971, 42134828, 42421601, 43193290, 44104058, 44149815, 44809704, 45658670, 46363701, 47010123, 48675532, 49242277, 49550316, 50744636, 52010861, 54285054, 55873326, 56744456, 57702829, 57702962, 57703198, 58660738, 59514642, 61205946, 61268862, 63008243, 64230811, 64805112, 65000062, 65000925, 65021904, 66401973, 67032175, 68907371, 71538843, 71887283]