In [1]:
import re

In [2]:
def load_reference_from_stream(f):
    """Read relevance judgements (qrels) from an iterable of text lines.

    Every line must hold at least three whitespace-separated columns. The
    first column is parsed as the query id and the third as the relevant
    passage id; the remaining columns are ignored.

    Args:
        f: Iterable yielding text lines, e.g. an open file object.

    Returns:
        dict mapping each int query id to the list of int passage ids seen
        for it, in order of appearance (duplicates are kept).

    Raises:
        IOError: If a line has fewer than three columns or a non-integer id.
            The underlying IndexError/ValueError is attached as ``__cause__``.
    """
    qids_to_relevant_passageids = {}
    for line in f:
        # str.split() without a separator collapses runs of whitespace, so
        # columns padded with several spaces/tabs keep their positions.
        tokens = line.split()
        try:
            qid = int(tokens[0])
            pid = int(tokens[2])
        except (IndexError, ValueError) as exc:
            # Only parsing failures are translated; anything else (e.g.
            # KeyboardInterrupt) propagates unchanged.
            raise IOError(f'"{line}" is not a valid format') from exc
        qids_to_relevant_passageids.setdefault(qid, []).append(pid)
    return qids_to_relevant_passageids

def load_reference(path_to_reference):
    """Open the qrels file at ``path_to_reference`` and parse it.

    The file is read as UTF-8 text and handed to
    ``load_reference_from_stream``; whatever that function returns is
    returned unchanged.
    """
    with open(path_to_reference, mode='r', encoding='utf-8') as stream:
        return load_reference_from_stream(stream)


In [3]:

def load_candidate_from_stream(f):
    """Read a ranked candidate run from an iterable of text lines.

    Every line must hold at least three tab-separated columns, parsed in
    order as query id, passage id and rank (all integers).

    Args:
        f: Iterable yielding text lines, e.g. an open file object.

    Returns:
        dict mapping each int query id to the list of its int passage ids,
        ordered by ascending rank. If a query repeats a rank, the passage
        from the later line replaces the earlier one.

    Raises:
        IOError: If a line has fewer than three columns or a non-integer
            field. The underlying IndexError/ValueError is attached as
            ``__cause__``.
    """
    pids_by_rank_per_qid = {}
    for line in f:
        tokens = line.strip().split('\t')
        try:
            qid = int(tokens[0])
            pid = int(tokens[1])
            rank = int(tokens[2])
        except (IndexError, ValueError) as exc:
            # Only parsing failures are translated; anything else (e.g.
            # KeyboardInterrupt) propagates unchanged.
            raise IOError(f'"{line}" is not a valid format') from exc
        pids_by_rank_per_qid.setdefault(qid, {})[rank] = pid

    # Lines may arrive in any order, so sort each query's passages by rank.
    return {
        qid: [pid for _, pid in sorted(pids_by_rank.items())]
        for qid, pids_by_rank in pids_by_rank_per_qid.items()
    }

def load_candidate(path_to_candidate):
    """Open the run file at ``path_to_candidate`` and parse it.

    The file is read as UTF-8 text and handed to
    ``load_candidate_from_stream``; whatever that function returns is
    returned unchanged.
    """
    with open(path_to_candidate, mode='r', encoding='utf-8') as stream:
        return load_candidate_from_stream(stream)

In [None]:
def compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages):
    """Compute MRR@10 of a candidate ranking against relevance judgements.

    For each query in the reference, the reciprocal of the 1-based position
    of the first relevant passage within the top 10 candidates is taken
    (0 if there is none, or if the query has no candidates at all). The sum
    is averaged over the number of reference queries.

    Args:
        qids_to_relevant_passageids: dict mapping query id to an iterable of
            relevant passage ids.
        qids_to_ranked_candidate_passages: dict mapping query id to its
            candidate passage ids, best first.

    Returns:
        dict with the single key ``'MRR@10'`` mapped to a float. An empty
        reference yields ``0.0`` instead of dividing by zero.
    """
    max_mrr_rank = 10
    total_queries = len(qids_to_relevant_passageids)
    if total_queries == 0:
        # No queries to average over; avoid ZeroDivisionError.
        return {'MRR@10': 0.0}

    reciprocal_rank_sum = 0.0
    for qid, relevant_pids in qids_to_relevant_passageids.items():
        if qid not in qids_to_ranked_candidate_passages:
            # Unranked queries still count in the denominator with score 0.
            continue
        target_pids = set(relevant_pids)
        for rank, pid in enumerate(qids_to_ranked_candidate_passages[qid], start=1):
            if rank > max_mrr_rank:
                # Anything past the cutoff cannot contribute to MRR@10.
                break
            if pid in target_pids:
                # Only the first relevant hit counts.
                reciprocal_rank_sum += 1.0 / rank
                break

    return {'MRR@10': reciprocal_rank_sum / total_queries}

In [5]:
def compute_metrics_from_files(path_to_reference, path_to_candidate):
    """Load a qrels file and a run file, then score the run.

    The reference is loaded first with ``load_reference``, the candidates
    second with ``load_candidate``; the result of ``compute_metrics`` on the
    two is returned as-is.
    """
    reference = load_reference(path_to_reference)
    candidates = load_candidate(path_to_candidate)
    return compute_metrics(reference, candidates)

In [None]:
# Ground-truth relevance judgements (qrels)
path_to_reference = r"../data/msmarco_ans_small/qrels.dev.small.tsv"

# Documents re-ranked by the respective models
bert_ranking = r"../data/run.monobert.dev.small.tsv/run.monobert.dev.small.tsv"
# NOTE(review): the file name inside the T5 run directory still says
# "monobert" - possibly a copy-paste leftover; confirm against the data on disk.
t5_ranking = r"../data/run.t5.dev.small.tsv/run.monobert.dev.small.tsv"

# Calculating MRR@10
metrics_bert = compute_metrics_from_files(path_to_reference, bert_ranking)
metrics_t5 = compute_metrics_from_files(path_to_reference, t5_ranking)

# Evaluation: one section per model (title, metrics, closing separator)
report_sections = (
    ('BERT:', metrics_bert, '##################### \n'),
    ('T5:', metrics_t5, '#####################'),
)
for title, metrics, separator in report_sections:
    print(title)
    for metric in sorted(metrics):
        print(f'{metric}: {metrics[metric]:.5f}')
    print(separator)

BERT: 
MRR@10: 0.31546
#####################
T5: 
MRR@10: 0.35633
#####################
