In [1]:
import numpy as np

# Mean Reciprocal Rank (MRR)

In [2]:
def mean_reciprocal_rank(results):
    """
    Calculate the Mean Reciprocal Rank (MRR) for a collection of ranked results.

    Parameters:
    results (iterable): One entry per query. Each entry is a sequence (list or numpy row) of
                        relevance indicators in rank order: a non-zero value marks a relevant
                        result and zero marks an irrelevant one.

    Returns:
    float: The mean, over all queries, of 1 / (1-based position of the first relevant result).
           A query with no relevant result contributes 0. Returns 0.0 when `results` is empty.
    """
    reciprocal_ranks = []
    for result in results:
        # 0-based positions of every non-zero (relevant) entry for this query.
        hit_positions = np.asarray(result).nonzero()[0]
        # Only the first hit counts; a query without any hit scores 0.
        reciprocal_ranks.append(1. / (hit_positions[0] + 1) if hit_positions.size else 0.)

    # np.mean of an empty list is NaN and emits a RuntimeWarning; report 0.0 instead.
    if not reciprocal_ranks:
        return 0.0
    return np.mean(reciprocal_ranks)

In [3]:
def convert_preds_to_results_array(predictions, correct_answers):
    """
    Convert predictions and correct answers to a results array suitable for input to the mean_reciprocal_rank function.

    Parameters:
    predictions (list of lists): A list of lists of predicted answers for each query, in rank order.
                                 The inner lists are expected to share one length so the rows
                                 stack into a rectangular array.
    correct_answers (list): A list of correct answers, one per query, aligned by position
                            with `predictions`.

    Returns:
    numpy.ndarray: A 2D numpy array where each row represents the ranked results for a query,
                   with 1 indicating a correct prediction and 0 indicating an incorrect prediction.

    Raises:
    ValueError: If the number of prediction lists differs from the number of correct answers.
    """
    # Materialise both inputs so their lengths can be compared even when iterators are passed.
    predictions = list(predictions)
    correct_answers = list(correct_answers)

    # zip() would silently drop the unmatched tail, producing a score over fewer queries.
    if len(predictions) != len(correct_answers):
        raise ValueError(
            "predictions and correct_answers must have the same length "
            f"(got {len(predictions)} and {len(correct_answers)})"
        )

    results = [
        [1 if pred == correct else 0 for pred in preds]
        for preds, correct in zip(predictions, correct_answers)
    ]
    return np.array(results)

In [4]:
# Ranked predictions for three example queries, one inner list per query.
predictions = [
    ['chunk_a', 'chunk_b', 'chunk_c', 'chunk_d', 'chunk_e'],
    ['chunk_b', 'chunk_a', 'chunk_c', 'chunk_d', 'chunk_e'],
    ['chunk_c', 'chunk_a', 'chunk_b', 'chunk_d', 'chunk_e'],
]

# Correct answer for each query, aligned by position with `predictions`.
correct_answers = ['chunk_b', 'chunk_a', 'chunk_e']

# Mark every prediction as a hit (1) or a miss (0) against its query's correct answer.
results_array = convert_preds_to_results_array(predictions, correct_answers)
results_array

array([[0, 1, 0, 0, 0],
       [0, 1, 0, 0, 0],
       [0, 0, 0, 0, 1]])

In [5]:
# Score the example rankings. The first hits sit at positions 2, 2 and 5, so the
# expected value is (1/2 + 1/2 + 1/5) / 3 = 0.4, up to floating-point rounding.
mrr = mean_reciprocal_rank(results_array)
print("Mean Reciprocal Rank:", mrr)

Mean Reciprocal Rank: 0.39999999999999997
