In [81]:
import json
import math
from tqdm import tqdm

In [82]:
import json

def load_json(file_path):
    """Read a UTF-8 encoded JSON file and return the value under its top-level 'data' key.

    Raises KeyError if the decoded top-level object has no 'data' key.
    """
    with open(file_path, encoding='utf-8') as handle:
        payload = json.load(handle)
    return payload['data']
    
# Input files: the parsed ground-truth training questions, and the prediction file
# to score against them (a BM25 results file, judging by its name).
training_dataset_path, test_dataset = (
    "datasets/training/parsed_data_final.json",
    "results/bm25_results_training.json",
)

In [83]:
def get_ground_truth_articles_for_question(question):
    """Return the question's 'ground_truth_documents_pid' value, or [] when the key is absent."""
    if 'ground_truth_documents_pid' in question:
        return question['ground_truth_documents_pid']
    return []

def get_ground_truth_snippets_for_question(question):
    """Return the question's 'ground_truth_snippets' value, or [] when the key is absent."""
    if 'ground_truth_snippets' in question:
        return question['ground_truth_snippets']
    return []

def get_top_10_articles(question):
    """Return the question's 'top_10_articles' value, or [] when the key is absent."""
    if 'top_10_articles' in question:
        return question['top_10_articles']
    return []

def get_top_10_snippets(question):
    """Return the question's 'snippets' value, or [] when the key is absent."""
    if 'snippets' in question:
        return question['snippets']
    return []

In [84]:
def compute_mrr(predicted_pids, ground_truth_pids):
    """Reciprocal rank for a single question (the caller averages it into MRR).

    Returns 1 / rank, with 1-based rank, of the first element of ``predicted_pids``
    that is contained in ``ground_truth_pids``, or 0.0 when none of them is.
    """
    hit_ranks = (
        rank
        for rank, pid in enumerate(predicted_pids, start=1)
        if pid in ground_truth_pids
    )
    first_rank = next(hit_ranks, None)
    return 0.0 if first_rank is None else 1.0 / first_rank

In [85]:
def compute_average_precision(predicted_pids, ground_truth_pids):
    """Average precision of one ranked prediction list for a single question.

    AP = sum_i(P@i * rel_i) / |relevant documents|

    P@i is the precision of the first i predictions. rel_i is 1 when the i-th
    prediction is relevant and has not already appeared earlier in the list.
    A repeated prediction still occupies its rank but earns no extra credit,
    which keeps the result within [0, 1]. ``ground_truth_pids`` is deduplicated
    before counting |relevant documents|.

    Returns 0.0 when no prediction is relevant (including when the ground truth
    is empty).
    """
    relevant = set(ground_truth_pids)
    seen = set()  # relevant pids already credited, so duplicates don't count twice
    hits = 0
    score = 0.0
    for rank, pid in enumerate(predicted_pids, start=1):
        if pid in relevant and pid not in seen:
            seen.add(pid)
            hits += 1
            score += hits / rank  # precision at this rank
    return score / len(relevant) if hits > 0 else 0.0

In [86]:
def compute_ndcg(predicted_pids, ground_truth_pids, k=10):
    """nDCG@k for a single question with binary relevance.

    A prediction has gain 1 if it is in ``ground_truth_pids`` and has not already
    appeared earlier in the list; otherwise its gain is 0. The discount at
    1-based rank r is 1 / log2(r + 1). The ideal DCG places
    min(|relevant|, k) relevant documents at the top ranks. Because duplicate
    predictions earn no extra gain, the result stays within [0, 1].

    Returns 0.0 when the ideal DCG is 0, i.e. empty ground truth or k <= 0.
    """
    relevant = set(ground_truth_pids)
    seen = set()  # relevant pids already credited, so duplicates don't add gain
    dcg = 0.0
    for rank, pid in enumerate(predicted_pids[:k], start=1):
        if pid in relevant and pid not in seen:
            seen.add(pid)
            dcg += 1 / math.log2(rank + 1)

    ideal_hits = min(len(relevant), k)
    ideal_dcg = sum(1 / math.log2(rank + 1) for rank in range(1, ideal_hits + 1))
    return dcg / ideal_dcg if ideal_dcg > 0 else 0.0

In [None]:
def evaluate_metrics_for_articles(ground_truth_data, predicted_data, k=10):
    """Average per-question MRR, MAP and nDCG@k over positionally aligned questions.

    Args:
        ground_truth_data: sequence of question dicts. Ground-truth article ids are
            read with ``get_ground_truth_articles_for_question``.
        predicted_data: sequence of question dicts whose ``top_10_articles``
            entries each carry a ``'pid'`` key. Element i is assumed to describe
            the same question as ``ground_truth_data[i]``; pairing is by position
            only.
        k: rank cutoff used for nDCG only. MRR and MAP use the whole predicted list.

    Returns:
        dict with keys "MRR", "MAP", f"nDCG@{k}" and "count". "count" is the number
        of questions that were scored. Questions without any ground-truth article
        are skipped and do not contribute to the averages. When nothing is scored,
        every metric is 0.0 and "count" is 0.

    Raises:
        ValueError: if the two inputs differ in length, since ``zip`` would
            otherwise silently drop the unmatched tail.
    """
    n_gt, n_pred = len(ground_truth_data), len(predicted_data)
    if n_gt != n_pred:
        raise ValueError(
            f"ground truth has {n_gt} questions but predictions have {n_pred}; "
            "they must be aligned one-to-one"
        )

    mrr_total = 0.0
    map_total = 0.0
    ndcg_total = 0.0
    count = 0

    pairs = zip(ground_truth_data, predicted_data)
    for gt_question, pred_question in tqdm(pairs, total=n_gt, desc="Processing questions..."):
        gt_pids = set(get_ground_truth_articles_for_question(gt_question))
        if not gt_pids:
            # Every metric is undefined without relevant articles; leave the question out.
            continue
        pred_pids = [article['pid'] for article in get_top_10_articles(pred_question)]

        mrr_total += compute_mrr(pred_pids, gt_pids)
        map_total += compute_average_precision(pred_pids, gt_pids)
        ndcg_total += compute_ndcg(pred_pids, gt_pids, k=k)
        count += 1

    if count == 0:
        # Keep the same keys as the normal return so callers can rely on the schema.
        return {"MRR": 0.0, "MAP": 0.0, f"nDCG@{k}": 0.0, "count": 0}

    return {
        "MRR": mrr_total / count,
        "MAP": map_total / count,
        f"nDCG@{k}": ndcg_total / count,
        "count": count,
    }

In [88]:
ground_truth_data = load_json(training_dataset_path)
predicted_data = load_json(test_dataset)

print("len gt: ", len(ground_truth_data))
print("len pred: ", len(predicted_data))
# Metrics pair questions by list position, so both files must cover the same questions
# in the same order; a length mismatch means they cannot be aligned.
assert len(ground_truth_data) == len(predicted_data), "ground truth and predictions are not aligned"

results = evaluate_metrics_for_articles(ground_truth_data, predicted_data)
results

len gt:  5390
len pred:  5390


Processing questions...: 3100it [00:00, 27441.93it/s]

GT: {'http://www.ncbi.nlm.nih.gov/pubmed/15858239', 'http://www.ncbi.nlm.nih.gov/pubmed/12239580', 'http://www.ncbi.nlm.nih.gov/pubmed/15829955', 'http://www.ncbi.nlm.nih.gov/pubmed/6650562', 'http://www.ncbi.nlm.nih.gov/pubmed/20598273', 'http://www.ncbi.nlm.nih.gov/pubmed/21995290', 'http://www.ncbi.nlm.nih.gov/pubmed/23001136', 'http://www.ncbi.nlm.nih.gov/pubmed/15617541', 'http://www.ncbi.nlm.nih.gov/pubmed/8896569'}
Pred: ['http://www.ncbi.nlm.nih.gov/pubmed/15858239', 'http://www.ncbi.nlm.nih.gov/pubmed/6650562', 'http://www.ncbi.nlm.nih.gov/pubmed/15829955', 'http://www.ncbi.nlm.nih.gov/pubmed/12239580', 'http://www.ncbi.nlm.nih.gov/pubmed/8896569', 'http://www.ncbi.nlm.nih.gov/pubmed/37522903', 'http://www.ncbi.nlm.nih.gov/pubmed/7719019', 'http://www.ncbi.nlm.nih.gov/pubmed/15956201', 'http://www.ncbi.nlm.nih.gov/pubmed/36742534', 'http://www.ncbi.nlm.nih.gov/pubmed/23001136']
GT: {'http://www.ncbi.nlm.nih.gov/pubmed/23787814', 'http://www.ncbi.nlm.nih.gov/pubmed/23212918', '

Processing questions...: 5390it [00:00, 27304.29it/s]

{'MRR': 0.823394440027094, 'MAP': 0.5506906661779183, 'nDCG@10': 0.7318113311402503, 'count': 5390}



