## Loading the queries & qrels

In [1]:
from constants import load_pickle_file, argsme_qrels_file ,argsme_queries_file
import numpy as np

"""
    Get the queries & qrels
    either calculate it online or get it from the pickle file
"""

# dict { query_id : list of {doc_id,rel} }
# qrels = get_qrels(use_cache=True)
qrels = load_pickle_file(argsme_qrels_file)

# list of { query_id , result : [ list of doc_id ] }
queries = load_pickle_file(argsme_queries_file)

qrels_docs_dict = {}
for qid, l in qrels.items():
    qrels_docs_dict.update({qid: [item['doc_id'] for item in l]})

loading : ../../pickle_files/argsme/argsme_qrels.pickle
loading : ../../pickle_files/argsme/argsme_queries.pickle


## Intersected documents

In [2]:
def calculate_relevant_count(retrieved_docs: list, query_docs: list) -> int:
    """
    Find the number of intersected documents between `retrieved_docs` and `query_docs`

    Args:
        retrieved_docs: Al list of doc_id that returned from matching
        query_docs: A list of doc_id belonging to a qid from qrel file

    Returns:
        Number of shared results
    """
    intersect_values = np.intersect1d(retrieved_docs, query_docs)
    matched_count = len(intersect_values)
    return matched_count


## Presision@K

In [3]:
def precision_at_k(relevant_docs, retrieved_docs, k=10) -> float | int:
    """
    Calculates Precision@k

    Args:
        relevant_docs: A dictionary mapping query IDs to a list of relevant document IDs.
        retrieved_docs: A dictionary mapping query IDs to a list of retrieved document IDs, ranked by relevance.
        k: The number of top retrieved documents to consider (default 10).

    Returns:
        The Precision value.
    """

    retrieved = retrieved_docs[:k]
    num_retrieved = len(retrieved)
    num_retrieved_relevant = calculate_relevant_count(retrieved, relevant_docs)
    return num_retrieved_relevant / num_retrieved if num_retrieved > 0 else 0


## Recall@K

In [4]:
def c_recall(relevant_docs, retrieved_docs, k=10):
    """
    Calculates Recall@k

    Args:
        relevant_docs: A dictionary mapping query IDs to a list of relevant document IDs.
        retrieved_docs: A dictionary mapping query IDs to a list of retrieved document IDs, ranked by relevance.
        k: The number of top retrieved documents to consider (default 10).

    Returns:
        A dictionary mapping query IDs to recall@k scores.
    """
    recall_scores = {}
    for query_id, relevant in relevant_docs.items():
        if query_id in retrieved_docs:
            retrieved = retrieved_docs[query_id][:k]  # Consider only top k retrieved documents
            num_relevant = len(relevant)
            num_retrieved_relevant = calculate_relevant_count(retrieved, relevant)
            recall_scores[query_id] = num_retrieved_relevant / num_relevant if num_relevant > 0 else 0
    return recall_scores


c_recall(qrels_docs_dict, queries, k=10)

{1: 0.08695652173913043,
 2: 0.02,
 3: 0.0,
 4: 0.022727272727272728,
 5: 0.020833333333333332,
 6: 0.045454545454545456,
 7: 0.04081632653061224,
 8: 0.0,
 9: 0.022222222222222223,
 10: 0.046511627906976744,
 11: 0.06666666666666667,
 12: 0.08,
 13: 0.0425531914893617,
 14: 0.04,
 15: 0.021739130434782608,
 16: 0.020833333333333332,
 17: 0.021739130434782608,
 18: 0.041666666666666664,
 19: 0.043478260869565216,
 20: 0.021739130434782608,
 21: 0.022727272727272728,
 22: 0.0425531914893617,
 23: 0.022222222222222223,
 24: 0.021739130434782608,
 26: 0.023255813953488372,
 27: 0.045454545454545456,
 28: 0.02040816326530612,
 29: 0.020833333333333332,
 30: 0.044444444444444446,
 31: 0.08333333333333333,
 32: 0.0,
 33: 0.06976744186046512,
 34: 0.08333333333333333,
 35: 0.021739130434782608,
 36: 0.08333333333333333,
 37: 0.06666666666666667,
 38: 0.08163265306122448,
 39: 0.023255813953488372,
 40: 0.0425531914893617,
 41: 0.06382978723404255,
 42: 0.10204081632653061,
 43: 0.163265306122

## Average precision

In [5]:
def get_rel_from_list(list_of_rel, doc_id) -> int:
    """Get the rel for a given doc_id from a list of {'doc_id': 'NCT00445783', 'rel': 1}, ..."""

    for rel in list_of_rel:
        if rel['doc_id'] == doc_id:
            return rel['rel']

    return 0




def average_precision(retrieved: list, relevant: list):
    p_sum = 0
    num_of_relevant = 0
    for i in range(10):
        k = i + 1

        # get the doc_id's for the current
        relevant_docs = [doc['doc_id'] for doc in relevant]

        p_at_k = precision_at_k(relevant_docs, retrieved, k)
        print(f'P@{k} : {p_at_k}')
        # get the k document id
        k_doc_id = retrieved[:k][-1]

        # get the rel(k)
        rel_at_k = get_rel_from_list(relevant, k_doc_id)

        if rel_at_k > 0:
            num_of_relevant += rel_at_k
        p_sum += p_at_k * rel_at_k

    return p_sum / num_of_relevant if num_of_relevant > 0 else 0


# MAP

In [6]:
def mean_average_precision(queries: dict, qrels: dict) -> float | int:
    ap_sum = 0
    for qid, query_results in queries.items():
        if qid == 10:
            continue
        print(f'query number {qid} : ')
        val = average_precision(queries[qid], qrels[qid])
        print(f'******* Average Precision : {val}')
        ap_sum += val
        print('-------------------------------------')
    return ap_sum / len(queries)

mean_average_precision(queries, qrels)

query number 1 : 
P@1 : 1.0
P@2 : 0.5
P@3 : 0.6666666666666666
P@4 : 0.5
P@5 : 0.4
P@6 : 0.3333333333333333
P@7 : 0.2857142857142857
P@8 : 0.375
P@9 : 0.4444444444444444
P@10 : 0.4
******* Average Precision : 0.6231481481481481
-------------------------------------
query number 2 : 
P@1 : 1.0
P@2 : 0.5
P@3 : 0.3333333333333333
P@4 : 0.25
P@5 : 0.2
P@6 : 0.16666666666666666
P@7 : 0.14285714285714285
P@8 : 0.125
P@9 : 0.1111111111111111
P@10 : 0.1
******* Average Precision : 1.0
-------------------------------------
query number 3 : 
P@1 : 0.0
P@2 : 0.0
P@3 : 0.0
P@4 : 0.0
P@5 : 0.0
P@6 : 0.0
P@7 : 0.0
P@8 : 0.0
P@9 : 0.0
P@10 : 0.0
******* Average Precision : 0
-------------------------------------
query number 4 : 
P@1 : 0.0
P@2 : 0.0
P@3 : 0.0
P@4 : 0.0
P@5 : 0.2
P@6 : 0.16666666666666666
P@7 : 0.14285714285714285
P@8 : 0.125
P@9 : 0.1111111111111111
P@10 : 0.1
******* Average Precision : 0.2
-------------------------------------
query number 5 : 
P@1 : 0.0
P@2 : 0.5
P@3 : 0.333333333

0.32562871139766186

## MRR

In [7]:
def mean_reciprocal_rank(relevant_docs, retrieved_docs):
    """
    Calculates Mean Reciprocal Rank (MRR)

    Args:
        relevant_docs: A dictionary mapping query IDs to a list of relevant document IDs.
        retrieved_docs: A dictionary mapping query IDs to a list of retrieved document IDs, ranked by relevance.

    Returns:
        The mean reciprocal rank (MRR) score.
    """
    sum_reciprocal_rank = 0
    num_queries = len(relevant_docs)
    for query_id, relevant in relevant_docs.items():
        rele = [item['doc_id'] for item in relevant]
        if query_id in retrieved_docs:
            retrieved = retrieved_docs[query_id]
            reciprocal_rank = 0
            for rank,retrieved_doc in enumerate(retrieved, 1):
                if retrieved_doc in rele:
                    reciprocal_rank = 1 / rank
                    break
            sum_reciprocal_rank += reciprocal_rank
    return sum_reciprocal_rank / num_queries


mean_reciprocal_rank(qrels,queries)


0.4221898283122773