# Neural Reranker 

## Overview 

The Neural Reranker and Evaluation System extends the baseline Information Retrieval (IR) system with a CNN-based neural reranker. The reranker reorders the documents retrieved by BM25 so that relevant documents are ranked higher.
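Reranking keeps BM25's candidate set and changes only its order. Below is a minimal sketch that assumes the model has already produced one relevance score per candidate. The `rerank` helper and the example IDs are hypothetical.

```python
def rerank(candidate_ids, scores):
    """Reorder BM25 candidates by model score, highest first.

    Hypothetical helper: `scores[i]` is assumed to be the reranker's score
    for `candidate_ids[i]`. Ties keep their original BM25 order because
    Python's sort is stable (also with reverse=True).
    """
    if len(candidate_ids) != len(scores):
        raise ValueError("need exactly one score per candidate")
    order = sorted(range(len(candidate_ids)), key=lambda i: scores[i], reverse=True)
    return [candidate_ids[i] for i in order]


# BM25 order is d1, d2, d3; the reranker prefers d3.
print(rerank(["d1", "d2", "d3"], [0.2, 0.1, 0.9]))  # ['d3', 'd1', 'd2']
```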

## Imports & Config

This cell imports the required modules and defines the input and output paths:
- **Pretrained Embeddings**: Path to the pretrained GloVe embeddings file (text format).
- **Corpus File**: Path to the corpus file (JSONL format).
- **Questions File**: Path to the questions file (JSONL format).
- **Questions Ranked (BM25)**: Path to the BM25-ranked questions file (JSONL format).
- **Training Data**: Path to the training questions file (JSONL format).
- **BM25 Ranked Questions Training Data**: Path to the BM25-ranked training file (JSONL format).
- **Output File**: Path where the reranked results are saved (JSONL format).
- **Model Checkpoint**: Directory where checkpoints of the newly trained model are saved.

```python
import os
import sys
import torch
import torch.nn as nn
from datetime import datetime 
from pathlib import Path
from torch.utils.data import DataLoader

# Add parent directory to path to import from src
sys.path.append('..')

from src.model import CNNInteractionBasedModel, Tokenizer, PointWiseDataset
from src.utils import (
    load_pretrained_embeddings, 
    build_collate_fn, 
    get_all_doc_texts, 
    get_questions, 
)

PRETRAINED_EMB = "../data/glove.42B.300d.txt"
CORPUS_FILE = "../data/MEDLINE_2024_Baseline.jsonl"
QUESTIONS_FILE = "../data/questions.jsonl"
BM25_FILE = "../data/questions_bm25_ranked.jsonl"
TRAIN_Q_FILE = "../data/training_data.jsonl"
TRAIN_BM25_FILE = "../data/training_data_bm25_ranked.jsonl"
OUTPUT_FILE = "../output/final_ranked_questions.jsonl"
MODEL_CHECKPOINT = "../output/model"

Path(MODEL_CHECKPOINT).mkdir(parents=True, exist_ok=True)
print(f"Output model directory ready: {MODEL_CHECKPOINT}")
```

```text
Output model directory ready: ../output/model
```


## Training Procedure

We train the reranking model in a pointwise classification setup, where the model learns to predict whether a retrieved document is **relevant (positive)** or **not relevant (negative)** to a given query.

**1. Input Data**
- Each **query** comes with a **set of retrieved documents** (from BM25).
- Each (query, document) pair is labeled:
    - **Positive (label=1)**: if the document is listed in the **gold-standard set** for that query.
    - **Negative (label=0)**: otherwise.

This produces training samples of the form:

```python
(query_text, document_text, label)
```
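The labeling rule above can be sketched in plain Python. The sketch assumes that BM25 output for a query is a ranked list of document IDs and that the gold standard is a set of IDs. `label_candidates` and the example IDs and texts are hypothetical and are not part of `src`.

```python
def label_candidates(query_text, ranked_doc_ids, gold_doc_ids, doc_texts):
    """Turn one query's BM25 candidates into pointwise training samples.

    Hypothetical helper. Returns (query_text, document_text, label) tuples,
    with label 1 for gold-standard documents and 0 for everything else.
    """
    gold = set(gold_doc_ids)
    return [
        (query_text, doc_texts[doc_id], 1 if doc_id in gold else 0)
        for doc_id in ranked_doc_ids
    ]


docs = {"d1": "aspirin and stroke", "d2": "statins in elderly", "d3": "aspirin dosage"}
samples = label_candidates("aspirin stroke prevention", ["d1", "d2", "d3"], {"d1"}, docs)
for sample in samples:
    print(sample)
```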

**2. Negative/Positive Ratio**

- Since BM25 retrieves many more non-relevant documents than relevant ones, the dataset is **highly imbalanced**.
- To counteract this, we use a **negative sampling strategy**:
    - For every **positive document**, we sample up to **k negatives** (e.g., k = 2).
    - With k = 2, this yields an approximate **1:2 ratio** of positives to negatives.
    - This reduces the model's bias toward always predicting "non-relevant".
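The 1:k cap can be sketched as follows. This is an assumption about how the sampling might work, because the actual logic (presumably inside `PointWiseDataset`) is not shown. `cap_negatives` is a hypothetical helper that samples uniformly at random from whatever negative pool it receives.

```python
import random


def cap_negatives(positives, negatives, k=2, seed=None):
    """Keep at most k negatives per positive (hypothetical helper).

    Draws uniformly without replacement from `negatives`; if the pool is
    smaller than k * len(positives), all negatives are kept.
    """
    budget = k * len(positives)
    if len(negatives) <= budget:
        return list(negatives)
    return random.Random(seed).sample(list(negatives), budget)


pos = ["d1", "d7"]
neg = [f"n{i}" for i in range(100)]
kept = cap_negatives(pos, neg, k=2, seed=0)
print(len(pos), ":", len(kept))  # 2 : 4
```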

**3. Negative Mining Approach**

We start with **BM25 retrieval** as a source of candidate documents:

- Every gold-standard document that BM25 retrieves is kept as a positive.
- Negatives are chosen from the **top BM25 results that are not in the gold standard**.
    - These negatives are **“hard negatives”**, because they were highly ranked by BM25 but are not relevant.
    - Training on these improves discrimination compared to random negatives.
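Hard-negative selection differs from a random draw only in which non-gold documents are kept: the highest-ranked ones. The sketch below uses the same assumptions (a ranked list of IDs and a gold set of IDs). `hard_negatives` is hypothetical.

```python
def hard_negatives(ranked_doc_ids, gold_doc_ids, max_negatives):
    """Return the top-ranked BM25 documents that are not in the gold set.

    Hypothetical helper: walks the BM25 ranking from the top and keeps the
    first `max_negatives` non-gold documents, preserving BM25 order.
    """
    gold = set(gold_doc_ids)
    picked = []
    for doc_id in ranked_doc_ids:
        if len(picked) == max_negatives:
            break
        if doc_id not in gold:
            picked.append(doc_id)
    return picked


ranking = ["d4", "d1", "d9", "d2", "d5", "d8", "d3"]  # BM25 order, best first
gold = {"d1", "d2"}
print(hard_negatives(ranking, gold, max_negatives=2 * len(gold)))  # ['d4', 'd9', 'd5', 'd8']
```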

```python
def train_model(model, tokenizer, training_questions_file, training_ranked_file, corpus_file, device,
                batch_size=64, epochs=5, lr=1e-3, checkpoint_dir="../output/model"):
    print("Loading training dataset...")

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    # Move the model to the target device (CPU or CUDA) before creating the optimizer.
    model.to(device)

    train_dataset = PointWiseDataset(training_questions_file, training_ranked_file, corpus_file, tokenizer, return_label=True)

    # Determine max query/document lengths in a single pass (at least 3 tokens each).
    max_q_len, max_d_len = 3, 3
    for idx in range(len(train_dataset)):
        sample = train_dataset[idx]
        max_q_len = max(max_q_len, len(sample["question_token_ids"]))
        max_d_len = max(max_d_len, len(sample["document_token_ids"]))
    print(f"Max Q length: {max_q_len}, Max D length: {max_d_len}")

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        collate_fn=build_collate_fn(tokenizer, max_q_len, max_d_len),
        shuffle=True,
        pin_memory=(device.type == 'cuda')
    )

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    best_vloss = float('inf')

    for epoch in range(epochs):
        print('EPOCH {}:'.format(epoch + 1))
        model.train(True)

        running_loss = 0.
        last_loss = 0.

        for i, data in enumerate(train_loader):
            # Move the batch to the same device as the model.
            qids = data["question_token_ids"].to(device)
            dids = data["document_token_ids"].to(device)
            labels = data["label"].to(device)

            optimizer.zero_grad()
            scores = model(qids, dids)
            loss = loss_fn(scores, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Report the average loss over each block of 1000 batches.
            if i % 1000 == 999:
                last_loss = running_loss / 1000
                print('  batch {} loss: {}'.format(i + 1, last_loss))
                running_loss = 0.

        running_vloss = 0.0
        # Set the model to evaluation mode, disabling dropout and using population
        # statistics for batch normalization.
        model.eval()

        # NOTE: there is no held-out split. This pass re-scores the *training* data in
        # eval mode, so the reported "valid" loss is not a generalization estimate.
        # Disable gradient computation and reduce memory consumption.
        with torch.no_grad():
            for i, vdata in enumerate(train_loader):
                qids = vdata["question_token_ids"].to(device)
                dids = vdata["document_token_ids"].to(device)
                labels = vdata["label"].to(device)

                voutputs = model(qids, dids)
                vloss = loss_fn(voutputs, labels)
                running_vloss += vloss.item()

        avg_vloss = running_vloss / (i + 1)
        print('LOSS train {} valid {}'.format(last_loss, avg_vloss))

        # Track the lowest end-of-epoch loss and save the model's state.
        if avg_vloss < best_vloss:
            best_vloss = avg_vloss
            model_path = os.path.join(checkpoint_dir, 'model_{}_{}.pt'.format(timestamp, epoch + 1))
            torch.save(model.state_dict(), model_path)

    return model
```
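`train_model` re-evaluates the training loader, so its "valid" loss is not a held-out estimate. `torch.utils.data.random_split` is one way to get a real validation signal. The sketch below assumes only that the dataset behaves like a map-style `Dataset` (it has `__len__` and `__getitem__`). The 10% fraction, the seed, and the `split_train_val` helper are illustrative choices, not part of the original code.

```python
import torch
from torch.utils.data import DataLoader, random_split


def split_train_val(dataset, val_fraction=0.1, seed=42):
    """Split a map-style dataset into disjoint train/validation subsets.

    Hypothetical helper; a fixed generator seed makes the split reproducible.
    """
    n_val = max(1, int(len(dataset) * val_fraction))
    n_train = len(dataset) - n_val
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [n_train, n_val], generator=generator)


# Stand-in dataset: a plain list has __len__ and __getitem__, so it works here.
toy = list(range(100))
train_subset, val_subset = split_train_val(toy)
train_loader = DataLoader(train_subset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=16, shuffle=False)
print(len(train_subset), len(val_subset))  # 90 10
```

In `train_model`, the split would be applied to `train_dataset` before the loaders are built, and the evaluation loop would iterate over `val_loader` instead of `train_loader`.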

## Create Model

`CNNInteractionBasedModel` (from `src.model`) is a PyTorch model that scores each document's relevance to a query so that the candidates can be reranked. It uses convolutional layers to capture interactions between query and document embeddings.

**Key Components**:

- **Embedding Layer**: Converts token IDs into dense vectors. Supports loading pretrained embeddings (e.g., GloVe).
- **Convolutional Layer**: Captures local interactions between query and document embeddings.
- **Activation Function**: Applies ReLU activation to introduce non-linearity.
- **Pooling Layer**: Aggregates features using adaptive max pooling.
- **Fully Connected Layer**: Maps extracted features to a single relevance score.
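The component list can be illustrated with a small interaction-based CNN. This is a hypothetical sketch, not the code in `src.model`. It assumes the interaction signal is a query-by-document cosine-similarity matrix, which a single `Conv2d` layer scans with a 3×3 kernel before ReLU, adaptive max pooling, and a linear scoring layer. All sizes are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyInteractionCNN(nn.Module):
    """Hypothetical interaction-based CNN reranker (illustration only)."""

    def __init__(self, vocab_size, embedding_dim=300, num_filters=16,
                 kernel_size=3, pooled_size=(4, 4), pretrained_embeddings=None):
        super().__init__()
        # Embedding layer; index 0 is assumed to be the padding token.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        if pretrained_embeddings is not None:
            with torch.no_grad():
                self.embedding.weight.copy_(
                    torch.as_tensor(pretrained_embeddings, dtype=torch.float32))
        # Convolution over the (query length x document length) similarity map.
        self.conv = nn.Conv2d(1, num_filters, kernel_size=kernel_size)
        # Adaptive pooling yields a fixed-size output for any input length.
        self.pool = nn.AdaptiveMaxPool2d(pooled_size)
        self.fc = nn.Linear(num_filters * pooled_size[0] * pooled_size[1], 1)

    def forward(self, q_ids, d_ids):
        q = F.normalize(self.embedding(q_ids), dim=-1)        # (B, Lq, E)
        d = F.normalize(self.embedding(d_ids), dim=-1)        # (B, Ld, E)
        sim = torch.bmm(q, d.transpose(1, 2)).unsqueeze(1)    # (B, 1, Lq, Ld) cosine sims
        x = F.relu(self.conv(sim))                            # (B, F, Lq-k+1, Ld-k+1)
        x = self.pool(x).flatten(1)                           # (B, F * 4 * 4)
        return self.fc(x).squeeze(-1)                         # (B,) raw logits


model = TinyInteractionCNN(vocab_size=50, embedding_dim=8)
q = torch.randint(1, 50, (2, 6))   # batch of 2 queries, 6 tokens each
d = torch.randint(1, 50, (2, 40))  # batch of 2 documents, 40 tokens each
print(model(q, d).shape)  # torch.Size([2])
```

The model returns raw logits with no sigmoid, which is the input `BCEWithLogitsLoss` expects. A 3×3 kernel needs at least three positions along each axis. That may explain why `train_model` clamps both maximum lengths to at least 3, but the notebook does not say so.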

```python
# Cell 3: Tokenizer + Embeddings + Model
tokenizer = Tokenizer()
questions = get_questions(TRAIN_Q_FILE)
documents = get_all_doc_texts(TRAIN_Q_FILE, TRAIN_BM25_FILE, CORPUS_FILE)
tokenizer.fit(questions + documents)

if os.path.exists(PRETRAINED_EMB):
    print("Loading pretrained embeddings...")
    pretrained_embeddings = load_pretrained_embeddings(PRETRAINED_EMB, tokenizer, embedding_dim=300)
else:
    print("No pretrained embeddings found, using random init.")
    pretrained_embeddings = None

print("Vocabulary size is: ", tokenizer.vocab_size)

model = CNNInteractionBasedModel(vocab_size=tokenizer.vocab_size, pretrained_embeddings=pretrained_embeddings)
```


```text
Loading pretrained embeddings...
```
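`load_pretrained_embeddings` is not shown in this notebook. A common approach, sketched here as an assumption about what it does, has three steps:

1. Read the GloVe text format, where each line holds a token followed by its space-separated vector components.
2. Copy the vectors of in-vocabulary tokens into a `vocab_size x dim` matrix.
3. Leave small random vectors for tokens GloVe does not cover.

`build_embedding_matrix` and the `word2idx` mapping are hypothetical.

```python
import io
import random


def build_embedding_matrix(lines, word2idx, dim, seed=0):
    """Build a vocab_size x dim embedding matrix from GloVe-style lines.

    Hypothetical helper. Rows for tokens missing from GloVe keep a small
    random initialisation; row 0 is assumed to be padding and set to zeros.
    """
    rng = random.Random(seed)
    matrix = [[rng.uniform(-0.05, 0.05) for _ in range(dim)] for _ in range(len(word2idx))]
    matrix[0] = [0.0] * dim
    found = 0
    for line in lines:
        parts = line.rstrip().split(" ")
        word, values = parts[0], parts[1:]
        if word in word2idx and len(values) == dim:
            matrix[word2idx[word]] = [float(v) for v in values]
            found += 1
    return matrix, found


glove = io.StringIO("aspirin 0.1 0.2 0.3\nstroke 0.4 0.5 0.6\nthe 0.0 0.0 1.0\n")
vocab = {"<pad>": 0, "aspirin": 1, "stroke": 2, "statin": 3}
matrix, found = build_embedding_matrix(glove, vocab, dim=3)
print(f"{found}/{len(vocab) - 1} tokens covered")  # 2/3 tokens covered
```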


## Training

### Training Setup 

We configure the model training with the following hyperparameters:

- **EPOCHS**: number of full passes over the training dataset.
- **BATCH_SIZE**: number of query–document pairs processed per optimization step.
- **LR**: learning rate of the Adam optimizer, which controls the size of each weight update.

Training automatically uses a **GPU (CUDA)** if one is available and otherwise falls back to the **CPU**.

### Training the Model

We call the `train_model` function, which:

**1.** Loads the training dataset (queries, documents, labels).

**2.** Builds padded batches, using the maximum query and document lengths found in the training set.

**3.** Optimizes the model with Adam and **binary cross-entropy loss** (`BCEWithLogitsLoss`), where:

- **positive samples (label = 1)**: documents in the gold standard set for a query.
- **negative samples (label = 0)**: retrieved documents not in the gold set.

**4.** Runs the training loop for the defined number of epochs and saves a checkpoint whenever the end-of-epoch loss improves.

**5.** Returns the trained model, ready for reranking/evaluation.
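`BCEWithLogitsLoss` combines a sigmoid with binary cross-entropy. For a logit z and label y it computes ℓ = -[y·log σ(z) + (1 - y)·log(1 - σ(z))], averaged over the batch by default. The pure-Python sketch below compares that definition with the numerically stable rewrite max(z, 0) - z·y + log(1 + e^(-|z|)), which is algebraically the same quantity. Both function names are hypothetical.

```python
import math


def bce_with_logits(logits, labels):
    """Mean binary cross-entropy on raw logits (numerically stable form)."""
    losses = [max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
              for z, y in zip(logits, labels)]
    return sum(losses) / len(losses)


def naive_bce(logits, labels):
    """Same loss via an explicit sigmoid; fails with log(0) for large |z|."""
    total = 0.0
    for z, y in zip(logits, labels):
        p = 1.0 / (1.0 + math.exp(-z))
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(logits)


logits = [2.0, -1.0, 0.5]   # model scores for three query-document pairs
labels = [1.0, 0.0, 0.0]    # 1 = gold-standard document, 0 = negative
print(bce_with_logits(logits, labels), naive_bce(logits, labels))
```

The stable form still returns 1000.0 for a confident wrong prediction such as z = 1000, y = 0, whereas the naive version would try to take log(0).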

```python
EPOCHS = 1
BATCH_SIZE = 64
LR = 1e-3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

print("Training new model...")
model = train_model(
    model, tokenizer, TRAIN_Q_FILE, TRAIN_BM25_FILE,
    CORPUS_FILE, device, batch_size=BATCH_SIZE, epochs=EPOCHS, lr=LR
)
```

```text
Using device: cpu
Training new model...
Loading training dataset...
Max Q length: 30, Max D length: 1785
EPOCH 1:
  batch 1000 loss: 0.13784657931514085
  batch 2000 loss: 0.13973148022405804
  batch 3000 loss: 0.1431128241531551
  batch 4000 loss: 0.14808657264336944
  batch 5000 loss: 0.14313516185060143
  batch 6000 loss: 0.14648264915309847
  batch 7000 loss: 0.1511464210636914
LOSS train 0.1511464210636914 valid 0.0958528913795787
```
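`train_model` writes a new `model_<timestamp>_<epoch>.pt` file only when the end-of-epoch loss improves. Within one run, the file with the highest epoch number is therefore the best one. The helper below picks that file and assumes nothing beyond the naming pattern. `latest_best_checkpoint` and the example paths are hypothetical. The chosen file would then be restored with `model.load_state_dict(torch.load(path, map_location=device))`.

```python
import re
from pathlib import Path

CHECKPOINT_RE = re.compile(r"model_(\d{8}_\d{6})_(\d+)\.pt$")


def latest_best_checkpoint(paths):
    """Pick the checkpoint from the most recent run with the highest epoch.

    Timestamps use %Y%m%d_%H%M%S, so string order is chronological. Epochs
    are compared as integers, so "_10.pt" correctly beats "_9.pt" (a plain
    string sort of the filenames would get this wrong).
    """
    candidates = []
    for p in paths:
        m = CHECKPOINT_RE.search(Path(p).name)
        if m:
            candidates.append((m.group(1), int(m.group(2)), str(p)))
    if not candidates:
        return None
    return max(candidates)[2]


files = [
    "../output/model/model_20240101_120000_9.pt",
    "../output/model/model_20240101_120000_10.pt",
    "../output/model/model_20231231_090000_12.pt",
]
print(latest_best_checkpoint(files))  # ../output/model/model_20240101_120000_10.pt
```

In the notebook, the input list could come from `Path(MODEL_CHECKPOINT).glob("model_*.pt")`.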
