<a href="https://colab.research.google.com/github/RegNLP/ContextAware-Regulatory-GraphRAG-ObliQAMP/blob/main/06_Experiment_3_Cross_Encoder_Pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ==============================================================================
# Experiment 3: Cross-Encoder Pipeline Evaluation
#
# Purpose:
# 1. Use the best pre-trained, standard fine-tuned, and advanced fine-tuned
#    hybrid retrievers to generate initial candidate passages.
# 2. Re-rank these candidates using both fine-tuned and pre-trained
#    Cross-Encoder models.
# 3. Evaluate the final results to identify the definitive best-performing
#    end-to-end pipeline.
# 4. Save both a CSV summary and detailed JSON outputs for each run.
# ==============================================================================

# --- Essential Installations ---
!pip install -q -U sentence-transformers transformers datasets rank_bm25 pytrec_eval

import os
import json
import torch
import pickle
import numpy as np
import pandas as pd
import networkx as nx
import pytrec_eval
from tqdm import tqdm
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# --- Configuration ---
# Every input and output of this experiment lives under one Drive folder.
BASE_PATH = "/content/drive/MyDrive/Colab Notebooks/RIRAG-MultiPassage-NLLP/"


def _under_base(*parts):
    """Return the path obtained by joining *parts* onto BASE_PATH."""
    return os.path.join(BASE_PATH, *parts)


GRAPH_PATH = _under_base("graph.gpickle")
TEST_SET_PATH = _under_base("QADataset", "ObliQA_MultiPassage_test.json")
QREL_PATH = _under_base("qrels.trec")

# --- Model & Embedding Input Folders ---
FINETUNED_RETRIEVER_FOLDER = _under_base("fine_tuned_retrievers")
ADVANCED_FINETUNED_FOLDER = _under_base("fine_tuned_retrievers_advanced")
EMBEDDINGS_FOLDER = _under_base("embeddings_full_comparison")
CROSS_ENCODER_FOLDER = _under_base("fine_tuned_cross_encoders")

# --- Output Paths ---
RESULTS_CSV_OUTPUT_PATH = _under_base("experiment_3_final_pipeline_results.csv")
RESULTS_JSON_OUTPUT_FOLDER = _under_base("experiment_3_retrieval_results_json")
# The per-run JSON files are written into this folder, so create it up front.
os.makedirs(RESULTS_JSON_OUTPUT_FOLDER, exist_ok=True)


# --- Champion Retriever Configurations (from new Experiment 1) ---
# One entry per retriever, each paired with the context variant selected for it.
# "name" is the display label, "model_path" is what SentenceTransformer loads,
# and "context_key" names the embeddings sub-folder.
CHAMPION_CONFIGS = [
    {"name": name, "model_path": path, "context_key": context}
    for name, path, context in (
        (
            "e5-large-v2_FT_Advanced_parent",
            os.path.join(ADVANCED_FINETUNED_FOLDER, "e5-large-v2"),
            "parent",
        ),
        (
            "all-mpnet-base-v2_FT_Advanced_parent",
            os.path.join(ADVANCED_FINETUNED_FOLDER, "all-mpnet-base-v2"),
            "parent",
        ),
        (
            "bge-base-en-v1.5_FT_Advanced_parent_child",
            os.path.join(ADVANCED_FINETUNED_FOLDER, "bge-base-en-v1.5"),
            "parent_child",
        ),
        (
            "e5-large-v2_FT_parent",
            os.path.join(FINETUNED_RETRIEVER_FOLDER, "e5-large-v2"),
            "parent",
        ),
        (
            "e5-large-v2_Pretrained_parent_child",
            "intfloat/e5-large-v2",
            "parent_child",
        ),
    )
]

# --- Cross-Encoder Models to Evaluate ---
# Label -> checkpoint location. Insertion order is the evaluation order.
CROSS_ENCODERS_TO_EVALUATE = {
    # Checkpoints stored under CROSS_ENCODER_FOLDER.
    **{
        f"{family}_FT": os.path.join(CROSS_ENCODER_FOLDER, f"{family}_CrossEncoder")
        for family in ("MiniLM", "MPNet", "MSMarco", "BERT")
    },
    # Checkpoints referenced by model identifier rather than by local path.
    "MiniLM_Pretrained": "cross-encoder/ms-marco-MiniLM-L-6-v2",
    "MSMarco_Pretrained": "cross-encoder/ms-marco-TinyBERT-L-2-v2",
    "BERT_Pretrained": "bert-base-uncased",
}

# --- Experiment Parameters ---
# K_INITIAL: candidates taken from each of the dense and BM25 rankings.
# K_RERANK:  fused candidates handed to the cross-encoder.
# K_FINAL:   re-ranked passages kept per query in the final run.
K_INITIAL, K_RERANK, K_FINAL = 100, 25, 20

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# --- Load Static Components ---
print("Loading static components...")
# NOTE(security): pickle.load can execute arbitrary code stored in the file.
# Only load a graph pickle that comes from a trusted source.
with open(GRAPH_PATH, "rb") as f:
    G = pickle.load(f)
with open(TEST_SET_PATH, "r", encoding="utf-8") as f:
    test_data = json.load(f)
qid_to_question = {q["QuestionID"]: q["Question"] for q in test_data}
print(f"Loaded {len(test_data)} test questions.")

# Load QRELs
# Each line is parsed as "<qid> <iter> <uid ...> <rel>". The uid may contain
# spaces, so everything between the second field and the last one is re-joined.
qrel = {}
with open(QREL_PATH, "r", encoding="utf-8") as f:
    for line_no, line in enumerate(f, start=1):
        parts = line.split()
        if not parts:
            # Blank lines (e.g. a trailing newline) carry no judgement.
            continue
        if len(parts) < 4:
            # Fewer than four fields would yield an empty uid; skip loudly
            # instead of recording a bogus judgement.
            print(f"⚠️  Skipping malformed QREL line {line_no}: {line.rstrip()!r}")
            continue
        qid, uid, rel = parts[0], " ".join(parts[2:-1]), int(parts[-1])
        qrel.setdefault(qid, {})[uid] = rel
print(f"Loaded QRELs for {len(qrel)} queries.")

# Prepare BM25
print("Preparing BM25 index...")
# Only graph nodes tagged as passages are indexed.
all_passage_uids = [n for n, d in G.nodes(data=True) if d.get("type") == "Passage"]
# "<document_id>|||<passage_id>" -> graph node id, used to map run ids back
# to graph nodes when exporting results.
uid_map = {
    f"{G.nodes[uid].get('document_id')}|||{G.nodes[uid].get('passage_id')}": uid
    for uid in all_passage_uids
}
corpus_texts = [G.nodes[uid].get("text", "") for uid in all_passage_uids]
# Plain whitespace tokenisation; queries are split the same way at search time.
tokenized_corpus = [text.split() for text in corpus_texts]
bm25 = BM25Okapi(tokenized_corpus)
print("BM25 ready.")

# --- Helper Functions ---
def add_instruction_to_query(query, model_name):
    """Prefix *query* with ``"query: "`` when *model_name* names an E5/BGE model.

    The decision is a case-sensitive substring test: any *model_name*
    containing ``"e5"`` or ``"bge"`` gets the prefix; every other name
    returns *query* unchanged.
    """
    wants_prefix = any(family in model_name for family in ("e5", "bge"))
    return f"query: {query}" if wants_prefix else query

def evaluate_run(run_dict, qrel_dict, metrics=None):
    """Score a retrieval run with pytrec_eval and average each metric.

    Args:
        run_dict: ``{qid: {doc_id: score}}`` run to evaluate.
        qrel_dict: ``{qid: {doc_id: relevance}}`` ground-truth judgements.
        metrics: Iterable of pytrec_eval measure names. Defaults to
            ``{"recall_10", "map_cut_10", "ndcg_cut_10"}``.

    Returns:
        ``{metric: aggregated_value}`` with keys in sorted metric-name order,
        so downstream column order is the same on every run.
    """
    # Build the default per call instead of sharing one mutable set object
    # across all calls.
    if metrics is None:
        metrics = {"recall_10", "map_cut_10", "ndcg_cut_10"}

    evaluator = pytrec_eval.RelevanceEvaluator(qrel_dict, metrics)
    per_query = evaluator.evaluate(run_dict)

    # Iterating a set of strings directly gives an order that changes between
    # interpreter sessions (hash randomisation); sorting makes it reproducible.
    return {
        metric: pytrec_eval.compute_aggregated_measure(
            metric, [scores.get(metric, 0.0) for scores in per_query.values()]
        )
        for metric in sorted(metrics)
    }

def reciprocal_rank_fusion(ranked_lists, k=60):
    """Merge several rankings into one using Reciprocal Rank Fusion.

    Each item receives ``1 / (k + rank + 1)`` (``rank`` is zero-based) from
    every list it appears in; contributions are summed and the items are
    returned best-first. Items with equal totals keep first-seen order.
    """
    totals = {}
    for ranking in ranked_lists:
        for position, item in enumerate(ranking):
            contribution = 1 / (k + position + 1)
            totals[item] = totals.get(item, 0) + contribution
    # Python's sort is stable (also with reverse=True), so ties stay in
    # insertion order.
    return sorted(totals, key=totals.__getitem__, reverse=True)

def format_run_for_json(run_dict, qid_to_question_map, uid_to_internal_uid_map, graph, top_n=10):
    """Turn a scored run into a JSON-serialisable list, one record per query.

    Args:
        run_dict: ``{qid: {combined_uid: score}}``.
        qid_to_question_map: ``{qid: question text}``; unknown qids get ``""``.
        uid_to_internal_uid_map: ``{combined_uid: graph node id}``.
        graph: Object whose ``nodes[node_id]`` is a mapping holding ``"text"``.
        top_n: Number of highest-scoring passages considered per query.

    Returns:
        A list of dicts with keys ``QuestionID``, ``Question``,
        ``RetrievedPassages``, ``RetrievedScores`` and ``RetrievedIDs``
        (graph node ids). Passages whose combined uid has no mapping are
        omitted; they still count toward *top_n*.
    """
    output_list = []
    for qid, passages in run_dict.items():
        top_passages = sorted(passages.items(), key=lambda item: item[1], reverse=True)[:top_n]

        retrieved_passages_text, retrieved_scores, retrieved_ids = [], [], []
        for combined_uid, score in top_passages:
            internal_uid = uid_to_internal_uid_map.get(combined_uid)
            # Test against None rather than truthiness: a valid node id can be
            # falsy (e.g. the integer 0) and must not be dropped.
            if internal_uid is None:
                continue
            retrieved_passages_text.append(graph.nodes[internal_uid].get("text", ""))
            retrieved_scores.append(score)
            retrieved_ids.append(internal_uid)

        output_list.append({
            "QuestionID": qid,
            "Question": qid_to_question_map.get(qid, ""),
            "RetrievedPassages": retrieved_passages_text,
            "RetrievedScores": retrieved_scores,
            "RetrievedIDs": retrieved_ids,
        })
    return output_list

# --- Main Experiment Loop ---
# For each champion retriever: build hybrid (dense + BM25) candidate lists once,
# then re-rank those same candidates with every cross-encoder and score the run.
all_results = []

for config in CHAMPION_CONFIGS:
    champion_name = config["name"]
    model_path = config["model_path"]
    context_key = config["context_key"]
    # The embeddings folder is the champion name without its context suffix
    # (e.g., e5-large-v2_FT_Advanced). Strip the whole "_<context_key>" suffix:
    # a key such as "parent_child" contains an underscore itself, so dropping
    # only the last "_"-separated token would leave a stray "_parent".
    context_suffix = f"_{context_key}"
    if champion_name.endswith(context_suffix):
        model_key = champion_name[: -len(context_suffix)]
    else:
        # Name does not follow the "<model_key>_<context_key>" convention;
        # fall back to dropping the last "_"-separated token.
        model_key = "_".join(champion_name.split('_')[:-1])

    print("\n" + "="*80)
    print(f"--- TESTING CHAMPION RETRIEVER: {champion_name} ---")
    print("="*80)

    # Load Embeddings & Query Encoder
    print(f"Loading embeddings from: {EMBEDDINGS_FOLDER}")
    emb_path = os.path.join(EMBEDDINGS_FOLDER, model_key, context_key, "embeddings.pkl")
    id_path = os.path.join(EMBEDDINGS_FOLDER, model_key, context_key, "passage_ids.json")
    try:
        # NOTE(security): pickle.load can execute code stored in the file;
        # only load embedding files from a trusted source.
        with open(emb_path, "rb") as f:
            passage_embeddings = pickle.load(f)
        with open(id_path, "r", encoding="utf-8") as f:
            passage_ids = json.load(f)
        embeddings_tensor = torch.tensor(passage_embeddings).to(device)
        query_encoder = SentenceTransformer(model_path, device=device)
        print("Champion retriever components loaded successfully.")
    except FileNotFoundError:
        print(f"FATAL ERROR: Embeddings not found at {emb_path}. Skipping this champion.")
        continue

    # Pre-calculate initial retrievals
    # The candidate lists depend only on the retriever, so they are computed
    # once here and reused for every cross-encoder below.
    print("Pre-calculating initial hybrid retrievals...")
    initial_retrievals = {}
    dense_k = min(K_INITIAL, len(passage_ids))
    for q in tqdm(test_data, desc=f"Initial Retrieval for {champion_name}"):
        qid, query = q["QuestionID"], q["Question"]

        # Dense ranking: cosine similarity against the pre-computed embeddings.
        instructed_query = add_instruction_to_query(query, champion_name)
        query_emb = query_encoder.encode(instructed_query, convert_to_tensor=True, device=device)
        cos_scores = util.pytorch_cos_sim(query_emb, embeddings_tensor)[0]
        top_dense = torch.topk(cos_scores, k=dense_k)
        # .tolist() yields plain ints, avoiding per-element tensor indexing.
        dense_uids = [passage_ids[idx] for idx in top_dense.indices.tolist()]

        # Sparse ranking: BM25 over the raw (un-prefixed) query.
        bm25_scores = bm25.get_scores(query.split())
        top_bm25 = np.argsort(bm25_scores)[::-1][:K_INITIAL]
        bm25_uids = [all_passage_uids[i] for i in top_bm25]

        fused_uids = reciprocal_rank_fusion([dense_uids, bm25_uids])
        initial_retrievals[qid] = fused_uids[:K_RERANK]

    for ce_name, ce_path in CROSS_ENCODERS_TO_EVALUATE.items():
        print("\n" + "-"*80)
        print(f"--- Evaluating Cross-Encoder: {ce_name} (with {champion_name}) ---")
        print("-" * 80)

        # Best effort: a checkpoint that cannot be loaded is reported and
        # skipped so the remaining combinations still run.
        try:
            ce_tokenizer = AutoTokenizer.from_pretrained(ce_path)
            ce_model = AutoModelForSequenceClassification.from_pretrained(ce_path).to(device)
            ce_model.eval()
        except Exception as e:
            print(f"⚠️  Could not load model from {ce_path}. Skipping. Error: {e}")
            continue

        final_run = {}
        for q in tqdm(test_data, desc=f"Re-ranking with {ce_name}"):
            qid, query = q["QuestionID"], q["Question"]
            candidates_uids = initial_retrievals[qid]

            ce_input_pairs = [[query, G.nodes[uid].get("text", "")] for uid in candidates_uids]

            with torch.no_grad():
                inputs = ce_tokenizer(ce_input_pairs, padding=True, truncation=True, return_tensors="pt", max_length=512).to(device)
                logits = ce_model(**inputs).logits

            # Multi-logit heads: use the softmax probability of class index 1.
            # Single-logit heads: use the raw logit. Column indexing keeps the
            # result 1-D even when there is only one candidate.
            if logits.shape[1] > 1:
                score_tensor = torch.softmax(logits, dim=1)[:, 1]
            else:
                score_tensor = logits[:, 0]
            scores = score_tensor.cpu().tolist()

            # Stable sort: candidates with equal scores keep their fused order.
            reranked_candidates = sorted(
                zip(candidates_uids, scores), key=lambda pair: pair[1], reverse=True
            )

            final_run[qid] = {}
            for internal_uid, score in reranked_candidates[:K_FINAL]:
                node = G.nodes[internal_uid]
                # The run is keyed by "<document_id>|||<passage_id>".
                combined_uid = f"{node.get('document_id')}|||{node.get('passage_id')}"
                final_run[qid][combined_uid] = float(score)

        ce_metrics = evaluate_run(final_run, qrel)
        all_results.append({
            "Retriever": champion_name, "Cross-Encoder": ce_name, **ce_metrics
        })

        # Save JSON output for this run
        json_output_path = os.path.join(RESULTS_JSON_OUTPUT_FOLDER, f"{champion_name}_{ce_name}_results.json")
        json_data = format_run_for_json(final_run, qid_to_question, uid_map, G)
        with open(json_output_path, 'w', encoding="utf-8") as f:
            json.dump(json_data, f, indent=4)

# --- Save and Display Final Results ---
df = pd.DataFrame(all_results)

if df.empty:
    # No retriever/cross-encoder combination completed (all were skipped), so
    # there are no metric columns to sort on and no best row to show.
    print("\n⚠️  No pipeline produced results; nothing to rank or save.")
else:
    df = df.rename(columns={"recall_10": "Recall@10", "map_cut_10": "MAP@10"})
    # Best pipeline first: rank by Recall@10, break ties on MAP@10.
    df = df.sort_values(by=["Recall@10", "MAP@10"], ascending=False)

    print("\n📊 Final Evaluation Results:")
    print(df.to_string(index=False, float_format="%.4f"))

    df.to_csv(RESULTS_CSV_OUTPUT_PATH, index=False)
    print(f"\n✅ CSV summary saved to: {RESULTS_CSV_OUTPUT_PATH}")
    print(f"✅ Detailed JSON results saved to: {RESULTS_JSON_OUTPUT_FOLDER}")

    print("\n--- 🏆 Best Performing Full Pipeline ---")
    print(df.iloc[0])
