<a href="https://colab.research.google.com/github/RegNLP/ContextAware-Regulatory-GraphRAG-ObliQAMP/blob/main/06_Experiment_3_Cross_Encoder_Pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ==============================================================================
# Experiment 3: Full Pipeline with Cross-Encoders
#
# Purpose:
# 1. Use the top two champion hybrid retrievers from Experiment 1 to generate
#    initial sets of candidate passages.
# 2. Re-rank these candidates using both fine-tuned and pre-trained
#    Cross-Encoder models.
# 3. Evaluate the final results to identify the best-performing end-to-end
#    pipeline and measure the impact of fine-tuning.
# ==============================================================================

# --- Essential Installations ---
!pip install -q sentence-transformers pytrec_eval rank_bm25 pandas networkx transformers

import os
import json
import torch
import pickle
import numpy as np
import pandas as pd
import networkx as nx
import pytrec_eval
from tqdm import tqdm
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# --- Configuration ---
# Every input and output of this experiment lives under one project folder on Google Drive.
BASE_PATH = "/content/drive/MyDrive/Colab Notebooks/RIRAG-MultiPassage-NLLP/"


def _under_base(*parts):
    """Return ``parts`` joined onto BASE_PATH."""
    return os.path.join(BASE_PATH, *parts)


GRAPH_PATH = _under_base("graph.gpickle")
TEST_SET_PATH = _under_base("QADataset", "ObliQA_MultiPassage_test.json")
QREL_PATH = _under_base("qrels.trec")
EMBEDDINGS_FOLDER = _under_base("embeddings")
RESULTS_OUTPUT_PATH = _under_base("experiment_3_cross_encoder_results_comparison.csv")

# --- Champion Retriever Configurations (from Experiment 1) ---
# Both champions use the same e5-large encoder; they differ only in
# ``context_key``, which selects which stored embedding set is loaded.
_CHAMPION_MODEL_KEY = "e5-large"
_CHAMPION_MODEL_PATH = "intfloat/e5-large-v2"
CHAMPION_CONFIGS = [
    {
        "name": f"{_CHAMPION_MODEL_KEY}-{context}",
        "model_key": _CHAMPION_MODEL_KEY,
        "model_path": _CHAMPION_MODEL_PATH,
        "context_key": context,
    }
    for context in ("parent_child", "full_neighborhood")
]

# --- Cross-Encoder Models to Evaluate ---
# Maps a label to a local checkpoint folder or a Hugging Face model id.
# Iteration order is the order in which the models are evaluated.
CROSS_ENCODER_FOLDER = _under_base("fine_tuned_cross_encoders")
_FINE_TUNED_SUBFOLDERS = {
    "MiniLM_FT": "MiniLM_CrossEncoder",
    "MPNet_FT": "MPNet_CrossEncoder",
    "MSMarco_FT": "MSMarco_CrossEncoder",
    "BERT_FT": "BERT_CrossEncoder",
}
# Original (not fine-tuned) checkpoints, used as a comparison baseline.
# "MSMarco_Pretrained" ("cross-encoder/ms-marco-TinyBERT-L-2-v2") is currently disabled.
# NOTE(review): all-mpnet-base-v2 and bert-base-uncased do not appear to be
# cross-encoder checkpoints. Loading them for sequence classification
# presumably adds an untrained scoring head, so confirm their scores are
# meaningful before comparing them.
_PRETRAINED_IDS = {
    "MiniLM_Pretrained": "cross-encoder/ms-marco-MiniLM-L-6-v2",
    "MPNet_Pretrained": "sentence-transformers/all-mpnet-base-v2",
    "BERT_Pretrained": "bert-base-uncased",
}
CROSS_ENCODERS_TO_EVALUATE = {
    **{
        label: os.path.join(CROSS_ENCODER_FOLDER, subfolder)
        for label, subfolder in _FINE_TUNED_SUBFOLDERS.items()
    },
    **_PRETRAINED_IDS,
}


# --- Experiment Parameters ---
K_INITIAL = 100  # candidates taken from EACH of dense and BM25 retrieval before fusion
K_RERANK = 25    # top fused candidates handed to the cross-encoder
K_FINAL = 20     # re-ranked results kept for evaluation

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# --- Load Static Components ---
print("Loading static components...")
# pickle.load can execute arbitrary code: only ever point GRAPH_PATH at a trusted file.
with open(GRAPH_PATH, "rb") as f:
    G = pickle.load(f)

with open(TEST_SET_PATH, "r", encoding="utf-8") as f:
    test_data = json.load(f)
print(f"Loaded {len(test_data)} test questions.")

# Load QRELs: each line is "<qid> <iteration> <doc uid ...> <relevance>".
# The doc uid is rebuilt from every middle field so that uids containing
# whitespace survive. Blank lines are ignored. Lines with fewer than four
# fields cannot be parsed; they are skipped and counted, not crashing or
# producing bogus entries.
qrel = {}
malformed_qrel_lines = 0
with open(QREL_PATH, "r", encoding="utf-8") as f:
    for line in f:
        parts = line.split()
        if not parts:
            continue
        if len(parts) < 4:
            malformed_qrel_lines += 1
            continue
        qid, uid, rel = parts[0], " ".join(parts[2:-1]), int(parts[-1])
        qrel.setdefault(qid, {})[uid] = rel
print(f"Loaded QRELs for {len(qrel)} queries.")
if malformed_qrel_lines:
    print(f"⚠️  Skipped {malformed_qrel_lines} malformed QREL line(s) with fewer than 4 fields.")

# Prepare BM25 over every graph node typed "Passage", using whitespace tokenisation.
print("Preparing BM25 index...")
all_passage_uids = [n for n, d in G.nodes(data=True) if d.get("type") == "Passage"]
if not all_passage_uids:
    # BM25Okapi would otherwise fail with an opaque ZeroDivisionError on an empty corpus.
    raise ValueError(f"No nodes with type 'Passage' found in graph {GRAPH_PATH}")
corpus_texts = [G.nodes[uid].get("text", "") for uid in all_passage_uids]
tokenized_corpus = [text.split() for text in corpus_texts]
bm25 = BM25Okapi(tokenized_corpus)
print("BM25 ready.")

# --- Helper Functions ---
def evaluate_run(run_dict, qrel_dict, metrics=None):
    """Evaluate a ranking run against relevance judgements with pytrec_eval.

    Args:
        run_dict: ``{query_id: {doc_id: score}}`` ranking to evaluate.
        qrel_dict: ``{query_id: {doc_id: relevance}}`` judgements.
        metrics: Iterable of pytrec_eval measure names. When None, defaults
            to ``{"recall_10", "map_cut_10", "ndcg_cut_10"}``.

    Returns:
        ``{metric: aggregated value}``. Each value is aggregated with
        ``pytrec_eval.compute_aggregated_measure`` over the per-query results
        returned by the evaluator.
    """
    # A None sentinel replaces the former mutable set default, which would be
    # shared across calls. Copying also accepts any iterable and protects the
    # caller's object.
    if metrics is None:
        metrics = {"recall_10", "map_cut_10", "ndcg_cut_10"}
    metrics = set(metrics)

    evaluator = pytrec_eval.RelevanceEvaluator(qrel_dict, metrics)
    per_query = evaluator.evaluate(run_dict)
    # NOTE(review): only queries present in `per_query` contribute. Queries
    # the evaluator did not score (e.g. absent from the run) are presumably
    # left out of the mean rather than counted as 0; confirm this is intended.
    return {
        metric: pytrec_eval.compute_aggregated_measure(
            metric, [scores.get(metric, 0.0) for scores in per_query.values()]
        )
        for metric in metrics
    }

def reciprocal_rank_fusion(ranked_lists, k=60):
    """Merge several best-first rankings using Reciprocal Rank Fusion (RRF).

    Each list contributes ``1 / (k + position)`` to every item it contains,
    where ``position`` is the item's 1-based rank in that list. An item that
    occurs more than once in a list is credited for each occurrence.

    Args:
        ranked_lists: Iterable of rankings. Each ranking is an iterable of
            hashable ids, best first.
        k: Smoothing constant that damps the advantage of top positions.

    Returns:
        Every distinct id, ordered by total fused score (highest first). Ties
        keep the order in which the ids were first encountered.
    """
    contributions = (
        (item, 1 / (k + offset + 1))
        for ranking in ranked_lists
        for offset, item in enumerate(ranking)
    )
    totals = {}
    for item, share in contributions:
        totals[item] = totals.get(item, 0) + share
    # sorted() is stable even with reverse=True, so equal totals keep first-seen order.
    return sorted(totals, key=totals.__getitem__, reverse=True)

# --- Main Experiment Loop ---
# For each champion retriever:
#   1. Build hybrid candidates once per query (dense top-K plus BM25 top-K, fused with RRF).
#   2. Re-rank those candidates with every cross-encoder.
#   3. Score the re-ranked runs against the QRELs.
#
# NOTE(review): e5 checkpoints are documented to expect a "query: " prefix on
# queries (and "passage: " on passages). Queries are encoded here WITHOUT a
# prefix. Confirm this matches how the stored passage embeddings and the
# Experiment 1 runs were produced before changing it.


def _dense_top_uids(query, encoder, embeddings, ids):
    """Return the ids of the (up to) K_INITIAL passages most cosine-similar to ``query``."""
    query_emb = encoder.encode(query, convert_to_tensor=True, device=device)
    cos_scores = util.pytorch_cos_sim(query_emb, embeddings)[0]
    top = torch.topk(cos_scores, k=min(K_INITIAL, len(ids)))
    # Move the indices to the host once instead of indexing with one device tensor per element.
    return [ids[idx] for idx in top.indices.tolist()]


_bm25_cache = {}


def _bm25_top_uids(query):
    """Return the ids of the K_INITIAL best BM25 passages for ``query``.

    BM25 does not depend on the champion retriever, so results are cached per
    query string and reused instead of rescoring the whole corpus for every champion.
    """
    if query not in _bm25_cache:
        bm25_scores = bm25.get_scores(query.split())
        top = np.argsort(bm25_scores)[::-1][:K_INITIAL]
        _bm25_cache[query] = [all_passage_uids[i] for i in top]
    return _bm25_cache[query]


def _cross_encoder_scores(query, passages, tokenizer, model):
    """Score each (query, passage) pair with a cross-encoder.

    Returns one float per passage, or an empty list when ``passages`` is empty.
    A head with more than one output is scored by the softmax probability of
    class index 1. A single-output head uses its raw logit.
    """
    if not passages:
        return []
    pairs = [[query, text] for text in passages]
    with torch.no_grad():
        inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors="pt", max_length=512).to(device)
        logits = model(**inputs).logits
        if logits.shape[1] > 1:
            scores = torch.softmax(logits, dim=1)[:, 1]
        else:
            # Select the single column rather than squeeze(). squeeze() turns a
            # one-candidate batch into a 0-d tensor, which cannot be indexed.
            scores = logits[:, 0]
    return scores.cpu().tolist()


def _to_trec_ranking(scored_uids, limit):
    """Turn best-first ``(internal_uid, score)`` pairs into a pytrec_eval ranking dict.

    Keeps the first ``limit`` pairs and keys each as ``"document_id|||passage_id"``.
    Graph nodes missing either attribute are dropped.
    """
    ranking = {}
    for uid, score in scored_uids[:limit]:
        node = G.nodes[uid]
        doc_id = node.get("document_id", "")
        passage_id = node.get("passage_id", "")
        if not (doc_id and passage_id):
            continue
        # Input is best-first. If two graph nodes share an evaluation id, keep
        # the first (highest) score instead of overwriting it with a lower one.
        ranking.setdefault(f"{doc_id}|||{passage_id}", float(score))
    return ranking


all_results = []

for champion_config in CHAMPION_CONFIGS:
    champion_name = champion_config["name"]
    champion_model_key = champion_config["model_key"]
    champion_model_path = champion_config["model_path"]
    champion_context_key = champion_config["context_key"]

    print("\n" + "="*80)
    print(f"--- TESTING CHAMPION RETRIEVER: {champion_name} ---")
    print("="*80)

    # --- Load Champion Embeddings & Query Encoder ---
    print(f"Loading champion retriever: {champion_model_key} / {champion_context_key}")
    champion_dir = os.path.join(EMBEDDINGS_FOLDER, champion_model_key, champion_context_key)
    emb_path = os.path.join(champion_dir, "embeddings.pkl")
    id_path = os.path.join(champion_dir, "passage_ids.json")
    try:
        # pickle.load can execute arbitrary code: only load trusted embedding files.
        with open(emb_path, "rb") as f:
            passage_embeddings = pickle.load(f)
        with open(id_path, "r", encoding="utf-8") as f:
            passage_ids = json.load(f)
        # Cast to float32. A float64/float16 pickle would otherwise presumably
        # hit a dtype mismatch in the cosine-similarity matmul against the
        # (float32) query embedding.
        embeddings_tensor = torch.as_tensor(
            np.asarray(passage_embeddings, dtype=np.float32), device=device
        )
        query_encoder = SentenceTransformer(champion_model_path, device=device)
    except FileNotFoundError as e:
        # Report the file that is actually missing (embeddings or ids).
        print(f"FATAL ERROR: Champion file not found at {e.filename or emb_path}. Skipping this champion.")
        continue

    if len(passage_ids) != embeddings_tensor.shape[0]:
        # A mismatch would silently map similarity rows to the wrong passages.
        print(
            f"FATAL ERROR: {len(passage_ids)} passage ids but {embeddings_tensor.shape[0]} "
            f"embeddings for {champion_name}. Skipping this champion."
        )
        continue
    print("Champion retriever loaded successfully.")

    # --- Pre-calculate initial retrievals for all queries ---
    print("\nPre-calculating initial hybrid retrievals for all queries...")
    initial_retrievals = {}
    for q in tqdm(test_data, desc=f"Initial Hybrid Retrieval for {champion_name}"):
        query = q["Question"]
        fused_uids = reciprocal_rank_fusion([
            _dense_top_uids(query, query_encoder, embeddings_tensor, passage_ids),
            _bm25_top_uids(query),
        ])
        initial_retrievals[q["QuestionID"]] = fused_uids[:K_RERANK]

    for ce_name, ce_path in CROSS_ENCODERS_TO_EVALUATE.items():
        print("\n" + "-"*80)
        print(f"--- Evaluating Cross-Encoder: {ce_name} (with {champion_name}) ---")
        print("-" * 80)

        try:
            ce_tokenizer = AutoTokenizer.from_pretrained(ce_path)
            ce_model = AutoModelForSequenceClassification.from_pretrained(ce_path).to(device)
            ce_model.eval()
        except Exception as e:
            # Best-effort: one unavailable checkpoint should not abort the whole sweep.
            print(f"⚠️  Could not load model from {ce_path}. Skipping. Error: {e}")
            continue

        # --- Stage 2: Cross-Encoder Re-ranking ---
        final_run = {}
        for q in tqdm(test_data, desc=f"Re-ranking with {ce_name}"):
            qid = q["QuestionID"]
            candidate_uids = initial_retrievals[qid]
            passages = [G.nodes[uid].get("text", "") for uid in candidate_uids]
            scores = _cross_encoder_scores(q["Question"], passages, ce_tokenizer, ce_model)
            # Stable sort: candidates with equal scores keep their fused (RRF) order.
            reranked = sorted(zip(candidate_uids, scores), key=lambda pair: pair[1], reverse=True)
            final_run[qid] = _to_trec_ranking(reranked, K_FINAL)

        # Evaluate and store results for this cross-encoder.
        ce_metrics = evaluate_run(final_run, qrel)
        all_results.append({
            "Retriever": champion_name,
            "Cross-Encoder": ce_name,
            "Recall@10": ce_metrics["recall_10"],
            "MAP@10": ce_metrics["map_cut_10"],
            "nDCG@10": ce_metrics["ndcg_cut_10"]
        })

# --- Save and Display Final Results ---
results_df = pd.DataFrame(all_results)

if results_df.empty:
    # Every champion or cross-encoder was skipped. A column-less DataFrame
    # cannot be sorted by "nDCG@10" (KeyError) and has no best row
    # (IndexError). Saving it would also overwrite an earlier results file
    # with nothing.
    print("\n⚠️  No pipeline produced results; nothing to save.")
else:
    results_df = results_df.sort_values(by="nDCG@10", ascending=False)

    print("\n📊 Final Evaluation Results:")
    print(results_df.to_string(index=False, float_format="%.4f"))

    results_df.to_csv(RESULTS_OUTPUT_PATH, index=False)
    print(f"\n✅ Results saved to: {RESULTS_OUTPUT_PATH}")

    # After the descending sort, the first row is the pipeline with the highest nDCG@10.
    print("\n--- 🏆 Best Performing Full Pipeline ---")
    print(results_df.iloc[0])
