<a href="https://colab.research.google.com/github/RegNLP/ContextAware-Regulatory-GraphRAG-ObliQAMP/blob/main/03_Fine_Tune_Cross_Encoders.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ==============================================================================
# Fine-Tune Cross-Encoder Models for Re-ranking (High Quality)
#
# Purpose:
# 1. Load training and validation datasets.
# 2. Generate high-quality training data by mining "hard negatives" using BM25.
# 3. Fine-tune each base cross-encoder model, evaluating on the validation set
#    periodically during training (roughly every 10% of an epoch) and keeping
#    only the best-scoring checkpoint.
# 4. Save each fine-tuned model to a designated folder for Experiment 3.
#
# This script should be run in a Google Colab environment with Google Drive mounted.
# ==============================================================================

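# --- Essential Installations ---
# Installs (or upgrades to the latest release) the libraries this notebook needs.
# Note: -U does not pin versions, so it cannot guarantee compatibility; pin
# explicit versions here if a newer release changes the APIs used below.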
!pip install -q -U sentence-transformers transformers datasets rank_bm25

import os
import json
import random
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from sentence_transformers import CrossEncoder, InputExample
from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator
from rank_bm25 import BM25Okapi

# --- Configuration ---
BASE_PATH = "/content/drive/MyDrive/Colab Notebooks/RIRAG-MultiPassage-NLLP/"
TRAIN_SET_PATH = os.path.join(BASE_PATH, "QADataset", "ObliQA_MultiPassage_train.json")
# Assumes you have a dev/validation set in the same folder
DEV_SET_PATH = os.path.join(BASE_PATH, "QADataset", "ObliQA_MultiPassage_val.json")
MODEL_OUTPUT_FOLDER = os.path.join(BASE_PATH, "fine_tuned_cross_encoders")
os.makedirs(MODEL_OUTPUT_FOLDER, exist_ok=True)

# --- Models to Fine-Tune ---
# Key: sub-folder name under MODEL_OUTPUT_FOLDER for the fine-tuned model.
# Value: model identifier passed to CrossEncoder as the starting checkpoint.
BASE_MODELS_TO_FINETUNE = {
    "MiniLM_CrossEncoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
    "MPNet_CrossEncoder": "sentence-transformers/all-mpnet-base-v2",
    "MSMarco_CrossEncoder": "cross-encoder/ms-marco-TinyBERT-L-2-v2",
    "BERT_CrossEncoder": "bert-base-uncased"
}

# --- Training Parameters ---
NUM_EPOCHS = 10  # Increased epochs for better convergence with harder negatives
BATCH_SIZE = 16
LEARNING_RATE = 2e-5
NUM_HARD_NEGATIVES = 5  # Number of hard negatives to mine for each positive example
SEED = 42  # Seed for the Python / NumPy / PyTorch random number generators

# --- Reproducibility ---
# Seed every RNG involved (DataLoader shuffling and freshly initialised weights
# draw from PyTorch) so repeated runs are as comparable as possible. GPU kernels
# can still introduce small nondeterminism.
import torch  # already a dependency through torch.utils.data

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Load Data ---
print("Loading training and validation data...")
try:
    with open(TRAIN_SET_PATH, "r", encoding="utf-8") as f:
        train_data = json.load(f)
    with open(DEV_SET_PATH, "r", encoding="utf-8") as f:
        dev_data = json.load(f)
except FileNotFoundError as e:
    print(f"FATAL ERROR: Data file not found. {e}. Cannot proceed.")
    # Re-raise rather than calling exit(): this stops the cell with a visible
    # traceback, and no later code can run with train_data/dev_data undefined.
    raise
print(f"Loaded {len(train_data)} training and {len(dev_data)} validation examples.")

# --- Prepare for Hard Negative Mining ---
print("Preparing for hard negative mining...")
# Corpus of all unique passages in the training data, keyed by PassageID.
# If an ID occurs more than once, the last occurrence's text is kept.
all_passages = {
    p["PassageID"]: p["Passage"]
    for item in train_data
    for p in item["Passages"]
}
if not all_passages:
    # BM25 statistics (e.g. average document length) are undefined for an
    # empty corpus, so fail early with a clear message instead.
    raise ValueError("Training data contains no passages; cannot build the BM25 index.")

corpus_pids = list(all_passages)
corpus_texts = [all_passages[pid] for pid in corpus_pids]
# NOTE: tokenization is case-sensitive and splits on single spaces only;
# queries must be tokenized the same way when scoring against this index.
tokenized_corpus = [text.split(" ") for text in corpus_texts]
bm25 = BM25Okapi(tokenized_corpus)
print("BM25 index built for hard negative mining.")

# --- Function to Prepare Samples (with Hard Negative Mining) ---
def prepare_samples(data, bm25_model, corpus_pids, corpus, num_negatives, candidate_pool=100):
    """Build labelled (question, passage) pairs with BM25-mined hard negatives.

    For every item, one positive example (label 1.0) is created per entry in
    ``item["Passages"]``. The question is then scored against the BM25 index,
    and the highest-scoring corpus passages that are not positives are added
    as negatives (label 0.0). At most ``num_negatives`` negatives are added per
    unique positive PassageID.

    Args:
        data: Items holding a ``"Question"`` string and a ``"Passages"`` list of
            dicts with ``"PassageID"`` and ``"Passage"`` keys.
        bm25_model: Index exposing ``get_scores(tokens)``. Its scores must be
            index-aligned with ``corpus_pids`` and ``corpus``.
        corpus_pids: Passage IDs, index-aligned with ``corpus``.
        corpus: Passage texts the BM25 index was built from.
        num_negatives: Negatives to mine per unique positive PassageID.
        candidate_pool: Number of top-ranked BM25 candidates scanned for
            negatives (default 100, the previously hard-coded value).

    Returns:
        A list of ``InputExample`` objects: for each item, its positives
        followed by its mined negatives.
    """
    samples = []
    for item in tqdm(data, desc="Generating samples"):
        query = item["Question"]
        passages = item["Passages"]
        positive_pids = {p["PassageID"] for p in passages}
        # Also exclude passages whose text equals a positive's text. A duplicate
        # stored under another ID would otherwise get contradictory labels.
        positive_texts = {p["Passage"] for p in passages}

        # Create positive examples
        for p in passages:
            samples.append(InputExample(texts=[query, p["Passage"]], label=1.0))

        negatives_wanted = num_negatives * len(positive_pids)
        if negatives_wanted <= 0:
            # Checking the quota before appending avoids emitting one spurious
            # negative when none were requested (or the item has no positives).
            continue

        # Mine hard negatives. Single-space tokenization matches the way the
        # BM25 index is built in this notebook.
        bm25_scores = bm25_model.get_scores(query.split(" "))
        top_n_indices = np.argsort(bm25_scores)[::-1][:candidate_pool]

        hard_negatives_added = 0
        for idx in top_n_indices:
            passage_text = corpus[idx]
            if corpus_pids[idx] in positive_pids or passage_text in positive_texts:
                continue
            samples.append(InputExample(texts=[query, passage_text], label=0.0))
            hard_negatives_added += 1
            if hard_negatives_added >= negatives_wanted:
                break
    return samples

# --- Create Training and Validation Sets ---
# Both splits mine their negatives from the same BM25 index over training passages.
_mining_args = (bm25, corpus_pids, corpus_texts, NUM_HARD_NEGATIVES)
train_samples = prepare_samples(train_data, *_mining_args)
dev_samples = prepare_samples(dev_data, *_mining_args)

for _split_label, _split_samples in (("training", train_samples), ("validation", dev_samples)):
    print(f"Created {len(_split_samples)} {_split_label} samples.")

# --- Main Fine-Tuning Loop ---
# Fine-tunes every model in BASE_MODELS_TO_FINETUNE on the same training pairs,
# evaluates on the validation pairs, and keeps the best checkpoint per model.
import gc
import math

import torch

for model_name, model_path in BASE_MODELS_TO_FINETUNE.items():
    print("\n" + "="*80)
    print(f"--- Fine-Tuning Model: {model_name} ---")
    print(f"Base model: {model_path}")
    print("="*80)

    # 1. Initialize the CrossEncoder model with a single-output scoring head
    model = CrossEncoder(model_path, num_labels=1)

    # 2. Create DataLoaders
    train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=BATCH_SIZE)
    steps_per_epoch = len(train_dataloader)

    # 3. Create Evaluator
    evaluator = CEBinaryClassificationEvaluator.from_input_examples(dev_samples, name=f"{model_name}-dev")

    # 4. Define the output path for this model
    output_save_path = os.path.join(MODEL_OUTPUT_FOLDER, model_name)

    # 5. Start the training process.
    # Warm up over 10% of *all* training steps, as in the sentence-transformers
    # cross-encoder examples. The previous value was 10% of a single epoch,
    # i.e. only 1% of training when NUM_EPOCHS = 10.
    warmup_steps = math.ceil(steps_per_epoch * NUM_EPOCHS * 0.1)
    # Evaluate every ~10% of an epoch. NOTE(review): for fewer than 10 batches
    # this is 0; confirm how the installed version treats 0.
    evaluation_steps = int(steps_per_epoch * 0.1)
    model.fit(
        train_dataloader=train_dataloader,
        evaluator=evaluator,
        epochs=NUM_EPOCHS,
        evaluation_steps=evaluation_steps,
        warmup_steps=warmup_steps,
        # LEARNING_RATE was previously defined but never passed to fit().
        optimizer_params={"lr": LEARNING_RATE},
        output_path=output_save_path,
        save_best_model=True,  # This is crucial for saving the best performing model
        show_progress_bar=True
    )

    print(f"✅ Best model for '{model_name}' fine-tuned and saved to: {output_save_path}")

    # Drop all references to this model before the next one is constructed.
    # Otherwise the old model stays alive until `model` is reassigned, so two
    # models would occupy (GPU) memory at the same time.
    del model, evaluator, train_dataloader
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print("\nAll cross-encoder models have been fine-tuned successfully.")


In [None]:
## Import the sound-alert dependencies
from IPython.display import Audio, display


def allDone(url='https://www.myinstants.com/media/sounds/money-soundfx.mp3'):
    """Play an audio clip in the notebook output to signal that work is done.

    Args:
        url: Address of the audio file to play. Defaults to the original alert
            sound; pass any other audio URL to customise the alert.
    """
    display(Audio(url=url, autoplay=True))


allDone()