<a href="https://colab.research.google.com/github/RegNLP/ContextAware-Regulatory-GraphRAG-ObliQAMP/blob/main/05_Experiment_2_Graph_Reranking.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive so this notebook can read the graph, datasets and
# fine-tuned models stored under MyDrive, and write results back there.
from google.colab import drive

drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# ==============================================================================
# Experiment 2 (Re-run): Enriched Graph Re-ranking with Fine-Tuned Retrievers
#
# Purpose:
# 1. Use the new champion fine-tuned retrievers.
# 2. Systematically test multiple graph re-ranking strategies to see if they can
#    improve upon a strong, fine-tuned baseline.
# 3. Evaluate all combinations to find the optimal graph re-ranking strategy.
# 4. Save both a CSV summary and detailed JSON outputs for each run.
# ==============================================================================

# --- Essential Installations ---
!pip install -q -U sentence-transformers transformers datasets rank_bm25 pytrec_eval

import os
import json
import torch
import pickle
import numpy as np
import pandas as pd
import networkx as nx
import pytrec_eval
from tqdm import tqdm
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, util
import itertools

# --- Configuration ---
BASE_PATH = "/content/drive/MyDrive/Colab Notebooks/RIRAG-MultiPassage-NLLP/"
GRAPH_PATH = os.path.join(BASE_PATH, "graph.gpickle")
TEST_SET_PATH = os.path.join(BASE_PATH, "QADataset", "ObliQA_MultiPassage_test.json")
QREL_PATH = os.path.join(BASE_PATH, "qrels.trec")

# --- Model & Embedding Input Folders ---
FINETUNED_RETRIEVER_FOLDER = os.path.join(BASE_PATH, "fine_tuned_retrievers")
ADVANCED_FINETUNED_FOLDER = os.path.join(BASE_PATH, "fine_tuned_retrievers_advanced")
EMBEDDINGS_FOLDER_FT = os.path.join(BASE_PATH, "embeddings_finetuned")
EMBEDDINGS_FOLDER_ADVANCED = os.path.join(BASE_PATH, "embeddings_advanced")

# --- Output Paths ---
RESULTS_CSV_OUTPUT_PATH = os.path.join(BASE_PATH, "experiment_2_rerun_with_finetuned_results.csv")
RESULTS_JSON_OUTPUT_FOLDER = os.path.join(BASE_PATH, "experiment_2_retrieval_results_json")
os.makedirs(RESULTS_JSON_OUTPUT_FOLDER, exist_ok=True)


# --- Champion Retriever Configurations (from new Experiment 1) ---
# Each entry names a retriever, the directory of its fine-tuned query encoder,
# the root folder of its pre-computed passage embeddings, and the passage
# context variant ("context_key") those embeddings were built from.
CHAMPION_CONFIGS = [
    {
        "name": "e5-large-v2_FT_Advanced_parent",
        "model_path": os.path.join(ADVANCED_FINETUNED_FOLDER, "e5-large-v2"),
        "embeddings_folder": EMBEDDINGS_FOLDER_ADVANCED,
        "context_key": "parent"
    },
    {
        "name": "e5-large-v2_FT_parent",
        "model_path": os.path.join(FINETUNED_RETRIEVER_FOLDER, "e5-large-v2"),
        "embeddings_folder": EMBEDDINGS_FOLDER_FT,
        "context_key": "parent"
    }
]


# --- Enriched Experiment Hyperparameters ---
K_INITIAL = 100  # candidates kept from each first-stage ranking and after RRF fusion
K_FINAL = 20     # passages written to the run after graph re-ranking
# The grid below is searched exhaustively (itertools.product of all options).
PARENT_BONUS_OPTIONS = [0.01, 0.02, 0.05]
CITES_BONUS_OPTIONS = [0.01, 0.02]
CITED_BY_BONUS_OPTIONS = [0.02, 0.04]
SIBLING_BONUS_OPTIONS = [0.005, 0.01]
PAGERANK_WEIGHT_OPTIONS = [0.0, 0.1, 0.2]
ISOLATION_PENALTY_OPTIONS = [0.0, -0.01]

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# --- Load Static Components ---
print("Loading static components...")
# NOTE: pickle.load can execute arbitrary code; only load graph files produced
# by this project, never untrusted downloads.
with open(GRAPH_PATH, "rb") as f:
    G = pickle.load(f)
with open(TEST_SET_PATH, "r", encoding="utf-8") as f:
    test_data = json.load(f)
qid_to_question = {q["QuestionID"]: q["Question"] for q in test_data}
print(f"Loaded {len(test_data)} test questions.")

# --- Pre-calculate PageRank ---
print("Pre-calculating PageRank for all graph nodes...")
pagerank_scores = nx.pagerank(G, alpha=0.85)
print("PageRank calculation complete.")

# Load QRELs. Each line is "qid <ignored> docid... relevance"; every field
# between the second and the last is re-joined with single spaces, so doc ids
# that contain whitespace are kept intact.
qrel = {}
with open(QREL_PATH, "r", encoding="utf-8") as f:
    for line_no, line in enumerate(f, start=1):
        parts = line.split()
        if not parts:
            continue  # tolerate blank lines (e.g. a trailing empty line)
        if len(parts) < 4:
            raise ValueError(
                f"Malformed qrel line {line_no} in {QREL_PATH!r}: {line.rstrip()!r}"
            )
        qid, uid, rel = parts[0], " ".join(parts[2:-1]), int(parts[-1])
        qrel.setdefault(qid, {})[uid] = rel
print(f"Loaded QRELs for {len(qrel)} queries.")

# Prepare BM25 over every "Passage" node. Tokenisation is plain, case-sensitive
# whitespace splitting.
print("Preparing BM25 index...")
all_passage_uids = [n for n, d in G.nodes(data=True) if d.get("type") == "Passage"]
# "document_id|||passage_id" (the id format written into runs) -> graph node id.
uid_map = {f"{G.nodes[uid].get('document_id')}|||{G.nodes[uid].get('passage_id')}": uid for uid in all_passage_uids}
corpus_texts = [G.nodes[uid].get("text", "") for uid in all_passage_uids]
tokenized_corpus = [text.split() for text in corpus_texts]
bm25 = BM25Okapi(tokenized_corpus)
print("BM25 ready.")

# --- Helper Functions ---
def add_instruction_to_query(query, model_name):
    """Prepend the retrieval instruction expected by the model family.

    The family is detected by substring match on ``model_name``; "e5" is
    tested before "bge". E5 models get a ``"query: "`` prefix, BGE models get
    their search instruction, and any other model receives ``query`` unchanged
    (the very same object, not a copy).
    """
    if "e5" in model_name:
        instructed = f"query: {query}"
    elif "bge" in model_name:
        instructed = f"Represent this sentence for searching relevant passages: {query}"
    else:
        instructed = query
    return instructed

def evaluate_run(run_dict, qrel_dict, metrics=("recall_10", "map_cut_10", "ndcg_cut_10")):
    """Evaluate a retrieval run with pytrec_eval and aggregate each metric.

    Args:
        run_dict: ``{qid: {doc_id: score}}`` ranking to evaluate.
        qrel_dict: ``{qid: {doc_id: relevance}}`` relevance judgements.
        metrics: pytrec_eval measure names. Any iterable is accepted; with an
            ordered sequence (such as the default tuple) the returned dict, and
            therefore any DataFrame/CSV built from it, has a stable key order.

    Returns:
        Dict mapping each metric name to
        ``pytrec_eval.compute_aggregated_measure`` over the per-query results;
        a query whose result lacks the metric contributes 0.0.
    """
    # De-duplicate while keeping the caller's order. The previous default was
    # a set, whose iteration order for strings varies between interpreter runs
    # (hash randomisation), so result columns could be shuffled run to run.
    metric_names = list(dict.fromkeys(metrics))
    # Hand the evaluator a set, as the original implementation did.
    evaluator = pytrec_eval.RelevanceEvaluator(qrel_dict, set(metric_names))
    per_query = evaluator.evaluate(run_dict)
    return {
        metric: pytrec_eval.compute_aggregated_measure(
            metric, [scores.get(metric, 0.0) for scores in per_query.values()]
        )
        for metric in metric_names
    }

def reciprocal_rank_fusion(ranked_lists, k=60):
    """Fuse several rankings with Reciprocal Rank Fusion.

    An item at 0-based position ``r`` in a ranking contributes ``1 / (k + r + 1)``;
    contributions are summed across all rankings.

    Returns:
        ``[(item, fused_score), ...]`` sorted by descending score. Ties keep the
        order in which items were first seen (the sort is stable).
    """
    totals = {}
    for ranking in ranked_lists:
        for position, item in enumerate(ranking):
            previous = totals.get(item, 0)
            totals[item] = previous + 1 / (k + position + 1)
    fused = list(totals.items())
    fused.sort(key=lambda pair: pair[1], reverse=True)
    return fused

def format_run_for_json(run_dict, qid_to_question_map, uid_to_internal_uid_map, graph, top_n=10):
    """Convert a run into a list of per-question JSON-serialisable records.

    For each question, the ``top_n`` highest-scoring combined ids are taken
    first. Those with no (or a falsy) mapping in ``uid_to_internal_uid_map`` are
    then dropped, so a record may hold fewer than ``top_n`` passages. Passage
    text is read from ``graph.nodes[internal_uid]["text"]`` (default "").

    Returns:
        A list of dicts with keys QuestionID, Question, RetrievedPassages,
        RetrievedScores and RetrievedIDs, in ``run_dict`` iteration order.
    """
    records = []
    for qid, passage_scores in run_dict.items():
        ranked = sorted(passage_scores.items(), key=lambda kv: kv[1], reverse=True)

        kept = []
        for combined_uid, score in ranked[:top_n]:
            internal_uid = uid_to_internal_uid_map.get(combined_uid)
            if not internal_uid:
                continue
            kept.append((internal_uid, score))

        records.append({
            "QuestionID": qid,
            "Question": qid_to_question_map.get(qid, ""),
            "RetrievedPassages": [graph.nodes[uid].get("text", "") for uid, _ in kept],
            "RetrievedScores": [score for _, score in kept],
            "RetrievedIDs": [uid for uid, _ in kept],
        })
    return records

# --- Main Experiment Loop ---
all_results = []
hyperparameter_combinations = list(itertools.product(
    PARENT_BONUS_OPTIONS, CITES_BONUS_OPTIONS, CITED_BY_BONUS_OPTIONS,
    SIBLING_BONUS_OPTIONS, PAGERANK_WEIGHT_OPTIONS, ISOLATION_PENALTY_OPTIONS
))


def _extract_graph_features(graph, candidate_uids):
    """Compute the bonus-weight-independent graph signals for one query.

    The structural facts used by the re-ranker depend only on the candidate set,
    not on the hyperparameters. They are: is the parent retrieved, is there a
    sibling, and how many citation links point to or from other candidates. So
    they are computed once per query rather than once per combination.

    Args:
        graph: Directed graph containing every candidate uid as a node.
        candidate_uids: Internal node ids of the candidates, in rank order.
            Assumed unique (they are the keys of the RRF result dict).

    Returns:
        Dict ``uid -> (parent_in_set, has_sibling, n_cites, n_cited_by)``.
    """
    retrieved = set(candidate_uids)

    # Only the first PARENT_OF predecessor of a candidate is considered.
    first_parent = {}
    for uid in candidate_uids:
        parents = [
            p for p in graph.predecessors(uid)
            if graph.get_edge_data(p, uid, {}).get("type") == "PARENT_OF"
        ]
        if parents:
            first_parent[uid] = parents[0]

    # Number of candidates sharing each first parent; a count > 1 means the
    # candidate has a sibling in the list (replaces a pairwise O(n^2) scan).
    parent_counts = {}
    for parent_uid in first_parent.values():
        parent_counts[parent_uid] = parent_counts.get(parent_uid, 0) + 1

    features = {}
    for uid in candidate_uids:
        if uid in first_parent:
            parent_uid = first_parent[uid]
            parent_in_set = parent_uid in retrieved
            has_sibling = parent_counts[parent_uid] > 1
        else:
            parent_in_set = has_sibling = False
        n_cites = sum(
            1 for n in graph.successors(uid)
            if graph.get_edge_data(uid, n, {}).get("type") == "CITES" and n in retrieved
        )
        n_cited_by = sum(
            1 for n in graph.predecessors(uid)
            if graph.get_edge_data(n, uid, {}).get("type") == "CITED_BY" and n in retrieved
        )
        features[uid] = (parent_in_set, has_sibling, n_cites, n_cited_by)
    return features


def _graph_bonus(features, parent_b, cites_b, cited_by_b, sibling_b, isolation_p):
    """Turn one candidate's graph features into its additive bonus.

    Bonuses are accumulated one ``+=`` at a time, in the order parent,
    sibling, each CITES link, each CITED_BY link, then the isolation penalty.
    This gives bit-identical floats to per-neighbour accumulation, so score
    ties break the same way.
    """
    parent_in_set, has_sibling, n_cites, n_cited_by = features
    bonus = 0.0
    if parent_in_set:
        bonus += parent_b
    if has_sibling:
        bonus += sibling_b
    for _ in range(n_cites):
        bonus += cites_b
    for _ in range(n_cited_by):
        bonus += cited_by_b
    # Candidates with no graph connection to the rest of the list get the penalty.
    if not (parent_in_set or has_sibling or n_cites or n_cited_by):
        bonus += isolation_p
    return bonus


for config in CHAMPION_CONFIGS:
    champion_name = config["name"]
    model_path = config["model_path"]
    embeddings_folder = config["embeddings_folder"]
    context_key = config["context_key"]
    # Embedding sub-folder name = champion name without its last "_"-separated
    # token (e.g. "e5-large-v2_FT_parent" -> "e5-large-v2_FT").
    model_key_parts = champion_name.split('_')
    model_key = "_".join(model_key_parts[:-1]) if len(model_key_parts) > 1 else champion_name

    print("\n" + "="*80)
    print(f"--- TESTING CHAMPION RETRIEVER: {champion_name} ---")
    print("="*80)

    # Load Embeddings & Query Encoder
    print(f"Loading embeddings from: {embeddings_folder}")
    emb_path = os.path.join(embeddings_folder, model_key, context_key, "embeddings.pkl")
    id_path = os.path.join(embeddings_folder, model_key, context_key, "passage_ids.json")
    try:
        with open(emb_path, "rb") as f:
            passage_embeddings = pickle.load(f)  # trusted artifact from an earlier notebook
        with open(id_path, "r", encoding="utf-8") as f:
            passage_ids = json.load(f)
        embeddings_tensor = torch.tensor(passage_embeddings).to(device)
        query_encoder = SentenceTransformer(model_path, device=device)
        print("Champion retriever components loaded successfully.")
    except FileNotFoundError as err:
        # Report the file that is actually missing (embeddings, ids or model files).
        print(f"FATAL ERROR: Required file not found ({err.filename or err}). Skipping this champion.")
        continue

    # Pre-calculate initial hybrid retrievals (independent of the graph hyperparameters).
    print("Pre-calculating initial hybrid retrievals...")
    initial_retrievals = {}
    for q in tqdm(test_data, desc=f"Initial Retrieval for {champion_name}"):
        qid, query = q["QuestionID"], q["Question"]
        instructed_query = add_instruction_to_query(query, champion_name)
        query_emb = query_encoder.encode(instructed_query, convert_to_tensor=True, device=device)
        cos_scores = util.pytorch_cos_sim(query_emb, embeddings_tensor)[0]
        top_dense = torch.topk(cos_scores, k=min(K_INITIAL, len(passage_ids)))
        # Convert to Python once instead of indexing with device tensors per element.
        dense_indices = top_dense.indices.tolist()
        dense_values = top_dense.values.tolist()
        dense_uids = [passage_ids[idx] for idx in dense_indices]
        uid_to_initial_score = {passage_ids[idx]: score for idx, score in zip(dense_indices, dense_values)}

        bm25_scores = bm25.get_scores(query.split())
        top_bm25 = np.argsort(bm25_scores)[::-1][:K_INITIAL]
        bm25_uids = [all_passage_uids[i] for i in top_bm25]

        fused_results = reciprocal_rank_fusion([dense_uids, bm25_uids])
        initial_retrievals[qid] = (fused_results, uid_to_initial_score)

    # Traverse the graph once per query here instead of once per combination.
    # NOTE(review): the RRF score only selects the K_INITIAL candidates; the base
    # score for re-ranking is the dense cosine similarity, and BM25-only
    # candidates start from 0. Confirm this is the intended base score.
    query_candidates = {}
    for qid, (fused_results, uid_to_initial_score) in initial_retrievals.items():
        candidate_uids = [uid for uid, _rrf_score in fused_results[:K_INITIAL]]
        features = _extract_graph_features(G, candidate_uids)
        query_candidates[qid] = [
            (uid, uid_to_initial_score.get(uid, 0), pagerank_scores.get(uid, 0.0), features[uid])
            for uid in candidate_uids
        ]

    # Loop through hyperparameter combinations
    for params in tqdm(hyperparameter_combinations, desc=f"Evaluating Hyperparameters for {champion_name}"):
        parent_b, cites_b, cited_by_b, sibling_b, pagerank_w, isolation_p = params

        graph_reranked_run = {}
        for q in test_data:
            qid = q["QuestionID"]
            # graph_score = initial score + graph bonus + weighted PageRank.
            scored = [
                (
                    uid,
                    initial_score
                    + _graph_bonus(feats, parent_b, cites_b, cited_by_b, sibling_b, isolation_p)
                    + pagerank * pagerank_w,
                )
                for uid, initial_score, pagerank, feats in query_candidates[qid]
            ]
            # Stable sort: equal scores keep their RRF order.
            scored.sort(key=lambda item: item[1], reverse=True)

            graph_reranked_run[qid] = {}
            for uid, graph_score in scored[:K_FINAL]:
                node = G.nodes[uid]
                combined_uid = f"{node.get('document_id')}|||{node.get('passage_id')}"
                graph_reranked_run[qid][combined_uid] = float(graph_score)

        graph_metrics = evaluate_run(graph_reranked_run, qrel)
        all_results.append({
            "champion_retriever": champion_name, "parent_bonus": parent_b, "cites_bonus": cites_b,
            "cited_by_bonus": cited_by_b, "sibling_bonus": sibling_b, "pagerank_weight": pagerank_w,
            "isolation_penalty": isolation_p, **graph_metrics
        })

        # Save JSON output for this run
        param_str = f"p{parent_b}_c{cites_b}_cb{cited_by_b}_s{sibling_b}_pg{pagerank_w}_ip{isolation_p}"
        json_output_path = os.path.join(RESULTS_JSON_OUTPUT_FOLDER, f"{champion_name}_{param_str}.json")
        json_data = format_run_for_json(graph_reranked_run, qid_to_question, uid_map, G)
        with open(json_output_path, "w", encoding="utf-8") as f:
            json.dump(json_data, f, indent=4)


# --- Save and Display Final Results ---
results_df = pd.DataFrame(all_results)

if results_df.empty:
    # Every champion was skipped (e.g. missing embeddings); sorting or indexing
    # an empty frame would raise, so report it instead.
    print("\n⚠️ No results were produced (all champion retrievers were skipped); nothing to save.")
else:
    # Give all three evaluated pytrec_eval metrics their display names.
    results_df = results_df.rename(columns={
        "recall_10": "Recall@10",
        "map_cut_10": "MAP@10",
        "ndcg_cut_10": "nDCG@10",
    })
    results_df = results_df.sort_values(by=["Recall@10", "MAP@10"], ascending=False)

    print("\n📊 Final Evaluation Results:")
    print(results_df.to_string(index=False, float_format="%.4f"))

    results_df.to_csv(RESULTS_CSV_OUTPUT_PATH, index=False)
    print(f"\n✅ CSV summary saved to: {RESULTS_CSV_OUTPUT_PATH}")
    print(f"✅ Detailed JSON results saved to: {RESULTS_JSON_OUTPUT_FOLDER}")

    print("\n--- 🏆 Best Performing Configuration ---")
    print(results_df.iloc[0])


Using device: cuda
Loading static components...
Pre-calculating PageRank for all graph nodes...
PageRank calculation complete.
Loaded 447 test questions.
Loaded QRELs for 447 queries.
Preparing BM25 index...
BM25 ready.

--- TESTING CHAMPION RETRIEVER: e5-large-parent_child ---
Loading champion embeddings: e5-large / parent_child
Champion embeddings loaded successfully.
Loading champion query encoder: intfloat/e5-large-v2

Starting enriched graph re-ranking experiment with 144 combinations for e5-large-parent_child.
Pre-calculating initial hybrid retrievals for all queries...


  return forward_call(*args, **kwargs)
Initial Retrieval for e5-large-parent_child: 100%|██████████| 447/447 [01:07<00:00,  6.65it/s]
Evaluating Hyperparameters for e5-large-parent_child: 100%|██████████| 144/144 [03:37<00:00,  1.51s/it]



--- TESTING CHAMPION RETRIEVER: e5-large-full_neighborhood ---
Loading champion embeddings: e5-large / full_neighborhood
Champion embeddings loaded successfully.
Loading champion query encoder: intfloat/e5-large-v2

Starting enriched graph re-ranking experiment with 144 combinations for e5-large-full_neighborhood.
Pre-calculating initial hybrid retrievals for all queries...


Initial Retrieval for e5-large-full_neighborhood: 100%|██████████| 447/447 [00:59<00:00,  7.45it/s]
Evaluating Hyperparameters for e5-large-full_neighborhood: 100%|██████████| 144/144 [03:32<00:00,  1.48s/it]


📊 Final Evaluation Results:
        champion_retriever  parent_bonus  cites_bonus  cited_by_bonus  sibling_bonus  pagerank_weight  isolation_penalty  Recall@10  MAP@10  nDCG@10
e5-large-full_neighborhood        0.0100       0.0100          0.0400         0.0050           0.0000             0.0000     0.3720  0.2628   0.3400
e5-large-full_neighborhood        0.0100       0.0100          0.0200         0.0050           0.1000             0.0000     0.3720  0.2628   0.3400
e5-large-full_neighborhood        0.0100       0.0100          0.0200         0.0050           0.2000             0.0000     0.3720  0.2628   0.3400
e5-large-full_neighborhood        0.0100       0.0200          0.0200         0.0050           0.1000             0.0000     0.3720  0.2628   0.3400
e5-large-full_neighborhood        0.0100       0.0200          0.0400         0.0050           0.0000             0.0000     0.3720  0.2628   0.3400
e5-large-full_neighborhood        0.0100       0.0200          0.0400        




In [None]:
## Import up sound alert dependencies
from IPython.display import Audio, display

def allDone(url="https://www.myinstants.com/media/sounds/money-soundfx.mp3"):
    """Play an autoplaying audio clip in the notebook to signal completion.

    Args:
        url: Address of the sound to play. Pass another clip to customise the
            alert, e.g.
            "https://www.myinstants.com/media/sounds/anime-wow-sound-effect.mp3".
    """
    display(Audio(url=url, autoplay=True))


allDone()