In [1]:
# Install dependencies
!pip install -q faiss-cpu sentence-transformers pytrec_eval torch tqdm

  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m30.7/30.7 MB[0m [31m32.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for pytrec_eval (setup.py) ... [?25l[?25hdone


In [None]:
import os
os.environ["WANDB_DISABLED"] = "true"  #  Disable WandB globally

In [3]:
import os
import logging
import json
import torch
import faiss
import numpy as np
import pytrec_eval
from sentence_transformers import SentenceTransformer, losses
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset

In [4]:
!nvidia-smi

/bin/bash: line 1: nvidia-smi: command not found


In [5]:
# Enable logging
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s - %(message)s')
log = logging.getLogger(__name__)

In [6]:
# Set paths
DATA_PATH = '/kaggle/input/trec2023/TREC2023 Data/TREC-ToT'
MODEL_DIR = "/kaggle/working/AllMiniLM"
os.makedirs(MODEL_DIR, exist_ok=True)

In [7]:
# change accordingly
# CORPUS_PATH = '/kaggle/input/mini-wikipedia-dumps/MiniWikipediaDumps.jsonl'

TRAIN_QUERIES_PATH = os.path.join(DATA_PATH, 'TREC-TOT/train/queries.jsonl')
TRAIN_QREL_PATH = os.path.join(DATA_PATH, 'TREC-TOT/train/qrel.txt')

DEV_QUERIES_PATH = ('/kaggle/input/trec2023/TREC2023 Data/TREC-ToT/TREC-TOT/dev/queries.jsonl')
DEV_QREL_PATH =('/kaggle/input/trec2023/TREC2023 Data/TREC-ToT/TREC-TOT/dev/qrel.txt')

TEST_QUERIES_PATH = ('/kaggle/input/trec2023/TREC2023 Data/test/test/queries.jsonl')


In [8]:

# Load allMiniLM model
device = "cuda" if torch.cuda.is_available() else "cpu"
# embedding_model_n = "sentence-transformers/all-MiniLM-L6-v2"
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=device, trust_remote_code=True) # Add trust_remote_code=True

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [9]:
device

'cpu'

In [10]:
def load_jsonl(file_path):
    """Load a JSONL file into a list."""
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            data.append(json.loads(line))
    return data

In [11]:
PROCESSED_CORPUS_PATH = "/kaggle/input/preprocessed-corpus/preprocessed_corpus.jsonl"

In [None]:
#  3 Load Queries & Qrels
# ===========================

def load_queries(query_file):
    """Load queries from a JSONL file."""
    queries = {}
    with open(query_file, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line)
            queries[data["id"]] = data["text"]
    return queries

In [13]:
def load_qrels(qrel_file):
    """Load qrels (relevance judgments)."""
    qrels = {}
    with open(qrel_file, 'r', encoding='utf-8') as f:
        for line in f:
            qid, _, docid, rel = line.strip().split()
            if qid not in qrels:
                qrels[qid] = {}
            qrels[qid][docid] = int(rel)
    return qrels

In [14]:
# Load evaluation data

dev_queries = load_queries(os.path.join(DATA_PATH, 'TREC-TOT/dev/queries.jsonl'))
dev_qrels = load_qrels(os.path.join(DATA_PATH, 'TREC-TOT/dev/qrel.txt'))

In [None]:
#  Load the pre-saved FAISS index and document IDs
FAISS_INDEX_PATH = "/kaggle/input/faiss_index_all-minilm-l6-v2/other/default/1/faiss_index.bin"
DOC_IDS_PATH = "/kaggle/input/faiss_index_all-minilm-l6-v2/other/default/1/doc_ids.npy"

print(" Loading pre-saved FAISS index and document IDs...")
index = faiss.read_index(FAISS_INDEX_PATH)  # Load FAISS index
doc_ids = np.load(DOC_IDS_PATH, allow_pickle=True).tolist()  # Load document IDs

print(" FAISS index and document IDs loaded successfully!")

🔄 Loading pre-saved FAISS index and document IDs...
✅ FAISS index and document IDs loaded successfully!


In [None]:
# ===========================
#  4 Retrieve Top-K Documents (Before Fine-Tuning)
# ===========================

def retrieve_top_k(queries, top_k=1000):
    """Retrieve top-K documents using FAISS."""
    log.info(f"Retrieving top-{top_k} documents for queries...")
    results = {}

    for qid, query_text in tqdm(queries.items()):
        query_embedding = model.encode(query_text, convert_to_tensor=False).reshape(1, -1)
        distances, indices = index.search(query_embedding, top_k)
        # In retrieve_top_k() function (before fine-tuning)
        # retrieved_docs = {doc_ids[idx]: -float(dist) for idx, dist in zip(indices[0], distances[0])}  # ✅ Use negative distance
        retrieved_docs = {doc_ids[idx]: float(dist) for idx, dist in zip(indices[0], distances[0])}
        results[qid] = retrieved_docs

    return results

log.info("Retrieving before fine-tuning...")
before_results = retrieve_top_k(dev_queries)


  0%|          | 0/150 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  1%|          | 1/150 [00:00<00:28,  5.24it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  2%|▏         | 3/150 [00:00<00:15,  9.39it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  3%|▎         | 5/150 [00:00<00:12, 11.22it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  5%|▍         | 7/150 [00:00<00:12, 11.85it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  6%|▌         | 9/150 [00:00<00:12, 11.42it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  7%|▋         | 11/150 [00:01<00:12, 11.54it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  9%|▊         | 13/150 [00:01<00:10, 12.58it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 10%|█         | 15/150 [00:01<00:11, 12.16it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 11%|█▏        | 17/150 [00:01<00:10, 12.98it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 13%|█▎        | 19/150 [00:01<00:09, 13.86it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 14%|█▍        | 21/150 [00:01<00:09, 13.29it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 15%|█▌        | 23/150 [00:01<00:10, 12.67it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 17%|█▋        | 25/150 [00:02<00:09, 12.79it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 18%|█▊        | 27/150 [00:02<00:09, 13.66it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 19%|█▉        | 29/150 [00:02<00:08, 13.76it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 21%|██        | 31/150 [00:02<00:08, 14.01it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 22%|██▏       | 33/150 [00:02<00:07, 14.91it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 23%|██▎       | 35/150 [00:02<00:08, 13.70it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 25%|██▍       | 37/150 [00:02<00:08, 12.61it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 26%|██▌       | 39/150 [00:03<00:08, 13.62it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 27%|██▋       | 41/150 [00:03<00:07, 13.76it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 29%|██▊       | 43/150 [00:03<00:07, 13.63it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 30%|███       | 45/150 [00:03<00:08, 13.12it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 31%|███▏      | 47/150 [00:03<00:07, 13.39it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 33%|███▎      | 49/150 [00:03<00:07, 13.32it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 34%|███▍      | 51/150 [00:03<00:07, 12.89it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 35%|███▌      | 53/150 [00:04<00:07, 12.51it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 37%|███▋      | 55/150 [00:04<00:07, 12.35it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 38%|███▊      | 57/150 [00:04<00:07, 13.01it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 39%|███▉      | 59/150 [00:04<00:07, 12.47it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 41%|████      | 61/150 [00:04<00:06, 13.34it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 42%|████▏     | 63/150 [00:04<00:06, 13.12it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 43%|████▎     | 65/150 [00:05<00:06, 12.86it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 45%|████▍     | 67/150 [00:05<00:06, 13.39it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 46%|████▌     | 69/150 [00:05<00:06, 12.59it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 47%|████▋     | 71/150 [00:05<00:06, 12.65it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 49%|████▊     | 73/150 [00:05<00:06, 12.67it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 50%|█████     | 75/150 [00:05<00:06, 12.37it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 51%|█████▏    | 77/150 [00:05<00:05, 13.18it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 53%|█████▎    | 79/150 [00:06<00:05, 13.21it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 54%|█████▍    | 81/150 [00:06<00:05, 13.55it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 55%|█████▌    | 83/150 [00:06<00:04, 14.58it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 57%|█████▋    | 85/150 [00:06<00:04, 13.51it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 58%|█████▊    | 87/150 [00:06<00:04, 14.09it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 59%|█████▉    | 89/150 [00:06<00:04, 13.36it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 61%|██████    | 91/150 [00:07<00:04, 13.37it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 62%|██████▏   | 93/150 [00:07<00:04, 12.93it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 63%|██████▎   | 95/150 [00:07<00:04, 13.29it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 65%|██████▍   | 97/150 [00:07<00:03, 14.03it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 66%|██████▌   | 99/150 [00:07<00:03, 13.35it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 67%|██████▋   | 101/150 [00:07<00:03, 13.49it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 69%|██████▊   | 103/150 [00:07<00:03, 12.21it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 70%|███████   | 105/150 [00:08<00:03, 12.24it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 71%|███████▏  | 107/150 [00:08<00:03, 12.59it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 73%|███████▎  | 109/150 [00:08<00:03, 12.55it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 74%|███████▍  | 111/150 [00:08<00:02, 13.33it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 75%|███████▌  | 113/150 [00:08<00:02, 13.29it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 77%|███████▋  | 115/150 [00:08<00:02, 13.58it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 78%|███████▊  | 117/150 [00:09<00:02, 12.93it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 79%|███████▉  | 119/150 [00:09<00:02, 13.19it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 81%|████████  | 121/150 [00:09<00:02, 12.43it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 82%|████████▏ | 123/150 [00:09<00:02, 13.09it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 83%|████████▎ | 125/150 [00:09<00:01, 12.92it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 85%|████████▍ | 127/150 [00:09<00:01, 11.76it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 86%|████████▌ | 129/150 [00:10<00:01, 11.74it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 87%|████████▋ | 131/150 [00:10<00:01, 12.52it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 89%|████████▊ | 133/150 [00:10<00:01, 13.48it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 90%|█████████ | 135/150 [00:10<00:01, 13.11it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 91%|█████████▏| 137/150 [00:10<00:00, 13.54it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 93%|█████████▎| 139/150 [00:10<00:00, 13.76it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 94%|█████████▍| 141/150 [00:10<00:00, 12.64it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 95%|█████████▌| 143/150 [00:11<00:00, 12.32it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 97%|█████████▋| 145/150 [00:11<00:00, 13.05it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 98%|█████████▊| 147/150 [00:11<00:00, 13.34it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

 99%|█████████▉| 149/150 [00:11<00:00, 14.07it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 150/150 [00:11<00:00, 12.98it/s]


In [None]:
# Evaluation function
def evaluate(retrieved, qrels):
    common_qids = set(retrieved.keys()) & set(qrels.keys())
    if not common_qids:
        log.warning("No overlapping queries between results and qrels")
        return {}

    evaluator = pytrec_eval.RelevanceEvaluator(
    {qid: qrels[qid] for qid in common_qids},
    {'ndcg_cut_10',    # same as plain 'ndcg'
      'ndcg_cut_100',
      'ndcg_cut_1000',
      'recip_rank',
      'recall_3', 'recall_100', 'recall_1000',
      'success_3', 'success_100', 'success_1000'}
    )
    
    results = evaluator.evaluate({qid: retrieved[qid] for qid in common_qids})
    
    return {metric: np.mean([v[metric] for v in results.values()]) 
           for metric in results[next(iter(results))].keys()}



In [18]:
metrics = evaluate(before_results, dev_qrels)

In [19]:
# Print results
print("\nEvaluation Results Before Fine-Tuning:")
for metric, value in metrics.items():
    print(f"{metric}: {value:.4f}")


Evaluation Results Before Fine-Tuning:
recip_rank: 0.0005
recall_3: 0.0000
recall_100: 0.0000
recall_1000: 0.3067
ndcg_cut_10: 0.0000
ndcg_cut_100: 0.0000
ndcg_cut_1000: 0.0322
success_3: 0.0000
success_100: 0.0000
success_1000: 0.3067


**Fine Tuning**

In [20]:
# Load datasets
log.info("Loading data...")
corpus = load_jsonl(PROCESSED_CORPUS_PATH)
train_queries = load_queries(TRAIN_QUERIES_PATH)
train_qrels = load_qrels(TRAIN_QREL_PATH)
dev_queries = load_queries(DEV_QUERIES_PATH)
dev_qrels = load_qrels(DEV_QREL_PATH)

In [21]:
# Training Dataset
from sentence_transformers import InputExample
class TrainDataset(Dataset):
    def __init__(self, queries, qrels, corpus):
        self.corpus = {doc["doc_id"]: doc["text"] for doc in corpus}
        self.examples = []
        
        for qid, docs in qrels.items():
            if qid not in queries:
                continue
            query_text = queries[qid]
            # Only keep relevant documents (rel >= 1)
            relevant_docs = [docid for docid, rel in docs.items() if rel >= 1]
            if not relevant_docs:
                continue
            # Create positive pairs
            for docid in relevant_docs:
                if docid in self.corpus:
                    self.examples.append(InputExample(
                        texts=[query_text, self.corpus[docid]],
                        label=1.0  # Binary label for MNRL
                    ))
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, idx):
        return self.examples[idx]

In [22]:
# Prepare training data
train_dataset = TrainDataset(train_queries, train_qrels, corpus)
if len(train_dataset) == 0:
    raise ValueError("No valid training samples found!")

In [23]:
BATCH_SIZE = 32
EPOCHS = 50  # Actual training epochs
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)

# Fine-tune model
log.info(f"Starting fine-tuning for {EPOCHS} epochs...")
train_loss = losses.MultipleNegativesRankingLoss(model)

In [24]:
OUTPUT_PATH = "/kaggle/working/FINETUNE"
os.makedirs(OUTPUT_PATH, exist_ok=True)

In [None]:
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=EPOCHS,
    warmup_steps=500,
    optimizer_params={'lr': 2e-5},
    show_progress_bar=True,
    output_path=OUTPUT_PATH,
    # use_wandb=False  # Disable WandB integration
)

# Save fine-tuned model
model.save(OUTPUT_PATH)
log.info("Model saved successfully.")

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Step,Training Loss


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

In [26]:
# Rebuild FAISS index with fine-tuned embeddings
def build_faiss_index(model, corpus, batch_size=256):
    log.info("Re-embedding corpus with fine-tuned model...")
    doc_ids = []
    embeddings = []
    
    for doc in tqdm(corpus, desc="Embedding documents"):
        doc_ids.append(doc["doc_id"])
        embeddings.append(model.encode(doc["text"], 
                                      batch_size=batch_size, 
                                      show_progress_bar=False))
    
    embeddings = np.vstack(embeddings).astype('float32')
    
    # Create FAISS index
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    
    return index, doc_ids

In [27]:
# Create new index
new_index, new_doc_ids = build_faiss_index(model, corpus)
log.info("New FAISS index created with fine-tuned embeddings.")

Embedding documents: 100%|██████████| 231826/231826 [3:24:42<00:00, 18.87it/s]


In [28]:
# Save
faiss.write_index(new_index, "/kaggle/working/faiss_index_finetuned.bin")
np.save("/kaggle/working/doc_ids_finetuned.npy", np.array(new_doc_ids))

In [None]:
# Load the pre-saved FAISS index and document IDs
# FAISS_INDEX_PATH = "/kaggle/input/v43-output/faiss_index_finetuned.bin"
FAISS_INDEX_PATH ="/kaggle/working/faiss_index_finetuned.bin"
# DOC_IDS_PATH = "/kaggle/input/v43-output/doc_ids_finetuned.npy"
DOC_IDS_PATH ="/kaggle/working/doc_ids_finetuned.npy"

print(" Loading pre-saved FAISS index and document IDs...")
new_index = faiss.read_index(FAISS_INDEX_PATH)  # Load FAISS index
new_doc_ids = np.load(DOC_IDS_PATH, allow_pickle=True).tolist()  # Load document IDs

print(" FAISS index and document IDs loaded successfully!")

🔄 Loading pre-saved FAISS index and document IDs...
✅ FAISS index and document IDs loaded successfully!


In [30]:
# Retrieve with new index
def retrieve_top_k_finetuned(queries, index, doc_ids, top_k=1000):
    results = {}
    for qid, query_text in tqdm(queries.items(), desc="Retrieving"):
        query_embedding = model.encode(query_text, 
                                      convert_to_tensor=False,
                                      device=device).reshape(1, -1)
        distances, indices = index.search(query_embedding, top_k)
        results[qid] = {doc_ids[i]: float(distances[0][idx]) 
                       for idx, i in enumerate(indices[0])}
    return results

In [None]:
# Evaluate after fine-tuning
after_finetune_results = retrieve_top_k_finetuned(dev_queries, new_index, new_doc_ids, top_k=1000)
after_finetune_eval = evaluate(after_finetune_results, dev_qrels)  # Compute metrics

Retrieving:   0%|          | 0/150 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:   1%|▏         | 2/150 [00:00<00:12, 11.45it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:   3%|▎         | 4/150 [00:00<00:11, 12.63it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:   4%|▍         | 6/150 [00:00<00:10, 13.35it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:   5%|▌         | 8/150 [00:00<00:10, 13.46it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:   7%|▋         | 10/150 [00:00<00:11, 12.06it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:   8%|▊         | 12/150 [00:00<00:10, 12.63it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:   9%|▉         | 14/150 [00:01<00:10, 13.37it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  11%|█         | 16/150 [00:01<00:10, 13.22it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  12%|█▏        | 18/150 [00:01<00:10, 13.09it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  13%|█▎        | 20/150 [00:01<00:10, 12.69it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  15%|█▍        | 22/150 [00:01<00:10, 12.69it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  16%|█▌        | 24/150 [00:01<00:10, 12.15it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  17%|█▋        | 26/150 [00:02<00:09, 13.23it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  19%|█▊        | 28/150 [00:02<00:08, 13.82it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  20%|██        | 30/150 [00:02<00:08, 14.00it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  21%|██▏       | 32/150 [00:02<00:08, 14.62it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  23%|██▎       | 34/150 [00:02<00:08, 14.00it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  24%|██▍       | 36/150 [00:02<00:08, 13.02it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  25%|██▌       | 38/150 [00:02<00:08, 12.68it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  27%|██▋       | 40/150 [00:03<00:08, 13.71it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  28%|██▊       | 42/150 [00:03<00:07, 13.60it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  29%|██▉       | 44/150 [00:03<00:07, 13.31it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  31%|███       | 46/150 [00:03<00:08, 12.67it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  32%|███▏      | 48/150 [00:03<00:07, 12.92it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  33%|███▎      | 50/150 [00:03<00:07, 12.67it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  35%|███▍      | 52/150 [00:04<00:08, 11.88it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  36%|███▌      | 54/150 [00:04<00:07, 12.04it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  37%|███▋      | 56/150 [00:04<00:07, 12.17it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  39%|███▊      | 58/150 [00:04<00:07, 12.33it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  40%|████      | 60/150 [00:04<00:07, 12.23it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  41%|████▏     | 62/150 [00:04<00:06, 13.04it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  43%|████▎     | 64/150 [00:04<00:07, 12.20it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  44%|████▍     | 66/150 [00:05<00:06, 12.71it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  45%|████▌     | 68/150 [00:05<00:06, 12.45it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  47%|████▋     | 70/150 [00:05<00:06, 11.92it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  48%|████▊     | 72/150 [00:05<00:06, 11.48it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  49%|████▉     | 74/150 [00:05<00:06, 12.26it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  51%|█████     | 76/150 [00:05<00:06, 11.90it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  52%|█████▏    | 78/150 [00:06<00:05, 12.15it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  53%|█████▎    | 80/150 [00:06<00:05, 12.45it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  55%|█████▍    | 82/150 [00:06<00:05, 13.57it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  56%|█████▌    | 84/150 [00:06<00:04, 13.27it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  57%|█████▋    | 86/150 [00:06<00:04, 13.53it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  59%|█████▊    | 88/150 [00:06<00:04, 13.82it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  60%|██████    | 90/150 [00:07<00:04, 13.32it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  61%|██████▏   | 92/150 [00:07<00:04, 12.08it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  63%|██████▎   | 94/150 [00:07<00:04, 12.92it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  64%|██████▍   | 96/150 [00:07<00:03, 13.67it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  65%|██████▌   | 98/150 [00:07<00:03, 14.11it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  67%|██████▋   | 100/150 [00:07<00:03, 13.83it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  68%|██████▊   | 102/150 [00:07<00:03, 13.31it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  69%|██████▉   | 104/150 [00:08<00:03, 12.85it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  71%|███████   | 106/150 [00:08<00:03, 12.59it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  72%|███████▏  | 108/150 [00:08<00:03, 13.51it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  73%|███████▎  | 110/150 [00:08<00:03, 13.00it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  75%|███████▍  | 112/150 [00:08<00:02, 13.22it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  76%|███████▌  | 114/150 [00:08<00:02, 13.14it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  77%|███████▋  | 116/150 [00:08<00:02, 13.09it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  79%|███████▊  | 118/150 [00:09<00:02, 13.13it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  80%|████████  | 120/150 [00:09<00:02, 11.69it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  81%|████████▏ | 122/150 [00:09<00:02, 11.26it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  83%|████████▎ | 124/150 [00:09<00:02, 11.88it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  84%|████████▍ | 126/150 [00:09<00:02, 11.78it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  85%|████████▌ | 128/150 [00:10<00:01, 11.80it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  87%|████████▋ | 130/150 [00:10<00:01, 12.12it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  88%|████████▊ | 132/150 [00:10<00:01, 13.06it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  89%|████████▉ | 134/150 [00:10<00:01, 13.39it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  91%|█████████ | 136/150 [00:10<00:01, 13.34it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  92%|█████████▏| 138/150 [00:10<00:00, 13.82it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  93%|█████████▎| 140/150 [00:10<00:00, 13.38it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  95%|█████████▍| 142/150 [00:11<00:00, 12.29it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  96%|█████████▌| 144/150 [00:11<00:00, 12.29it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  97%|█████████▋| 146/150 [00:11<00:00, 12.36it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving:  99%|█████████▊| 148/150 [00:11<00:00, 12.31it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving: 100%|██████████| 150/150 [00:11<00:00, 12.79it/s]


In [32]:
# Print results
print("\nEvaluation Results After Fine-Tuning:")
for metric, value in after_finetune_eval.items():
    print(f"{metric}: {value:.4f}")


Evaluation Results After Fine-Tuning:
recip_rank: 0.0660
recall_3: 0.0600
recall_100: 0.2667
recall_1000: 0.4800
ndcg_cut_10: 0.0755
ndcg_cut_100: 0.1033
ndcg_cut_1000: 0.1295
success_3: 0.0600
success_100: 0.2667
success_1000: 0.4800


In [33]:
# Save results
with open("/kaggle/working/evaluation_results.json", "w") as f:
    json.dump({
        "before_finetune": metrics,
        "after_finetune": after_finetune_eval
    }, f)

# RERANKER

In [34]:
# # Re-ranking Strategies Implementation

import os
import json
import torch
import numpy as np
import pandas as pd
import logging
from tqdm import tqdm
from transformers import (
    AutoTokenizer, 
    AutoModelForSequenceClassification,
    T5ForConditionalGeneration,
    T5Tokenizer
)
from sentence_transformers import SentenceTransformer, util
import pytrec_eval
from concurrent.futures import ThreadPoolExecutor

In [35]:
from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn.functional
from tqdm import tqdm

In [36]:
# Initialize corpus_dict
corpus = load_jsonl(PROCESSED_CORPUS_PATH)
corpus_dict = {doc["doc_id"]: doc["text"] for doc in corpus}

In [37]:
# # Configure logging
# logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s - %(message)s')
# log = logging.getLogger(__name__)

In [38]:
# # Set device (CPU or GPU)
# device = "cuda" if torch.cuda.is_available() else "cpu"
# log.info(f"Using device: {device}")

In [39]:
# Initialize paths and constants
RESULTS_DIR = "/kaggle/working/reranking_results"
os.makedirs(RESULTS_DIR, exist_ok=True)

In [40]:
# 1. MonoT5 Re-ranking Implementation
class MonoT5Reranker:
    def __init__(self, model_name="castorini/monot5-base-msmarco", batch_size=8):
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
        self.batch_size = batch_size
        self.query_prefix = "Query: "
        self.doc_prefix = " Document: "
        self.max_length = 512

    def compute_scores(self, query, docs):
        inputs = [self.query_prefix + query + self.doc_prefix + doc for doc in docs]
        all_scores = []
        
        for i in range(0, len(inputs), self.batch_size):
            batch = inputs[i:i+self.batch_size]
            encoded = self.tokenizer(batch, padding=True, truncation=True, 
                                    return_tensors="pt", max_length=self.max_length).to(device)
            with torch.no_grad():
                outputs = self.model.generate(**encoded)
            batch_outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
            batch_scores = [1.0 if out.startswith("true") else 0.0 for out in batch_outputs]
            all_scores.extend(batch_scores)
            
        return all_scores

    def rerank(self, queries, initial_results, top_k=100):
        """Rerank initial retrieval results using MonoT5."""
        reranked_results = {}
        
        for qid, query_text in tqdm(queries.items(), desc="MonoT5 reranking"):
            if qid not in initial_results:
                continue
                
            # Get initial results
            doc_ids = list(initial_results[qid].keys())[:top_k]
            docs = [corpus_dict.get(doc_id, "") for doc_id in doc_ids]
            
            # Skip if no documents to rerank
            if not docs:
                reranked_results[qid] = {}
                continue
                
            # Compute relevance scores
            scores = self.compute_scores(query_text, docs)
            
            # Create reranked results
            reranked = {doc_id: float(score) for doc_id, score in zip(doc_ids, scores)}
            reranked_sorted = {k: v for k, v in sorted(reranked.items(), key=lambda item: item[1], reverse=True)}
            reranked_results[qid] = reranked_sorted
            
        return reranked_results


In [41]:
# 2. CE-MiniLM-L6-v2 Re-ranking Implementation
class RobertaReranker:
    def __init__(self, model_name="cross-encoder/ms-marco-MiniLM-L-6-v2", batch_size=16):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
        self.batch_size = batch_size
        self.max_length = 512

    def compute_scores(self, query, docs):
        pairs = [[query, doc] for doc in docs]
        all_scores = []
        
        for i in range(0, len(pairs), self.batch_size):
            batch = pairs[i:i+self.batch_size]
            encoded = self.tokenizer(batch, padding=True, truncation=True, 
                                    return_tensors="pt", max_length=self.max_length).to(device)
            with torch.no_grad():
                outputs = self.model(**encoded)
                scores = torch.sigmoid(outputs.logits.squeeze(-1)).cpu().tolist()
                all_scores.extend(scores)
                
        return all_scores

    def rerank(self, queries, initial_results, top_k=100):
        """Rerank initial retrieval results using Cross-Encoder."""
        reranked_results = {}
        
        for qid, query_text in tqdm(queries.items(), desc="RoBERTa reranking"):
            if qid not in initial_results:
                continue
                
            # Get initial results
            doc_ids = list(initial_results[qid].keys())[:top_k]
            docs = [corpus_dict.get(doc_id, "") for doc_id in doc_ids]
            
            # Skip if no documents to rerank
            if not docs:
                reranked_results[qid] = {}
                continue
                
            # Compute relevance scores
            scores = self.compute_scores(query_text, docs)
            
            # Create reranked results
            reranked = {doc_id: float(score) for doc_id, score in zip(doc_ids, scores)}
            reranked_sorted = {k: v for k, v in sorted(reranked.items(), key=lambda item: item[1], reverse=True)}
            reranked_results[qid] = reranked_sorted
            
        return reranked_results

In [42]:
# 3. Reciprocal Rank Fusion Implementation
class ReciprocalRankFusion:
    def __init__(self, k=60):
        log.info(f"Initializing Reciprocal Rank Fusion with k={k}")
        self.k = k  # Constant for RRF formula
        
    def fuse_results(self, result_lists):
        """Fuse multiple result lists using Reciprocal Rank Fusion."""
        # Aggregate all query IDs
        all_qids = set()
        for results in result_lists:
            all_qids.update(results.keys())
            
        fused_results = {}
        
        for qid in all_qids:
            doc_scores = {}
            
            # Process each result list
            for results in result_lists:
                if qid not in results:
                    continue
                    
                # Get document IDs and their ranks
                doc_ids = list(results[qid].keys())
                
                # Compute RRF scores
                for rank, doc_id in enumerate(doc_ids, 1):
                    rrf_score = 1 / (rank + self.k)
                    if doc_id not in doc_scores:
                        doc_scores[doc_id] = 0
                    doc_scores[doc_id] += rrf_score
            
            # Sort by RRF score
            sorted_docs = {k: v for k, v in sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)}
            fused_results[qid] = sorted_docs
            
        return fused_results

In [43]:
# Function to run all re-ranking strategies
def run_reranking(queries, initial_results, qrels, top_k=100):
    """Run all re-ranking strategies and evaluate results."""
    results = {}
    
    # 1. MonoT5 Reranking
    monot5 = MonoT5Reranker()
    monot5_results = monot5.rerank(queries, initial_results, top_k)
    monot5_eval = evaluate(monot5_results, qrels)
    results["monot5"] = monot5_eval
    
    # Save MonoT5 results
    with open(os.path.join(RESULTS_DIR, "monot5_results.json"), "w") as f:
        json.dump(monot5_results, f)
    
    # 2. RoBERTa Reranking
    roberta = RobertaReranker()
    roberta_results = roberta.rerank(queries, initial_results, top_k)
    roberta_eval = evaluate(roberta_results, qrels)
    results["roberta"] = roberta_eval
    
    # Save RoBERTa results
    with open(os.path.join(RESULTS_DIR, "CE-MiniLM-L6-v2.json"), "w") as f:
        json.dump(roberta_results, f)
    
    # 3. Reciprocal Rank Fusion
    rrf = ReciprocalRankFusion()
    rrf_results = rrf.fuse_results([initial_results, monot5_results, roberta_results])
    rrf_eval = evaluate(rrf_results, qrels)
    results["rrf"] = rrf_eval
    
    # Save RRF results
    with open(os.path.join(RESULTS_DIR, "rrf_results.json"), "w") as f:
        json.dump(rrf_results, f)
    
    # Initial results evaluation for comparison
    initial_eval = evaluate(initial_results, qrels)
    results["initial"] = initial_eval
    
    return results

In [44]:
# Function to create comparison table
def create_comparison_table(results):
    """Create a comparison table of evaluation metrics."""
    # Extract metrics from results
    metrics = list(next(iter(results.values())).keys())
    systems = list(results.keys())
    
    # Create DataFrame
    data = []
    for system in systems:
        row = [system]
        for metric in metrics:
            row.append(results[system].get(metric, 0))
        data.append(row)
    
    # Create DataFrame
    df = pd.DataFrame(data, columns=["System"] + metrics)
    
    # Round values
    for col in df.columns[1:]:
        df[col] = df[col].round(4)
    
    return df

In [45]:
# Main function to execute re-ranking
def main(dev_queries_path, dev_qrels_path, initial_results):
    """Main function to execute re-ranking strategies."""
    # Load queries and qrels
    queries = load_queries(dev_queries_path)
    qrels = load_qrels(dev_qrels_path)
    
    log.info(f"Loaded {len(queries)} queries and {len(qrels)} qrels")
    
    # Run re-ranking
    results = run_reranking(queries, initial_results, qrels)
    
    # Create comparison table
    table = create_comparison_table(results)

In [46]:
initial_results = after_finetune_results
main(DEV_QUERIES_PATH, DEV_QREL_PATH, initial_results)

tokenizer_config.json:   0%|          | 0.00/1.89k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/1.79k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.84k [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


pytorch_model.bin:   0%|          | 0.00/892M [00:00<?, ?B/s]

MonoT5 reranking: 100%|██████████| 150/150 [3:34:05<00:00, 85.64s/it] 


tokenizer_config.json:   0%|          | 0.00/1.33k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/794 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

RoBERTa reranking: 100%|██████████| 150/150 [30:08<00:00, 12.06s/it]


In [47]:
# !pip install -q transformers sentence-transformers pytrec_eval faiss-cpu rank_bm25 beir
