<a href="https://colab.research.google.com/github/RegNLP/ContextAware-Regulatory-GraphRAG-ObliQAMP/blob/main/2_Generate_Embedding_Variants_MultiModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ==============================================
# 02_Generate_Embedding_Variants_MultiModel.ipynb (Corrected)
#
# Generate four types of embeddings using two models.
# Fix: Corrected the case sensitivity for the node type check.
# ==============================================

# !pip install -q sentence-transformers networkx

import os
import json
import pickle
import networkx as nx
from tqdm import tqdm
from sentence_transformers import SentenceTransformer

# --- Config ---
# All inputs and outputs live under one project folder
# (presumably on a mounted Google Drive — confirm the mount step runs first).
BASE_PATH = "/content/drive/MyDrive/Colab Notebooks/RIRAG-MultiPassage-NLLP/"
GRAPH_PATH = os.path.join(BASE_PATH, "graph.gpickle")
OUTPUT_FOLDER = os.path.join(BASE_PATH, "embeddings")

# Create the embeddings folder up front; a no-op when it already exists.
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# --- Load Graph ---
# NOTE: pickle executes arbitrary code while loading, so only point GRAPH_PATH
# at a graph file produced by this project.
print("Loading graph...")
with open(GRAPH_PATH, "rb") as graph_file:
    G = pickle.load(graph_file)
print("Graph loaded successfully.")

# --- Embedding Models ---
# Short key (used as the output sub-folder name) -> model identifier that is
# passed to SentenceTransformer.
embedding_models = {
    "mpnet": "sentence-transformers/all-mpnet-base-v2",
    "e5-large": "intfloat/e5-large-v2",
}


def _unique_in_order(*node_iterables):
    """Flatten several iterables of node ids, dropping repeats.

    The first occurrence of each node id wins, so the relative order
    (predecessors before successors, etc.) is preserved.
    """
    return list(
        dict.fromkeys(
            node for nodes in node_iterables for node in nodes
        )
    )


# --- Neighbor Configurations ---
# Each entry maps a context name (also used as an output sub-folder name) to a
# callable ``(G, node) -> list of neighbor node ids`` whose text is appended to
# the passage text before embedding.
#
# Neighbor lists are de-duplicated: on a networkx directed graph
# ``G.neighbors(node)`` yields the same nodes as ``G.successors(node)``, so
# plain concatenation would append every child's text twice, and a node linked
# in both directions would be repeated as well.
# NOTE(review): with de-duplication, "full_neighborhood" returns the same nodes
# as "parent_child" on a directed graph — confirm whether a wider (e.g. 2-hop)
# neighborhood was intended.
neighbor_configs = {
    "passage_only": lambda G, node: [],
    "parent": lambda G, node: list(G.predecessors(node)),
    "parent_child": lambda G, node: _unique_in_order(
        G.predecessors(node), G.successors(node)
    ),
    "full_neighborhood": lambda G, node: _unique_in_order(
        G.predecessors(node), G.successors(node), G.neighbors(node)
    ),
}

# --- Function to build contextualized text ---
def build_contextual_text(G, node_id, neighbors):
    """
    Return the text to embed for ``node_id``: its own "text" attribute
    followed by the "text" of each node in ``neighbors``, one per line.

    Neighbors whose text is missing or empty are skipped. The node's own
    text is always the first piece, even when it is empty. Neighbors are
    kept in the order given and are not de-duplicated here.
    """
    node_attrs = G.nodes
    pieces = [node_attrs[node_id].get("text", "")]
    neighbor_texts = (node_attrs[nbr].get("text", "") for nbr in neighbors)
    pieces.extend(text for text in neighbor_texts if text)
    return "\n".join(pieces)

# --- Run All Combinations ---
# For every (model, neighbor configuration) pair, embed each passage's
# contextualized text and write two files to
# <OUTPUT_FOLDER>/<model_key>/<config_key>/:
#   passage_ids.json - the passage node ids
#   embeddings.pkl   - the pickled output of model.encode for those passages

# Which nodes are passages depends only on the graph, so scan it once here
# rather than once per (model, config) combination.
print("Preparing texts from graph nodes...")
uids_to_save = [
    node_id
    for node_id in tqdm(G.nodes, desc="Finding Passages")
    # The node type check is case-sensitive: "Passage" with an uppercase 'P'.
    if G.nodes[node_id].get("type") == "Passage"
]

if not uids_to_save:
    # Nothing to embed: stop before loading any (large) model.
    print("⚠️  Warning: No nodes of type 'Passage' found in the graph. Skipping embedding generation.")
else:
    print(f"Found {len(uids_to_save)} passages to embed.")

    for model_key, model_name in embedding_models.items():
        print(f"\n📌 Loading model: {model_name}")
        model = SentenceTransformer(model_name)

        for config_key, get_neighbors in neighbor_configs.items():
            print(f"\n🔍 Generating embeddings: Model={model_key}, Context={config_key}")

            # Texts are built in the same order as uids_to_save so that
            # position i in the texts corresponds to uids_to_save[i].
            # NOTE(review): the intfloat/e5 model card recommends prefixing
            # inputs with "passage: " / "query: "; no prefix is added here —
            # confirm this matches how queries are embedded at retrieval time.
            texts_to_encode = [
                build_contextual_text(G, node_id, get_neighbors(G, node_id))
                for node_id in uids_to_save
            ]
            embeddings = model.encode(texts_to_encode, show_progress_bar=True, batch_size=32)

            # Save the generated embeddings and corresponding UIDs
            out_dir = os.path.join(OUTPUT_FOLDER, model_key, config_key)
            os.makedirs(out_dir, exist_ok=True)

            with open(os.path.join(out_dir, "passage_ids.json"), "w") as f:
                json.dump(uids_to_save, f)

            # Use pickle for saving numpy arrays
            with open(os.path.join(out_dir, "embeddings.pkl"), "wb") as f:
                pickle.dump(embeddings, f)

            print(f"✅ Saved: {out_dir}")

print("\nAll embedding generation tasks are complete.")


In [None]:
## Import sound alert dependencies
from IPython.display import Audio, display

def allDone():
    """Display an auto-playing audio widget as an audible "finished" alert."""
    # Swap in whatever audio URL you want here, e.g. the alternative:
    #   https://www.myinstants.com/media/sounds/anime-wow-sound-effect.mp3
    sound_url = 'https://www.myinstants.com/media/sounds/money-soundfx.mp3'
    alert = Audio(url=sound_url, autoplay=True)
    display(alert)


# Trigger the alert as soon as this cell runs.
allDone()