# Initial Loading (From M1)

In [None]:
from scipy.sparse import load_npz

# Sparse co-occurrence adjacency written by the M1 notebook.
load_path = "extracted_output/A_co_occur.npz"
A_co_occur = load_npz(load_path)

# Quick sanity report: container type and (N, N) shape.
print("Loaded type:", type(A_co_occur))
print("Shape:", A_co_occur.shape)


In [None]:
## Load the saved entities

import json
import os

def load_extracted_entities(base_dir="extracted_output"):
    """
    Read the entity-extraction artefacts stored under ``base_dir``.

    Files read:
      - entities.json                 -> list of deduplicated entity/event nodes
      - entities_per_chunk_debug.json -> raw per-chunk extraction data

    Returns:
        (entities, per_chunk): the decoded JSON content of the two files.

    Raises:
        FileNotFoundError: if either file is missing. Both paths are checked
            before anything is read, entities.json first.
    """
    required = [
        os.path.join(base_dir, name)
        for name in ("entities.json", "entities_per_chunk_debug.json")
    ]

    # Fail fast before reading anything if one of the files is absent.
    for path in required:
        if not os.path.exists(path):
            raise FileNotFoundError(f"Missing file: {path}")

    decoded = []
    for path in required:
        with open(path, "r", encoding="utf-8") as handle:
            decoded.append(json.load(handle))
    entities, per_chunk = decoded

    print("Loaded:")
    print(" - entities.json")
    print(" - entities_per_chunk_debug.json")

    return entities, per_chunk


# Example usage:
if __name__ == "__main__":
    # Smoke-test the loader on the default output directory and report sizes.
    nodes, nodes_per_chunk = load_extracted_entities()
    print(
        f"Total nodes loaded: {len(nodes)}\n"
        f"Total chunks loaded: {len(nodes_per_chunk)}"
    )


In [None]:
## Load the saved relations

import os
import json

def load_extracted_relations(base_dir="extracted_output"):
    """
    Read the relation-extraction artefacts stored under ``base_dir``.

    Files read:
      - relations.json
      - relations_per_chunk_debug.json

    Returns:
      relations: decoded content of relations.json (deduplicated relation edges)
      per_chunk_relations: decoded content of relations_per_chunk_debug.json

    Raises:
      FileNotFoundError: if either file is missing (relations.json is checked
        first; nothing is read until both are known to exist).
    """
    relations_path, per_chunk_path = (
        os.path.join(base_dir, filename)
        for filename in ("relations.json", "relations_per_chunk_debug.json")
    )

    for candidate in (relations_path, per_chunk_path):
        if not os.path.exists(candidate):
            raise FileNotFoundError(f"Missing file: {candidate}")

    def _read_json(path):
        # Small local reader so both files are decoded identically.
        with open(path, "r", encoding="utf-8") as fh:
            return json.load(fh)

    relations = _read_json(relations_path)
    per_chunk_relations = _read_json(per_chunk_path)

    print("Loaded:")
    print(" - relations.json")
    print(" - relations_per_chunk_debug.json")

    return relations, per_chunk_relations


# Example usage
if __name__ == "__main__":
    # Smoke-test the relation loader and report how much was read.
    relations, relations_per_chunk = load_extracted_relations()
    print(
        f"Total relations loaded: {len(relations)}\n"
        f"Total chunks returned: {len(relations_per_chunk)}"
    )


# 3.3 $C_{prior}$ generation

The goal here is to create a sparse answer key, $C_{prior}$ (our causal prior).

Our $C_{prior}$ is essentially a high-quality "training dataset" that tells us the true causal links in the document.
Since we have no human labels, we must generate this dataset ourselves. This is a self-supervised process.
The core idea is to use a large, powerful "Teacher" LLM (e.g., gemini-2.5-flash) to perform complex causal reasoning.
The output of this phase is $C_{prior}$ (the causal prior).

## Active Candidate-Set Expansion (ACE)

We can't query the LLM with all $O(N^2)$ pairs of nodes, so we first prune them to a list of plausible candidate pairs, $E_{prior}$. Only this small list is sent to the LLM. The pruning is done in two stages:
1. Structural Filter (GAE)
2. Semantic Filter

### 1. Structural Filter (GAE):

Our initial co-occurrence matrix $A_w$ is myopic: it only contains local, 1-hop links within a paragraph. We need an unsupervised way to find pairs $(i, j)$ that are strongly connected via multi-hop paths, because such pairs are plausible candidates for causal relationships.

In [None]:
import numpy as np
from scipy.sparse import identity, csr_matrix

# Assuming A_co_occur is the final CSR matrix generated previously.
# If A_co_occur is not defined, this code assumes it has been loaded or created.

def calculate_A_hat(A_co_occur: csr_matrix) -> csr_matrix:
    """
    Build the binary adjacency matrix with self-loops: A_hat = sign(A_co_occur + I).

    This is the *un-normalized* A + I; no D^-1/2 scaling is applied here.
    Every stored entry of the result is clipped to 1. Co-occurrence counts
    greater than 1 become 1, and so do diagonal entries that were already
    present (1 + 1 = 2).

    Args:
        A_co_occur: Square (N x N) co-occurrence matrix in any SciPy sparse
            format; expected to be non-negative.

    Returns:
        A_hat as a CSR matrix with dtype int8.

    Raises:
        ValueError: if A_co_occur is not square.
    """
    n_rows, n_cols = A_co_occur.shape
    if n_rows != n_cols:
        raise ValueError(f"A_co_occur must be square, got shape {A_co_occur.shape}")

    # Sparse N x N identity with the same dtype as the input.
    I = identity(n_rows, dtype=A_co_occur.dtype, format='csr')

    # Sparse addition takes the union of the two non-zero patterns.
    A_hat = A_co_occur + I

    # sign() maps every positive sum back to 1 (e.g. the 2 on an existing diagonal entry).
    A_hat_binary = A_hat.sign()

    # Always return CSR int8, whatever sparse format the addition produced.
    return csr_matrix(A_hat_binary, dtype=np.int8)

# Build A_hat from the co-occurrence matrix loaded earlier and summarise it.
A_hat = calculate_A_hat(A_co_occur)

print("✅ A_hat successfully created.")
print("A_hat Shape: " + str(A_hat.shape))
print("Non-Zero Entries (nnz): " + str(A_hat.nnz))

In [None]:
import numpy as np
from scipy.sparse import csr_matrix, diags
from typing import List

def calculate_D_hat(A_hat: csr_matrix) -> csr_matrix:
    """
    Build the diagonal degree matrix of A_hat: D_hat[i, i] = sum_j A_hat[i, j].

    The row sums are flattened with ``np.asarray(...).ravel()`` rather than the
    ``numpy.matrix``-only ``.A1`` attribute. This works for both legacy sparse
    matrices, where ``sum(axis=1)`` returns an (N, 1) ``numpy.matrix``, and
    SciPy sparse arrays, where it returns a 1-D ndarray.

    Args:
        A_hat: Square sparse adjacency matrix (typically A + I).

    Returns:
        D_hat as a CSR diagonal matrix.
    """
    # Row sums -> plain 1-D degree vector of shape (N,).
    degree_vector = np.asarray(A_hat.sum(axis=1)).ravel()

    D_hat = diags(degree_vector, format='csr')

    print("Degree Vector Shape (N):", degree_vector.shape)
    print(f"D_hat diagonal values (first 5): {degree_vector[:5].tolist()}")

    return D_hat

# Degree matrix of the self-looped adjacency built above.
D_hat = calculate_D_hat(A_hat)

print("✅ D_hat matrix created.")
print("D_hat Shape: " + str(D_hat.shape))

In [None]:
import os
import json
import numpy as np
from sentence_transformers import SentenceTransformer
from typing import Dict, List, Any


# =========================================================
# 1. Load nodes (entities) from entities.json
# =========================================================

def load_nodes(path: str = "extracted_output/entities.json") -> List[Dict[str, Any]]:
    """
    Read and return the node list written by the entity-extraction step.

    The file is decoded as UTF-8 JSON and returned as-is (no validation).
    Each element is expected to resemble:

    {
        "id": "N8",
        "name": "Black Market Sales of Titan's strategic assets",
        "type": "event",
        "time": "around July 10th",
        "location": null,
        "description": "A significant increase ...",
        "source_chunks": [0]
    }
    """
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)


# =========================================================
# 2. Build node texts for embedding (using rich metadata)
# =========================================================

def make_node_text(node: Dict[str, Any]) -> str:
    """
    Build the compact text that gets embedded for a node.

    Several fields are used, not just the name, joined with " | " in this order:
      - name (always considered; stripped)
      - "type: <type>"         (if truthy)
      - "time: <time>"         (if truthy)
      - "location: <location>" (if truthy)
      - description            (if truthy; stripped)

    Empty components are dropped. A missing or null (JSON ``null``) name or
    description no longer raises. Non-string values are converted with
    ``str()`` before stripping.
    """
    name = node.get("name")
    parts = ["" if name is None else str(name).strip()]

    # Labelled metadata fields, added only when present and truthy.
    for key in ("type", "time", "location"):
        value = node.get(key)
        if value:
            parts.append(f"{key}: {value}")

    desc = node.get("description")
    if desc:
        parts.append(str(desc).strip())

    # Drop empty components (e.g. a blank name or whitespace-only description).
    return " | ".join(p for p in parts if p)


# =========================================================
# 3. Build NodeEmbeddingMap and X matrix (with N0)
# =========================================================

def build_node_embeddings(
    nodes: List[Dict[str, Any]],
    model_name: str = "all-MiniLM-L6-v2",
) -> (Dict[str, np.ndarray], np.ndarray, List[str]):
    """
    Embed every node and lay the vectors out by numeric node id.

    Given the rich node list (with id/name/type/time/location/description),
    build:

      - NodeEmbeddingMap: { "N0": vec0, "N1": vec1, ... }
      - X: 2D matrix where row index == numeric node ID (0..max_id).
            * row 0 -> N0 (virtual padding node, all zeros)
            * row k -> Nk
            * rows for ids missing from ``nodes`` (numbering gaps) stay zero
      - sorted_node_ids: ["N0"] followed by the real IDs in numeric order.

    The input ``nodes`` list is not mutated. With an empty ``nodes`` list,
    X has a single (zero) N0 row.

    Args:
        nodes: Node dicts whose "id" is "N" followed by an integer.
        model_name: SentenceTransformer model used to encode the node texts.

    Returns:
        (NodeEmbeddingMap, X, sorted_node_ids)

    Raises:
        ValueError: if two nodes share a numeric id, or a node's numeric id is
            <= 0. Either case would silently overwrite rows of X (0 is
            reserved for N0).
    """

    def numeric_id(n: Dict[str, Any]) -> int:
        return int(n["id"][1:])  # "N10" -> 10

    # 1) Sort real nodes by numeric id (N1, N2, ..., N14), parsing each id once.
    real_nodes_sorted = sorted(nodes, key=numeric_id)
    real_numeric_ids = [numeric_id(n) for n in real_nodes_sorted]

    # Reject ids that would collide with N0 or with each other in X.
    if real_numeric_ids and real_numeric_ids[0] <= 0:
        raise ValueError(
            f"Node id number {real_numeric_ids[0]} is invalid; real node ids must be >= 1 "
            "(0 is reserved for the virtual N0 node)"
        )
    for prev_id, cur_id in zip(real_numeric_ids, real_numeric_ids[1:]):
        if prev_id == cur_id:
            raise ValueError(f"Duplicate numeric node id: {cur_id}")

    # 2) IDs including N0 as a virtual node (N0 is not part of `nodes`).
    sorted_node_ids: List[str] = ["N0"] + [n["id"] for n in real_nodes_sorted]

    # 3) Texts to embed; N0 gets a dummy label and is zeroed out below.
    node_texts: List[str] = ["GRAPH_PAD_NODE"] + [make_node_text(n) for n in real_nodes_sorted]

    # 4) Load embedder and compute embeddings
    model = SentenceTransformer(model_name)
    print(f"Loaded embedding model: {model_name}")
    print(f"Embedding dimension: {model.get_sentence_embedding_dimension()}")
    print("Starting node embedding...")

    embeddings = model.encode(
        node_texts,
        show_progress_bar=True,
        convert_to_numpy=True,
    )

    # 5) Overwrite N0 vector with zeros
    embeddings[0, :] = 0.0
    print("\n✅ N0 embedding set to zero.")

    # 6) NodeEmbeddingMap: { node_id: vector } (row i of embeddings <-> sorted_node_ids[i])
    NodeEmbeddingMap: Dict[str, np.ndarray] = {
        node_id: embeddings[i] for i, node_id in enumerate(sorted_node_ids)
    }

    # 7) X with row index == numeric id. Ids are sorted, so the last one is the max;
    #    fall back to 0 (just the N0 row) when there are no real nodes.
    max_id = real_numeric_ids[-1] if real_numeric_ids else 0
    d = embeddings.shape[1]

    X = np.zeros((max_id + 1, d), dtype=embeddings.dtype)
    X[0, :] = NodeEmbeddingMap["N0"]
    for n, idx in zip(real_nodes_sorted, real_numeric_ids):
        X[idx, :] = NodeEmbeddingMap[n["id"]]

    # Some prints for verification
    print("\n--- Node Embedding Results ---")
    print(f"Total real nodes (no N0): {len(real_nodes_sorted)}")
    print(f"Total logical nodes (with N0): {len(sorted_node_ids)}")
    print(f"Embedding matrix X shape: {X.shape}")
    print(f"Verification: sum(N0 row) = {np.sum(X[0]):.4f}")

    return NodeEmbeddingMap, X, sorted_node_ids


# =========================================================
# 4. Optional: pretty-print a small snippet
# =========================================================

def debug_print_node_embeddings(
    NodeEmbeddingMap: Dict[str, np.ndarray],
    sorted_node_ids: List[str],
    max_nodes_to_show: int = 5,
):
    """
    Print up to ``max_nodes_to_show`` node IDs, each followed by the first
    three components of its vector (4 decimals). Returns None.
    """
    print("\n--- NodeEmbeddingMap (snippet) ---")
    shown_ids = sorted_node_ids[:max_nodes_to_show]
    rendered = (
        "{}: [{:.4f}, {:.4f}, {:.4f}, ...]".format(nid, *NodeEmbeddingMap[nid][:3])
        for nid in shown_ids
    )
    for line in rendered:
        print(line)


# =========================================================
# 5. Build X and NodeEmbeddingMap end-to-end
# =========================================================

if __name__ == "__main__":
    # Load extracted entities, embed them, and show a few vectors.
    nodes = load_nodes("extracted_output/entities.json")
    NodeEmbeddingMap, X, sorted_node_ids = build_node_embeddings(nodes)
    debug_print_node_embeddings(NodeEmbeddingMap, sorted_node_ids, max_nodes_to_show=6)

    # Persist X (.npy) plus an id -> row-index map (.json). The vectors live in X,
    # so the JSON only records which row belongs to which node.
    os.makedirs("extracted_output", exist_ok=True)
    np.save("extracted_output/node_features_X.npy", X)

    # "Nk" -> k; this also yields 0 for the placeholder "N0".
    id_to_row_index = {node_id: int(node_id[1:]) for node_id in sorted_node_ids}
    with open("extracted_output/node_id_to_row_index.json", "w", encoding="utf-8") as f:
        json.dump(id_to_row_index, f, indent=2, ensure_ascii=False)

    print(
        "\nSaved:\n"
        "  extracted_output/node_features_X.npy\n"
        "  extracted_output/node_id_to_row_index.json"
    )


#### Pre-computation and GAE Model Definition:

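This section prepares the graph data for PyTorch Geometric and defines the PyTorch module for the GAE. The symmetric normalization $\hat{D}^{-\frac{1}{2}}\hat{A}\hat{D}^{-\frac{1}{2}}$ (with $\hat{A} = A + I$) is not computed by a separate helper; each `GCNConv` layer adds the self-loops and applies this normalization internally.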

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GAE
from torch_geometric.data import Data
from torch_geometric.utils import from_scipy_sparse_matrix
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.preprocessing import StandardScaler

# --- Hyperparameters and Dimensions ---
N_NODES = 1 + len(nodes)  # real nodes plus the N0 padding row
D_IN = 384                # input feature width (presumably the sentence-embedding size; confirm)
D_HIDDEN = 128            # width of the first GCN layer
D_LATENT = 64             # width of the latent embedding Z
LEARNING_RATE = 0.005
WEIGHT_DECAY = 5e-4
NUM_EPOCHS = 500          # 500 is enough to see behaviour


# --- 1. Model Definitions ---

class GCNEncoder(nn.Module):
    """
    Two-layer GCN encoder for the GAE: in_channels -> hidden (ReLU) -> out_channels.

    Both GCNConv layers use ``cached=True``. Per the PyG docs, the normalized
    graph is cached on the first call, so the encoder assumes it always sees
    the same ``edge_index`` (transductive use).
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels, cached=True)
        self.conv2 = GCNConv(hidden_channels, out_channels, cached=True)
        # Re-initialise explicitly (PyG layers already start from sensible defaults).
        self.reset_parameters()

    def reset_parameters(self):
        for layer in (self.conv1, self.conv2):
            layer.reset_parameters()

    def forward(self, x, edge_index):
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)


# --- 2. Data Preparation (WITH FEATURE SCALING) ---

def prepare_pyg_data(X_np: np.ndarray, A_cooccur_csr: csr_matrix):
    """
    Convert NumPy/SciPy inputs into a PyTorch Geometric ``Data`` object.

    Steps:
      1. Standardize each feature column of X_np (StandardScaler) for stability.
      2. Symmetrize the adjacency (A OR A^T), so the graph is undirected.
      3. Drop self-loops from ``edge_index``. GCNConv adds its own self-loops,
         and leaving diagonal entries in would make them positive edges in the
         GAE reconstruction loss.

    Args:
        X_np: Node feature matrix of shape (N, F).
        A_cooccur_csr: Sparse (N, N) co-occurrence adjacency.

    Returns:
        Data(x=float tensor (N, F), edge_index=long tensor (2, E)) without self-loops.

    Raises:
        ValueError: if the adjacency is not (N, N) for N = X_np.shape[0].
    """
    n_rows = X_np.shape[0]
    if A_cooccur_csr.shape != (n_rows, n_rows):
        raise ValueError(
            f"Adjacency shape {A_cooccur_csr.shape} does not match the "
            f"{n_rows} rows of the feature matrix"
        )

    # Scale node features
    scaler = StandardScaler()
    X_np_scaled = scaler.fit_transform(X_np)

    # Symmetrize: A_undirected = A OR A^T
    A = A_cooccur_csr
    A_sym = ((A + A.T) > 0).astype(np.float32)

    # Use PyG helper to create edge_index
    edge_index, _ = from_scipy_sparse_matrix(A_sym)

    # Remove self-loops; GCNConv adds them itself, so don't add I here either.
    edge_index = edge_index[:, edge_index[0] != edge_index[1]]

    x = torch.from_numpy(X_np_scaled).float()
    data = Data(x=x, edge_index=edge_index)

    return data


# --- 3. Training Function ---

def train_pyg_gae(data: Data, epochs: int):
    """
    Train a graph auto-encoder on ``data`` (full batch) and return node embeddings.

    The encoder's input width comes from ``data.num_node_features`` instead of
    the hard-coded D_IN constant, so feature matrices of any width work.
    D_HIDDEN, D_LATENT, LEARNING_RATE and WEIGHT_DECAY are still read from
    module scope.

    Args:
        data: PyG Data with ``x`` (node features) and ``edge_index``.
        epochs: Number of optimisation steps (one full-graph step per epoch).

    Returns:
        Latent embeddings, one row per node (shape (num_nodes, D_LATENT)),
        computed under ``torch.no_grad()`` after training.
    """
    encoder = GCNEncoder(
        in_channels=data.num_node_features,
        hidden_channels=D_HIDDEN,
        out_channels=D_LATENT,
    )
    model = GAE(encoder)

    optimizer = optim.Adam(model.parameters(),
                           lr=LEARNING_RATE,
                           weight_decay=WEIGHT_DECAY)

    print(f"Training GAE for {epochs} epochs (lr={LEARNING_RATE}, decay={WEIGHT_DECAY})...")

    for epoch in range(1, epochs + 1):
        model.train()
        optimizer.zero_grad()

        z = model.encode(data.x, data.edge_index)

        # recon_loss treats these edges as positives; with no negative edges
        # passed, PyG samples the negatives internally.
        loss = model.recon_loss(z, data.edge_index)

        loss.backward()
        optimizer.step()

        if epoch % 50 == 0 or epoch == 1:
            print(f"Epoch {epoch:03d}, Loss: {loss.item():.4f}")

    # Return final embeddings
    model.eval()
    with torch.no_grad():
        z = model.encode(data.x, data.edge_index)

    return z


# --- 4. Execution Block ---

# Train on the real node features X (row k == node Nk, built by
# build_node_embeddings) and the co-occurrence graph A_co_occur loaded from M1.
# Training on a random graph makes Z, and every candidate pair derived from Z
# later, meaningless. Set USE_SIMULATED_DATA = True only to smoke-test the
# training loop.
USE_SIMULATED_DATA = False

if USE_SIMULATED_DATA:
    X_np_sim = np.random.rand(N_NODES, D_IN).astype(np.float32)
    A_cooccur_sim = csr_matrix(
        np.random.randint(0, 2, size=(N_NODES, N_NODES), dtype=np.int8)
    )
    # Make sure nonzero entries are 1
    A_cooccur_sim.data[:] = 1
    X_train, A_train = X_np_sim, A_cooccur_sim
else:
    X_train = np.asarray(X, dtype=np.float32)
    A_train = csr_matrix(A_co_occur)
    # NOTE(review): assumes A_co_occur rows use the same Nk -> k indexing as X; confirm in M1.
    if A_train.shape != (X_train.shape[0], X_train.shape[0]):
        raise ValueError(
            f"Shape mismatch: X has {X_train.shape[0]} rows but A_co_occur is "
            f"{A_train.shape}; both must be indexed by numeric node id (row k == Nk)."
        )

pyg_data = prepare_pyg_data(X_train, A_train)

Final_Embeddings_Z = train_pyg_gae(pyg_data, epochs=NUM_EPOCHS)

print("-" * 40)
print("✅ PyG GAE Training Successfully Executed.")
print(f"Final Latent Embedding Matrix Z Shape: {Final_Embeddings_Z.shape}")


In [None]:
# Alias the trained GAE embeddings under the short name Z used when saving them.
Z = Final_Embeddings_Z

In [None]:
## Saving the above

import os
import torch

# Persist the GAE embeddings so later cells can reload them without retraining.
save_path = "extracted_output/Z.pt"
os.makedirs(os.path.dirname(save_path), exist_ok=True)

# Z must already exist (assigned from Final_Embeddings_Z).
torch.save(Z, save_path)

print(f"Saved GAE embeddings Z to: {save_path}")


In [None]:
## loading from the above

import torch

load_path = "extracted_output/Z.pt"

# Load the saved tensor onto the CPU, so a file written from a CUDA session
# still loads on a machine without a GPU.
Z = torch.load(load_path, map_location="cpu")

print(f"Loaded GAE embeddings Z from: {load_path}")
print(f"Z shape: {Z.shape}")
print(f"Z dtype: {Z.dtype}")


In [None]:
import os
import json
from typing import List, Dict, Any, Tuple, Set

import torch
import numpy as np
import faiss


# ============================================================
# CONFIG
# ============================================================

Z_PATH = "extracted_output/Z.pt"               # GAE embeddings saved earlier
NODES_PATH = "extracted_output/entities.json"  # extracted entity/event nodes

# Neighbour budget as a fraction of real nodes: k' = int(num_real_nodes * SCALING_RATIO)
SCALING_RATIO = 0.30


# ============================================================
# 1. LOAD NODES AND BUILD INDEX MAPPINGS
# ============================================================

with open(NODES_PATH, "r", encoding="utf-8") as f:
    nodes: List[Dict[str, Any]] = json.load(f)

# Node ids look like "N1", "N2", ... (from the entity extractor). Assumed layout:
#   - row 0 of Z is the N0 placeholder,
#   - row k of Z is node "Nk".
# So the row index is just the numeric part of the id ("N10" -> 10).
node_id_to_index: Dict[str, int] = {node["id"]: int(node["id"][1:]) for node in nodes}

# Real nodes exclude N0; Z is expected to have NUM_REAL_NODES + 1 rows.
NUM_REAL_NODES = len(nodes)
print(f"Loaded {NUM_REAL_NODES} real nodes from entities.json")


# ============================================================
# 2. LOAD GAE EMBEDDINGS Z
# ============================================================

if not os.path.exists(Z_PATH):
    raise FileNotFoundError(f"Z embeddings file not found at: {Z_PATH}")

# map_location="cpu": FAISS works on host memory, and this keeps the load
# working when Z.pt was written from a CUDA session.
Z: torch.Tensor = torch.load(Z_PATH, map_location="cpu")
if not isinstance(Z, torch.Tensor):
    raise TypeError("Loaded object from Z.pt is not a torch.Tensor")
if Z.dim() != 2:
    raise ValueError(
        f"Expected a 2-D embedding matrix in {Z_PATH}, got shape {tuple(Z.shape)}"
    )

print(f"Loaded Z from {Z_PATH}")
print(f"Z shape: {tuple(Z.shape)}  (rows x latent_dim)")

# Z should have row 0 = N0, rows 1.. = N1..Nn
NUM_ROWS_Z, LATENT_DIM = Z.shape

# Basic consistency check: NUM_ROWS_Z should be NUM_REAL_NODES + 1 (for N0)
if NUM_ROWS_Z != NUM_REAL_NODES + 1:
    print(
        f"[WARN] Z rows ({NUM_ROWS_Z}) != num_real_nodes + 1 "
        f"({NUM_REAL_NODES + 1}). Check your graph construction."
    )

# FAISS requires float32 NumPy input.
Z_np = Z.detach().cpu().numpy().astype("float32")


# ============================================================
# 3. CHOOSE DYNAMIC K' FOR ANN NEIGHBORS
# ============================================================

# We want k' scaled by real nodes, not counting N0.
# k' scales with the real node count (N0 not counted), capped at NUM_REAL_NODES
# and never below 1.
K_PRIME_SCALED = int(NUM_REAL_NODES * SCALING_RATIO)
K_PRIME = min(K_PRIME_SCALED, NUM_REAL_NODES)
if K_PRIME < 1:
    K_PRIME = 1

print(
    f"Total rows in Z (including N0): {NUM_ROWS_Z}\n"
    f"Number of real nodes (excluding N0): {NUM_REAL_NODES}\n"
    f"SCALING_RATIO = {SCALING_RATIO}\n"
    f"k' (neighbors per node) = {K_PRIME}"
)


# ============================================================
# 4. BUILD FAISS INDEX AND RUN ANN SEARCH
# ============================================================

print("\nBuilding FAISS Index on Z...")
index = faiss.IndexFlatL2(LATENT_DIM)
index.add(Z_np)

print("Running ANN search...")
# Each row's own vector is normally its nearest hit, hence the "+ 1".
# D (squared L2 distances) and I (row indices) are both shaped
# (NUM_ROWS_Z, K_PRIME + 1).
D, I = index.search(Z_np, K_PRIME + 1)


# ============================================================
# 5. BUILD CANDIDATE SET C1 (IN INDEX SPACE)
# ============================================================

C1: Set[Tuple[int, int]] = set()
raw_candidate_list: List[Dict[str, Any]] = []

print("\nPopulating Candidate Set C1, excluding N0 (index 0)...")

# Row 0 is the N0 placeholder, so it is never used as a source.
for i in range(1, NUM_ROWS_Z):
    # Scan every returned hit instead of blindly dropping position 0. With tied
    # distances the query row is not guaranteed to come first, and dropping
    # position 0 would then discard a genuine neighbour.
    for pos, neighbor_idx in enumerate(I[i]):
        j = int(neighbor_idx)

        # FAISS pads with -1 when fewer than k results exist; also skip N0 (0)
        # and the query row itself.
        if j <= 0 or j == i:
            continue

        pair = (i, j)
        if pair in C1:
            continue
        C1.add(pair)

        # FAISS IndexFlatL2 returns squared L2 distances; clamp tiny negative
        # round-off before taking the square root (avoids NaN).
        dist_sq = float(D[i, pos])
        dist = float(np.sqrt(max(dist_sq, 0.0)))

        raw_candidate_list.append(
            {
                "source_idx": i,
                "target_idx": j,
                "distance_squared": dist_sq,
                "distance": dist,
            }
        )

# ============================================================
# 6. FINAL REPORT
# ============================================================

print(
    "\n--- Candidate Set C1 Results ---\n"
    f"Total rows in Z (including N0): {NUM_ROWS_Z}\n"
    f"Number of real nodes (excluding N0): {NUM_REAL_NODES}\n"
    f"k' (neighbors searched per node): {K_PRIME}\n"
    f"Size of Candidate Set C1 (|C1|): {len(C1)} pairs"
)

# Show the first few candidates as "source -> target | Dist: d".
SNIPPET = 10
print(f"\nSnippet of first {SNIPPET} candidate pairs:")
for candidate in raw_candidate_list[:SNIPPET]:
    print(
        f"  {candidate['source_idx']} -> {candidate['target_idx']} "
        f"| Dist: {candidate['distance']:.4f}"
    )

# Row index -> node id ("Nk"), assuming row k corresponds to node Nk; 0 is the placeholder.
index_to_node_id: Dict[int, str] = {int(node["id"][1:]): node["id"] for node in nodes}
index_to_node_id[0] = "N0"

# Translate the first candidate pair, if there is one.
if raw_candidate_list:
    first = raw_candidate_list[0]
    src, dst = first["source_idx"], first["target_idx"]
    print(
        f"\nExample mapping: {src} -> {dst} "
        f"== {index_to_node_id.get(src, '?')} -> {index_to_node_id.get(dst, '?')}"
    )

# C1 (set of index pairs) and raw_candidate_list (pairs with distances) feed the
# downstream causal verification step.


In [None]:
from pprint import pprint

# Dump every candidate record (source/target row indices and distances) for inspection.
pprint(raw_candidate_list)

In [None]:
# Keep only the (source, target) row-index pair of each candidate, in discovery order.
K1: List[Tuple[int, int]] = [
    (candidate['source_idx'], candidate['target_idx'])
    for candidate in raw_candidate_list
]


In [None]:
# Show the full list of (source, target) candidate index pairs.
pprint(K1)

In [None]:
## Save K1 to the saved stuff directory

import json
import os

# Ensure the output directory exists before k1.json is written into it.
os.makedirs("extracted_output", exist_ok=True)

def convert_k1_to_jsonable(K1):
    """Return K1 as a list of (int, int) tuples of built-in Python ints.

    NumPy integer scalars (e.g. np.int64) are not JSON-serializable, so both
    elements of every pair are coerced with int().
    """
    jsonable = []
    for src, dst in K1:
        jsonable.append((int(src), int(dst)))
    return jsonable

save_path = "extracted_output/k1.json"
jsonable_K1 = convert_k1_to_jsonable(K1)

# Pairs are written as JSON lists: [[i, j], ...]
with open(save_path, "w") as f:
    json.dump(jsonable_K1, f, indent=2)

print(f"Saved K1 to {save_path}")


In [None]:
## Retrieve the K1

import json
import numpy as np

load_path = "extracted_output/k1.json"

with open(load_path, "r", encoding="utf-8") as f:
    loaded_list = json.load(f)

# Restore each JSON [i, j] pair as a tuple of built-in ints. K1 was built from
# Python ints, so there is no reason to reintroduce np.int64 for the target index.
K1_loaded = [(int(a), int(b)) for (a, b) in loaded_list]

print(f"Loaded K1 from {load_path}")
print(K1_loaded[:10])  # preview

In [None]:
## Retrieve the Z (GAE embeddings)

import torch

load_path = "extracted_output/Z.pt"
# map_location="cpu" keeps the load working on machines without a GPU.
Z_loaded = torch.load(load_path, map_location="cpu")

print(
    f"Loaded Z from: {load_path}\n"
    f"Shape: {Z_loaded.shape}\n"
    f"Dtype: {Z_loaded.dtype}"
)

In [None]:
## Retrieve the A_co_occur

from scipy.sparse import load_npz

load_path = "extracted_output/A_co_occur.npz"
A_co_occur_loaded = load_npz(load_path)

# Report the source path, then the full sparse matrix listing.
print(f"Loaded A_co_occur from: {load_path}\n{A_co_occur_loaded}")


In [None]:
# Make the reloaded artefacts the working copies for the rest of the notebook.
E_prior, A_co_occur, Z = K1_loaded, A_co_occur_loaded, Z_loaded

### Semantic filter 

Here, the goal is to filter the "structurally plausible" set $K_1$ down to a smaller "semantically plausible" set $K_2$.

In [None]:
# TODO: the semantic filter (K1 -> K2) is not implemented yet; the current search
# space is small, so this cell is only a placeholder for now.
print("hello")  # placeholder output

### Brief summary of the pre-CoCaD filtering stage:

1. Structural Filter (GAE/ANN): You begin with $N^2$ possible pairs. The GAE/ANN process filters this down to the structurally plausible set, $\mathbf{C}_1$.
2. Input to Semantic Filter: The $\mathbf{C}_1$ set is the full input to the Semantic Filter.
3. Semantic Filter (Verification/CPC): This stage rigorously verifies and classifies the plausibility of each link in $\mathbf{C}_1$. It discards links that are semantically implausible, not supported by context, or lack mechanism/temporality.
4. Final Output: The resulting, highly curated, high-quality set is $\mathbf{K}_2$.$$\mathbf{C}_1 \xrightarrow[\text{Verification, Classification}]{\text{Semantic Filter}} \mathbf{K}_2$$

The final $\mathbf{K}_2$ set represents your ultimate $\mathbf{E}_{\text{prior}}$, containing only the most promising candidates ready for the final causal algorithm.

### Note: The loaded $E_{prior}$ should be direction-agnostic: if it contains the pair $(i, j)$, it must also contain the pair $(j, i)$.

Reason: Structural Completeness

The primary goal of $E_{\text{prior}}$ is to provide a complete list of plausible hypotheses for the final causal discovery algorithm.
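1. Symmetry of evidence: the structural similarity between two nodes (the GAE decoder score $Z_i Z_j^\top$, or the L2 distance $\lVert Z_i - Z_j \rVert$ used by the ANN search) is symmetric. If $i$ is structurally close to $j$, then $j$ is equally close to $i$. The $k'$-nearest-neighbour relation, however, is not symmetric: $j$ may be among $i$'s nearest neighbours without $i$ being among $j$'s. The raw candidate list can therefore contain $(i, j)$ without $(j, i)$.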
2. Causal Test Requirement: The final causal discovery phase must individually test the hypothesis $i \to j$ against the hypothesis $j \to i$. If you only included $(i, j)$, you would prematurely exclude the possibility that $j$ causes $i$.
Therefore, to maintain structural completeness, both directed hypotheses must be present in $E_{\text{prior}}$.

In [None]:
# Start E_prior from the reloaded structural candidate pairs (K1).
E_prior = K1_loaded

In [None]:
import numpy as np
from typing import List, Tuple, Set


def enforce_bidirectional_symmetry(initial_E_prior: List[Tuple[int, np.int64]]) -> List[Tuple[int, int]]:
    """
    Close a list of directed pairs under reversal.

    For every (i, j) in ``initial_E_prior`` the result contains both (i, j)
    and (j, i). Duplicates are removed, every element is converted to a
    built-in Python int, and the pairs are returned in sorted
    (deterministic) order.
    """
    symmetric_pairs = {
        pair
        for i_raw, j_raw in initial_E_prior
        for pair in ((int(i_raw), int(j_raw)), (int(j_raw), int(i_raw)))
    }
    return sorted(symmetric_pairs)

# enforce_bidirectional_symmetry returns a new list; rebind E_prior to it.
# Record the size beforehand instead of hard-coding it.
initial_size = len(E_prior)
E_prior = enforce_bidirectional_symmetry(E_prior)

print("--- E_prior Symmetry Enforcement Complete ---")
print(f"Initial size: {initial_size}")
print(f"Final size: {len(E_prior)} (Should be ~2x the number of unique links)")
print("\nSnippet of Final E_prior:")
print(E_prior[:5])
print("...")
print(E_prior[-5:])