<a href="https://colab.research.google.com/github/annadymanus/IR-project/blob/main/model_metrics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/gdrive', force_remount=True)

Mounted at /gdrive


In [2]:
gold_standard_path = '/content/drive/Shareddrives/IRProject/validation/2019qrels-docs.txt'
model_predictions_path = '/content/drive/Shareddrives/IRProject/model_predictions/tf_idf_pointwise_preds.pickle'

In [3]:
!pip install pickle5
import pickle5 as pickle
import pandas as pd 


def rank_k_documents(query_results, k=100):
    """
    Rank the results of a query based on its descending model score and 
    cut at max length k.

    Args:
        query_results (list): List of tuples (docid, score) for a 
            certain queryid.
        k (int): cutpoint of the results, where metrics should be evaluated.
            Defaults to 100.
    
    Returns:
        list[str]: ranked list of docids, with max length k.
    """
    ranked_results = sorted(query_results, key=lambda tup: tup[1], reverse=True)
    ranked_docids = [result[0] for result in ranked_results]

    return ranked_docids if len(ranked_docids)<=k else ranked_docids[0:k-1]


def get_query_precision_rr_at_k(queryid, query_results, gold_standard, k=100):
    """
    Calculate precision at k and reciprocal rank at k for a certain queryid, 
    by comparing query_results with gold_standard.

    Args:
        queryid: Query ID to be evaluated.
        query_results (list): List of tuples (docid, score) for a 
            certain queryid.
        gold_standard (pandas.DataFrame): DataFrame with true relevant docids
            for each query.
        k (int): cutpoint of the results, where metrics should be evaluated.
            Defaults to 100.
    
    Returns:
        tuple[float, float]: Tuple (precision at k, reciprocal rank at k) 
            for the given queryid.
    """
    ranked_docids = rank_k_documents(
        query_results, 
        k=k,
    )
    
    if len(ranked_docids)==0:
        return 0, 0
    
    else:    
        actual_relevant_docs = gold_standard[
            (gold_standard['queryid']==int(queryid))
            & (gold_standard['rating']>0)
        ].sort_values(
            by='rating', 
            ascending=False,
        )['docid'].tolist()

        relevant_docs = set(ranked_docids).intersection(
            set(actual_relevant_docs)
        )

        precision_at_k = len(relevant_docs) / len(ranked_docids)

        reciprocal_rank_at_k = 0.0
        for doc_position in range(len(ranked_docids)):
            if ranked_docids[doc_position] in relevant_docs:
                reciprocal_rank_at_k = 1/(doc_position+1)
        
        return precision_at_k, reciprocal_rank_at_k


def get_model_metrics_per_query_at_k(model_predictions, gold_standard, k=100):
    """
    Calculate precision at k and reciprocal rank at k for all queryids
    in dict model_predictions.

    Args:
        model_predictions (dict): Dict of queryids and their lists of 
            documents retrieved by the model. Each key should by the queryid 
            in string format, and its value should be a list of tuples 
            (docid, score). Lists don't need to be ordered.
        gold_standard (pandas.DataFrame): DataFrame with true relevant docids
            for each query.
        k (int): cutpoint of the results, where metrics should be evaluated.
            Defaults to 100.
    
    Returns:
        list[dict]: List of records (queryid, precision at k, reciprocal rank 
            at k) for all queryids in model_predictions.
    """
    query_metrics = []
    for queryid in model_predictions.keys():
        query_results = model_predictions[str(queryid)]
        query_precision_at_k, rr_at_k = get_query_precision_rr_at_k(
            queryid=queryid,
            query_results=query_results,
            gold_standard=gold_standard,
            k=k,
        )
        query_metrics.append({
            'queryid': queryid,
            f'precision_at_{k}': query_precision_at_k,
            f'reciprocal_rank_at_{k}': rr_at_k,
        })
    
    return query_metrics


gold_standard = pd.read_csv(
    gold_standard_path, 
    sep=' ', 
    names=[
        'queryid', 
        'Q0', 
        'docid', 
        'rating',
    ],
)

with open(model_predictions_path, 'rb') as file:
    model_predictions = pickle.load(file)


model_query_metrics = pd.DataFrame(
    get_model_metrics_per_query_at_k(model_predictions, gold_standard)
)

model_query_metrics

Collecting pickle5
  Downloading pickle5-0.0.12-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (256 kB)
[?25l[K     |█▎                              | 10 kB 18.2 MB/s eta 0:00:01[K     |██▋                             | 20 kB 22.6 MB/s eta 0:00:01[K     |███▉                            | 30 kB 25.0 MB/s eta 0:00:01[K     |█████▏                          | 40 kB 24.6 MB/s eta 0:00:01[K     |██████▍                         | 51 kB 26.3 MB/s eta 0:00:01[K     |███████▊                        | 61 kB 28.0 MB/s eta 0:00:01[K     |█████████                       | 71 kB 27.0 MB/s eta 0:00:01[K     |██████████▎                     | 81 kB 27.4 MB/s eta 0:00:01[K     |███████████▌                    | 92 kB 28.8 MB/s eta 0:00:01[K     |████████████▉                   | 102 kB 23.6 MB/s eta 0:00:01[K     |██████████████                  | 112 kB 23.6 MB/s eta 0:00:01[K     |███████████████▍                | 122 kB 23.6 MB/s eta 0:00:01[K     |████████████████▋ 

Unnamed: 0,queryid,precision_at_100,reciprocal_rank_at_100
0,156493,0.272727,0.010417
1,1110199,0.222222,0.010101
2,1063750,0.242424,0.010204
3,130510,0.141414,0.01087
4,489204,0.343434,0.010101
5,573724,0.181818,0.012195
6,1133167,0.30303,0.010101
7,527433,0.181818,0.011364
8,1037798,0.010101,0.011905
9,915593,0.161616,0.010417


In [4]:
model_query_metrics.drop(columns=['queryid']).describe()

Unnamed: 0,precision_at_100,reciprocal_rank_at_100
count,43.0,43.0
mean,0.192624,0.011351
std,0.099151,0.004298
min,0.010101,0.010101
25%,0.111111,0.010101
50%,0.191919,0.010417
75%,0.282828,0.01099
max,0.343434,0.038462
