## Installation

In [None]:
!pip install ranx

Collecting ranx
  Downloading ranx-0.3.19-py3-none-any.whl (99 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/99.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m99.2/99.2 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
Collecting ir-datasets (from ranx)
  Downloading ir_datasets-0.5.7-py3-none-any.whl (337 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m337.9/337.9 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
Collecting orjson (from ranx)
  Downloading orjson-3.10.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (142 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m142.5/142.5 kB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting lz4 (from ranx)
  Downloading lz4-4.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m13.9 MB/s[0m eta [36m0:00

### Imports

In [None]:
import os
import gzip
import logging
from collections import defaultdict
from ranx import Qrels, Run, fuse, optimize_fusion, evaluate
from google.colab import drive
from ranx import fuse, optimize_fusion, evaluate
import numpy as np
from google.colab import drive
import os
from ranx import Run

## Some Preparation

### Get qrels

In [None]:
data_folder = 'trec2019-data'
os.makedirs(data_folder, exist_ok=True)

!wget https://msmarco.z22.web.core.windows.net/msmarcoranking/queries.tar.gz
!tar -xvzf queries.tar.gz

model_save_path = "/content/gdrive/MyDrive/cross-encoder-reranker-ir-course-2023/finetuned_models/cross-encoder-distilbert-distilroberta-base-2024-05-12_07-39-41"

queries = {}
queries_filepath = os.path.join(data_folder, 'msmarco-test2019-queries.tsv.gz')
if not os.path.exists(queries_filepath):
    logging.info("Download " + os.path.basename(queries_filepath))
    !wget https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz -O {queries_filepath}

with gzip.open(queries_filepath, 'rt', encoding='utf8') as fIn:
    for line in fIn:
        qid, query = line.strip().split("\t")
        queries[qid] = query

relevant_docs = defaultdict(lambda: defaultdict(int))
qrels_filepath = os.path.join(data_folder, '2019qrels-pass.txt')

if not os.path.exists(qrels_filepath):
    logging.info("Download " + os.path.basename(qrels_filepath))
    !wget https://trec.nist.gov/data/deep/2019qrels-pass.txt -O {qrels_filepath}

with open(qrels_filepath) as fIn:
    for line in fIn:
        qid, _, pid, score = line.strip().split()
        score = int(score)
        if score > 0:
            relevant_docs[qid][pid] = score

relevant_qid = [qid for qid in queries if len(relevant_docs[qid]) > 0]

qrels = Qrels(relevant_docs)

--2024-05-17 15:30:07--  https://msmarco.z22.web.core.windows.net/msmarcoranking/queries.tar.gz
Resolving msmarco.z22.web.core.windows.net (msmarco.z22.web.core.windows.net)... 20.150.34.1
Connecting to msmarco.z22.web.core.windows.net (msmarco.z22.web.core.windows.net)|20.150.34.1|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 18882551 (18M) [application/gzip]
Saving to: ‘queries.tar.gz.1’


2024-05-17 15:30:08 (28.6 MB/s) - ‘queries.tar.gz.1’ saved [18882551/18882551]

queries.dev.tsv
queries.eval.tsv
queries.train.tsv


### Get Run Files

In [None]:
# The run files live on Google Drive, so mount it first.
drive.mount('/content/gdrive')

base_path = "/content/gdrive/MyDrive/cross-encoder-reranker-ir-course-2023/"

# Ranking files (read in TREC format below), given relative to base_path.
file_paths = [
    "finetuned_models/cross-encoder-cross-encoder-ms-marco-MiniLM-L-2-v2-2024-05-10_20-46-58ranking.run",
    "finetuned_models/cross-encoder-cross-encoder-ms-marco-TinyBERT-L-2-v2-2024-05-11_07-11-32ranking.run",
    "finetuned_models/cross-encoder-distilbert-distilroberta-base-2024-05-12_07-39-41ranking.run",
]

# Load every run file into a ranx Run.
runs = []
for file_path in file_paths:
    full_path = os.path.join(base_path, file_path)
    run = Run.from_file(full_path, kind="trec")
    runs.append(run)

# Keep only the query ids that appear in the qrels and in every run.
common_qids = set(qrels.qrels.keys()).intersection(*[run.keys() for run in runs])

# Restrict the qrels and each run to those shared queries ...
filtered_qrels = {
    qid: dict(qrels[qid].items())
    for qid in common_qids
}
filtered_runs = [
    {qid: run[qid] for qid in common_qids}
    for run in runs
]

# ... and rebuild the ranx objects from the filtered dictionaries.
qrels = Qrels(filtered_qrels)
runs = [Run(run) for run in filtered_runs]

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


### Get best_params

## Ensemble Methods

### 1. RRF (Reciprocal Rank Fusion, Rank-based Method)

In [None]:
# 1. RRF (Reciprocal Rank Fusion, Rank-based Method)

# Arguments shared by the tuning step and the fusion step.
fusion_kwargs = dict(runs=runs, norm="min-max", method="rrf")

# NOTE(review): the parameters are tuned on the same qrels that are used for
# the evaluation below, so the reported scores may be optimistic.
best_params = optimize_fusion(qrels=qrels, metric="ndcg@100", **fusion_kwargs)

# Fuse the runs with the tuned parameters.
combined_run = fuse(params=best_params, **fusion_kwargs)

# Score the fused ranking.
metrics = evaluate(
    qrels=qrels,
    run=combined_run,
    metrics=["ndcg@10", "recall@100", "map@1000"],
)

print(metrics)

# Express the scores as percentages for the summary.
ndcg_10, recall_100, map_1000 = (
    metrics[name] * 100 for name in ("ndcg@10", "recall@100", "map@1000")
)

print(
    f"Queries: {len(common_qids)}",
    f"NDCG@10: {ndcg_10:.2f}",
    f"Recall@100: {recall_100:.2f}",
    f"MAP@1000: {map_1000:.2f}",
    sep="\n",
)

Output()

{'ndcg@10': 0.7016068663669671, 'recall@100': 0.5144399762612633, 'map@1000': 0.4635443281120344}
Queries: 43
NDCG@10: 70.16
Recall@100: 51.44
MAP@1000: 46.35


### 2. BayesFuse

In [None]:
# 2. BayesFuse

# Arguments shared by the tuning step and the fusion step.
fusion_kwargs = dict(runs=runs, norm="min-max", method="bayesfuse")

# NOTE(review): the parameters are tuned on the same qrels that are used for
# the evaluation below, so the reported scores may be optimistic.
best_params = optimize_fusion(qrels=qrels, metric="ndcg@100", **fusion_kwargs)

# Fuse the runs with the tuned parameters.
combined_run = fuse(params=best_params, **fusion_kwargs)

# Score the fused ranking.
metrics = evaluate(
    qrels=qrels,
    run=combined_run,
    metrics=["ndcg@10", "recall@100", "map@1000"],
)

print(metrics)

# Express the scores as percentages for the summary.
ndcg_10, recall_100, map_1000 = (
    metrics[name] * 100 for name in ("ndcg@10", "recall@100", "map@1000")
)

print(
    f"Queries: {len(common_qids)}",
    f"NDCG@10: {ndcg_10:.2f}",
    f"Recall@100: {recall_100:.2f}",
    f"MAP@1000: {map_1000:.2f}",
    sep="\n",
)

{'ndcg@10': 0.6998484272244461, 'recall@100': 0.5154443882913463, 'map@1000': 0.45978124243682267}
Queries: 43
NDCG@10: 69.98
Recall@100: 51.54
MAP@1000: 45.98


### 3. Condorcet

In [None]:
# 3. Condorcet (no optimize_fusion step is run for this method)

# Fuse the runs directly; no tuned parameters are passed.
combined_run = fuse(runs=runs, norm="min-max", method="condorcet")

# Score the fused ranking.
metrics = evaluate(
    qrels=qrels,
    run=combined_run,
    metrics=["ndcg@10", "recall@100", "map@1000"],
)

print(metrics)

# Express the scores as percentages for the summary.
ndcg_10, recall_100, map_1000 = (
    metrics[name] * 100 for name in ("ndcg@10", "recall@100", "map@1000")
)

print(
    f"Queries: {len(common_qids)}",
    f"NDCG@10: {ndcg_10:.2f}",
    f"Recall@100: {recall_100:.2f}",
    f"MAP@1000: {map_1000:.2f}",
    sep="\n",
)

{'ndcg@10': 0.7007382254405954, 'recall@100': 0.5156299963943269, 'map@1000': 0.46261835497475695}
Queries: 43
NDCG@10: 70.07
Recall@100: 51.56
MAP@1000: 46.26


### 4. CombSUM (no `optimize_fusion` step)

In [None]:
# 4. CombSUM (no optimize_fusion step is run for this method)

# Fuse the runs directly; no tuned parameters are passed.
combined_run = fuse(runs=runs, norm="min-max", method="sum")

# Score the fused ranking.
metrics = evaluate(
    qrels=qrels,
    run=combined_run,
    metrics=["ndcg@10", "recall@100", "map@1000"],
)

print(metrics)

# Express the scores as percentages for the summary.
ndcg_10, recall_100, map_1000 = (
    metrics[name] * 100 for name in ("ndcg@10", "recall@100", "map@1000")
)

print(
    f"Queries: {len(common_qids)}",
    f"NDCG@10: {ndcg_10:.2f}",
    f"Recall@100: {recall_100:.2f}",
    f"MAP@1000: {map_1000:.2f}",
    sep="\n",
)

{'ndcg@10': 0.6907513253280211, 'recall@100': 0.5098510683461052, 'map@1000': 0.4547381128449681}
Queries: 43
NDCG@10: 69.08
Recall@100: 50.99
MAP@1000: 45.47


### 5. BordaFuse (no `optimize_fusion` step)

In [None]:
# 5. BordaFuse (no optimize_fusion step is run for this method)

# Fuse the runs directly; no tuned parameters are passed.
combined_run = fuse(runs=runs, norm="min-max", method="bordafuse")

# Score the fused ranking.
metrics = evaluate(
    qrels=qrels,
    run=combined_run,
    metrics=["ndcg@10", "recall@100", "map@1000"],
)

print(metrics)

# Express the scores as percentages for the summary.
ndcg_10, recall_100, map_1000 = (
    metrics[name] * 100 for name in ("ndcg@10", "recall@100", "map@1000")
)

print(
    f"Queries: {len(common_qids)}",
    f"NDCG@10: {ndcg_10:.2f}",
    f"Recall@100: {recall_100:.2f}",
    f"MAP@1000: {map_1000:.2f}",
    sep="\n",
)

{'ndcg@10': 0.7014228264342354, 'recall@100': 0.5129418478649971, 'map@1000': 0.462321250228213}
Queries: 43
NDCG@10: 70.14
Recall@100: 51.29
MAP@1000: 46.23
