## Installation

In [None]:
!pip install sentence-transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
!pip install pytrec_eval

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


### Imports

In [None]:
"""
This example shows how to train a Cross-Encoder for the MS Marco dataset (https://github.com/microsoft/MSMARCO-Passage-Ranking).

The query and the passage are passed simultaneously to a Transformer network. The network then returns
a score between 0 and 1 indicating how relevant the passage is for a given query.

The resulting Cross-Encoder can then be used for passage re-ranking: You retrieve for example 100 passages
for a given query, for example with ElasticSearch, and pass the query+retrieved_passage to the CrossEncoder
for scoring. You sort the results then according to the output of the CrossEncoder.

This gives a significant boost compared to out-of-the-box ElasticSearch / BM25 ranking.
"""
from torch.utils.data import DataLoader
from sentence_transformers import LoggingHandler, util
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator
from sentence_transformers import InputExample
from datetime import datetime
import gzip
import os
import tarfile
import tqdm
import logging
from collections import defaultdict
import numpy as np
import sys
import pytrec_eval
# Configure the root logger with a timestamped format at DEBUG level.
# The notebook environment has already attached a handler to the root logger
# by the time this cell runs (records were printed as "INFO:root:..."), and
# basicConfig() silently does nothing in that case. force=True (Python 3.8+)
# replaces the pre-existing handlers so the format below actually applies.
logging.basicConfig(
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.DEBUG,
    force=True,
)
logger = logging.getLogger()

## Evaluation preparation

### Mount Google Drive and set the base path used by this notebook

In [None]:
from google.colab import drive

# Drive is mounted at an absolute location, so the project folder is derived
# from that same mount point. A "./gdrive/..." path only resolves while the
# current working directory happens to be /content.
drive_mount_point = '/content/gdrive'
drive.mount(drive_mount_point)
base_path = drive_mount_point + "/MyDrive/cross-encoder-reranker-ir-course-2023/"

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [None]:
# Create the project folder (and any missing parents) if it does not exist yet.
# Done in Python rather than via an unquoted shell interpolation so that a
# path containing spaces is handled correctly.
os.makedirs(base_path, exist_ok=True)

## Evaluate the model


### Load the fine-tuned model that you trained using the previous notebook. You need to set the path of your own fine-tuned model here.

In [None]:
# Folder of the fine-tuned cross-encoder: the model is loaded from here in the
# Prediction section and the ranking run file is stored under this path.
model_save_path = "./gdrive/MyDrive/cross-encoder-reranker-ir-course-2023/finetuned_models/cross-encoder-cross-encoder-ms-marco-MiniLM-L-2-v2-2023-04-10_13-13-29/" #@param {type:"string"}

### Load data (For evaluation on TREC DL'19)

In [None]:
"""
This file evaluates CrossEncoder on the TREC 2019 Deep Learning (DL) Track: https://arxiv.org/abs/2003.07820

TREC 2019 DL is based on the corpus of MS Marco. MS Marco provides a sparse annotation, i.e., usually only a single
passage is marked as relevant for a given query. Many other highly relevant passages are not annotated and hence are treated
as an error if a model ranks those high.

TREC DL instead annotated up to 200 passages per query for their relevance to a given query. It is better suited to estimate
the model performance for the task of reranking in Information Retrieval.

Original script usage (in this notebook the model is selected via `model_save_path` instead):
python eval_cross-encoder-trec-dl.py cross-encoder-model-name

"""


# Local cache directory for all TREC DL 2019 files used below.
data_folder = 'trec2019-data'
os.makedirs(data_folder, exist_ok=True)

# Test queries: gzipped TSV with one "<qid>\t<query text>" record per line.
# The file is downloaded only when it is not already cached.
queries_filepath = os.path.join(data_folder, 'msmarco-test2019-queries.tsv.gz')
if not os.path.exists(queries_filepath):
    logging.info("Download " + os.path.basename(queries_filepath))
    util.http_get(
        'https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz',
        queries_filepath,
    )

# Build the {qid: query text} mapping in one pass over the file.
with gzip.open(queries_filepath, 'rt', encoding='utf8') as fIn:
    queries = dict(record.strip().split("\t") for record in fIn)

# Read which passages are relevant.
# relevant_docs maps qid -> {pid: graded relevance}; only judgements with a
# grade above 0 are stored, so a qid is present only if it has at least one
# relevant passage.
relevant_docs = defaultdict(lambda: defaultdict(int))
qrels_filepath = os.path.join(data_folder, '2019qrels-pass.txt')

if not os.path.exists(qrels_filepath):
    logging.info("Download "+os.path.basename(qrels_filepath))
    util.http_get('https://trec.nist.gov/data/deep/2019qrels-pass.txt', qrels_filepath)


# Each qrels line is whitespace separated: qid, an ignored column, pid, grade.
with open(qrels_filepath, encoding='utf8') as fIn:
    for line in fIn:
        qid, _, pid, score = line.strip().split()
        score = int(score)
        if score > 0:
            relevant_docs[qid][pid] = score

# Only use queries that have at least one relevant passage.
# A membership test is used on purpose: indexing the defaultdict
# (relevant_docs[qid]) would insert an empty entry for every unjudged query
# into the qrels that are later handed to the evaluator.
relevant_qid = [qid for qid in queries if qid in relevant_docs]


# Candidate passages that are supposed to be re-ranked (top-1000 file),
# downloaded on first use and cached in data_folder afterwards.
passage_filepath = os.path.join(data_folder, 'msmarco-passagetest2019-top1000.tsv.gz')

if not os.path.exists(passage_filepath):
    logging.info("Download " + os.path.basename(passage_filepath))
    util.http_get(
        'https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-passagetest2019-top1000.tsv.gz',
        passage_filepath,
    )

# passage_cand maps qid -> list of [pid, passage text], in file order.
# Every line carries four tab-separated fields: qid, pid, query, passage.
passage_cand = {}
with gzip.open(passage_filepath, 'rt', encoding='utf8') as fIn:
    for line in fIn:
        qid, pid, query, passage = line.strip().split("\t")
        passage_cand.setdefault(qid, []).append([pid, passage])

logging.info("Queries: {}".format(len(queries)))


INFO:root:Queries: 200


## Prediction 

In [None]:
queries_result_list = []
# run maps qid -> {pid: predicted relevance score}; this is the structure
# passed to the evaluator in the Evaluation section.
run = {}
model = CrossEncoder(model_save_path, max_length=512)

for qid in tqdm.tqdm(relevant_qid):
    query = queries[qid]

    # Candidates are stored as [pid, passage] pairs.
    cand = passage_cand[qid]
    pids = [c[0] for c in cand]
    cross_inp = [[query, c[1]] for c in cand]

    # The inner per-call "Batches" progress bar is switched off: it printed one
    # extra bar per query and buried the outer per-query bar in the output.
    if model.config.num_labels > 1:
        # Multi-label head: softmax over the labels, then column 1 is used as
        # the relevance score (that is the last column only for two labels).
        cross_scores = model.predict(cross_inp, apply_softmax=True, show_progress_bar=False)[:, 1].tolist()
    else:
        cross_scores = model.predict(cross_inp, show_progress_bar=False).tolist()

    # Scores come back in input order, so they pair up with pids positionally.
    run[qid] = {pid: float(score) for pid, score in zip(pids, cross_scores)}

INFO:sentence_transformers.cross_encoder.CrossEncoder:Use pytorch device: cuda
  0%|          | 0/43 [00:00<?, ?it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

  2%|▏         | 1/43 [00:00<00:30,  1.40it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

  5%|▍         | 2/43 [00:01<00:26,  1.53it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

  7%|▋         | 3/43 [00:01<00:25,  1.54it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

  9%|▉         | 4/43 [00:02<00:23,  1.68it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 12%|█▏        | 5/43 [00:03<00:22,  1.66it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 14%|█▍        | 6/43 [00:03<00:21,  1.76it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 16%|█▋        | 7/43 [00:04<00:21,  1.70it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 19%|█▊        | 8/43 [00:04<00:19,  1.78it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 21%|██        | 9/43 [00:05<00:19,  1.78it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 23%|██▎       | 10/43 [00:05<00:18,  1.74it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 26%|██▌       | 11/43 [00:06<00:19,  1.66it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 28%|██▊       | 12/43 [00:07<00:18,  1.69it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 30%|███       | 13/43 [00:07<00:17,  1.71it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 33%|███▎      | 14/43 [00:08<00:16,  1.75it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 35%|███▍      | 15/43 [00:08<00:16,  1.72it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 37%|███▋      | 16/43 [00:09<00:16,  1.66it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 40%|███▉      | 17/43 [00:10<00:14,  1.74it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 42%|████▏     | 18/43 [00:10<00:15,  1.59it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 44%|████▍     | 19/43 [00:11<00:14,  1.66it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 47%|████▋     | 20/43 [00:11<00:13,  1.67it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 49%|████▉     | 21/43 [00:12<00:12,  1.74it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 51%|█████     | 22/43 [00:13<00:12,  1.67it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 53%|█████▎    | 23/43 [00:13<00:11,  1.70it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 56%|█████▌    | 24/43 [00:14<00:11,  1.70it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 58%|█████▊    | 25/43 [00:14<00:10,  1.78it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 60%|██████    | 26/43 [00:15<00:09,  1.78it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 63%|██████▎   | 27/43 [00:15<00:09,  1.67it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 65%|██████▌   | 28/43 [00:16<00:09,  1.64it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 67%|██████▋   | 29/43 [00:17<00:08,  1.65it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 70%|██████▉   | 30/43 [00:17<00:07,  1.72it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 72%|███████▏  | 31/43 [00:18<00:07,  1.68it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 74%|███████▍  | 32/43 [00:18<00:06,  1.73it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 79%|███████▉  | 34/43 [00:19<00:04,  2.20it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 84%|████████▎ | 36/43 [00:20<00:02,  2.56it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 86%|████████▌ | 37/43 [00:20<00:02,  2.39it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 88%|████████▊ | 38/43 [00:21<00:02,  2.09it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 91%|█████████ | 39/43 [00:21<00:02,  1.94it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 93%|█████████▎| 40/43 [00:22<00:01,  1.85it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 95%|█████████▌| 41/43 [00:23<00:01,  1.84it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 98%|█████████▊| 42/43 [00:23<00:00,  1.80it/s]

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

100%|██████████| 43/43 [00:24<00:00,  1.77it/s]


## Evaluation

In [None]:
# Score the run against the qrels and report the mean of each metric over all
# evaluated queries, as a percentage with two decimals.
# NOTE(review): the qrels keep every grade > 0 and the evaluator's relevance
# threshold is left at its default; confirm this is the intended protocol for
# the binary metrics (Recall, MAP) on TREC DL.
measures = {'ndcg_cut.10', 'recall_100', 'map_cut.1000'}
evaluator = pytrec_eval.RelevanceEvaluator(relevant_docs, measures)
scores = evaluator.evaluate(run)

print("Queries:", len(relevant_qid))
# (printed label, key under which the evaluator returns the per-query value)
reported_metrics = (
    ("NDCG@10", "ndcg_cut_10"),
    ("Recall@100", "recall_100"),
    ("MAP@1000", "map_cut_1000"),
)
for metric_label, metric_key in reported_metrics:
    per_query_values = [query_scores[metric_key] for query_scores in scores.values()]
    print("{}: {:.2f}".format(metric_label, np.mean(per_query_values)*100))

Queries: 43
NDCG@10: 68.37
Recall@100: 50.10
MAP@1000: 43.96


## Sorting candidate documents of each query based on their relevance score 

In [None]:
import operator

# Replace each query's {pid: score} dict with a list of (pid, score) pairs
# ordered from highest to lowest score.
# After this cell the values in `run` are lists, so the Evaluation cell (which
# needs dicts) has to be executed before it.
for qid in run.keys():
  scored = run[qid]
  # On a re-run the entry is already a list of pairs; accepting both shapes
  # keeps the cell idempotent instead of failing on list.items().
  pairs = scored.items() if isinstance(scored, dict) else scored
  run[qid] = sorted(pairs, key=operator.itemgetter(1), reverse = True)

## Storing ranking run file

In [None]:
# One line per (query, passage) pair in the form
# "<qid> Q0 <pid> <rank> <score> STANDARD", where rank is the 0-based position
# of the passage in the query's sorted candidate list.
run_line_template = "{qid} Q0 {did} {rank} {pred_score} STANDARD"
ranking_lines = []
for qid, ranked_passages in run.items():
  for rank, (did, pred_score) in enumerate(ranked_passages):
    ranking_lines.append(
      run_line_template.format(qid=qid, did=did, rank=rank, pred_score=str(pred_score))
    )

In [None]:
# Store the run file inside the model folder. os.path.join gives the same path
# whether or not model_save_path ends with a slash.
ranking_run_file_path = os.path.join(model_save_path, "ranking.run")

# The context manager closes the file even if writing fails. Every line,
# including the last one, is newline-terminated.
with open(ranking_run_file_path, "w") as f_w:
    f_w.writelines(ranking_line + "\n" for ranking_line in ranking_lines)

### Print the first three lines of the stored ranking run file

In [None]:
!head -n 3 $ranking_run_file_path

156493 Q0 1960255 0 4.394301414489746 STANDARD
156493 Q0 1960260 1 4.225707054138184 STANDARD
156493 Q0 2928707 2 4.06801700592041 STANDARD
