In [None]:
import logging
import numpy as np
import pandas as pd
import os
import csv
from sklearn.metrics import average_precision_score
from sentence_transformers import SentenceTransformer, CrossEncoder, evaluation, InputExample, datasets

class CERerankingEvaluator:
    """
    Evaluate a cross-encoder model on a passage re-ranking task.

    This class is a modified version of the CERerankingEvaluator from SentenceTransformers that reports
    mean average precision (MAP) in addition to MRR@k.

    The test data is either a list of dictionaries or a dictionary whose values are such dictionaries,
    each of the form ``{'query': '', 'positive': [], 'negative': []}``. ``query`` is the search query,
    ``positive`` is a collection of positive (relevant) documents and ``negative`` is a collection of
    negative (irrelevant) documents. Samples with no positives or no negatives are skipped.
    """
    def __init__(self, samples, mrr_at_k: int = 10, name: str = '', write_csv: bool = True):
        """
        :param samples: list of sample dictionaries, or a dictionary whose values are sample dictionaries.
        :param mrr_at_k: only the top ``mrr_at_k`` ranked documents are considered for the MRR score.
        :param name: optional label, used in the CSV file name and in log messages.
        :param write_csv: whether ``__call__`` appends its results to a CSV file when given an output path.
        """
        self.samples = samples
        self.name = name
        self.mrr_at_k = mrr_at_k

        # Only the per-query sub-dictionaries are needed; the outer keys are dropped.
        if isinstance(self.samples, dict):
            self.samples = list(self.samples.values())

        self.csv_file = "CERerankingEvaluator" + ("_" + name if name else '') + "_results.csv"
        self.csv_headers = ["epoch", "steps", "MRR@{}".format(mrr_at_k), "MAP"]
        self.write_csv = write_csv

    def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> dict:
        """
        Score every (query, document) pair with ``model`` and compute MRR@k and MAP over all queries.

        :param model: object exposing ``predict(pairs, convert_to_numpy=True, show_progress_bar=False)``
            that returns one score per ``[query, document]`` pair (higher means more relevant).
        :param output_path: directory for the results CSV. It is created if missing. No file is written
            when this is ``None`` or when ``write_csv`` is false.
        :param epoch: epoch number recorded in the CSV row and log message (-1 if not applicable).
        :param steps: step number recorded in the CSV row and log message (-1 if not applicable).
        :return: ``{'mrr': mean MRR@k, 'map': mean average precision}``. Both values are NaN when no
            sample has at least one positive and one negative document.
        """
        logger = logging.getLogger(__name__)

        if epoch != -1:
            if steps == -1:
                out_txt = " after epoch {}:".format(epoch)
            else:
                out_txt = " in epoch {} after {} steps:".format(epoch, steps)
        else:
            out_txt = ":"
        logger.info("CERerankingEvaluator: Evaluating the model on %s dataset%s", self.name, out_txt)

        all_mrr_scores = []
        all_ap_scores = []
        num_positives = []
        num_negatives = []
        for instance in self.samples:
            query = instance['query']
            positive = list(instance['positive'])
            negative = list(instance['negative'])

            # Ranking metrics are undefined without at least one document of each kind.
            if not positive or not negative:
                continue

            num_positives.append(len(positive))
            num_negatives.append(len(negative))

            docs = positive + negative
            is_relevant = [True] * len(positive) + [False] * len(negative)

            model_input = [[query, doc] for doc in docs]
            pred_scores = model.predict(model_input, convert_to_numpy=True, show_progress_bar=False)
            ranking = np.argsort(-pred_scores)  # indices sorted by decreasing score

            # MRR@k: reciprocal rank of the first relevant document within the top k, 0 if there is none.
            mrr_score = 0
            for rank, index in enumerate(ranking[:self.mrr_at_k], start=1):
                if is_relevant[index]:
                    mrr_score = 1 / rank
                    break
            all_mrr_scores.append(mrr_score)

            # Average precision over the full ranking of this query.
            all_ap_scores.append(average_precision_score(is_relevant, pred_scores.tolist()))

        if all_mrr_scores:
            mean_mrr = float(np.mean(all_mrr_scores))
            mean_ap = float(np.mean(all_ap_scores))
            logger.info(
                "Queries: %d \t Positives: Min %.1f, Mean %.1f, Max %.1f \t Negatives: Min %.1f, Mean %.1f, Max %.1f",
                len(all_mrr_scores),
                np.min(num_positives), np.mean(num_positives), np.max(num_positives),
                np.min(num_negatives), np.mean(num_negatives), np.max(num_negatives),
            )
            logger.info("MRR@%d: %.4f \t MAP: %.4f", self.mrr_at_k, mean_mrr, mean_ap)
        else:
            # Keep the NaN result of averaging nothing, but say why instead of emitting a numpy warning.
            logger.warning("CERerankingEvaluator: no sample has both positive and negative documents; "
                           "metrics are NaN.")
            mean_mrr = float('nan')
            mean_ap = float('nan')

        if output_path is not None and self.write_csv:
            if output_path:
                # An empty string means the current directory, which needs no creation.
                os.makedirs(output_path, exist_ok=True)
            csv_path = os.path.join(output_path, self.csv_file)
            output_file_exists = os.path.isfile(csv_path)
            # newline="" lets the csv module control line endings (avoids blank rows on Windows).
            with open(csv_path, mode="a" if output_file_exists else 'w', newline="", encoding="utf-8") as f:
                writer = csv.writer(f)
                if not output_file_exists:
                    writer.writerow(self.csv_headers)

                writer.writerow([epoch, steps, mean_mrr, mean_ap])

        return {'mrr': mean_mrr, 'map': mean_ap}

In [None]:
# Dataset-specific locations; change these depending on the dataset being evaluated.
# NOTE(review): usr_path is not defined in the visible cells - it must be set before this cell runs.
# data_folder keeps its trailing slash because file names are appended to it by plain concatenation.
data_folder = usr_path+ '/cross-encoder/eli5/splits/'
output_folder = usr_path+ '/cross-encoder/eli5/results/' # to store results csv

In [None]:
# Load the test split. The 'positive' and 'negative' columns are stored as text in the CSV and are
# turned back into Python objects by passing each cell through pd.eval.
# NOTE(review): pd.eval evaluates expressions read from the file - only use this on trusted CSVs;
# consider ast.literal_eval if the file may come from an untrusted source.
test_samples = pd.read_csv(data_folder + 'test_samples.csv', converters={'positive': pd.eval, 'negative': pd.eval})

In [None]:
# Collapse duplicate documents within each query's negative and positive collections,
# then turn the frame into a {row index: {column: value}} dictionary.
for doc_column in ('negative', 'positive'):
    test_samples[doc_column] = test_samples[doc_column].apply(set)
test_samples = test_samples.to_dict(orient='index')

In [None]:
# Build the evaluator over the test samples, keeping the constructor defaults
# (mrr_at_k=10, empty name, write_csv=True).
evaluator = CERerankingEvaluator(test_samples)

In [None]:
# Load the cross-encoder to evaluate from a local directory; replace with the path to your model.
# NOTE(review): judging by the directory name this is presumably an MS MARCO MiniLM model
# fine-tuned on ELI5 - confirm.
cross_encoder = CrossEncoder('./ms-marco-MiniLM-L-6-v2_eli5/',num_labels=1, max_length=512)

In [None]:
# Run the evaluation. Because output_path is given and write_csv is left at its default,
# a results row is also written to the CSV in output_folder. The call returns a dict with
# the keys 'mrr' and 'map', which is the value of this cell's last expression.
evaluator(cross_encoder, output_path=output_folder)