<a href="https://colab.research.google.com/github/RegNLP/ContextAware-Regulatory-GraphRAG-ObliQAMP/blob/main/00_Advanced_Fine_Tuning_for_Dense_Retrievers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ==============================================================================
# Advanced Fine-Tuning for Dense Retriever Models
#
# Purpose:
# 1. Load training and validation datasets.
# 2. Generate a high-quality training set by mining "hard negatives" using BM25.
# 3. Apply instruction-tuning prefixes for compatible models (e.g., e5, bge).
# 4. Fine-tune models with a contrastive objective on (query, positive, negative)
#    triplets.
# 5. Use a validation set to save only the best-performing
#    model checkpoint and prevent overfitting.
# ==============================================================================

# --- Essential Installations ---
!pip install -q -U sentence-transformers transformers datasets rank_bm25

import os
import json
import random
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from rank_bm25 import BM25Okapi

# --- Configuration ---
# All inputs and outputs live under one Google Drive folder. The "/content/drive"
# prefix presumably means Drive is already mounted by the Colab runtime (no mount
# call is visible in this cell) — TODO confirm.
BASE_PATH = "/content/drive/MyDrive/Colab Notebooks/RIRAG-MultiPassage-NLLP/"
TRAIN_SET_PATH = os.path.join(BASE_PATH, "QADataset", "ObliQA_MultiPassage_train.json")
DEV_SET_PATH = os.path.join(BASE_PATH, "QADataset", "ObliQA_MultiPassage_val.json")
# Save to a new folder to distinguish from the basic fine-tuned models
MODEL_OUTPUT_FOLDER = os.path.join(BASE_PATH, "fine_tuned_retrievers_advanced")
# Created eagerly so a bad path fails here rather than after training.
os.makedirs(MODEL_OUTPUT_FOLDER, exist_ok=True)

# --- Models to Fine-Tune ---
# Maps a short key to the model identifier handed to SentenceTransformer().
# The key is also used as the output sub-folder name and is the string that
# add_instruction() inspects ("e5" / "bge") to choose a prefix style.
# Comment/uncomment entries to select which models are trained.
BASE_MODELS_TO_FINETUNE = {
    #"e5-large-v2": "intfloat/e5-large-v2",
    "all-mpnet-base-v2": "sentence-transformers/all-mpnet-base-v2",
    "bge-base-en-v1.5": "BAAI/bge-base-en-v1.5"
}

# --- Training Parameters ---
NUM_EPOCHS = 3 # Number of passes over the training set; training is not stopped early, the fit() call only keeps the best checkpoint on the validation set
BATCH_SIZE = 8 # Smaller batch size is often needed for contrastive loss
LEARNING_RATE = 1e-5 # Lower learning rate for fine-tuning
NUM_HARD_NEGATIVES_PER_POSITIVE = 2 # Number of BM25 hard negatives mined per positive passage

# --- Load Data ---
# Both files are parsed as JSON. The rest of the script iterates over each of
# them as a sequence of question records with "Question" and "Passages" keys.
print("Loading training and validation data...")
try:
    with open(TRAIN_SET_PATH, "r", encoding="utf-8") as f:
        train_data = json.load(f)
    with open(DEV_SET_PATH, "r", encoding="utf-8") as f:
        dev_data = json.load(f)
except FileNotFoundError as e:
    print(f"FATAL ERROR: Data file not found: {e}. Cannot proceed.")
    # exit() is a helper the site module injects for interactive shells: it is
    # not guaranteed to exist and, called without arguments, reports success.
    # Raise SystemExit with a non-zero status so the failure is visible.
    raise SystemExit(1) from e
else:
    print(f"Loaded {len(train_data)} training and {len(dev_data)} validation examples.")

# --- Prepare for Hard Negative Mining ---
print("Preparing for hard negative mining...")
# Gather every passage that appears in the training data, keyed by passage ID so
# a passage attached to several questions is indexed only once.
all_passages = {
    passage["PassageID"]: passage["Passage"]
    for record in train_data
    for passage in record["Passages"]
}
# Parallel lists: position i in both refers to the same passage, which is also
# document i of the BM25 index built below.
corpus_pids = list(all_passages)
corpus_texts = list(all_passages.values())
# Tokenise by splitting on single spaces; queries are split the same way.
tokenized_corpus = [passage_text.split(" ") for passage_text in corpus_texts]
bm25 = BM25Okapi(tokenized_corpus)
print("BM25 index built for hard negative mining.")

# --- Function for Instruction Prefixes ---
def add_instruction(text, model_key, text_type="query"):
    """Return *text* with the instruction prefix matching the model family.

    Args:
        text: The query or passage text to encode.
        model_key: Short model name; the family is detected by substring
            ("e5" or "bge").
        text_type: Role of the text, e.g. "query" or "passage".

    Returns:
        For keys containing "e5", ``"<text_type>: <text>"``. For keys
        containing "bge", queries get the fixed retrieval instruction and
        other text types are returned as-is. Any other key returns *text*
        unchanged.
    """
    if "e5" in model_key:
        # e5-style models get the role itself as the prefix.
        prefix = f"{text_type}: "
    elif text_type == "query" and "bge" in model_key:
        # BGE models have a specific instruction for queries only.
        prefix = "Represent this sentence for searching relevant passages: "
    else:
        return text
    return f"{prefix}{text}"

# --- Create Training Set with Hard Negatives ---
# For each question, BM25 ranks the training corpus against the question text.
# The best-ranked passages that are not gold passages for that question become
# its hard negatives. Every gold passage is then paired with its own slice of
# NUM_HARD_NEGATIVES_PER_POSITIVE negatives, producing (query, positive,
# negative) triplets. Pairing each positive with *all* mined negatives would
# multiply the configured count by the number of positives.
train_samples = []
queries_without_negatives = 0
print("Mining hard negatives and creating training samples...")
for item in tqdm(train_data, desc="Mining negatives"):
    query = item["Question"]
    positives = item["Passages"]
    if not positives:
        # Nothing to pair a negative with.
        continue
    positive_pids = {p["PassageID"] for p in positives}
    negatives_needed = NUM_HARD_NEGATIVES_PER_POSITIVE * len(positives)

    # Find hard negatives for this query. The query is split on single spaces,
    # matching how the corpus was tokenised when the BM25 index was built.
    tokenized_query = query.split(" ")
    bm25_scores = bm25.get_scores(tokenized_query)
    top_n_indices = np.argsort(bm25_scores)[::-1][:100] # Get top 100 candidates

    hard_negatives = []
    for idx in top_n_indices:
        if len(hard_negatives) >= negatives_needed:
            break
        if corpus_pids[idx] not in positive_pids:
            hard_negatives.append(corpus_texts[idx])

    if not hard_negatives:
        # An empty-string "negative" carries no training signal, so the
        # question is skipped and counted instead of emitting a dummy triplet.
        queries_without_negatives += 1
        continue

    # Create training triplets: positive i takes the i-th slice of negatives.
    for pos_idx, p in enumerate(positives):
        start = pos_idx * NUM_HARD_NEGATIVES_PER_POSITIVE
        negatives_for_positive = hard_negatives[start:start + NUM_HARD_NEGATIVES_PER_POSITIVE]
        if not negatives_for_positive:
            # Fewer negatives were mined than requested (tiny corpus): reuse
            # the top-ranked ones rather than dropping this positive.
            negatives_for_positive = hard_negatives[:NUM_HARD_NEGATIVES_PER_POSITIVE]
        for neg_passage in negatives_for_positive:
            train_samples.append(InputExample(texts=[query, p["Passage"], neg_passage]))

if queries_without_negatives:
    print(f"Skipped {queries_without_negatives} questions for which no hard negative could be mined.")
print(f"Created {len(train_samples)} training triplets.")

# --- Prepare Validation Data ---
# Build the three mappings passed to InformationRetrievalEvaluator further down:
# question ID -> question text, passage ID -> passage text, and
# question ID -> set of passage IDs listed for that question.
print("Preparing validation data...")
dev_queries = {record["QuestionID"]: record["Question"] for record in dev_data}
dev_corpus = {
    passage["PassageID"]: passage["Passage"]
    for record in dev_data
    for passage in record["Passages"]
}
dev_relevant_docs = {
    record["QuestionID"]: set(passage["PassageID"] for passage in record["Passages"])
    for record in dev_data
}
print("Validation data prepared.")

# --- Main Fine-Tuning Loop ---
# Each configured base model is fine-tuned independently on the same triplets
# and evaluated on the same validation set; only the prefixes differ per model.
for model_name, model_path in BASE_MODELS_TO_FINETUNE.items():
    print("\n" + "="*80)
    print(f"--- Fine-Tuning Model: {model_name} ---")
    print(f"Base model: {model_path}")
    print("="*80)

    # 1. Initialize the SentenceTransformer model
    model = SentenceTransformer(model_path)

    # 2. Apply instruction prefixes to the training data.
    #    texts[0] is the query; texts[1] and texts[2] are the positive and the
    #    negative passage, so both receive the "passage" treatment.
    instructed_train_samples = [
        InputExample(texts=[
            add_instruction(sample.texts[0], model_name, "query"),
            add_instruction(sample.texts[1], model_name, "passage"),
            add_instruction(sample.texts[2], model_name, "passage"),
        ])
        for sample in train_samples
    ]

    # 3. Create a DataLoader
    train_dataloader = DataLoader(instructed_train_samples, shuffle=True, batch_size=BATCH_SIZE)
    steps_per_epoch = len(train_dataloader)

    # 4. Define the loss function for triplet data.
    #    MultipleNegativesRankingLoss accepts (anchor, positive, negative)
    #    inputs and needs no labels. A pair loss such as OnlineContrastiveLoss
    #    requires explicit 0/1 labels and only reads the first two texts, so
    #    with unlabelled triplets it would never see the mined negative.
    train_loss = losses.MultipleNegativesRankingLoss(model)

    # 5. Create the evaluator with instruction prefixes
    instructed_dev_queries = {qid: add_instruction(q_text, model_name, "query") for qid, q_text in dev_queries.items()}
    instructed_dev_corpus = {pid: add_instruction(p_text, model_name, "passage") for pid, p_text in dev_corpus.items()}
    evaluator = InformationRetrievalEvaluator(
        queries=instructed_dev_queries,
        corpus=instructed_dev_corpus,
        relevant_docs=dev_relevant_docs,
        name=f"{model_name}-val",
        show_progress_bar=True
    )

    # 6. Define the output path
    output_save_path = os.path.join(MODEL_OUTPUT_FOLDER, model_name)

    # 7. Start the training process
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        evaluator=evaluator,
        epochs=NUM_EPOCHS,
        warmup_steps=int(steps_per_epoch * 0.1),
        # Without this, fit() uses its own default learning rate and the
        # configured LEARNING_RATE is silently ignored.
        optimizer_params={"lr": LEARNING_RATE},
        output_path=output_save_path,
        save_best_model=True,
        show_progress_bar=True,
        evaluation_steps=int(steps_per_epoch * 0.25) # Evaluate 4 times per epoch
    )

    print(f"✅ Best model for '{model_name}' fine-tuned and saved to: {output_save_path}")

print("\nAll dense retriever models have been fine-tuned successfully.")
