<a href="https://colab.research.google.com/github/RegNLP/ContextAware-Regulatory-GraphRAG-ObliQAMP/blob/main/00_Fine_Tune_Dense_Retriever_Models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive; the data and output paths used later
# in this notebook live under that mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# ==============================================================================
# Fine-Tune Dense Retriever Models
#
# Purpose:
# 1. Load a training dataset containing questions and their corresponding
#    relevant passages.
# 2. For each base sentence-transformer model specified, fine-tune it on the
#    training data using Multiple Negatives Ranking Loss.
# 3. Use a validation set to evaluate the model during training and save the
#    best-performing checkpoint.
# 4. Save each fine-tuned model to a designated folder, making them
#    ready for re-evaluation in Experiment 1.
#
# This script should be run in a Google Colab environment with the Drive mounted.
# ==============================================================================

# --- Essential Installations ---
# Quietly (-q) upgrade (-U) the required libraries; mismatched versions can cause import errors.
!pip install -q -U sentence-transformers transformers datasets

import os
import json
from tqdm import tqdm
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses
from sentence_transformers.evaluation import InformationRetrievalEvaluator

# --- Configuration ---
# Project root on the mounted Drive; every other path below is derived from it.
BASE_PATH = "/content/drive/MyDrive/Colab Notebooks/RIRAG-MultiPassage-NLLP/"

# Training and validation (development) splits, both stored in "QADataset".
TRAIN_SET_PATH = os.path.join(BASE_PATH, "QADataset", "ObliQA_MultiPassage_train.json")
DEV_SET_PATH = os.path.join(BASE_PATH, "QADataset", "ObliQA_MultiPassage_val.json")

# Destination folder for the fine-tuned models; created up front if missing.
MODEL_OUTPUT_FOLDER = os.path.join(BASE_PATH, "fine_tuned_retrievers")
os.makedirs(MODEL_OUTPUT_FOLDER, exist_ok=True)

# --- Models to Fine-Tune ---
# Key: short name used for the per-model output sub-folder.
# Value: identifier of the base model that is loaded and fine-tuned.
BASE_MODELS_TO_FINETUNE = {
    "e5-large-v2": "intfloat/e5-large-v2",
    "all-mpnet-base-v2": "sentence-transformers/all-mpnet-base-v2",
    "bge-base-en-v1.5": "BAAI/bge-base-en-v1.5",
}

# --- Training Parameters ---
NUM_EPOCHS = 3        # Number of passes over the training pairs
BATCH_SIZE = 16       # A larger batch size is beneficial for MultipleNegativesRankingLoss
LEARNING_RATE = 2e-5  # Learning rate for fine-tuning

# --- Load and Prepare Data ---
# Read both JSON splits into memory. A missing file is fatal: the run is
# stopped right here so that no later code executes without the data.
print("Loading training and validation data...")
try:
    with open(TRAIN_SET_PATH, "r", encoding="utf-8") as f:
        train_data = json.load(f)
    with open(DEV_SET_PATH, "r", encoding="utf-8") as f:
        dev_data = json.load(f)
except FileNotFoundError as e:
    print(f"FATAL ERROR: Data file not found: {e}. Cannot proceed.")
    # Raise SystemExit instead of calling exit(): exit() is a helper intended
    # for the interactive interpreter and may not abort the running cell in a
    # notebook kernel, whereas raising stops execution immediately.
    raise SystemExit(1) from e
else:
    print(f"Loaded {len(train_data)} training and {len(dev_data)} validation examples.")

# Build the (question, positive passage) pairs consumed by
# MultipleNegativesRankingLoss: one InputExample per passage listed under
# each question, so a question with several passages yields several pairs.
print("Preparing training samples...")
train_samples = [
    InputExample(texts=[record["Question"], passage["Passage"]])
    for record in tqdm(train_data, desc="Processing training data")
    for passage in record["Passages"]
]
print(f"Created {len(train_samples)} positive training pairs.")

# Prepare validation data for the InformationRetrievalEvaluator.
# Three mappings are built from the validation split:
#   dev_queries:       QuestionID -> question text
#   dev_corpus:        PassageID  -> passage text
#   dev_relevant_docs: QuestionID -> set of PassageIDs listed for that question
print("Preparing validation data...")
dev_queries = {}
dev_corpus = {}
dev_relevant_docs = {}
for item in tqdm(dev_data, desc="Processing validation data"):
    qid = item["QuestionID"]
    dev_queries[qid] = item["Question"]
    # setdefault (rather than assigning a fresh set) keeps passages collected
    # earlier if the same QuestionID appears in more than one record.
    relevant_ids = dev_relevant_docs.setdefault(qid, set())
    for p in item["Passages"]:
        # Use a unique passage identifier for the corpus and relevance mapping.
        # Here we assume PassageID is unique across the entire dataset.
        pid = p["PassageID"]
        dev_corpus[pid] = p["Passage"]
        relevant_ids.add(pid)
print("Validation data prepared.")


# --- Main Fine-Tuning Loop ---
# Each base model is loaded, fine-tuned on the same training pairs, evaluated
# on the validation set during training, and saved to its own sub-folder of
# MODEL_OUTPUT_FOLDER.
for model_name, model_path in BASE_MODELS_TO_FINETUNE.items():
    print("\n" + "="*80)
    print(f"--- Fine-Tuning Model: {model_name} ---")
    print(f"Base model: {model_path}")
    print("="*80)

    # 1. Initialize the SentenceTransformer model
    model = SentenceTransformer(model_path)

    # 2. Create a DataLoader
    # A fresh loader is built per model; the library handles collation and batching.
    train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=BATCH_SIZE)
    steps_per_epoch = len(train_dataloader)

    # 3. Define the loss function
    # MultipleNegativesRankingLoss only needs (query, positive) pairs.
    # NOTE(review): each question is paired with every one of its passages, so
    # the same question can occur more than once in a batch -- confirm whether
    # duplicate-free batching is wanted for this loss.
    train_loss = losses.MultipleNegativesRankingLoss(model)

    # 4. Define the output path for this model
    output_save_path = os.path.join(MODEL_OUTPUT_FOLDER, model_name)

    # 5. Create the evaluator
    evaluator = InformationRetrievalEvaluator(
        queries=dev_queries,
        corpus=dev_corpus,
        relevant_docs=dev_relevant_docs,
        name=f"{model_name}-val",
        show_progress_bar=True
    )

    # 6. Start the training process
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        evaluator=evaluator,
        epochs=NUM_EPOCHS,
        # Warm-up lasts 10% of ONE epoch's steps (not 10% of the whole run).
        warmup_steps=int(steps_per_epoch * 0.1),
        # Pass the configured learning rate explicitly; otherwise the
        # LEARNING_RATE constant would have no effect on training.
        optimizer_params={"lr": LEARNING_RATE},
        output_path=output_save_path,
        show_progress_bar=True,
        # Save a checkpoint 4 times per epoch; clamp to at least 1 so a very
        # small dataset does not produce a step interval of 0.
        checkpoint_save_steps=max(1, int(steps_per_epoch * 0.25)),
        checkpoint_path=os.path.join(output_save_path, "checkpoints"),
        # Save the best model based on its performance on the validation set
        save_best_model=True
    )

    print(f"✅ Best model for '{model_name}' fine-tuned and saved to: {output_save_path}")

print("\nAll dense retriever models have been fine-tuned successfully.")


In [None]:
## Import sound alert dependencies
from IPython.display import Audio, display

def allDone():
  """Play a notification sound in the notebook output by displaying an autoplaying audio widget."""
  # Alternative sound: https://www.myinstants.com/media/sounds/anime-wow-sound-effect.mp3
  alert_sound = Audio(url='https://www.myinstants.com/media/sounds/money-soundfx.mp3', autoplay=True)
  display(alert_sound)
## Insert whatever audio file you want above

# Trigger the sound alert.
allDone()