<a href="https://colab.research.google.com/github/RegNLP/ContextAware-Regulatory-GraphRAG-ObliQAMP/blob/main/02_Generate_Embedding_Variants_MultiModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ==============================================================================
# 02_Generate_Embeddings_with_Advanced_Models.ipynb
#
# Purpose:
# 1. Load both fine-tuned and pre-trained retriever models.
# 2. Apply instruction prefixes where necessary for models like e5 and bge.
# 3. Generate a new, high-quality set of passage embeddings for each model
#    and context strategy for a full comparison.
# ==============================================================================

# !pip install -q -U sentence-transformers transformers networkx

import os
import json
import pickle
import networkx as nx
from tqdm import tqdm
from sentence_transformers import SentenceTransformer

# --- Config ---
# Root project folder on the mounted Google Drive; every other path below is
# derived from it.
BASE_PATH = "/content/drive/MyDrive/Colab Notebooks/RIRAG-MultiPassage-NLLP/"
# Pickled graph whose nodes carry the passage text.
GRAPH_PATH = os.path.join(BASE_PATH, "graph.gpickle")
# Folder holding the locally fine-tuned retriever checkpoints.
FINETUNED_RETRIEVER_FOLDER = os.path.join(BASE_PATH, "fine_tuned_retrievers_advanced")
# Dedicated output folder so earlier embedding runs are not overwritten.
OUTPUT_FOLDER = os.path.join(BASE_PATH, "embeddings_advanced")
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# --- Load Graph ---
# NOTE: pickle can execute arbitrary code while loading; only point
# GRAPH_PATH at a file you created yourself or otherwise trust.
print("Loading graph...")
with open(GRAPH_PATH, mode="rb") as graph_file:
    G = pickle.load(graph_file)
print("Graph loaded successfully.")

# --- Models to Use for Embedding ---
# Maps an output label to what SentenceTransformer should load: a local
# fine-tuned checkpoint folder (``*_FT_Advanced``) or a pre-trained model id
# (``*_Pretrained``). The label doubles as the output sub-folder name.
MODELS_TO_USE = {
    # Fine-tuned checkpoints
    "e5-large-v2_FT_Advanced": os.path.join(FINETUNED_RETRIEVER_FOLDER, "e5-large-v2"),
    "all-mpnet-base-v2_FT_Advanced": os.path.join(FINETUNED_RETRIEVER_FOLDER, "all-mpnet-base-v2"),
    "bge-base-en-v1.5_FT_Advanced": os.path.join(FINETUNED_RETRIEVER_FOLDER, "bge-base-en-v1.5"),
    # Pre-trained originals
    "e5-large-v2_Pretrained": "intfloat/e5-large-v2",
    "all-mpnet-base-v2_Pretrained": "sentence-transformers/all-mpnet-base-v2",
    "bge-base-en-v1.5_Pretrained": "BAAI/bge-base-en-v1.5",
}


# --- Neighbor Configurations ---
# Each strategy takes (G, node) and returns the list of node ids whose text
# is appended to the passage's own text as context.
def _no_context(G, node):
    """Return no context nodes: the passage is embedded on its own."""
    return []


def _parents(G, node):
    """Return the predecessors of ``node``."""
    return list(G.predecessors(node))


def _parents_and_children(G, node):
    """Return the predecessors of ``node`` followed by its successors."""
    return [*G.predecessors(node), *G.successors(node)]


def _neighbors(G, node):
    """Return whatever ``nx.neighbors`` yields for ``node``."""
    # NOTE(review): G is used with .predecessors(), so it looks directed; on a
    # directed graph nx.neighbors appears to yield successors only, which
    # would make this strategy narrower than "parent_child". Confirm whether
    # nx.all_neighbors (both directions) was intended.
    return list(nx.neighbors(G, node))


neighbor_configs = {
    "passage_only": _no_context,
    "parent": _parents,
    "parent_child": _parents_and_children,
    "full_neighborhood": _neighbors,
}

# --- Helper Functions ---
def add_instruction(text, model_key):
    """Return ``text`` with the instruction prefix required by the model, if any.

    Only models whose key contains ``"e5"`` get the ``"passage: "`` prefix;
    for every other model (including BGE, which needs no passage prefix) the
    text is returned unchanged.
    """
    is_e5_model = "e5" in model_key
    return f"passage: {text}" if is_e5_model else text

def build_contextual_text(G, node_id, get_neighbors_func, model_key):
    """
    Build the string to embed for one node: its own text plus neighbor context.

    The node's ``"text"`` attribute (empty string if missing) always comes
    first. The ``"text"`` of each node returned by
    ``get_neighbors_func(G, node_id)`` follows, in that order, skipping
    neighbors whose text is empty or missing. Every part is passed through
    ``add_instruction`` and the parts are joined with newlines.
    """
    node_view = G.nodes

    # The passage's own text is kept even when it is empty.
    parts = [add_instruction(node_view[node_id].get("text", ""), model_key)]

    neighbor_texts = (
        node_view[neighbor_id].get("text", "")
        for neighbor_id in get_neighbors_func(G, node_id)
    )
    # NOTE(review): the prefix is applied to every part, so for e5 models
    # "passage: " recurs inside the combined text rather than appearing once
    # at the start -- confirm this matches how the models were fine-tuned.
    parts.extend(
        add_instruction(neighbor_text, model_key)
        for neighbor_text in neighbor_texts
        if neighbor_text
    )

    return "\n".join(parts)

# --- Run All Combinations ---
# For every model and every neighbor-context strategy: build one text per
# "Passage" node, encode all texts, and write the ids and embeddings to
# OUTPUT_FOLDER/<model_key>/<config_key>/.
for model_key, model_path in MODELS_TO_USE.items():
    print("\n" + "="*80)
    print(f"--- Loading Model: {model_key} from {model_path} ---")
    print("="*80)

    try:
        model = SentenceTransformer(model_path)
    except Exception as e:
        # Deliberately broad: a missing or broken checkpoint should not stop
        # the remaining models from being processed.
        print(f"⚠️ Could not load model {model_key}. Skipping. Error: {e}")
        continue

    for config_key, get_neighbors in neighbor_configs.items():
        print(f"\n🔍 Generating embeddings: Model={model_key}, Context={config_key}")

        out_dir = os.path.join(OUTPUT_FOLDER, model_key, config_key)
        os.makedirs(out_dir, exist_ok=True)

        # Parallel lists: texts_to_encode[i] is the text for uids_to_save[i].
        texts_to_encode, uids_to_save = [], []

        print("Preparing texts from graph nodes...")
        for node_id in tqdm(G.nodes, desc="Finding Passages"):
            # Only nodes typed "Passage" are embedded; other nodes contribute
            # solely as neighbor context.
            if G.nodes[node_id].get("type") == "Passage":
                full_text = build_contextual_text(G, node_id, get_neighbors, model_key)
                texts_to_encode.append(full_text)
                uids_to_save.append(node_id)

        if not texts_to_encode:
            print("⚠️ Warning: No passages found to embed. Skipping.")
            continue

        print(f"Found {len(texts_to_encode)} passages to embed.")
        embeddings = model.encode(texts_to_encode, show_progress_bar=True, batch_size=32)

        # Save the generated embeddings and corresponding UIDs; row i of the
        # embeddings belongs to passage_ids[i].
        with open(os.path.join(out_dir, "passage_ids.json"), "w", encoding="utf-8") as f:
            json.dump(uids_to_save, f)
        with open(os.path.join(out_dir, "embeddings.pkl"), "wb") as f:
            pickle.dump(embeddings, f)

        print(f"✅ Saved: {out_dir}")

print("\nAll embedding generation tasks are complete.")


In [None]:
## Import sound alert dependencies
from IPython.display import Audio, display

def allDone():
    """Play a short notification sound in the notebook output."""
    # Swap in any audio URL you like here, for example:
    #   https://www.myinstants.com/media/sounds/anime-wow-sound-effect.mp3
    sound_url = 'https://www.myinstants.com/media/sounds/money-soundfx.mp3'
    notification = Audio(url=sound_url, autoplay=True)
    display(notification)


allDone()