# Neural Reranker 

## Overview 

The Neural Reranker and Evaluation System enhances the baseline Information Retrieval (IR) system by introducing a CNN-based neural reranker model to improve the relevance of retrieved documents.

## Imports & Config

This section imports the required modules and defines the files and paths used throughout the notebook:
- **Pretrained Embeddings**: Path to the pretrained embeddings file (txt format).
- **Corpus File**: Path to the corpus file (JSONL format).
- **Questions File**: Path to the questions file (JSONL format).
- **Questions Ranked (BM25)**: Path to the BM25-ranked file (JSONL format).
- **Training Data**: Path to the questions training file (JSONL format). 
- **BM25 Ranked Questions Training Data**: Path to the BM25-ranked training file (JSONL format).
- **Output File**: Path to save the reranked results (JSONL format).
- **Model Checkpoint**: Directory where trained model checkpoints are saved (created if it does not exist).

In [None]:
import os
import sys
import torch
import torch.nn as nn
import ujson
from datetime import datetime 
from pathlib import Path
from torch.utils.data import DataLoader

# Add parent directory to path to import from src
sys.path.append('..')

import src.evaluation as ndcg 
from src.model import CNNInteractionBasedModel
from src.tokenizer import Tokenizer
from src.dataset import PointWiseDataset
from src.data_processing import load_tokenizer_config
from src.utils import (
    load_pretrained_embeddings, 
    build_collate_fn, 
    get_all_doc_texts, 
    get_questions, 
)

TOKENIZER_CONFIG = {
    'min_token_length': 3,
    'lowercase': True,
    'stem': True,
    'stopwords': None  # Can provide a set of stopwords
}

BATCH_SIZE = 64

OUTPUT_DIR = "../output"
PRETRAINED_EMB = "../data/glove.42B.300d.txt"
CORPUS_FILE = "../data/MEDLINE_2024_Baseline.jsonl"
QUESTIONS_FILE = "../data/questions.jsonl"
BM25_FILE = "../data/questions_bm25_ranked.jsonl"
TRAIN_Q_FILE = "../data/training_data.jsonl"
TRAIN_BM25_FILE = "../data/training_data_bm25_ranked.jsonl"
BM25_OUPUT_FILE = "../output/ranked_questions.jsonl"
OUTPUT_FILE = "../output/final_ranked_questions.jsonl"
MODEL_CHECKPOINT = "../output/model"

Path(MODEL_CHECKPOINT).mkdir(parents=True, exist_ok=True)
print(f"Output model directory ready: {MODEL_CHECKPOINT}")

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)


In [None]:
# Tokenizer 

config_path = os.path.join(OUTPUT_DIR, "index/tokenizer_config.msgpack")

if os.path.exists(config_path):
    tokenizer_config = load_tokenizer_config(config_path)
    print("Loaded tokenizer configuration:")
    for key, value in tokenizer_config.items():
        print(f"  {key}: {value}")
else:
    tokenizer_config = TOKENIZER_CONFIG
    print("Using default tokenizer configuration")

tokenizer = Tokenizer(
    tokenizer_config.get('min_token_length', 3),
    tokenizer_config.get('lowercase', True),
    tokenizer_config.get('stem', False),
    set(tokenizer_config.get('stopwords', [])) if tokenizer_config.get('stopwords') else None
)
questions = get_questions(TRAIN_Q_FILE)
documents = get_all_doc_texts(TRAIN_Q_FILE, TRAIN_BM25_FILE, CORPUS_FILE)
tokenizer.fit(questions + documents)


## Dataset and DataLoader
The `PointWiseDataset` and `DataLoader` classes encapsulate the process of pulling your data from storage and exposing it to your training loop in batches.

The `Dataset` is responsible for accessing and processing single instances of data.
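
To make this division of labour concrete, here is a minimal sketch of a map-style dataset. It is not the project's `PointWiseDataset`: the class name `ToyPointwiseDataset`, the whitespace tokenization, the reserved ids (0 for padding, 1 for unknown tokens), and the `pad_to` helper are all assumptions made for illustration. A map-style dataset only needs `__len__` and `__getitem__`; padding to a common length is typically left to the collate function.

```python
class ToyPointwiseDataset:
    """Hypothetical stand-in for PointWiseDataset: one (query, document, label) per index."""

    def __init__(self, samples, vocab):
        self.samples = samples  # list of (query_text, document_text, label)
        self.vocab = vocab      # token -> id; 0 is reserved for padding, 1 for unknown

    def __len__(self):
        return len(self.samples)

    def _encode(self, text):
        # Whitespace tokenization is an assumption; unknown tokens map to id 1
        return [self.vocab.get(tok, 1) for tok in text.lower().split()]

    def __getitem__(self, idx):
        query, document, label = self.samples[idx]
        return {
            "question_token_ids": self._encode(query),
            "document_token_ids": self._encode(document),
            "label": float(label),
        }


def pad_to(token_ids, max_len, pad_id=0):
    """Truncate or right-pad a list of token ids to exactly max_len entries."""
    return token_ids[:max_len] + [pad_id] * max(0, max_len - len(token_ids))


vocab = {"heart": 2, "attack": 3, "symptoms": 4, "of": 5}
dataset = ToyPointwiseDataset(
    [("symptoms of heart attack", "heart attack risk factors", 1)], vocab
)
sample = dataset[0]
padded_query = pad_to(sample["question_token_ids"], 6)
# padded_query == [4, 5, 2, 3, 0, 0]
```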

In [None]:
print("Loading training dataset...")
train_dataset = PointWiseDataset(TRAIN_Q_FILE, TRAIN_BM25_FILE, CORPUS_FILE, tokenizer, use_negative_sampling=True)

print("Loading validation dataset...")
validation_dataset = PointWiseDataset(QUESTIONS_FILE, BM25_FILE, CORPUS_FILE, tokenizer)

# Determine max lengths
q_lens = [len(train_dataset[i]["question_token_ids"]) for i in range(len(train_dataset))]
d_lens = [len(train_dataset[i]["document_token_ids"]) for i in range(len(train_dataset))]
max_q_len = max(3, max(q_lens))
max_d_len = max(3, max(d_lens))
print(f"Max Q length: {max_q_len}, Max D length: {max_d_len}")

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=build_collate_fn(tokenizer, max_q_len, max_d_len),
    shuffle=True,
    pin_memory=(DEVICE.type == 'cuda')
)

validation_loader = DataLoader(
    validation_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=build_collate_fn(tokenizer, max_q_len, max_d_len),
    pin_memory=(DEVICE.type == "cuda")
)   


## Training Procedure

We train the reranking model in a pointwise classification setup, where the model learns to predict whether a retrieved document is **relevant (positive)** or **not relevant (negative)** to a given query.

**1. Input Data**
- Each **query** comes with a **set of retrieved documents** (from BM25).
- Each (query, document) pair is labeled:
    - **Positive (label=1)**: if the document is listed in the **gold-standard set** for that query.
    - **Negative (label=0)**: otherwise.

This produces training samples of the form:

```text
(query_text, document_text, label)
```
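
The labeling rule can be sketched in a few lines. The function name `label_pairs` and the example ids are hypothetical, and the sketch works on ids only; looking up the query and document text is left out.

```python
def label_pairs(query_id, retrieved_doc_ids, gold_doc_ids):
    """Return (query_id, doc_id, label) triples; label is 1 if the doc is in the gold set."""
    gold = set(gold_doc_ids)
    return [(query_id, doc_id, 1 if doc_id in gold else 0) for doc_id in retrieved_doc_ids]


pairs = label_pairs("q1", ["d3", "d7", "d1"], gold_doc_ids=["d1", "d9"])
# pairs == [("q1", "d3", 0), ("q1", "d7", 0), ("q1", "d1", 1)]
```

Note that `d9` is in the gold set but was not retrieved, so it produces no training sample.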

**2. Negative/Positive Ratio**

- Since BM25 retrieves many more non-relevant documents than relevant ones, the dataset is **highly imbalanced**.
- To counter this, we use a **negative sampling strategy**:
    - For every **positive document**, we sample up to **k negatives** (e.g., k = 2, twice as many negatives as positives).
    - This yields an approximate **1:2 ratio** of positives to negatives.
    - This prevents the model from being biased toward always predicting "non-relevant".
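
A minimal sketch of such a sampling step, assuming the positives and negatives for one query are already separated. The function name `sample_negatives`, the uniform random choice, and the fixed seed are assumptions; the project's `use_negative_sampling` option may work differently.

```python
import random


def sample_negatives(positives, negatives, ratio=2, seed=0):
    """Keep all positives and at most ratio * len(positives) randomly chosen negatives."""
    rng = random.Random(seed)  # seeded so the sketch is reproducible
    n_keep = min(len(negatives), ratio * len(positives))
    return positives, rng.sample(negatives, n_keep)


pos, neg = sample_negatives(["d1", "d4"], ["d2", "d3", "d5", "d6", "d7", "d8"], ratio=2)
# len(pos) == 2 and len(neg) == 4, i.e. a 1:2 ratio
```

The `min` is what makes this "up to" k negatives per positive: a query with few retrieved negatives simply contributes fewer.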

**3. Negative Mining Approach**

We start with **BM25 retrieval** as a source of candidate documents:

- Positives are included whenever BM25 retrieved them.
- Negatives are chosen from the **top BM25 results that are not in the gold standard**.
    - These negatives are **“hard negatives”**, because they were highly ranked by BM25 but are not relevant.
    - Training on these improves discrimination compared to random negatives.
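
As a sketch of this selection, assuming the BM25 list is ordered best first, hard negatives are the highest-ranked documents outside the gold set. The helper name `hard_negatives` and the cutoff `n` are illustrative.

```python
def hard_negatives(bm25_ranked_doc_ids, gold_doc_ids, n):
    """Return the n highest-ranked BM25 documents that are not in the gold set."""
    gold = set(gold_doc_ids)
    return [doc_id for doc_id in bm25_ranked_doc_ids if doc_id not in gold][:n]


ranked = ["d5", "d1", "d8", "d2", "d9"]  # best BM25 score first (assumed ordering)
negatives = hard_negatives(ranked, gold_doc_ids={"d1", "d9"}, n=2)
# negatives == ["d5", "d8"]
```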

In [None]:
def train_model(model, epochs=1, lr=1e-3):
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    best_vloss = float('inf')
    model_path = None  # set when the first checkpoint is saved

    model = model.to(DEVICE)

    for epoch in range(epochs):
        print('EPOCH {}:'.format(epoch + 1))
        model.train(True)

        running_loss = 0.
        last_loss = 0.

        for i, data in enumerate(train_loader):
            qids = data["question_token_ids"].to(DEVICE, non_blocking=True)
            dids = data["document_token_ids"].to(DEVICE, non_blocking=True)
            labels = data["label"].to(DEVICE, non_blocking=True)

            optimizer.zero_grad()
            scores = model(qids, dids)
            loss = loss_fn(scores, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

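            if i % 1000 == 999:
                last_loss = running_loss / 1000
                print('  batch {} loss: {}'.format(i + 1, last_loss))
                running_loss = 0.

        # Report the trailing partial window so that epochs with fewer than 1000
        # batches (or a batch count that is not a multiple of 1000) do not show a stale loss.
        if (i + 1) % 1000 != 0:
            last_loss = running_loss / ((i + 1) % 1000)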

        running_vloss = 0.0
        # Set the model to evaluation mode, disabling dropout and using population
        # statistics for batch normalization.
        model.eval()

        reranked_results = {}

        # Disable gradient computation and reduce memory consumption.
        with torch.no_grad():
            for i, vdata in enumerate(validation_loader):
                q_tokens = vdata["question_token_ids"].to(DEVICE, non_blocking=True)
                d_tokens = vdata["document_token_ids"].to(DEVICE, non_blocking=True)
                labels = vdata["label"].to(DEVICE, non_blocking=True)
                qids = vdata["query_ids"]
                dids = vdata["document_ids"]

                voutputs = model(q_tokens, d_tokens)
                vloss = loss_fn(voutputs, labels)
                running_vloss += vloss.item()

                scores = torch.sigmoid(voutputs).cpu().numpy()

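                for j, qid in enumerate(qids):
                    if qid not in reranked_results:
                        reranked_results[qid] = []
                    reranked_results[qid].append((dids[j], float(scores[j])))

        # Average over validation batches; the inner loop above must not reuse the batch index
        avg_vloss = running_vloss / len(validation_loader)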
        print('LOSS train {} valid {}'.format(last_loss, avg_vloss))

        # sort by model score
        for qid, doc_scores in reranked_results.items():
            reranked_results[qid] = [doc for doc, _ in sorted(doc_scores, key=lambda x: x[1], reverse=True)]

        validation_file = os.path.join(OUTPUT_DIR, 'validation_ranked_questions_model.jsonl')
        with open(validation_file, 'w') as f:
            for qid, docs in reranked_results.items():
                entry = {
                    "query_id": qid,
                    "retrieved_documents": docs
                }
                f.write(ujson.dumps(entry) + '\n')

        print("nDCG@10 (Model):", ndcg.compute_average_ndcg(questions_file_path=QUESTIONS_FILE, results_file_path=validation_file, k=10))

        # Track best performance, and save the model's state
        if avg_vloss < best_vloss:
            best_vloss = avg_vloss
            model_path = os.path.join(MODEL_CHECKPOINT, 'model_{}_{}.pt'.format(timestamp, epoch + 1))
            torch.save(model.state_dict(), model_path)

    return model, model_path

## Create Model

`CNNInteractionBasedModel` is a PyTorch model that reranks documents by their relevance to a given query. It uses convolutional layers to capture interactions between query and document embeddings.

**Key Components**:

- **Embedding Layer**: Converts token IDs into dense vectors. Supports loading pretrained embeddings (e.g., GloVe).
- **Convolutional Layer**: Captures local interactions between query and document embeddings.
- **Activation Function**: Applies ReLU activation to introduce non-linearity.
- **Pooling Layer**: Aggregates features using adaptive max pooling.
- **Fully Connected Layer**: Maps extracted features to a single relevance score.
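
The following sketch shows how these five components could fit together. It is not the code in `src.model`: the class name `TinyInteractionCNN`, the dot-product interaction matrix, the layer sizes, and the single convolution are assumptions chosen to keep the example small.

```python
import torch
import torch.nn as nn


class TinyInteractionCNN(nn.Module):
    """Illustrative only; layer sizes and the dot-product interaction are assumptions."""

    def __init__(self, vocab_size, embedding_dim=16, channels=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.conv = nn.Conv2d(1, channels, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Linear(channels, 1)

    def forward(self, query_ids, doc_ids):
        q = self.embedding(query_ids)                  # (batch, q_len, dim)
        d = self.embedding(doc_ids)                    # (batch, d_len, dim)
        interaction = torch.bmm(q, d.transpose(1, 2))  # (batch, q_len, d_len)
        features = torch.relu(self.conv(interaction.unsqueeze(1)))  # (batch, channels, q_len, d_len)
        pooled = self.pool(features).flatten(1)        # (batch, channels)
        return self.fc(pooled).squeeze(-1)             # (batch,) raw logits


model_sketch = TinyInteractionCNN(vocab_size=50)
logits = model_sketch(torch.randint(1, 50, (2, 5)), torch.randint(1, 50, (2, 12)))
# logits has one raw score per (query, document) pair in the batch
```

The output is a raw logit rather than a probability, which matches the use of `nn.BCEWithLogitsLoss` during training and `torch.sigmoid` at scoring time.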

In [None]:
# Embeddings + Model
if os.path.exists(PRETRAINED_EMB):
    print("Loading pretrained embeddings...")
    pretrained_embeddings = load_pretrained_embeddings(PRETRAINED_EMB, tokenizer, embedding_dim=300)
else:
    print("No pretrained embeddings found, using random init.")
    pretrained_embeddings = None

print("Vocabulary size is: ", tokenizer.vocab_size)

model = CNNInteractionBasedModel(vocab_size=tokenizer.vocab_size, pretrained_embeddings=pretrained_embeddings)


## Training

### Training Setup 

We configure the model training with the following hyperparameters:

- **EPOCHS**: number of times that the model will see the entire training dataset.
- **BATCH_SIZE**: number of query–document pairs processed per batch (set to 64 in the configuration cell).
- **LR**: learning rate controls how fast the optimizer updates the weights.

The training will automatically use a **GPU (CUDA)** if available, otherwise it will fall back to the **CPU**.

### Training the Model

We call the `train_model` function, which:

**1.** Iterates over the training batches from `train_loader` (queries, documents, labels).

**2.** Receives batches already padded by the collate function and moves them to the selected device.

**3.** Optimizes the model using **binary cross-entropy loss**, where:

- **positive samples (label = 1)**: documents in the gold standard set for a query.
- **negative samples (label = 0)**: retrieved documents not in the gold set.

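**4.** Runs the training loop for the defined number of epochs; after each epoch it evaluates on the validation set, reports nDCG@10, and saves a checkpoint whenever the validation loss improves.

**5.** Returns the trained model and the path to the best saved checkpoint, ready for reranking/evaluation.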
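
To make the loss in step 3 concrete, here is binary cross-entropy computed on a single raw score (logit) in plain Python. `train_model` uses `nn.BCEWithLogitsLoss`, which applies the sigmoid internally; the helper name `bce_with_logits` and the example values are illustrative.

```python
import math


def bce_with_logits(logit, label):
    """Binary cross-entropy on a raw logit, in the numerically stable form
    max(x, 0) - x * y + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


confident_right = bce_with_logits(4.0, 1.0)  # high score for a relevant document: small loss
confident_wrong = bce_with_logits(4.0, 0.0)  # high score for a non-relevant document: large loss
undecided = bce_with_logits(0.0, 1.0)        # score of 0 means probability 0.5: loss is log(2)
```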

In [None]:
EPOCHS = 10
LR = 1e-3

print("Training new model...")
model, model_path = train_model(model, epochs=EPOCHS, lr=LR)

## Computing Ranking Metrics (BM25 Results) 

The system includes a script, `nDCG.py`, that computes the **Normalized Discounted Cumulative Gain (nDCG)** metric, which evaluates the quality of the ranked retrieval results. In this notebook we call `ndcg.compute_average_ndcg` directly on the BM25 results.

### How nDCG Works

- **DCG (Discounted Cumulative Gain)**: Measures the gain (relevance) of each document in the result list, discounted by its position in the list.
- **IDCG (Ideal DCG)**: The maximum possible DCG achievable, obtained by an ideal ranking of documents.
- **nDCG**: The ratio of DCG to IDCG, normalized to a value between 0 and 1.
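
A minimal sketch of these three quantities, assuming binary gains (1 for a gold document, 0 otherwise), a `log2(rank + 1)` discount, and no duplicate ids in the ranked list. The function `ndcg_at_k` is illustrative; `compute_average_ndcg` in `src.evaluation` may define gains or discounts differently.

```python
import math


def ndcg_at_k(ranked_doc_ids, gold_doc_ids, k=10):
    """nDCG@k with binary gains: 1 for a gold document, 0 otherwise."""
    gold = set(gold_doc_ids)
    # DCG: each relevant document contributes a gain discounted by its rank
    dcg = sum(
        1.0 / math.log2(rank + 1)
        for rank, doc_id in enumerate(ranked_doc_ids[:k], start=1)
        if doc_id in gold
    )
    # IDCG: all gold documents placed at the top of the list
    ideal_hits = min(len(gold), k)
    idcg = sum(1.0 / math.log2(rank + 1) for rank in range(1, ideal_hits + 1))
    return dcg / idcg if idcg > 0 else 0.0


perfect = ndcg_at_k(["d1", "d2", "d3"], {"d1", "d2"})  # both gold documents first: 1.0
worse = ndcg_at_k(["d3", "d1", "d2"], {"d1", "d2"})    # gold documents pushed down: below 1.0
```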

In [None]:
# Compute nDCG for the given results
ndcg.compute_average_ndcg(
    questions_file_path=QUESTIONS_FILE,
    results_file_path=BM25_OUPUT_FILE,
    k=10
    )

In [None]:
# Load trained model
model = CNNInteractionBasedModel(vocab_size=tokenizer.vocab_size)
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
model.to(DEVICE)
model.eval()

In [None]:
# Prepare dataset for reranking
dataset = PointWiseDataset(QUESTIONS_FILE, BM25_FILE, CORPUS_FILE, tokenizer)

loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=build_collate_fn(tokenizer, max_q_len, max_d_len),
    pin_memory=(DEVICE.type == "cuda")
)

## Reranking with the Neural Model

1. **Batch scoring**  
- For each batch from the DataLoader, we take the tokenized queries and candidate documents.  
- The model outputs a **relevance score** for each (query, document) pair.  

2. **Collect scores per query**  
- We store the `(document_id, score)` pairs for every query.  

3. **Sort candidates**  
- For each query, we sort the candidate documents in descending order of model score.  
- This step produces the final reranked list of documents for each query.  

4. **Save results**  
- The results are saved in a JSONL file with the format:
   ```json
   {
      "query_id": "...",
      "retrieved_documents": ["doc1", "doc2", "doc3", ...]
   }
   ```
- This keeps the same structure as the BM25 file, making it easy to compare baseline vs reranked performance.  

This reranking step does not retrieve new documents — it only **reorders the BM25 shortlist** according to the learned neural model.
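
Steps 2 to 4 can be sketched independently of the model, starting from already-scored pairs. The function name `rerank` and the example scores are illustrative, and the sketch uses the standard `json` module rather than `ujson`.

```python
import json
from collections import defaultdict


def rerank(scored_pairs):
    """Group (query_id, doc_id, score) triples by query and sort each group by score, best first."""
    per_query = defaultdict(list)
    for query_id, doc_id, score in scored_pairs:
        per_query[query_id].append((doc_id, score))
    return {
        query_id: [doc_id for doc_id, _ in sorted(docs, key=lambda pair: pair[1], reverse=True)]
        for query_id, docs in per_query.items()
    }


scored = [("q1", "d1", 0.2), ("q1", "d2", 0.9), ("q2", "d7", 0.5), ("q1", "d3", 0.4)]
reranked = rerank(scored)
# One JSON object per line, matching the structure of the BM25 results file
jsonl_lines = [
    json.dumps({"query_id": qid, "retrieved_documents": docs})
    for qid, docs in reranked.items()
]
# reranked == {"q1": ["d2", "d3", "d1"], "q2": ["d7"]}
```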

In [None]:
# Run reranking
reranked_results = {}

model.eval()

with torch.no_grad():
    for batch in loader:
        q_tokens = batch["question_token_ids"].to(DEVICE)
        d_tokens = batch["document_token_ids"].to(DEVICE)
        qids = batch["query_ids"]
        dids = batch["document_ids"]

        scores = model(q_tokens, d_tokens)
        scores = torch.sigmoid(scores).cpu().numpy()

        for i, qid in enumerate(qids):
            if qid not in reranked_results:
                reranked_results[qid] = []
            reranked_results[qid].append((dids[i], float(scores[i])))

# sort by model score
for qid, doc_scores in reranked_results.items():
    reranked_results[qid] = [doc for doc, _ in sorted(doc_scores, key=lambda x: x[1], reverse=True)]

output_file = os.path.join(OUTPUT_DIR, 'ranked_questions_model.jsonl')
with open(output_file, 'w') as f:
    for qid, docs in reranked_results.items():
        entry = {
            "query_id": qid,
            "retrieved_documents": docs
        }
        f.write(ujson.dumps(entry) + '\n')

print(f"Reranked results saved to: {output_file}")

## Evaluate Retrieved Documents (Model Reranking)

Here we compute the **Normalized Discounted Cumulative Gain (nDCG)** metric, which evaluates the quality of the ranked retrieval results after model reranking.

In [None]:
print("Reranked nDCG@10 (Model)")

# Compute nDCG after reranking
ndcg.compute_average_ndcg(
    questions_file_path=QUESTIONS_FILE,
    results_file_path=output_file,
    k=10
    )