Let's start the CausGT-HS core model implementation (a lot of the earlier work is still pending; let's see how it goes).

# 3.4 CausGT-HS: A self-supervised, probabilistic energy-based causal graph token transformer

The CausGT-HS (Causal Graph-Token Hierarchical Self-Supervised) model is the final deep-learning encoder of our system. It is trained as an Energy-Based Model (EBM) to learn a single, globally coherent, multi-relational causal graph $G$. Its design is motivated by two needs: uncertainty quantification and robust mechanism learning. This is our main "student" GNN, into which we will distill $C_{\text{prior}}$ along with the other signals.

- An Overview: 
The CausGT-HS defines a probability distribution $P(G)$ over all possible causal graphs, in which graphs consistent with the evidence have low energy:
$$
P(G) \propto e^{-E_{\psi}(G)}
$$
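To make the proportionality concrete, here is a toy sketch that enumerates every directed graph on three nodes, scores each with a made-up energy, and normalizes $e^{-E}$ into a distribution. `toy_energy`, its `sparsity` cost and the evidence matrix `W` are illustrative assumptions, not the learned $E_{\psi}$.

```python
import itertools

import numpy as np


def enumerate_digraphs(n: int):
    """Yield every directed graph on n nodes (no self-loops) as an (n, n) 0/1 matrix."""
    off_diag = [(i, j) for i in range(n) for j in range(n) if i != j]
    for bits in itertools.product([0, 1], repeat=len(off_diag)):
        A = np.zeros((n, n))
        for (i, j), b in zip(off_diag, bits):
            A[i, j] = b
        yield A


def toy_energy(A: np.ndarray, W: np.ndarray, sparsity: float = 0.5) -> float:
    # Hypothetical energy: edges supported by evidence W lower it, every edge pays a sparsity cost
    return float(-(W * A).sum() + sparsity * A.sum())


def boltzmann(energies: np.ndarray) -> np.ndarray:
    # P(G) = exp(-E(G)) / Z, shifted by the minimum energy for numerical stability
    p = np.exp(-(energies - energies.min()))
    return p / p.sum()


W = np.array([[0.0, 2.0, 0.1],
              [0.0, 0.0, 1.5],
              [0.0, 0.0, 0.0]])
graphs = list(enumerate_digraphs(3))
E = np.array([toy_energy(A, W) for A in graphs])
P = boltzmann(E)
best = graphs[int(P.argmax())]  # lowest-energy graph: edges 0->1 and 1->2 only
```

Enumeration grows as $2^{n(n-1)}$ graphs, so the normalizer of $P(G)$ is intractable beyond a handful of nodes; this is why training is contrastive and inference relies on sampling.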

- Method: 
The model's global energy function $E_{\psi}(G)$ is trained to lower the energy of "correct" graphs ($G_{\text{positive}}$, derived from $C_{\text{prior}}$) and to raise the energy of "incorrect" graphs ($G_{\text{negative}}$). Training uses a contrastive loss (InfoNCE) over these energies.
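A minimal sketch of that objective, assuming an energy network has already scored one positive graph and $K$ negatives per example. Treating $-E$ as a logit and the positive as the correct class turns InfoNCE into a cross-entropy; `energy_infonce_loss` and its `temperature` argument are hypothetical names.

```python
import torch
import torch.nn.functional as F


def energy_infonce_loss(e_pos: torch.Tensor, e_neg: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """InfoNCE over energies.

    e_pos: (B,) energies of the positive graphs
    e_neg: (B, K) energies of K negative graphs per positive
    Lowering E(G_pos) or raising E(G_neg) both reduce the loss.
    """
    logits = torch.cat([-e_pos.unsqueeze(1), -e_neg], dim=1) / temperature  # (B, 1 + K)
    targets = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)  # positive is column 0
    return F.cross_entropy(logits, targets)


# All energies tied: positives are indistinguishable from negatives, loss = log(1 + K) = log(4)
loss = energy_infonce_loss(torch.zeros(2), torch.zeros(2, 3))
# Positive well below the negatives: loss close to 0
sep_loss = energy_infonce_loss(torch.tensor([-5.0]), torch.tensor([[5.0, 5.0, 5.0]]))
```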

- Inference (Uncertainty): Unlike deterministic models, CausGT-HS uses Markov chain Monte Carlo (MCMC) sampling, specifically Langevin dynamics, to draw an ensemble of plausible graphs $G_k$ from the low-energy regions of the distribution. The spread of this ensemble provides principled confidence intervals and structural uncertainty estimates.
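A minimal sketch of the sampler, assuming the graph is relaxed to a continuous matrix of edge logits (Langevin dynamics follows $\nabla E$, which a discrete adjacency does not have). `langevin_sample`, the quadratic `quad_energy` and the target logits `A_star` are hypothetical stand-ins for $E_{\psi}$; the update is the standard unadjusted Langevin step $x_{t+1} = x_t - \tfrac{\eta}{2}\nabla E(x_t) + \sqrt{\eta}\,\xi_t$ with $\xi_t \sim \mathcal{N}(0, I)$.

```python
import torch


def langevin_sample(energy_fn, x0: torch.Tensor, n_steps: int = 500, step_size: float = 1e-2, seed: int = 0) -> torch.Tensor:
    """Unadjusted Langevin dynamics targeting p(x) proportional to exp(-energy_fn(x)).

    energy_fn maps a batch of chains (C, ...) to per-chain energies (C,).
    """
    gen = torch.Generator().manual_seed(seed)
    x = x0.clone()
    for _ in range(n_steps):
        x.requires_grad_(True)
        (grad,) = torch.autograd.grad(energy_fn(x).sum(), x)  # chains are independent, so summing is safe
        noise = torch.randn(x.shape, generator=gen)
        x = (x - 0.5 * step_size * grad + step_size ** 0.5 * noise).detach()
    return x


# Hypothetical target edge logits: +3 on edges 0->1 and 1->2, -3 everywhere else (including the diagonal)
A_star = torch.full((3, 3), -3.0)
A_star[0, 1] = 3.0
A_star[1, 2] = 3.0


def quad_energy(x: torch.Tensor, sigma: float = 0.1) -> torch.Tensor:
    # Toy energy for relaxed adjacency logits, shape (chains, n, n) -> (chains,)
    return 0.5 * ((x - A_star) ** 2).sum(dim=(1, 2)) / sigma ** 2


chains = langevin_sample(quad_energy, torch.zeros(64, 3, 3))  # 64 independent chains
graphs = (torch.sigmoid(chains) > 0.5).float()                # one discrete graph G_k per chain
edge_freq = graphs.mean(dim=0)                                 # fraction of samples containing each edge
```

`edge_freq` then serves as a per-edge confidence. The unadjusted step carries a small discretization bias, which a Metropolis correction (MALA) would remove.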


## 3.4.1 Inputs to the CausGT-HS model

- Correlational Graphs: 
Our initial (correlational) adjacency matrices:
$$
A_w = \{ W_1, W_2, \ldots, W_k \}
$$
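If the $k$ correlational graphs are stored as scipy sparse matrices over the same $N$ nodes (the notebook code loads a single `AW.npz`, so this multi-matrix layout is an assumption), they can be stacked into one relation-indexed sparse tensor of shape $(k, N, N)$. `stack_adjacencies` is a hypothetical helper.

```python
import numpy as np
import torch
from scipy import sparse


def stack_adjacencies(W_list):
    """Stack k sparse (N, N) matrices into one torch sparse COO tensor of shape (k, N, N)."""
    N = W_list[0].shape[0]
    rel_idx, rows, cols, vals = [], [], [], []
    for r, W in enumerate(W_list):
        assert W.shape == (N, N), "all correlational graphs must share the same node set"
        coo = sparse.coo_matrix(W)
        rel_idx.append(np.full(coo.nnz, r, dtype=np.int64))
        rows.append(coo.row.astype(np.int64))
        cols.append(coo.col.astype(np.int64))
        vals.append(coo.data.astype(np.float32))
    indices = torch.from_numpy(np.vstack([np.concatenate(rel_idx), np.concatenate(rows), np.concatenate(cols)]))
    values = torch.from_numpy(np.concatenate(vals))
    return torch.sparse_coo_tensor(indices, values, size=(len(W_list), N, N)).coalesce()


W1 = sparse.csr_matrix(np.array([[0, 1, 0], [0, 0, 2], [0, 0, 0]], dtype=np.float32))
W2 = sparse.csr_matrix(np.array([[0, 0, 3], [0, 0, 0], [1, 0, 0]], dtype=np.float32))
A_w = stack_adjacencies([W1, W2])
dense = A_w.to_dense()  # (2, 3, 3); only for inspection on small graphs
```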

- Rich Causal Prior ($P_{rich}$)

This is the output of our CoCaD pipeline: a sparse list of tuples

(i, j, type, evidence, score, uncertainty_variance)

where:

- type: a string indicating the causal relation, one of 'DIRECT', 'MEDIATED' or 'CONFOUNDED'
- evidence: a list of node IDs (e.g., the mediator [k] or the confounders [$k_c$])
- score: the final $P_{direct}$ score (e.g., 0.95 for a direct edge, 0 for a mediated one)
- uncertainty_variance: the statistical variance of the score, i.e. the variance across the teachers
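One way to hold these tuples in code and derive a positive graph from them. `PriorEdge` and `positive_adjacency` are hypothetical, and the rule "keep DIRECT edges with high score and low teacher variance", along with the `min_score`/`max_var` thresholds, is an assumption rather than the pipeline's definition of $G_{\text{positive}}$.

```python
from dataclasses import dataclass, field
from typing import List

import numpy as np

VALID_TYPES = {"DIRECT", "MEDIATED", "CONFOUNDED"}


@dataclass
class PriorEdge:
    i: int
    j: int
    type: str
    evidence: List[int] = field(default_factory=list)  # mediator / confounder node IDs
    score: float = 0.0                                  # P_direct
    uncertainty_variance: float = 0.0                   # variance across the teachers

    def __post_init__(self):
        if self.type not in VALID_TYPES:
            raise ValueError(f"unknown relation type: {self.type!r}")


def positive_adjacency(prior: List[PriorEdge], n: int, min_score: float = 0.5, max_var: float = 0.1) -> np.ndarray:
    """Hypothetical G_positive: confident DIRECT edges weighted by their score."""
    A = np.zeros((n, n), dtype=np.float32)
    for e in prior:
        if e.type == "DIRECT" and e.score >= min_score and e.uncertainty_variance <= max_var:
            A[e.i, e.j] = e.score
    return A


prior = [
    PriorEdge(0, 1, "DIRECT", [], 0.95, 0.01),
    PriorEdge(0, 2, "MEDIATED", [1], 0.0, 0.02),
    PriorEdge(2, 1, "DIRECT", [], 0.90, 0.40),  # teachers disagree too much -> dropped
]
A_pos = positive_adjacency(prior, n=3)
```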

- Raw Text Data (S):

The original document segmented into sentence snippets. This is the ground-truth textual evidence that our tokenphormer reads and reasons over.
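The notebook code links nodes to these sentences through `tmap.json` (node id → indices into `sent_embs.npy`) and summarizes a node's text as the mean embedding of its sentences. A standalone sketch of that step, assuming the same layout and that ids may be stored either as `"12"` or as `"N12"`; `node_context_vectors` is a hypothetical name.

```python
from typing import Dict, List

import numpy as np


def node_context_vectors(node_ids: List[int], tmap: Dict[str, List[int]], sent_embs: np.ndarray) -> np.ndarray:
    """Mean sentence embedding per node; a zero vector for nodes never mentioned in S."""
    out = np.zeros((len(node_ids), sent_embs.shape[1]), dtype=np.float32)
    for row, nid in enumerate(node_ids):
        sidx = tmap.get(str(nid)) or tmap.get(f"N{nid}", [])
        if sidx:
            out[row] = sent_embs[sidx].mean(axis=0)
    return out


sent_embs = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=np.float32)  # 3 sentences
tmap = {"0": [0, 2], "N1": [1]}
ctx = node_context_vectors([0, 1, 2], tmap, sent_embs)  # node 2 has no sentences
```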

- GAE Embeddings (Z):

The $N \times d_z$ matrix of structural node embeddings produced by our graph autoencoder (GAE). These are the initial structural embeddings $H^{(0)}$, which the encoder refines layer by layer.


## 3.4.2 CausGT-HS model architecture

The model is an **encoder-only**, hierarchical, *L*-layer stack (we keep $L = 3$ for now).  
This encoder, denoted **$f_θ$**, takes all raw inputs and produces a final **causally aware graph**:

<p align="center">

$$
G = (H^{(L)},\ A)
$$

</p>

The encoder is composed of three major components:

### • Embedding Layers
Responsible for converting all raw signals (text, structural embeddings, priors, adjacency matrices) into a unified token or node representation.

### • Reasoning Layers
These layers iteratively refine the representations by performing multi-step reasoning over:
- graph structure  
- causal priors  
- textual evidence  
- learned embeddings  

### • Proposer Network
The final module that generates the updated causal graph edges and relations based on the refined representations.
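A minimal skeleton of how the three components could be wired together. Every concrete choice here is an assumption, not taken from the design above: the embedding layer fuses $Z$ with per-node text context vectors, each reasoning layer is a standard Transformer block whose attention is masked by the input adjacency, and the proposer is a bilinear edge scorer. Priors, facets and the energy head are omitted; `CausGTSkeleton` is hypothetical.

```python
import torch
import torch.nn as nn


class CausGTSkeleton(nn.Module):
    """Hypothetical skeleton of f_theta: embedding -> L reasoning layers -> edge proposer."""

    def __init__(self, d_z: int, d_ctx: int, d_model: int = 256, n_layers: int = 3, n_heads: int = 4):
        super().__init__()
        # Embedding layers: fuse structural (GAE) and textual context into one node vector
        self.embed_z = nn.Linear(d_z, d_model)
        self.embed_ctx = nn.Linear(d_ctx, d_model)
        # Reasoning layers: standard Transformer blocks with graph-masked attention
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward=2 * d_model, batch_first=True)
            for _ in range(n_layers)
        ])
        # Proposer: bilinear scorer producing directed edge probabilities
        self.src_proj = nn.Linear(d_model, d_model, bias=False)
        self.dst_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, Z: torch.Tensor, ctx: torch.Tensor, adj: torch.Tensor):
        # Z: (N, d_z), ctx: (N, d_ctx), adj: (N, N) non-negative input structure
        N = adj.shape[0]
        eye = torch.eye(N, device=adj.device)
        H = (self.embed_z(Z) + self.embed_ctx(ctx)).unsqueeze(0)  # H^(0): (1, N, d_model)
        # Boolean mask, True = blocked: each node sees itself and its neighbours in either direction
        blocked = (adj + adj.T + eye) == 0
        for layer in self.layers:
            H = layer(H, src_mask=blocked)
        H = H.squeeze(0)  # H^(L): (N, d_model)
        logits = self.src_proj(H) @ self.dst_proj(H).T / H.shape[-1] ** 0.5
        A = torch.sigmoid(logits) * (1 - eye)  # directed edge probabilities, no self-loops
        return H, A


torch.manual_seed(0)
N, d_z, d_ctx = 5, 16, 8
model = CausGTSkeleton(d_z, d_ctx, d_model=32, n_layers=3, n_heads=4).eval()
adj = torch.zeros(N, N)
adj[0, 1] = 1.0
adj[1, 2] = 1.0
H, A = model(torch.randn(N, d_z), torch.randn(N, d_ctx), adj)  # G = (H^(L), A)
```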


### 3.4.2.1 Multi-faceted Representations:

A single embedding per node is not expressive enough: nodes can have multiple meanings (polysemy), and we must model this.

Facets are selected dynamically at runtime:

Each query (e.g., facet $k$ of node $i$) attends to the $M$ facets of neighbor $j$.
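A vectorized sketch of this facet-to-facet attention for a batch of $(i, j)$ pairs, assuming both nodes already have $M$ facet vectors of size $d$. Note that `MultiFacetModule` below queries with the gated node vector `H_active` rather than with each facet; this sketch shows the per-facet variant described here, without the learned q/k/v projections. `facet_to_facet_attention` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def facet_to_facet_attention(q_facets: torch.Tensor, kv_facets: torch.Tensor):
    """Each facet k of node i attends over the M facets of its neighbour j.

    q_facets:  (B, M, d) facets of the query nodes i
    kv_facets: (B, M, d) facets of the paired neighbours j
    returns (blended (B, M, d), weights (B, M, M)); weights[b, k] is a distribution over j's facets
    """
    d = q_facets.shape[-1]
    scores = q_facets @ kv_facets.transpose(1, 2) / d ** 0.5  # (B, M, M) scaled dot products
    weights = F.softmax(scores, dim=-1)
    return weights @ kv_facets, weights


torch.manual_seed(0)
blended, w = facet_to_facet_attention(torch.randn(2, 3, 4), torch.randn(2, 3, 4))

# Query facet 0 points along neighbour facet 0, so its attention concentrates there
kv = torch.eye(3, 4).unsqueeze(0)  # (1, 3, 4): three orthogonal neighbour facets
q = torch.zeros(1, 3, 4)
q[0, 0, 0] = 10.0
_, w2 = facet_to_facet_attention(q, kv)
```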



##### Multi-Embedding Nodes

In [None]:

import os
import json
import math
import time
from typing import List, Dict, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from scipy import sparse
from sklearn.cluster import KMeans


DATA_DIR = "./data"  
NODES_JSONL = os.path.join(DATA_DIR, "entities.jsonl")   
TMAP_JSON = os.path.join(DATA_DIR, "tmap.json")        
SENT_EMBS_NPY = os.path.join(DATA_DIR, "sent_embs.npy") 
AW_NPZ = os.path.join(DATA_DIR, "AW.npz")               

M_FACETS = 5
D_MODEL = 256
EPOCHS = 30
BATCH_SIZE = 512
LR = 3e-4
LAMBDA_MAX = 1e-2
WARMUP_EPOCHS = 10
BATCH_NEIGH_K = 32
NUM_WORKERS = 4

# -------------------------
# Utilities
# -------------------------
def load_jsonl(path: str) -> List[dict]:
    out = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                out.append(json.loads(line))
    return out

def load_tmap(path: str) -> Dict[str, List[int]]:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def load_sent_embeddings(path: str) -> np.ndarray:
    return np.load(path)

def load_adj_npz(path: str) -> sparse.spmatrix:
    return sparse.load_npz(path)

# -------------------------
# Dataset for nodes (samples node indices and neighbor lists)
# -------------------------
class NodeDataset(Dataset):
    def __init__(self, node_ids: List[int], AW: Optional[sparse.spmatrix]=None, batch_neigh_k: int = 32):
        self.node_ids = node_ids
        self.AW = AW.tocsr() if AW is not None else None
        self.batch_neigh_k = batch_neigh_k

    def __len__(self):
        return len(self.node_ids)

    def __getitem__(self, idx):
        node_idx = self.node_ids[idx]
        neigh = None
        if self.AW is not None:
            row = self.AW.getrow(node_idx).tocoo()
            neigh = row.col
            if len(neigh) > self.batch_neigh_k:
                neigh = np.random.choice(neigh, size=self.batch_neigh_k, replace=False)
        return int(node_idx), neigh

def collate_batch(batch):
    nodes = torch.tensor([b[0] for b in batch], dtype=torch.long)
    neighs = [b[1] for b in batch]
    return nodes, neighs

# -------------------------
# MultiFacetModule
# -------------------------
class MultiFacetModule(nn.Module):
    def __init__(
        self,
        num_nodes: int,
        M: int,
        d_model: int,
        device: torch.device,
        init_facet_embeddings: Optional[np.ndarray] = None,
        use_context_gate: bool = True,
        ctx_dim: int = 384,
    ):
        super().__init__()
        self.N = num_nodes
        self.M = M
        self.d = d_model
        self.device = device

        # facets stored as embedding table (N*M, d)
        self.facet_table = nn.Embedding(self.N * self.M, self.d)
        nn.init.normal_(self.facet_table.weight, mean=0.0, std=0.02)

        # static gating logits (learnable)
        self.gate_logits = nn.Parameter(torch.zeros(self.N, self.M))

        self.use_context_gate = use_context_gate
        if use_context_gate:
            self.ctx_net = nn.Sequential(
                nn.Linear(ctx_dim, ctx_dim),
                nn.ReLU(inplace=True),
                nn.Linear(ctx_dim, self.M),
            )

        # small projections for facet attention (q/k/v)
        self.proj_q = nn.Linear(self.d, self.d, bias=False)
        self.proj_k = nn.Linear(self.d, self.d, bias=False)
        self.proj_v = nn.Linear(self.d, self.d, bias=False)

        if init_facet_embeddings is not None:
            assert init_facet_embeddings.shape == (self.N, self.M, self.d)
            flat = torch.tensor(init_facet_embeddings.reshape(-1, self.d), dtype=torch.float32)
            with torch.no_grad():
                self.facet_table.weight.data.copy_(flat)

    def node_facet_indices(self, node_idxs: Tensor) -> Tensor:
        # returns flattened indices into facet_table for node_idxs (B, M) -> (B*M,)
        base = node_idxs.unsqueeze(1) * self.M
        offsets = torch.arange(self.M, device=self.device).unsqueeze(0)
        return (base + offsets).reshape(-1)

    def fetch_facets_for_nodes(self, node_idxs: Tensor) -> Tensor:
        idxs = self.node_facet_indices(node_idxs)
        facets = self.facet_table(idxs).view(node_idxs.shape[0], self.M, self.d)
        return facets

    def forward(self, node_idxs: Tensor, context_vecs: Optional[Tensor] = None, neighbor_indices: Optional[List[np.ndarray]] = None, return_all=False):
        B = node_idxs.shape[0]
        facets = self.fetch_facets_for_nodes(node_idxs)            # (B, M, d)
        static_logits = self.gate_logits[node_idxs]               # (B, M)
        static_gate = torch.sigmoid(static_logits)                # (B, M)
        if self.use_context_gate and (context_vecs is not None):
            ctx_logits = self.ctx_net(context_vecs)               # (B, M)
            gates = torch.sigmoid(static_logits + ctx_logits)     # (B, M)
        else:
            gates = static_gate

        gates_unsq = gates.unsqueeze(-1)
        H_active = (facets * gates_unsq).sum(dim=1)               # (B, d)

        out = {"H_active": H_active, "facets": facets, "gates": gates}
        # neighbor blending (vectorized across neighbors where possible)
        if neighbor_indices is not None:
            nb_blended = []
            # process in-loop per-batch element (neighbors vary in length)
            for i in range(B):
                neigh = neighbor_indices[i]
                if neigh is None or len(neigh) == 0:
                    nb_blended.append(torch.zeros(self.d, device=self.device))
                    continue
                neigh_tensor = torch.tensor(np.asarray(neigh, dtype=np.int64), dtype=torch.long, device=self.device)
                neigh_facets = self.fetch_facets_for_nodes(neigh_tensor)  # (K, M, d)
                # project
                q = self.proj_q(H_active[i].unsqueeze(0))  # (1, d)
                k = self.proj_k(neigh_facets.view(-1, self.d))  # (K*M, d)
                v = self.proj_v(neigh_facets.view(-1, self.d))  # (K*M, d)
                att_logits = (q @ k.T) / math.sqrt(self.d)  # (1, K*M)
                att = F.softmax(att_logits, dim=-1)
                blended = att @ v
                nb_blended.append(blended.squeeze(0))
            out["neighbor_blended"] = torch.stack(nb_blended, dim=0)  # (B, d)
        if return_all:
            out["static_gate"] = static_gate
        return out

    def l1_gate_cost(self) -> Tensor:
        gate_vals = torch.sigmoid(self.gate_logits)
        return gate_vals.abs().sum() / (self.N * self.M)

    def update_facets_via_clustering(self, nodes: List[dict], Tmap: Dict[str, List[int]], sent_embeddings: np.ndarray, per_node_min_context: int = 2):
        init = np.random.normal(scale=0.02, size=(self.N, self.M, self.d)).astype(np.float32)
        num_sent_dim = sent_embeddings.shape[1]
        for node in nodes:
            nid = node.get("id")
            if isinstance(nid, str) and nid.startswith("N"):
                try:
                    nidx = int(nid[1:])
                except ValueError:
                    nidx = int(nid)
            else:
                nidx = int(nid)
            s_indices = Tmap.get(str(nid), []) or Tmap.get(str(nidx), [])
            if not s_indices or len(s_indices) < per_node_min_context:
                continue
            vecs = sent_embeddings[s_indices]  # (S, embed_dim)
            k = min(self.M, max(1, vecs.shape[0]))
            if k == 1:
                centroids = np.vstack([vecs.mean(axis=0) for _ in range(self.M)])
            else:
                km = KMeans(n_clusters=k, random_state=0, n_init=4).fit(vecs)
                centroids = km.cluster_centers_
                if k < self.M:
                    extras = np.tile(centroids[:1], (self.M - k, 1))
                    centroids = np.vstack([centroids, extras])
            if centroids.shape[1] != self.d:
                rng = np.random.RandomState(0)
                P = rng.normal(size=(centroids.shape[1], self.d)).astype(np.float32) * 0.02
                centroids = centroids @ P
            init[nidx] = centroids[:self.M]
        with torch.no_grad():
            flat = torch.tensor(init.reshape(-1, self.d), dtype=torch.float32, device=self.device)
            self.facet_table.weight.data.copy_(flat)

# -------------------------
# Proxy loss: neighbor reconstruction 
# -------------------------
def compute_proxy_causal_loss(
    module: MultiFacetModule,
    batch_node_idxs: Tensor,
    neighbor_lists: List[np.ndarray],
    AW: sparse.spmatrix,
    device: torch.device,
    context_vecs: Optional[Tensor] = None,
    num_neg: int = 1,
):
    # Anchor nodes use the context-aware gates, so ctx_net also receives gradients
    H_active = module(batch_node_idxs, context_vecs=context_vecs)["H_active"]  # (B, d)
    losses = []
    for i in range(H_active.shape[0]):
        neigh = neighbor_lists[i]
        if neigh is None or len(neigh) == 0:
            continue
        nid = int(batch_node_idxs[i].item())
        pos = np.asarray(neigh, dtype=np.int64)
        # uniform random negatives; an occasional true neighbour slips in, tolerable for a proxy loss
        neg = np.random.randint(0, module.N, size=len(pos) * num_neg)
        cand = torch.as_tensor(np.concatenate([pos, neg]), dtype=torch.long, device=device)
        Hn = module(cand)["H_active"]  # candidates use their static gates only
        scores = (H_active[i].unsqueeze(0) * Hn).sum(dim=-1)  # (K_pos + K_neg,)
        # soft targets from AW clipped to [0, 1]; raw counts > 1 would make BCE unbounded below
        row = AW.getrow(nid).tocoo()
        col_to_val = dict(zip(row.col.tolist(), row.data.tolist()))
        pos_targets = [min(max(col_to_val.get(int(n), 0.0), 0.0), 1.0) for n in pos]
        targets = torch.tensor(pos_targets + [0.0] * len(neg), dtype=torch.float32, device=device)
        losses.append(F.binary_cross_entropy_with_logits(scores, targets))
    if not losses:
        return torch.tensor(0.0, device=device)
    return torch.stack(losses).mean()

# -------------------------
# Annealing schedule
# -------------------------
def lambda_fac_schedule(epoch: int, max_lambda: float, warmup_epochs: int = 10):
    if epoch <= 0:
        return 0.0
    if epoch < warmup_epochs:
        return max_lambda * (epoch / warmup_epochs)
    return max_lambda

# -------------------------
# Training loop 
# -------------------------
def train_loop(
    module: MultiFacetModule,
    nodes: List[dict],
    tmap: Dict[str, List[int]],
    sent_embs: np.ndarray,
    AW: sparse.spmatrix,
    device: torch.device,
    epochs: int = EPOCHS,
    batch_size: int = BATCH_SIZE,
    lr: float = LR,
    lambda_max: float = LAMBDA_MAX,
):
    AW = AW.tocsr()  # fast row access inside the loss
    node_indices = list(range(module.N))
    dataset = NodeDataset(node_indices, AW=AW, batch_neigh_k=BATCH_NEIGH_K)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS, collate_fn=collate_batch)
    opt = torch.optim.AdamW(module.parameters(), lr=lr)
    module.train()
    sent_dim = sent_embs.shape[1]
    for ep in range(epochs):
        start = time.time()
        lam = lambda_fac_schedule(ep, lambda_max, warmup_epochs=WARMUP_EPOCHS)
        total_loss = 0.0
        for batch_nodes, batch_neigh in loader:
            batch_nodes = batch_nodes.to(device)
            # context vector per node = mean embedding of its sentences (tmap keys may be "12" or "N12")
            ctx_list = []
            for nid in batch_nodes.cpu().tolist():
                sidx = tmap.get(str(nid)) or tmap.get(f"N{nid}", [])
                vec = sent_embs[sidx].mean(axis=0) if sidx else np.zeros(sent_dim, dtype=np.float32)
                ctx_list.append(vec)
            context_vecs = torch.tensor(np.stack(ctx_list, axis=0), dtype=torch.float32, device=device)
            # the loss runs the forward pass itself, so no separate (unused) module(...) call is needed
            causal_loss = compute_proxy_causal_loss(module, batch_nodes, batch_neigh, AW, device, context_vecs=context_vecs)
            loss = causal_loss + lam * module.l1_gate_cost()
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item()
        avg_loss = total_loss / max(1, len(loader))
        elapsed = time.time() - start
        print(f"[Epoch {ep+1}/{epochs}] avg_loss={avg_loss:.6f} lam={lam:.6e} time={elapsed:.1f}s")

# -------------------------
# Main run - loads from DATA_DIR
# -------------------------
def run_all():
    assert os.path.isdir(DATA_DIR), f"Data dir '{DATA_DIR}' not found."
    assert os.path.exists(NODES_JSONL), f"{NODES_JSONL} missing."
    assert os.path.exists(TMAP_JSON), f"{TMAP_JSON} missing."
    assert os.path.exists(SENT_EMBS_NPY), f"{SENT_EMBS_NPY} missing."
    assert os.path.exists(AW_NPZ), f"{AW_NPZ} missing."

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    nodes = load_jsonl(NODES_JSONL)
    tmap = load_tmap(TMAP_JSON)
    sent_embs = load_sent_embeddings(SENT_EMBS_NPY)
    AW = load_adj_npz(AW_NPZ)

    N = len(nodes)
    module = MultiFacetModule(num_nodes=N, M=M_FACETS, d_model=D_MODEL, device=device, use_context_gate=True, ctx_dim=sent_embs.shape[1])
    module.to(device)
    module.update_facets_via_clustering(nodes=nodes, Tmap=tmap, sent_embeddings=sent_embs, per_node_min_context=1)

    train_loop(
        module=module,
        nodes=nodes,
        tmap=tmap,
        sent_embs=sent_embs,
        AW=AW,
        device=device,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        lr=LR,
        lambda_max=LAMBDA_MAX
    )

if __name__ == "__main__":
    run_all()


##### Multi-Relational Edges

In [None]:
import os, time, math, json
from typing import List, Tuple, Optional, Dict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy import sparse
from sklearn.neighbors import NearestNeighbors


DATA_DIR = "./data"
NODES_JSONL = os.path.join(DATA_DIR, "entities.jsonl")
TMAP_JSON = os.path.join(DATA_DIR, "tmap.json")
SENT_EMBS_NPY = os.path.join(DATA_DIR, "sent_embs.npy")
AW_NPZ = os.path.join(DATA_DIR, "AW.npz")   # base co-occurrence adjacency (sparse)
# Multi-relation config
RELATION_NAMES = ["causes", "inhibits", "related_to", "cooccurs"]  
DRANK = 128
K_ANN = 50           # Top-K for Sr via ANN
K_LEARN = 32         # per-node learnable extra neighbors (compact Slearned)
D_MODEL = 256        # must match your MultiFacetModule d_model
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# regularization & training params
UR_VR_DROPOUT = 0.1
UR_VR_WEIGHT_DECAY = 1e-5
SLEARN_L1_LAMBDA = 1e-3
UR_VR_L2_LAMBDA = 1e-4
EDGE_NEG_SAMPLE_RATIO = 1    # #neg per positive in minibatch
BATCH_SIZE = 512
EPOCHS = 20
LR = 3e-4

# -------------------------
# Utilities (loaders)
# -------------------------
def load_jsonl(path: str):
    out = []
    with open(path, "r", encoding="utf-8") as f:
        for l in f:
            if l.strip():
                out.append(json.loads(l))
    return out

def load_tmap(path: str):
    import json
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def load_sent_embs(path: str):
    return np.load(path)

def load_adj_npz(path: str):
    return sparse.load_npz(path)

# -------------------------
# Build Sr heuristic mask (sparse) from AW (cooccurrence) + ANN
# returns Sr_as_set: set of (i,j) tuples considered plausible
# -------------------------
def build_Sr_from_AW_and_ann(AW_sparse: sparse.spmatrix, node_seed_embeddings: np.ndarray, k_ann: int = K_ANN):
    N = AW_sparse.shape[0]
    Sr_pairs = set()
    # 1) text cooccurrence from AW (nonzero entries)
    coo = AW_sparse.tocoo()
    for i, j in zip(coo.row.tolist(), coo.col.tolist()):
        Sr_pairs.add((int(i), int(j)))
    # 2) ANN top-k on seed embeddings
    # +1 because each point is returned as its own nearest neighbour (skipped below)
    nbrs = NearestNeighbors(n_neighbors=min(k_ann + 1, node_seed_embeddings.shape[0]), algorithm="auto", metric="cosine").fit(node_seed_embeddings)
    _, indices = nbrs.kneighbors(node_seed_embeddings, return_distance=True)
    for i in range(node_seed_embeddings.shape[0]):
        for j in indices[i]:
            if i == j: 
                continue
            Sr_pairs.add((int(i), int(j)))
    return Sr_pairs

# -------------------------
# RelationFactorizer class
# -------------------------
class RelationFactorizer(nn.Module):
    def __init__(self, num_nodes: int, relation_names: List[str], drank: int = DRANK, device: torch.device = DEVICE, k_learn: int = K_LEARN, Sr_pairs: Optional[set] = None):
        super().__init__()
        self.N = num_nodes
        self.R = len(relation_names)
        self.rel_names = relation_names
        self.drank = drank
        self.device = device
        self.k_learn = k_learn

        # Ur and Vr per relation: implemented as Embedding (N x drank)
        self.U_tables = nn.ModuleDict({r: nn.Embedding(self.N, self.drank) for r in relation_names})
        self.V_tables = nn.ModuleDict({r: nn.Embedding(self.N, self.drank) for r in relation_names})
        for r in relation_names:
            nn.init.normal_(self.U_tables[r].weight, std=0.02)
            nn.init.normal_(self.V_tables[r].weight, std=0.02)

        # dropout on Ur/Vr during training
        self.urvr_dropout = nn.Dropout(UR_VR_DROPOUT)

        # initialize learned_candidate_idx as -1 for padding if insufficient candidates.
        # learned_idx: (N, K_learn) ints on CPU (not Parameters) to be used to index neighbor nodes
        self.learned_idx = torch.full((self.N, self.k_learn), -1, dtype=torch.long)  # filled later with ints (on cpu)
        # logits for those candidate edges (trainable)
        self.learned_logits = nn.Parameter(torch.zeros(self.N, self.k_learn))  # learnable logits
        # mask for which learned slots contain valid neighbors (1 if valid index >=0)
        self.register_buffer("_learned_valid_mask", torch.zeros(self.N, self.k_learn, dtype=torch.bool))

        # Sr (heuristic) stored as python set of tuples for quick membership checks
        self.Sr_pairs = Sr_pairs if Sr_pairs is not None else set()

    def to_device(self):
        self.to(self.device)

    def set_learned_candidates(self, learned_idx_np: np.ndarray):
        """Provide per-node learned candidate indices as numpy int array shape (N, K_learn),
           invalid entries must be -1. This stores indices on CPU buffer and sets valid mask."""
        assert learned_idx_np.shape == (self.N, self.k_learn)
        self.learned_idx = torch.from_numpy(learned_idx_np.astype(np.int64))
        valid = (self.learned_idx >= 0)
        self._learned_valid_mask = valid
        # move learned_idx to cpu buffer and keep on CPU to index; during forward we'll move to device
        # keep as attribute but not Parameter
        # ensure learned_idx on cpu
        if self.learned_idx.device != torch.device("cpu"):
            self.learned_idx = self.learned_idx.cpu()

    def get_edge_score(self, rel_name: str, src_idx: torch.LongTensor, dst_idx: torch.LongTensor) -> torch.Tensor:
        """Vectorized: src_idx, dst_idx are 1D LongTensors same len L -> returns logits (L,)"""
        # fetch rows
        U = self.U_tables[rel_name](src_idx.to(self.device))
        V = self.V_tables[rel_name](dst_idx.to(self.device))
        U = self.urvr_dropout(U)
        V = self.urvr_dropout(V)
        logits = (U * V).sum(dim=-1) / math.sqrt(self.drank)
        return logits

    def compute_prob_for_batch_pairs(self, rel_name: str, pair_list: List[Tuple[int,int]]) -> torch.Tensor:
        """pair_list: list of (i,j) on CPU ints -> returns probs tensor (len,)"""
        if len(pair_list) == 0:
            return torch.tensor([], device=self.device)
        src = torch.tensor([p[0] for p in pair_list], dtype=torch.long, device=self.device)
        dst = torch.tensor([p[1] for p in pair_list], dtype=torch.long, device=self.device)
        logits = self.get_edge_score(rel_name, src, dst)
        probs = torch.sigmoid(logits)
        # apply Shybrid mask (if pair not in Sr nor in learned -> zero)
        mask = torch.tensor([1 if ((int(p[0]), int(p[1])) in self.Sr_pairs) else 0 for p in pair_list], dtype=torch.bool, device=self.device)
        # to check learned candidates
        if hasattr(self, "learned_idx"):
            learned_idx_dev = self.learned_idx.to(self.device)
            # gather each src's learned neighbours and check whether dst is among them
            learned_neighbors_for_src = learned_idx_dev[src]  # (L, K)
            # compare dst to each col
            matches = (learned_neighbors_for_src == dst.unsqueeze(1))
            learned_mask = matches.any(dim=1)  # (L,)
            mask = mask | learned_mask
        probs = probs * mask.float()
        return probs

    def shybrid_mask_for_pairs(self, pair_list: List[Tuple[int,int]]) -> torch.Tensor:
        """Return mask (0/1) tensor for given pairs according to Sr ∪ Slearned"""
        L = len(pair_list)
        mask = torch.zeros(L, dtype=torch.bool, device=self.device)
        for idx, (i,j) in enumerate(pair_list):
            if (i,j) in self.Sr_pairs:
                mask[idx] = True
        if hasattr(self, "learned_idx"):
            learned_idx_dev = self.learned_idx.to(self.device)
            src = torch.tensor([p[0] for p in pair_list], dtype=torch.long, device=self.device)
            dst = torch.tensor([p[1] for p in pair_list], dtype=torch.long, device=self.device)
            learned_neighbors_for_src = learned_idx_dev[src]  # (L, K)
            matches = (learned_neighbors_for_src == dst.unsqueeze(1))
            learned_mask = matches.any(dim=1)
            mask = mask | learned_mask
        return mask.float()

    def forward_edge_logits(self, rel_name: str, pair_src: torch.LongTensor, pair_dst: torch.LongTensor) -> torch.Tensor:
        """Compute raw logits for given batches (tensor inputs). pair_src/dst on device."""
        U = self.U_tables[rel_name](pair_src)
        V = self.V_tables[rel_name](pair_dst)
        logits = (U * V).sum(dim=-1) / math.sqrt(self.drank)

        # membership test on CPU ints in one pass (avoids a device sync per pair)
        pairs_cpu = zip(pair_src.tolist(), pair_dst.tolist())
        mask = torch.tensor([(s, d) in self.Sr_pairs for s, d in pairs_cpu], dtype=torch.bool, device=self.device)
        if hasattr(self, "learned_idx"):
            learned_idx_dev = self.learned_idx.to(self.device)
            learned_neighbors_for_src = learned_idx_dev[pair_src]
            dst_expand = pair_dst.unsqueeze(1)
            learned_matches = (learned_neighbors_for_src == dst_expand).any(dim=1)
            mask = mask | learned_matches
        # pairs outside Sr ∪ Slearned get a constant logit of -50 (sigmoid ≈ 0) and therefore no gradient
        logits = logits * mask.float() + (1.0 - mask.float()) * (-50.0)
        return logits

    def learned_sparsity_penalty(self):
        """L1 penalty on positive/activated learned logits to discourage adding many corrections."""
        gate = torch.sigmoid(self.learned_logits)
        return gate.sum() / (self.N * max(1, self.k_learn))

    def ur_vr_l2(self):
        s = 0.0
        for r in self.rel_names:
            s = s + (self.U_tables[r].weight.norm(p=2) ** 2) + (self.V_tables[r].weight.norm(p=2) ** 2)
        return s

# -------------------------
# Integration + joint training loop 
# -------------------------
def joint_train_relation_and_facets(multi_facet_module, relation_factorizer, nodes, tmap, sent_embs, AW_sparse, device=DEVICE,
                                   epochs=EPOCHS, batch_size=BATCH_SIZE, lr=LR):
    N = len(nodes)
    all_node_ids = np.arange(N, dtype=np.int64)
    coo = AW_sparse.tocoo()
    pos_pairs = list(zip(coo.row.tolist(), coo.col.tolist()))
    rng = np.random.RandomState(0)

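    # create the Ur -> d_model projection *before* building the optimizer so its weights are actually trained
    if not hasattr(relation_factorizer, "ur_to_h_linear"):
        relation_factorizer.ur_to_h_linear = nn.Linear(relation_factorizer.drank, D_MODEL).to(device)
    # optimizer for both modules
    params = list(multi_facet_module.parameters()) + list(relation_factorizer.parameters())
    opt = torch.optim.AdamW(params, lr=lr, weight_decay=UR_VR_WEIGHT_DECAY)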

    def get_H_active_for_nodes(node_list: List[int]):
        out_list = []
        chunk = 2048
        for i in range(0, len(node_list), chunk):
            sub = node_list[i:i+chunk]
            sub_t = torch.tensor(sub, dtype=torch.long, device=device)
            ctx_list = []
            sent_dim = sent_embs.shape[1]
            for nid in sub:
                sidx = tmap.get(str(nid)) or tmap.get(f"N{nid}", [])  # ids may be stored as "12" or "N12"
                if sidx and len(sidx) > 0:
                    vec = sent_embs[sidx].mean(axis=0)
                else:
                    vec = np.zeros(sent_dim, dtype=np.float32)
                ctx_list.append(vec)
            context_vecs = torch.tensor(np.stack(ctx_list, axis=0), dtype=torch.float32, device=device)
            with torch.set_grad_enabled(True):
                out = multi_facet_module(sub_t, context_vecs=context_vecs, neighbor_indices=None)
            out_list.append(out["H_active"].detach())  
        return torch.cat(out_list, dim=0)

    # training iterations
    multi_facet_module.train()
    relation_factorizer.train()
    P = len(pos_pairs)
    for ep in range(epochs):
        t0 = time.time()
        perm = rng.permutation(P)
        batch_losses = []
        for start in range(0, P, batch_size):
            end = min(start + batch_size, P)
            batch_pos = [pos_pairs[perm[i]] for i in range(start, end)]
            # negatives
            batch_neg = []
            for (a,b) in batch_pos:
                for _ in range(EDGE_NEG_SAMPLE_RATIO):
                    neg_j = int(rng.randint(0, N))
                    batch_neg.append((a, neg_j))
            rel_target = "cooccurs" if "cooccurs" in relation_factorizer.rel_names else relation_factorizer.rel_names[0]
            pair_list = batch_pos + batch_neg
            labels = torch.tensor([1.0]*len(batch_pos) + [0.0]*len(batch_neg), dtype=torch.float32, device=device)

            src_tensor = torch.tensor([p[0] for p in pair_list], dtype=torch.long, device=device)
            dst_tensor = torch.tensor([p[1] for p in pair_list], dtype=torch.long, device=device)

            # compute logits from Ur/Vr
            logits = relation_factorizer.forward_edge_logits(rel_target, src_tensor, dst_tensor)  # (L,)
            prob = torch.sigmoid(logits)

            unique_nodes = np.unique(np.concatenate([np.array([p[0] for p in pair_list]), np.array([p[1] for p in pair_list])]))
            H_map = {}
            chunk_sz = 2048
            for i0 in range(0, len(unique_nodes), chunk_sz):
                chunk_nodes = unique_nodes[i0:i0+chunk_sz].tolist()
                sub_t = torch.tensor(chunk_nodes, dtype=torch.long, device=device)
                ctx_list = []
                sent_dim = sent_embs.shape[1]
                for nid in chunk_nodes:
                    sidx = tmap.get(str(nid)) or tmap.get(f"N{nid}", [])  # ids may be stored as "12" or "N12"
                    if sidx and len(sidx) > 0:
                        vec = sent_embs[sidx].mean(axis=0)
                    else:
                        vec = np.zeros(sent_dim, dtype=np.float32)
                    ctx_list.append(vec)
                context_vecs = torch.tensor(np.stack(ctx_list, axis=0), dtype=torch.float32, device=device)
                out = multi_facet_module(sub_t, context_vecs=context_vecs, neighbor_indices=None)
                H_out = out["H_active"]  # (chunk, d)
                for idx_local, nid in enumerate(chunk_nodes):
                    H_map[int(nid)] = H_out[idx_local]

            if not hasattr(relation_factorizer, "ur_to_h_linear"):
                relation_factorizer.ur_to_h_linear = nn.Linear(relation_factorizer.drank, D_MODEL).to(device)
            # get Ur for unique src nodes
            src_unique = torch.tensor([int(x) for x in unique_nodes], dtype=torch.long, device=device)
            # map Ur -> d_model
            Ur_batch = relation_factorizer.U_tables[rel_target](src_unique)  # (U, drank)
            projected = relation_factorizer.ur_to_h_linear(Ur_batch)  # (U, d_model)
            # build target H matrix
            H_target = torch.stack([H_map[int(x)] for x in unique_nodes], dim=0)  # (U, d_model)
            align_loss = F.mse_loss(projected, H_target)

            # main BCE loss for relation
            bce_loss = F.binary_cross_entropy_with_logits(logits, labels)

            # regularizers
            l1_slearn = relation_factorizer.learned_sparsity_penalty() * SLEARN_L1_LAMBDA
            l2_uv = relation_factorizer.ur_vr_l2() * UR_VR_L2_LAMBDA

            loss = bce_loss + 0.1 * align_loss + l1_slearn + l2_uv

            opt.zero_grad()
            loss.backward()
            opt.step()

            batch_losses.append(loss.item())

        avg_loss = float(np.mean(batch_losses)) if len(batch_losses) > 0 else 0.0
        t1 = time.time()
        print(f"[Epoch {ep+1}/{epochs}] avg_loss={avg_loss:.6f} time={(t1-t0):.1f}s")

# -------------------------
# run main 
# -------------------------
def run_relation_module_with_facets():
    assert os.path.isdir(DATA_DIR), f"{DATA_DIR} missing"
    assert os.path.exists(NODES_JSONL)
    assert os.path.exists(TMAP_JSON)
    assert os.path.exists(SENT_EMBS_NPY)
    assert os.path.exists(AW_NPZ)

    nodes = load_jsonl(NODES_JSONL)
    tmap = load_tmap(TMAP_JSON)
    sent_embs = load_sent_embs(SENT_EMBS_NPY)
    AW = load_adj_npz(AW_NPZ)
    N = len(nodes)

    multi_facet_module = MultiFacetModule(num_nodes=N, M=5, d_model=D_MODEL, device=DEVICE, use_context_gate=True, ctx_dim=sent_embs.shape[1])
    multi_facet_module.to(DEVICE)
    multi_facet_module.update_facets_via_clustering(nodes=nodes, Tmap=tmap, sent_embeddings=sent_embs, per_node_min_context=1)

    node_seed_emb = np.zeros((N, sent_embs.shape[1]), dtype=np.float32)
    for node in nodes:
        nid = node.get("id")
        # ids are ints / numeric strings, or strings of the form "N<int>"
        if isinstance(nid, str) and nid.startswith("N"):
            idx = int(nid[1:])
        else:
            idx = int(nid)
        sidx = tmap.get(str(nid), []) or tmap.get(str(idx), [])
        if sidx and len(sidx) > 0:
            node_seed_emb[idx] = sent_embs[sidx].mean(axis=0)
        else:
            node_seed_emb[idx] = np.zeros(sent_embs.shape[1], dtype=np.float32)

    Sr_pairs = build_Sr_from_AW_and_ann(AW, node_seed_emb, K_ANN)

    relation_factorizer = RelationFactorizer(num_nodes=N, relation_names=RELATION_NAMES, drank=DRANK, device=DEVICE, k_learn=K_LEARN, Sr_pairs=Sr_pairs)
    relation_factorizer.to(DEVICE)

    nbrs = NearestNeighbors(n_neighbors=min(2*K_LEARN+5, node_seed_emb.shape[0]-1), algorithm="auto", metric="cosine").fit(node_seed_emb)
    _, knn_idx = nbrs.kneighbors(node_seed_emb, return_distance=True)
    learned_idx = np.full((N, K_LEARN), -1, dtype=np.int64)
    for i in range(N):
        picks = []
        for cand in knn_idx[i]:
            if i == cand: 
                continue
            if (i, int(cand)) in Sr_pairs:
                continue
            picks.append(int(cand))
            if len(picks) >= K_LEARN:
                break
        for k in range(K_LEARN):
            learned_idx[i, k] = picks[k] if k < len(picks) else -1
    relation_factorizer.set_learned_candidates(learned_idx)

    # run joint training
    joint_train_relation_and_facets(multi_facet_module, relation_factorizer, nodes, tmap, sent_embs, AW, device=DEVICE, epochs=EPOCHS, batch_size=BATCH_SIZE, lr=LR)

if __name__ == "__main__":
    run_relation_module_with_facets()
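The candidate-selection loop above picks, for each node, the `K_LEARN` nearest neighbours in seed-embedding space that are neither the node itself nor already in $S_r$. The NumPy-only sketch below restates that rule so it can be checked on a toy input. It assumes plain cosine similarity and a Python set of `(i, j)` tuples for $S_r$; `select_learned_candidates` is a hypothetical helper, not part of the pipeline.

```python
import numpy as np

def select_learned_candidates(emb, k_learn, existing_pairs):
    """Hypothetical helper: for each node, return up to k_learn nearest
    neighbours by cosine similarity, skipping the node itself and any
    (i, j) pair already in existing_pairs. Empty slots are filled with -1."""
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    unit = emb / np.maximum(norms, 1e-12)          # zero rows stay zero
    sims = unit @ unit.T                           # (N, N) cosine similarities
    order = np.argsort(-sims, axis=1)              # most similar first
    n = emb.shape[0]
    out = np.full((n, k_learn), -1, dtype=np.int64)
    for i in range(n):
        picks = []
        for cand in order[i]:
            cand = int(cand)
            if cand == i or (i, cand) in existing_pairs:
                continue
            picks.append(cand)
            if len(picks) == k_learn:
                break
        out[i, :len(picks)] = picks
    return out
```

This dense version is O(N²) in memory and only meant for small N. The notebook bounds the search with scikit-learn's `NearestNeighbors` over `2*K_LEARN+5` neighbours instead, so a node with many existing $S_r$ pairs can end up with fewer than `K_LEARN` candidates (padded with -1).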


### 3.4.2.1 Hierarchical Architecture (Addressing the scalability issues)

#### Phase 1: Differentiable Graph Coarsening
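Phase 1 learns a soft assignment $C \in \mathbb{R}^{N \times K}$ (a row-wise softmax over $K$ communities) and pools both features and structure through it: $H_{coarse} = C^\top Z$ and $A_{coarse} = C^\top A C$, which is what `compute_coarse_embeddings` and `compute_Acoarse_from_sparse` compute in the cell below. A fine-level entry is read back as $\hat{A}_{ij} = C_i A_{coarse} C_j^\top$. The toy example below uses a hand-picked hard assignment (an assumption for illustration only) to show two small cliques collapsing into a $2 \times 2$ coarse graph.

```python
import numpy as np

# Toy graph: nodes {0,1} and {2,3} form two cliques joined by one edge (1-2).
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
Z = np.array([[1.0, 0.0], [1.0, 0.2], [0.0, 1.0], [0.2, 1.0]])  # node features

# Hard assignment for illustration; Phase 1 learns a soft one via softmax.
C = np.array([[1, 0], [1, 0], [0, 1], [0, 1]], dtype=float)  # (N, K)

H_coarse = C.T @ Z          # (K, d): features summed per community (C is not column-normalized)
A_coarse = C.T @ A @ C      # (K, K): edge mass within / between communities
A_hat = C @ A_coarse @ C.T  # (N, N): fine-level reconstruction
```

Here `A_coarse` is `[[2, 1], [1, 2]]`: each clique keeps its internal edge (counted in both directions) and the single bridge becomes the off-diagonal weight. With a soft `C` the same formulas hold; the reconstruction simply blends community-level values.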

In [None]:
import os
import json
import time
from typing import List, Optional

import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.utils import from_scipy_sparse_matrix

# -------------------------
# CONFIG 
# -------------------------
DATA_DIR = "./data"
NODES_JSONL = os.path.join(DATA_DIR, "entities.jsonl")
TMAP_JSON = os.path.join(DATA_DIR, "tmap.json")
SENT_EMBS_NPY = os.path.join(DATA_DIR, "sent_embs.npy")
AW_NPZ = os.path.join(DATA_DIR, "AW.npz")
RELATIONS_DIR = os.path.join(DATA_DIR, "relations")  
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

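# hyperparams
K_COMM = 256            # number of communities K
GCN_HID = 128           # GNNCluster hidden width (used below but not set in the original; assumed value)
BATCH_EPOCHS = 30
LR = 1e-3
RECON_MASK_TOPK = 1000000  # cap for masked reconstruction entries to compute loss on
ENTROPY_COEFF = 1e-3
PRINT_EVERY = 1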

# -------------------------
# I/O helpers
# -------------------------
def load_jsonl(path):
    out = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                out.append(json.loads(line))
    return out

def load_tmap(path):
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def load_sent_embeddings(path):
    return np.load(path)

def load_sparse_npz(path):
    return sp.load_npz(path).tocsr()

def load_relations(rel_dir):
    if not os.path.isdir(rel_dir):
        return []
    files = sorted([f for f in os.listdir(rel_dir) if f.endswith(".npz")])
    mats = []
    for fn in files:
        mats.append(load_sparse_npz(os.path.join(rel_dir, fn)))
    return mats

# -------------------------
# Utilities: sparse -> torch
# -------------------------
def scipy_csr_to_edge_index_and_weight(A_csr):
    A_coo = A_csr.tocoo()
    row = torch.tensor(A_coo.row, dtype=torch.long)
    col = torch.tensor(A_coo.col, dtype=torch.long)
    edge_index = torch.stack([row, col], dim=0)
    edge_weight = torch.tensor(A_coo.data, dtype=torch.float32)
    return edge_index, edge_weight

# -------------------------
# GNNCluster (2-layer GCN -> K logits)
# -------------------------
class GNNCluster(nn.Module):
    def __init__(self, in_dim, hid_dim, K):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hid_dim)
        self.conv2 = GCNConv(hid_dim, hid_dim)
        self.lin = nn.Linear(hid_dim, K)

    def forward(self, x, edge_index, edge_weight=None):
        x = self.conv1(x, edge_index, edge_weight=edge_weight)
        x = F.relu(x)
        x = self.conv2(x, edge_index, edge_weight=edge_weight)
        logits = self.lin(x)
        return logits  # (N, K)

# -------------------------
# Differentiable Coarsening wrapper
# -------------------------
class DifferentiableCoarsening(nn.Module):
    def __init__(self, in_dim, hid_dim, K, device):
        super().__init__()
        self.device = device
        self.gnncluster = GNNCluster(in_dim, hid_dim, K).to(device)

    def forward(self, Z, A_edge_index, A_edge_weight):
        logits = self.gnncluster(Z, A_edge_index, A_edge_weight)  # (N, K)
        C = F.softmax(logits, dim=-1)                            # row-wise soft assignment
        return C, logits

    def compute_coarse_embeddings(self, C, Z):
        H_coarse = C.t() @ Z   # (K, d)
        return H_coarse

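    def compute_Acoarse_from_sparse(self, C, A_scipy):
        # Acoarse = C^T A C. A stays sparse on the device; torch.sparse.mm keeps
        # C in the autograd graph on both sides (no detach through a NumPy round-trip).
        A_coo = A_scipy.tocoo()
        indices = torch.tensor(np.vstack([A_coo.row, A_coo.col]), dtype=torch.long, device=self.device)
        values = torch.tensor(A_coo.data, dtype=torch.float32, device=self.device)
        A_t = torch.sparse_coo_tensor(indices, values, A_coo.shape)
        AC = torch.sparse.mm(A_t, C)  # (N, K)
        return C.t() @ AC             # (K, K) on device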

    def reconstruct_fine_from_coarse(self, C, Acoarse):
        return C @ (Acoarse @ C.t())  # (N,N) dense -> careful for memory (we will only sample entries later)

# -------------------------
# Helper: sample mask indices for masked reconstruction loss
# -------------------------
def sample_recon_indices(A_scipy, max_samples):
    # returns arrays (rows, cols, vals) sampled from A (positives) and random negatives
    nz_rows, nz_cols = A_scipy.nonzero()
    nnz = len(nz_rows)
    if nnz == 0:
        return np.array([], dtype=np.int64), np.array([], dtype=np.int64), np.array([], dtype=np.float32)
    idxs = np.arange(nnz)
    if nnz > max_samples // 2:
        chosen = np.random.choice(idxs, size=max_samples//2, replace=False)
    else:
        chosen = idxs
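    pos_r = nz_rows[chosen]
    pos_c = nz_cols[chosen]
    # read values by coordinate: .data may hold explicit zeros, so its order need not match nonzero()
    pos_v = np.asarray(A_scipy[pos_r, pos_c]).ravel()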
    # sample equal number of negative pairs uniformly (ensure not in nz)
    neg_count = len(pos_r)
    N = A_scipy.shape[0]
    neg_r = np.random.randint(0, N, size=neg_count)
    neg_c = np.random.randint(0, N, size=neg_count)
    # filter negatives that accidentally are positives
    mask = A_scipy[neg_r, neg_c].A1 == 0
    # if few pass, expand until we have enough (bounded)
    attempts = 0
    while mask.sum() < neg_count and attempts < 5:
        more_r = np.random.randint(0, N, size=neg_count)
        more_c = np.random.randint(0, N, size=neg_count)
        combined_mask = A_scipy[more_r, more_c].A1 == 0
        sel = combined_mask.nonzero()[0]
        if len(sel) > 0:
            take = min(neg_count - mask.sum(), len(sel))
            insert_idx = np.where(mask == False)[0][:take]
            neg_r[insert_idx] = more_r[sel[:take]]
            neg_c[insert_idx] = more_c[sel[:take]]
            mask = A_scipy[neg_r, neg_c].A1 == 0
        attempts += 1
    neg_v = np.zeros_like(neg_r, dtype=np.float32)
    # concatenate pos and neg
    rows = np.concatenate([pos_r, neg_r])
    cols = np.concatenate([pos_c, neg_c])
    vals = np.concatenate([pos_v.astype(np.float32), neg_v])
    return rows, cols, vals

# -------------------------
# Main run for Phase 1
# -------------------------
def run_phase1():
    assert os.path.isdir(DATA_DIR), f"{DATA_DIR} missing"
    assert os.path.exists(NODES_JSONL), f"{NODES_JSONL} missing"
    assert os.path.exists(TMAP_JSON), f"{TMAP_JSON} missing"
    assert os.path.exists(SENT_EMBS_NPY), f"{SENT_EMBS_NPY} missing"
    assert os.path.exists(AW_NPZ), f"{AW_NPZ} missing"

    device = DEVICE
    nodes = load_jsonl(NODES_JSONL)
    tmap = load_tmap(TMAP_JSON)
    sent_embs = load_sent_embeddings(SENT_EMBS_NPY)   # (num_sentences, dim)
    AW = load_sparse_npz(AW_NPZ)                      # coarse co-occurrence (one relation fallback)
    relations = load_relations(RELATIONS_DIR)
    if len(relations) == 0:
        relations = [AW]  # single relation fallback

    N = len(nodes)
    sent_dim = sent_embs.shape[1]

    # Build node-level embeddings Z by averaging sentence embeddings per node as explained in spec
    Z_np = np.zeros((N, sent_dim), dtype=np.float32)
    for node in nodes:
        nid = node.get("id")
        # ids are ints / numeric strings, or strings of the form "N<int>"
        if isinstance(nid, str) and nid.startswith("N"):
            nidx = int(nid[1:])
        else:
            nidx = int(nid)
        sidx = tmap.get(str(nid), []) or tmap.get(str(nidx), [])
        if sidx and len(sidx) > 0:
            Z_np[nidx] = sent_embs[sidx].mean(axis=0)
        else:
            Z_np[nidx] = np.random.normal(scale=0.01, size=(sent_dim,))
    Z = torch.tensor(Z_np, dtype=torch.float32, device=device)

    # build edge_index / edge_weight for GCN from AW (use AW as structural connectivity)
    edge_index, edge_weight = scipy_csr_to_edge_index_and_weight(AW)
    edge_index = edge_index.to(device); edge_weight = edge_weight.to(device)

    model = DifferentiableCoarsening(in_dim=sent_dim, hid_dim=GCN_HID, K=K_COMM, device=device).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-5)

    Shybrid_path = os.path.join(DATA_DIR, "Shybrid.npz")
    Shybrid = None
    if os.path.exists(Shybrid_path):
        Shybrid = load_sparse_npz(Shybrid_path)
    for epoch in range(BATCH_EPOCHS):
        t0 = time.time()
        model.train()
        opt.zero_grad()
        C, logits = model(Z, edge_index, edge_weight)
        entropy = -(C * torch.log(C + 1e-12)).sum(dim=-1).mean()
        ent_loss = ENTROPY_COEFF * entropy

        recon_loss_total = 0.0
        for r_idx, Ar in enumerate(relations):
            # compute coarse adjacency
            Acoarse = model.compute_Acoarse_from_sparse(C, Ar)  # (K,K)
            rows, cols, vals = sample_recon_indices(Ar, max_samples=RECON_MASK_TOPK)
            if rows.size == 0:
                continue
            V = (Acoarse @ C.t())  # (K, N)
            Cr = C[rows]         # (S, K)
            Vc = V[:, cols].t()  # (S, K)
            A_hat_vals = (Cr * Vc).sum(dim=1)  # (S,)
            target = torch.tensor(vals, dtype=torch.float32, device=device)
            recon_loss = F.mse_loss(A_hat_vals, target)
            recon_loss_total = recon_loss_total + recon_loss

        loss = recon_loss_total + ent_loss
        loss.backward()
        opt.step()

        if (epoch + 1) % PRINT_EVERY == 0 or epoch == 0:
            model.eval()
            with torch.no_grad():
                C_eval, _ = model(Z, edge_index, edge_weight)
                H_coarse = model.compute_coarse_embeddings(C_eval, Z)  # (K, d)
                print(f"[Epoch {epoch+1}/{BATCH_EPOCHS}] loss={loss.item():.6f} recon={recon_loss_total.item() if hasattr(recon_loss_total,'item') else recon_loss_total:.6f} entropy={entropy.item():.6f} time={time.time()-t0:.1f}s")
                # save assignments and coarse graph outputs
                out_dir = os.path.join(DATA_DIR, "coarsening_out")
                os.makedirs(out_dir, exist_ok=True)
                np.save(os.path.join(out_dir, "C.npy"), C_eval.detach().cpu().numpy())
                np.save(os.path.join(out_dir, "H_coarse.npy"), H_coarse.detach().cpu().numpy())
                # save Acoarse for each relation
                for r_idx, Ar in enumerate(relations):
                    Acoarse = model.compute_Acoarse_from_sparse(C_eval, Ar)
                    np.save(os.path.join(out_dir, f"Acoarse_r{r_idx}.npy"), Acoarse.detach().cpu().numpy())

    print("Phase-1 coarsening finished. Outputs saved under ./data/coarsening_out")

if __name__ == "__main__":
    run_phase1()
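The Phase 1 training loop never materializes the $N \times N$ reconstruction `C @ Acoarse @ C.T`; for each sampled pair $(i, j)$ it evaluates only the needed entry as $\langle C_i, (A_{coarse} C^\top)_{:, j} \rangle$. The small PyTorch check below, with random tensors standing in for the learned `C` and `Acoarse` (only the shapes are taken from the notebook), confirms that the sampled values match the dense reconstruction at those positions.

```python
import torch

torch.manual_seed(0)
N, K, S = 50, 6, 20
C = torch.softmax(torch.randn(N, K), dim=-1)   # stand-in soft assignment (N, K)
A_coarse = torch.rand(K, K)                    # stand-in coarse adjacency (K, K)
rows = torch.randint(0, N, (S,))
cols = torch.randint(0, N, (S,))

# Sampled evaluation, as in run_phase1: only a (K, N) intermediate is formed.
V = A_coarse @ C.t()                              # (K, N)
sampled = (C[rows] * V[:, cols].t()).sum(dim=1)   # (S,)

# Dense reference: O(N^2) memory, only feasible for small N.
dense = C @ A_coarse @ C.t()                      # (N, N)
reference = dense[rows, cols]
```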


#### Phase 2: The Coarse-Grained Model (Learning "Highways")
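For the causal relations, Phase 2 adds the smooth acyclicity penalty $h(A) = \operatorname{tr}(e^{A \circ A}) - C$ implemented in `acyclicity_penalty` in the cell below (the NOTEARS constraint). Since $\operatorname{tr}(e^{M}) = \sum_{k \ge 0} \operatorname{tr}(M^k)/k!$, the $k=0$ term cancels the $-C$, and for non-negative $M = A \circ A$ each $\operatorname{tr}(M^k)$ sums the weights of closed walks of length $k$. Hence $h(A) = 0$ exactly when the weighted graph has no directed cycle. The sketch below checks this on a 3-node chain and a 3-node cycle; `acyclicity` is a standalone restatement, not the class method itself.

```python
import torch

def acyclicity(A: torch.Tensor) -> torch.Tensor:
    # h(A) = tr(exp(A * A)) - d; zero iff the weighted digraph A is acyclic
    d = A.shape[0]
    return torch.trace(torch.matrix_exp(A * A)) - d

chain = torch.tensor([[0., 1., 0.],
                      [0., 0., 1.],
                      [0., 0., 0.]])   # 0 -> 1 -> 2
cycle = chain.clone()
cycle[2, 0] = 1.0                      # adds 2 -> 0, closing a cycle

h_chain = acyclicity(chain)            # ~0
h_cycle = acyclicity(cycle)            # 3 * (1 + 1/3! + 1/6! + ...) - 3, about 0.504
```

In `CoarseEBM`, `A_prob` comes from a sigmoid, so every entry, including the diagonal (self-loops), is strictly positive; the penalty on `A_prob` can shrink toward zero but never reach it.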

In [None]:
import os
import json
import time
from typing import List, Dict, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

# -------------------------
# CONFIG 
# -------------------------
DATA_DIR = "./data"
COARSE_ASSIGN_NPY = os.path.join(DATA_DIR, "C_assign.npy")    # shape (N, C) soft assignments
PCOARSE_JSONL = os.path.join(DATA_DIR, "Pcoarse.jsonl")
OUT_DIR = os.path.join(DATA_DIR, "coarse_ebm_out")
os.makedirs(OUT_DIR, exist_ok=True)

# Hyperparameters
D_MODEL = 256
EPOCHS = 200
LR = 3e-4
BATCH_SIZE = 1  # coarse model is small; training is effectively on whole coarse graph
LAMBDA_DAG = 1.0     # weight for acyclicity penalty 
LAMBDA_MSE = 1.0     # weight for prior imitation loss
WEIGHT_DECAY = 1e-5

# -------------------------
# Utilities to load coarsened prior
# -------------------------
def load_coarse_assign(path: str) -> np.ndarray:
    return np.load(path)  # shape (N, C)

def load_pcoarse(path: str) -> List[dict]:
    out = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            out.append(json.loads(line))
    return out

def build_coarse_targets(pcoarse: List[dict], num_communities: int) -> Tuple[np.ndarray, np.ndarray, Dict[str, int]]:
    # Collect unique relations and build mapping
    rel_set = []
    for rec in pcoarse:
        r = rec["relation"]
        if r not in rel_set:
            rel_set.append(r)
    rel_map = {r:i for i,r in enumerate(rel_set)}
    R = len(rel_set)
    target_A = np.zeros((R, num_communities, num_communities), dtype=np.float32)
    # Also build mask to indicate which pairs exist in prior (for weighted loss)
    mask = np.zeros_like(target_A, dtype=np.float32)
    for rec in pcoarse:
        ca = int(rec["ca"])
        cb = int(rec["cb"])
        r = rec["relation"]
        score = float(rec.get("score", 1.0))
        ridx = rel_map[r]
        # aggregate by taking max (keeps strong signals); you could average weighted by counts
        target_A[ridx, ca, cb] = max(target_A[ridx, ca, cb], score)
        mask[ridx, ca, cb] = 1.0
    return target_A, mask, rel_map

# -------------------------
# Coarse EBM Module
# -------------------------
class CoarseEBM(nn.Module):
    def __init__(self, C: int, R: int, d_model: int):
        super().__init__()
        self.C = C
        self.R = R
        self.d = d_model
        # Coarse community embeddings
        self.H_coarse = nn.Parameter(torch.randn(C, d_model) * 0.02)
        # Dense adjacency logits for each relation (R, C, C)
        self.A_logits = nn.Parameter(torch.randn(R, C, C) * 0.02)
        # small MLP readout to produce features for any downstream use
        self.readout = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(inplace=True),
            nn.LayerNorm(d_model)
        )

    def forward(self):
        A_prob = torch.sigmoid(self.A_logits)  # (R, C, C)
        H = self.readout(self.H_coarse)        # (C, d)
        return A_prob, H

    def acyclicity_penalty(self, Ar: Tensor) -> Tensor:
        # h(Ar) = tr(exp(Ar ⊙ Ar)) - C
        sq = Ar * Ar
        # matrix exponential
        expm = torch.matrix_exp(sq)
        tr = torch.trace(expm)
        return tr - float(self.C)

# -------------------------
# Training routine
# -------------------------
def train_coarse_ebm(
    module: CoarseEBM,
    target_A: np.ndarray,
    mask: np.ndarray,
    causal_relation_names: List[str],
    rel_map: Dict[str,int],
    epochs: int = EPOCHS,
    lr: float = LR,
    lambda_dag: float = LAMBDA_DAG,
    lambda_mse: float = LAMBDA_MSE,
):
    device = next(module.parameters()).device
    target_t = torch.tensor(target_A, dtype=torch.float32, device=device)  # (R,C,C)
    mask_t = torch.tensor(mask, dtype=torch.float32, device=device)
    optimizer = torch.optim.AdamW(module.parameters(), lr=lr, weight_decay=WEIGHT_DECAY)
    for ep in range(epochs):
        t0 = time.time()
        module.train()
        optimizer.zero_grad()
        A_prob, H = module()  # A_prob (R,C,C)
        # imitation loss: MSE but only where mask==1 (we want to encourage matching known priors)
        mse_loss = F.mse_loss(A_prob * mask_t, target_t * mask_t, reduction='sum')
        # normalize by number of masked entries to keep scale stable
        denom = mask_t.sum().clamp_min(1.0)
        mse_loss = mse_loss / denom
        dag_pen = torch.tensor(0.0, device=device)
        for rname in causal_relation_names:
            if rname not in rel_map:
                continue
            ridx = rel_map[rname]
            Ar = A_prob[ridx]  # (C,C)
            dag_pen = dag_pen + module.acyclicity_penalty(Ar)
        loss = lambda_mse * mse_loss + lambda_dag * dag_pen
        loss.backward()
        optimizer.step()
        t_epoch = time.time() - t0
        if (ep + 1) % 5 == 0 or ep == 0:
            with torch.no_grad():
                reg_A_mean = A_prob.mean().item()
                print(f"[Epoch {ep+1}/{epochs}] loss={loss.item():.6f} mse={mse_loss.item():.6f} dag={dag_pen.item():.6f} A_mean={reg_A_mean:.4f} time={t_epoch:.2f}s")
    # final save
    module.eval()
    with torch.no_grad():
        A_prob_final, H_final = module()
        torch.save(A_prob_final.cpu(), os.path.join(OUT_DIR, "Acoarse_prob.pt"))
        torch.save(H_final.cpu(), os.path.join(OUT_DIR, "Hcoarse.pt"))
        torch.save(module.state_dict(), os.path.join(OUT_DIR, "coarse_ebm_state.pt"))
    return module

# -------------------------
# Main driver for Phase-2
# -------------------------
def run_phase2():
    assert os.path.exists(COARSE_ASSIGN_NPY), f"{COARSE_ASSIGN_NPY} missing."
    assert os.path.exists(PCOARSE_JSONL), f"{PCOARSE_JSONL} missing."

    C_assign = load_coarse_assign(COARSE_ASSIGN_NPY)  # (N, C)
    N, C = C_assign.shape

    pcoarse = load_pcoarse(PCOARSE_JSONL)
    target_A_np, mask_np, rel_map = build_coarse_targets(pcoarse, num_communities=C)

    R = target_A_np.shape[0]
    print(f"Loaded coarse assign: N={N}, C={C}, relations={R}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CoarseEBM(C=C, R=R, d_model=D_MODEL).to(device)

    # Determine which relation names correspond to causal relations requiring DAG constraint.
    causal_relations = [r for r in rel_map.keys() if "CAUSE" in r.upper() or "DIRECT" in r.upper() or "CAUSES" in r.upper()]
    print("Causal relations (acyclicity applied):", causal_relations)

    trained = train_coarse_ebm(
        module=model,
        target_A=target_A_np,
        mask=mask_np,
        causal_relation_names=causal_relations,
        rel_map=rel_map,
        epochs=EPOCHS,
        lr=LR,
        lambda_dag=LAMBDA_DAG,
        lambda_mse=LAMBDA_MSE,
    )
    print("Phase-2 finished. Outputs in:", OUT_DIR)

if __name__ == "__main__":
    run_phase2()
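The imitation term in `train_coarse_ebm` only scores coarse pairs that have a record in $P_{coarse}$: it sums squared errors over masked entries and divides by the number of masked entries, i.e. it is the mean squared error on the support of the prior. The check below, on random stand-in tensors with the $(R, C, C)$ layout, confirms the equivalence with a plain MSE over the masked positions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
A_prob = torch.rand(2, 5, 5)                   # predicted (R, C, C)
target = torch.rand(2, 5, 5)                   # prior scores (R, C, C)
mask = (torch.rand(2, 5, 5) > 0.7).float()     # 1 where the prior has a record

# As in train_coarse_ebm: summed masked error / number of masked entries
masked_mse = F.mse_loss(A_prob * mask, target * mask, reduction="sum") / mask.sum().clamp_min(1.0)

# Equivalent: ordinary MSE restricted to the masked positions
keep = mask.bool()
direct_mse = F.mse_loss(A_prob[keep], target[keep])
```

Unmasked pairs therefore receive no imitation gradient; in `train_coarse_ebm` they are shaped only by the acyclicity term (for causal relations) and by weight decay.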


#### Phase 3: The Fine-Grained Model (with Adaptive Context Integration)
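Phase 3 "primes" each fine node with context from the frozen coarse communities: a query built from the node's own embedding attends over keys and values projected from $H_{coarse}$, and the attended context is added to the projected node embedding and layer-normalized, $H^0_i = \mathrm{LN}\big(W_Z Z_i + \mathrm{softmax}(Q_i K^\top/\sqrt{d})\,V\big)$, as in `primed_embeddings` below. The sketch isolates that step with randomly initialized layers; all sizes are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, dz, d, K = 4, 8, 16, 5
Z = torch.randn(B, dz)          # fine node embeddings (stand-ins)
H_coarse = torch.randn(K, d)    # frozen community embeddings (stand-ins)

W_q = nn.Linear(dz, d, bias=False)
W_k = nn.Linear(d, d, bias=False)
W_v = nn.Linear(d, d, bias=False)
W_z = nn.Linear(dz, d, bias=False)   # projects Z into d_model, like Zproj
norm = nn.LayerNorm(d)

scores = W_q(Z) @ W_k(H_coarse).T / d ** 0.5   # (B, K)
att = F.softmax(scores, dim=-1)                 # each row sums to 1 over communities
context = att @ W_v(H_coarse)                   # (B, d)
H0 = norm(W_z(Z) + context)                     # primed embeddings (B, d)
```

Because the attention runs over the $K$ communities rather than over all $N$ nodes, the cost per node is O(K·d), which is what keeps this step cheap at fine resolution.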

In [None]:

import os
import json
import time
from typing import List, Tuple, Dict, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# -------------------------
# Config
# -------------------------
DATA_DIR = "./data"   
Z_NPY = os.path.join(DATA_DIR, "Z.npy")
C_NPY = os.path.join(DATA_DIR, "C.npy")
ACOARSE_NPY = os.path.join(DATA_DIR, "Acoarse.npy")
HCOARSE_NPY = os.path.join(DATA_DIR, "Hcoarse.npy")
PRICH_JSONL = os.path.join(DATA_DIR, "prich.jsonl")  

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
D_MODEL = 256        # should match Hcoarse dim
DZ = 64             # placeholder; overwritten from Z.npy below
K_COMM = None        # inferred
R_RELS = None        # inferred

BATCH_SIZE = 1024
LR = 1e-4
EPOCHS = 10
WEIGHT_DECAY = 1e-5

# -------------------------
# Basic loaders 
# -------------------------
assert os.path.exists(Z_NPY), f"{Z_NPY} not found"
assert os.path.exists(C_NPY), f"{C_NPY} not found"
assert os.path.exists(ACOARSE_NPY), f"{ACOARSE_NPY} not found"
assert os.path.exists(HCOARSE_NPY), f"{HCOARSE_NPY} not found"
assert os.path.exists(PRICH_JSONL), f"{PRICH_JSONL} not found"

Z = np.load(Z_NPY)                 # (N, dz)
C_mat = np.load(C_NPY)             # (N, K)
Acoarse = np.load(ACOARSE_NPY)     # (K, K, R); Phase 2's Acoarse_prob.pt is (R, C, C) and needs np.transpose(A, (1, 2, 0)) to match
Hcoarse = np.load(HCOARSE_NPY)     # (K, d_model)

N = Z.shape[0]
DZ = Z.shape[1]
K_COMM = C_mat.shape[1]
R_RELS = Acoarse.shape[2]
assert Hcoarse.shape[0] == K_COMM
assert Hcoarse.shape[1] == D_MODEL
assert Acoarse.shape[0] == K_COMM and Acoarse.shape[1] == K_COMM

print(f"Loaded: N={N}, dz={DZ}, K={K_COMM}, R={R_RELS}, d_model={D_MODEL}")

# -------------------------
# Prich dataset (node-level rulebook)
# -------------------------
class PriChDataset(Dataset):
    def __init__(self, jsonl_path: str):
        self.entries = []
        with open(jsonl_path, "r", encoding="utf-8") as f:
            for ln in f:
                if not ln.strip():
                    continue
                j = json.loads(ln)
                # Expect fields: i, j, r, score
                self.entries.append((int(j["i"]), int(j["j"]), int(j["r"]), float(j.get("score", 1.0))))
        if len(self.entries) == 0:
            raise RuntimeError("prich dataset empty")
    def __len__(self):
        return len(self.entries)
    def __getitem__(self, idx):
        return self.entries[idx]

# collate returns batched tensors
def collate_prich(batch):
    i_idx = torch.tensor([b[0] for b in batch], dtype=torch.long)
    j_idx = torch.tensor([b[1] for b in batch], dtype=torch.long)
    r_idx = torch.tensor([b[2] for b in batch], dtype=torch.long)
    score = torch.tensor([b[3] for b in batch], dtype=torch.float32)
    return i_idx, j_idx, r_idx, score

# -------------------------
# FineGrainedPhase3 Module
# -------------------------
class FineGrainedPhase3(nn.Module):
    def __init__(self, dz: int, d_model: int, K: int, R: int, device: torch.device):
        super().__init__()
        self.dz = dz
        self.d_model = d_model
        self.K = K
        self.R = R
        self.device = device

        # attention projection matrices (WQ, WK, WV) operate in d_model space.
        self.WQ = nn.Linear(dz, d_model, bias=False)        # Query from Init(Z_i) -> d_model
        self.WK = nn.Linear(d_model, d_model, bias=False)   # Key from Hcoarse -> d_model
        self.WV = nn.Linear(d_model, d_model, bias=False)   # Value from Hcoarse -> d_model

        self.layer_norm = nn.LayerNorm(d_model)

        # f_compat: small MLP that outputs [0,1] affinity from two primed embeddings
        self.fcompat = nn.Sequential(
            nn.Linear(2 * d_model, d_model),
            nn.ReLU(inplace=True),
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(inplace=True),
            nn.Linear(d_model // 2, 1)
        )

        if dz != d_model:
            self.Zproj = nn.Linear(dz, d_model, bias=False)
        else:
            self.Zproj = None

    def primed_embeddings(self, Z_batch: torch.Tensor, Hcoarse: torch.Tensor) -> torch.Tensor:
        # Z_batch: (B, dz) ; Hcoarse: (K, d_model)
        Q = self.WQ(Z_batch)                          # (B, d_model)
        Kc = self.WK(Hcoarse)                         # (K, d_model)
        Vc = self.WV(Hcoarse)                         # (K, d_model)
        # attention: Q (B, d) x Kc (K, d) -> scores (B, K)
        scores = (Q @ Kc.T) / (self.d_model ** 0.5)   # (B, K)
        att = F.softmax(scores, dim=-1)               # (B, K)
        context = att @ Vc                            # (B, d_model)
        # init projection
        if self.Zproj is not None:
            init_proj = self.Zproj(Z_batch)          # (B, d)
        else:
            # if dz==d_model, cast directly
            init_proj = Z_batch                       # (B, d)
        H0 = self.layer_norm(init_proj + context)     # (B, d)
        return H0

    def compute_coarse_prior_batch(self, C: torch.Tensor, Acoarse: torch.Tensor, i_idxs: torch.Tensor, j_idxs: torch.Tensor, r_idxs: torch.Tensor) -> torch.Tensor:
        Ci = C[i_idxs]         # (B, K)
        Cj = C[j_idxs]         # (B, K)
        # We'll process per distinct r value for efficiency
        batch_size = Ci.shape[0]
        out = torch.empty(batch_size, device=self.device, dtype=torch.float32)
        # loop over the distinct relations present in the batch
        for r in torch.unique(r_idxs):
            mask = (r_idxs == r)
            pos = torch.nonzero(mask, as_tuple=False).squeeze(1)
            Ci_sub = Ci[pos]   # (b, K)
            Cj_sub = Cj[pos]   # (b, K)
            A_r = Acoarse[:, :, int(r.item())]  
            tmp = Ci_sub @ A_r    # (b, K)
            vals = (tmp * Cj_sub).sum(dim=-1)    # (b,)
            out[pos] = vals
        return out  # (B,)

    def forward_predict_batch(self, Z_batch: torch.Tensor, idx_i: torch.Tensor, idx_j: torch.Tensor, r_idxs: torch.Tensor,
                              C_tensor: torch.Tensor, Acoarse_tensor: torch.Tensor, Hcoarse_tensor: torch.Tensor) -> torch.Tensor:
        # Not used: Z_i and Z_j are fetched per pair outside the model and passed to predict_pair_batch.
        raise NotImplementedError("Use predict_pair_batch convenience below")

    def predict_pair_batch(self,
                           Z_i: torch.Tensor,    # (B, dz)
                           Z_j: torch.Tensor,    # (B, dz)
                           i_idxs: torch.Tensor, # (B,)
                           j_idxs: torch.Tensor, # (B,)
                           r_idxs: torch.Tensor, # (B,)
                           C_tensor: torch.Tensor,   # (N, K)
                           Acoarse_tensor: torch.Tensor, # (K, K, R)
                           Hcoarse_tensor: torch.Tensor   # (K, d)
                           ) -> Tuple[torch.Tensor, torch.Tensor]:
        B = i_idxs.shape[0]
        # primed embeddings
        H0_i = self.primed_embeddings(Z_i, Hcoarse_tensor)  # (B, d)
        H0_j = self.primed_embeddings(Z_j, Hcoarse_tensor)  # (B, d)
        # fcompat
        pair_cat = torch.cat([H0_i, H0_j], dim=-1)          # (B, 2d)
        compat_logits = self.fcompat(pair_cat).squeeze(-1)  # (B,)
        compat = torch.sigmoid(compat_logits)               # (B,)
        # coarse prior per pair
        coarse_vals = self.compute_coarse_prior_batch(C_tensor, Acoarse_tensor, i_idxs, j_idxs, r_idxs)  # (B,)
        # final
        out = coarse_vals * compat
        return out, compat_logits  # out in [0,1], compat_logits raw

# -------------------------
# Helpers to batch fetch Z for indices
# -------------------------
def tensor_from_numpy(arr: np.ndarray, dtype=torch.float32, device=DEVICE):
    return torch.tensor(arr, dtype=dtype, device=device)

# -------------------------
# Training for Phase-3
# -------------------------
def train_phase3():
    ds = PriChDataset(PRICH_JSONL)
    loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_prich, num_workers=4, pin_memory=True)

    model = FineGrainedPhase3(dz=DZ, d_model=D_MODEL, K=K_COMM, R=R_RELS, device=DEVICE).to(DEVICE)

    # prepare tensors
    Z_t = tensor_from_numpy(Z, dtype=torch.float32, device=DEVICE)              # (N, dz)
    C_t = tensor_from_numpy(C_mat.astype(np.float32), dtype=torch.float32, device=DEVICE)  # (N, K)
    Acoarse_t = tensor_from_numpy(Acoarse.astype(np.float32), dtype=torch.float32, device=DEVICE)  # (K,K,R)
    Hcoarse_t = tensor_from_numpy(Hcoarse.astype(np.float32), dtype=torch.float32, device=DEVICE)  # (K, d)

    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    bce_loss = nn.BCEWithLogitsLoss()

    for ep in range(EPOCHS):
        t0 = time.time()
        total_loss = 0.0
        total_samples = 0
        model.train()
        for i_idxs, j_idxs, r_idxs, scores in loader:
            i_idxs = i_idxs.to(DEVICE)
            j_idxs = j_idxs.to(DEVICE)
            r_idxs = r_idxs.to(DEVICE)
            scores = scores.to(DEVICE)  

            Z_i = Z_t[i_idxs]            # (B, dz)
            Z_j = Z_t[j_idxs]            # (B, dz)

            # forward predict
            pred_probs, compat_logits = model.predict_pair_batch(Z_i, Z_j, i_idxs, j_idxs, r_idxs, C_t, Acoarse_t, Hcoarse_t)
            # pred_probs in [0,1], compat_logits raw; we want a training objective that backprops into compat and WQ/WK/WV
            # Use MSE on final prob + BCE on compat (signal is whether local affinity needed)
            # We'll use a composite loss:
            #  - L_prob = MSE(pred_probs, scores)
            #  - L_compat = BCEWithLogits(compat_logits, scores_binary) where scores_binary = scores>0.5
            scores_bin = (scores > 0.5).float()
            L_prob = F.mse_loss(pred_probs, scores)
            L_compat = bce_loss(compat_logits, scores_bin)
            loss = L_prob + 0.5 * L_compat

            opt.zero_grad()
            loss.backward()
            opt.step()

            total_loss += float(loss.item()) * i_idxs.shape[0]
            total_samples += i_idxs.shape[0]

        avg = total_loss / (total_samples + 1e-12)
        elapsed = time.time() - t0
        print(f"[Phase3] Epoch {ep+1}/{EPOCHS} avg_loss={avg:.6e} time={elapsed:.1f}s")

    # Save model checkpoint
    torch.save({
        "model_state": model.state_dict(),
        "meta": {"dz": DZ, "d_model": D_MODEL, "K": K_COMM, "R": R_RELS}
    }, os.path.join(DATA_DIR, "phase3_finegrained_ckpt.pth"))
    print("Saved Phase3 checkpoint")

# -------------------------
# Entry point
# -------------------------
if __name__ == "__main__":
    train_phase3()
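`compute_coarse_prior_batch` evaluates the coarse prior $C_i A_r C_j^\top$ with one matrix product per distinct relation in the batch. An equivalent vectorized form gathers one $K \times K$ slice per pair and contracts with `torch.einsum`; it removes the Python loop at the cost of $B \cdot K^2$ extra memory, so for large $K$ the per-relation loop may remain preferable. The check below uses random stand-in tensors in the $(K, K, R)$ layout that Phase 3 loads; both function names are hypothetical.

```python
import torch

def coarse_prior_loop(C, A, i, j, r):
    # mirrors compute_coarse_prior_batch: one matmul per distinct relation
    out = torch.empty(i.shape[0])
    for rel in torch.unique(r):
        pos = torch.nonzero(r == rel, as_tuple=False).squeeze(1)
        out[pos] = ((C[i[pos]] @ A[:, :, int(rel)]) * C[j[pos]]).sum(dim=-1)
    return out

def coarse_prior_einsum(C, A, i, j, r):
    A_sel = A.permute(2, 0, 1)[r]                  # (B, K, K), one slice per pair
    return torch.einsum("bk,bkl,bl->b", C[i], A_sel, C[j])

torch.manual_seed(0)
N, K, R, B = 30, 4, 3, 64
C = torch.softmax(torch.randn(N, K), dim=-1)      # rows sum to 1
A = torch.rand(K, K, R)                           # entries in [0, 1)
i = torch.randint(0, N, (B,))
j = torch.randint(0, N, (B,))
r = torch.randint(0, R, (B,))

p_loop = coarse_prior_loop(C, A, i, j, r)
p_vec = coarse_prior_einsum(C, A, i, j, r)
```

Since each row of `C` sums to 1, $C_i A_r C_j^\top$ is a convex combination of entries of $A_r$; this is why the prior, and hence the final `coarse_vals * compat` score, stays in $[0, 1]$.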


In [None]:
## Evaluating the above...

import os
import json
import math
import time
from typing import List, Tuple, Dict, Optional

import numpy as np
import scipy.sparse as sp
from scipy.stats import spearmanr
from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    precision_recall_curve,
    precision_score,
    recall_score,
)
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
import torch

# ---------- CONFIG ----------
DATA_DIR = "./data"
NODES_JSONL = os.path.join(DATA_DIR, "entities.jsonl")
TMAP_JSON = os.path.join(DATA_DIR, "tmap.json")
SENT_EMBS_NPY = os.path.join(DATA_DIR, "sent_embs.npy")
AW_NPZ = os.path.join(DATA_DIR, "AW.npz")
PRICH_JSONL = os.path.join(DATA_DIR, "Prich.jsonl")        # node-level prior 
PCOARSE_JSON = os.path.join(DATA_DIR, "Pcoarse.json")      # coarse prior 
C_ASSIGN_NPY = os.path.join(DATA_DIR, "C_assign.npy")      # soft assignment C (N x K) 
ACOARSE_PRED_NPY = os.path.join(DATA_DIR, "Acoarse_pred.npy") # predicted coarse adjacency (K,K,R) 
GFINAL_PRED_NPZ = os.path.join(DATA_DIR, "Gfinal_preds.npz")   # predicted edges list or dense structure 
S_HYBRID_NPZ = os.path.join(DATA_DIR, "Shybrid.npz")        # candidate set mask 
OUTPUT_REPORT = os.path.join(DATA_DIR, "evaluation_report.json")

# ---------- UTILITIES ----------
def load_jsonl(path):
    out=[]
    with open(path,"r",encoding="utf-8") as f:
        for line in f:
            if line.strip(): out.append(json.loads(line))
    return out

def load_npz_adj(path):
    m = sp.load_npz(path)
    return m

def safe_load(path, loader):
    if not os.path.exists(path):
        return None
    return loader(path)

# ---------- PARSERS FOR PRIORS ----------
def _node_id(x) -> int:
    # node IDs may be stored as ints or as strings such as "N42"
    return x if isinstance(x, int) else int(str(x).lstrip("N"))

def parse_prich(prich_path: str) -> List[Tuple[int,int,int,float]]:
    data = load_jsonl(prich_path)
    out=[]
    for item in data:
        # tolerate two record formats: {"pair": [a, b], "relation": r, ...} or {"i": a, "j": b, "r": r, ...}
        if "pair" in item:
            a, b = item["pair"]
            r = item.get("relation", 0)
        elif all(k in item for k in ("i","j")):
            a, b = item["i"], item["j"]
            r = item.get("r", 0)
        else:
            continue
        out.append((_node_id(a), _node_id(b), int(r), float(item.get("score", 1.0))))
    return out

def parse_pcoarse(pcoarse_path: str) -> List[Tuple[int,int,int,float]]:
    # best-effort parser: returns [] if the file is missing, malformed, or not a list of records
    try:
        with open(pcoarse_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        out=[]
        if isinstance(data, list):
            for d in data:
                ca = int(d.get("ca",d.get("c_a",0)))
                cb = int(d.get("cb",d.get("c_b",0)))
                r = int(d.get("r",0))
                s = float(d.get("score", d.get("sc",1.0)))
                out.append((ca,cb,r,s))
        return out
    except (OSError, ValueError, TypeError, AttributeError):
        return []

# ---------- METRIC COMPUTATIONS ----------
def rmse_over_S(pred_scores: List[float], target_scores: List[float]) -> float:
    p = np.asarray(pred_scores, dtype=np.float64)
    t = np.asarray(target_scores, dtype=np.float64)
    return float(np.sqrt(np.mean((p - t)**2)))

def precision_recall_at_k(pred_pairs_scores: List[Tuple[Tuple[int,int,int], float]],
                          true_edge_set: Dict[Tuple[int,int,int], float],
                          K: int) -> Tuple[float,float]:
    # Sort by predicted score desc
    sorted_preds = sorted(pred_pairs_scores, key=lambda x: -x[1])[:K]
    tp = 0
    for (i,j,r),_ in sorted_preds:
        if (i,j,r) in true_edge_set and true_edge_set[(i,j,r)]>0:
            tp += 1
    precision = tp / max(len(sorted_preds), 1)   # divide by the number actually ranked (< K when there are few candidates)
    recall = tp / max(len([x for x in true_edge_set.values() if x>0]), 1)
    return precision, recall

def auc_and_ap(pred_list: List[float], labels: List[int]) -> Tuple[float,float]:
    if len(set(labels))==1:
        return float("nan"), float("nan")
    try:
        return float(roc_auc_score(labels, pred_list)), float(average_precision_score(labels, pred_list))
    except Exception:
        return float("nan"), float("nan")

def soft_assignment_entropy(C: np.ndarray) -> np.ndarray:
    # C is N x K, rows sum to 1
    eps = 1e-12
    logc = np.log(C + eps)
    ent = -np.sum(C * logc, axis=1)
    return ent

def acyclicity_trace_exponential(A: np.ndarray) -> float:
    from scipy.linalg import expm
    A2 = A * A
    M = expm(A2)
    return float(np.trace(M) - A.shape[0])

# ---------- ASSEMBLY HELPERS ----------
def assemble_Gfinal_from_components(C: np.ndarray,
                                    Acoarse: np.ndarray,
                                    fcompat_fn,
                                    relation_count: int,
                                    nodes_per_community: Optional[List[List[int]]] = None
                                   ) -> Dict[Tuple[int,int,int], float]:
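    # Adaptive broadcasting: A[i,j,r] = sum_{k,l} C[i,k] * C[j,l] * Acoarse[k,l,r] * fcompat(i,j),
    # truncated to each node's top-K communities. Cost is O(N^2 * R * topK^2), so this is only
    # practical for small graphs; evaluate_all scores the candidate set instead.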
    N, K = C.shape
    preds = {}
    topK = min(8, K)  # pragmatic default
    top_indices = np.argsort(-C, axis=1)[:, :topK]  # (N, topK)
    for i in range(N):
        i_tops = top_indices[i]
        for j in range(N):
            j_tops = top_indices[j]
            for r in range(relation_count):
                coarse_score = 0.0
                for k in i_tops:
                    for l in j_tops:
                        coarse_score += C[i,k] * C[j,l] * float(Acoarse[k,l,r])
                # compute local affinity
                local_aff = float(fcompat_fn(i,j))
                score = coarse_score * local_aff
                if score>0:
                    preds[(i,j,r)] = score
    return preds

# ---------- MAIN EVALUATION WORKFLOW ----------
def evaluate_all(data_dir=DATA_DIR, report_path=OUTPUT_REPORT, top_k_eval=100):
    # load basic artifacts
    assert os.path.isdir(data_dir), f"Data dir {data_dir} missing"
    AW = load_npz_adj(AW_NPZ) if os.path.exists(AW_NPZ) else None
    nodes = load_jsonl(NODES_JSONL) if os.path.exists(NODES_JSONL) else None
    prich = parse_prich(PRICH_JSONL) if os.path.exists(PRICH_JSONL) else []
    pcoarse = parse_pcoarse(PCOARSE_JSON) if os.path.exists(PCOARSE_JSON) else []
    C = np.load(C_ASSIGN_NPY) if os.path.exists(C_ASSIGN_NPY) else None
    Acoarse_pred = np.load(ACOARSE_PRED_NPY) if os.path.exists(ACOARSE_PRED_NPY) else None
    gfinal_npz = None
    if os.path.exists(GFINAL_PRED_NPZ):
        gfinal_npz = np.load(GFINAL_PRED_NPZ, allow_pickle=True)
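    shybrid_mask = None
    if os.path.exists(S_HYBRID_NPZ):
        # np.load on a .npz returns an NpzFile (a mapping of arrays), not an ndarray
        with np.load(S_HYBRID_NPZ, allow_pickle=True) as z:
            is_sparse = "format" in z.files   # archive written by scipy.sparse.save_npz
            if not is_sparse:
                # assumes the mask is the first (or only) array in the archive
                shybrid_mask = np.asarray(z[z.files[0]]).astype(bool)
        if is_sparse:
            shybrid_mask = sp.load_npz(S_HYBRID_NPZ).toarray().astype(bool)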
    true_edges = {}
    if prich:
        for i,j,r,s in prich:
            true_edges[(i,j,r)] = float(s)
    if not true_edges and AW is not None:
        coo = AW.tocoo()
        for a,b,v in zip(coo.row.tolist(), coo.col.tolist(), coo.data.tolist()):
            true_edges[(int(a),int(b),0)] = float(v)
    candidate_list = []
    if shybrid_mask is not None:
        # expect boolean mask shape (N,N) or (N,N,R)
        mask = shybrid_mask
        if mask.ndim==2:
            rows, cols = np.where(mask)
            for a,b in zip(rows.tolist(), cols.tolist()):
                candidate_list.append((int(a),int(b),0))
        elif mask.ndim==3:
            R = mask.shape[2]
            for r in range(R):
                rows, cols = np.where(mask[:,:,r])
                for a,b in zip(rows.tolist(), cols.tolist()):
                    candidate_list.append((int(a),int(b),int(r)))
    elif true_edges:
        candidate_list = list(true_edges.keys())
    elif AW is not None:
        coo = AW.tocoo()
        candidate_list = [(int(a),int(b),0) for a,b in zip(coo.row.tolist(), coo.col.tolist())]
    else:
        raise RuntimeError("No candidate set available (need Prich, Shybrid, or AW).")
    # Load/predict scores for candidate set
    pred_scores = []
    target_scores = []
    labels = []
    # If Gfinal predictions available as dict-like with keys (i,j,r) -> score
    gfinal_dict = None
    if gfinal_npz is not None:
        # look for 'pred_dict' or arrays
        if 'pred_dict' in gfinal_npz.files:
            gfinal_dict = gfinal_npz['pred_dict'].item()
        else:
            # try arrays 'i','j','r','score'
            if all(k in gfinal_npz.files for k in ('i','j','r','score')):
                arr_i = gfinal_npz['i']; arr_j = gfinal_npz['j']; arr_r = gfinal_npz['r']; arr_s = gfinal_npz['score']
                gfinal_dict = {(int(a),int(b),int(c)): float(s) for a,b,c,s in zip(arr_i,arr_j,arr_r,arr_s)}
    assembled_preds = None
    def default_fcompat(i,j):
        return 1.0  # trivial fallback (no fine-grain model)
    # Assemble c_i^T A_r c_j * fcompat(i, j) only when no precomputed Gfinal predictions were loaded,
    # so that predictions read from GFINAL_PRED_NPZ are not silently overwritten.
    if gfinal_dict is None and C is not None and Acoarse_pred is not None:
        fcompat_fn = default_fcompat
        assembled_preds = {}
        for (i,j,r) in candidate_list:
            if Acoarse_pred.ndim==3:
                A_r = Acoarse_pred[:,:,r]
            else:
                A_r = Acoarse_pred
            ci = C[i]  # K
            cj = C[j]
            coarse_val = float(ci @ (A_r @ cj))
            local_val = float(fcompat_fn(i,j))
            assembled_preds[(i,j,r)] = coarse_val * local_val
        gfinal_dict = assembled_preds

    # Now for metrics compute lists
    for (i,j,r) in candidate_list:
        pred = float(gfinal_dict.get((i,j,r), 0.0)) if gfinal_dict is not None else 0.0
        target = float(true_edges.get((i,j,r), 0.0))
        label = 1 if target>0 else 0
        pred_scores.append(pred)
        target_scores.append(target)
        labels.append(label)

    # Compute RMSE
    rmse = rmse_over_S(pred_scores, target_scores)

    # Precision@K / Recall@K for K values
    K_values = sorted({10, 50, 100, top_k_eval})   # de-duplicated cut-offs
    pr_at_k = {}
    preds_with_pairs = list(zip(candidate_list, pred_scores))
    true_edge_map = {k:v for k,v in true_edges.items()}
    for K in K_values:
        p,r = precision_recall_at_k(preds_with_pairs, true_edge_map, K)
        pr_at_k[K] = {"precision": p, "recall": r}

    # AUC / AP
    try:
        auc, ap = auc_and_ap(pred_scores, labels)
    except Exception as e:
        auc, ap = float("nan"), float("nan")

    spearman_corr = None
    if pcoarse and Acoarse_pred is not None and C is not None:
        node_to_comm = np.argmax(C, axis=1)
        agg_true = {}
        for i,j,r,s in prich:
            ci = int(node_to_comm[i])
            cj = int(node_to_comm[j])
            agg_true.setdefault((ci,cj,r), []).append(s)
        true_vals = []
        pred_vals = []
        for (ci,cj,r), vals in agg_true.items():
            true_mean = float(np.mean(vals))
            # predicted coarse
            if Acoarse_pred.ndim==3:
                pred_mean = float(Acoarse_pred[ci,cj,r])
            else:
                pred_mean = float(Acoarse_pred[ci,cj])
            true_vals.append(true_mean)
            pred_vals.append(pred_mean)
        if len(true_vals)>=2:
            spearman_corr = float(spearmanr(true_vals, pred_vals).correlation)
        else:
            spearman_corr = float("nan")

    purity_metrics = {}
    communities_gt_path = os.path.join(data_dir, "communities.npy")
    if C is not None and os.path.exists(communities_gt_path):
        gt = np.load(communities_gt_path)
        pred_labels = np.argmax(C, axis=1)
        nmi = float(normalized_mutual_info_score(gt, pred_labels))
        ari = float(adjusted_rand_score(gt, pred_labels))
        ent = soft_assignment_entropy(C)
        purity_metrics = {"nmi": nmi, "ari": ari, "mean_entropy": float(ent.mean()), "median_entropy": float(np.median(ent))}
    elif C is not None:
        ent = soft_assignment_entropy(C)
        purity_metrics = {"mean_entropy": float(ent.mean()), "median_entropy": float(np.median(ent))}

    acyc_h = None
    if Acoarse_pred is not None:
        if Acoarse_pred.ndim==3:
            total = 0.0
            for r in range(Acoarse_pred.shape[2]):
                total += acyclicity_trace_exponential(np.asarray(Acoarse_pred[:,:,r], dtype=np.float64))
            acyc_h = float(total)
        else:
            acyc_h = float(acyclicity_trace_exponential(np.asarray(Acoarse_pred, dtype=np.float64)))

    # Build report
    report = {
        "rmse_candidate_S": rmse,
        "precision_recall_at_k": pr_at_k,
        "auc": auc,
        "average_precision": ap,
        "spearman_coarse_fine": spearman_corr,
        "clustering_metrics": purity_metrics,
        "acyclicity_h_coarse": acyc_h,
        "num_candidates_evaluated": len(candidate_list),
        "timestamp": time.time()
    }

    # Save
    with open(report_path, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)

    # Print summary
    print("=== Evaluation Summary ===")
    print(f"Candidates evaluated: {len(candidate_list)}")
    print(f"RMSE (candidates): {rmse:.6f}")
    print(f"AUC: {auc:.6f}, AP: {ap:.6f}")
    print("Precision/Recall @ K:")
    for K,v in pr_at_k.items():
        print(f"  @ {K}: precision={v['precision']:.4f}, recall={v['recall']:.4f}")
    if spearman_corr is not None:
        print(f"Coarse<->Fine Spearman: {spearman_corr:.4f}")
    if purity_metrics:
        print("Clustering: ", purity_metrics)
    if acyc_h is not None:
        print(f"Acyclicity h(Acoarse) total: {acyc_h:.6f}")
    print(f"Report saved to: {report_path}")
    return report

# Run evaluation
if __name__ == "__main__":
    rep = evaluate_all(DATA_DIR, OUTPUT_REPORT, top_k_eval=100)

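`acyclicity_trace_exponential` implements the NOTEARS measure $h(A) = \operatorname{tr}(e^{A \circ A}) - d$, which is zero exactly when the weighted graph has no directed cycle. `evaluate_all` sums $h$ over relation slices, which checks each relation on its own: a cycle that only closes across relations (e.g. $0 \to 1$ under relation 0 and $1 \to 0$ under relation 1) is not penalised. The sketch below uses SciPy's `expm`; `notears_h` and `per_relation_and_union_h` are hypothetical helper names:

```python
import numpy as np
from scipy.linalg import expm

def notears_h(A):
    """NOTEARS acyclicity measure h(A) = tr(exp(A * A)) - d, with * the elementwise product."""
    A = np.asarray(A, dtype=np.float64)
    return float(np.trace(expm(A * A)) - A.shape[0])

def per_relation_and_union_h(A_kkr):
    """Per-relation sum (as reported by evaluate_all) versus h of the union over relations."""
    A_kkr = np.asarray(A_kkr, dtype=np.float64)
    per_rel = sum(notears_h(A_kkr[:, :, r]) for r in range(A_kkr.shape[2]))
    union = notears_h(A_kkr.sum(axis=2))
    return per_rel, union

dag = np.array([[0.0, 1.0], [0.0, 0.0]])     # 0 -> 1 only
cycle = np.array([[0.0, 1.0], [1.0, 0.0]])   # 0 -> 1 -> 0

# Each relation slice is acyclic on its own, but together they form 0 -> 1 -> 0
A = np.zeros((2, 2, 2))
A[0, 1, 0] = 1.0
A[1, 0, 1] = 1.0
per_rel, union = per_relation_and_union_h(A)
```

For the 2-cycle, $e^{A \circ A}$ has trace $2\cosh 1$, so $h \approx 1.086$; reporting $h$ of the summed coarse matrix alongside the per-relation total would expose cross-relation cycles.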

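In `evaluate_all`, each candidate $(i, j, r)$ is scored in a Python loop as $c_i^\top A_r c_j$, where $c_i$ is row $i$ of the soft assignment $C$ and $A_r$ is the coarse adjacency for relation $r$ (the result is then multiplied by `fcompat`, which defaults to 1). A batched NumPy sketch of the same quantity that groups candidates by relation; `coarse_broadcast_scores` is a hypothetical name:

```python
import numpy as np

def coarse_broadcast_scores(C, A_coarse, rows, cols, rels):
    """Score every candidate (i, j, r) as c_i^T A_r c_j without a per-pair Python loop."""
    C = np.asarray(C, dtype=np.float64)
    A = np.asarray(A_coarse, dtype=np.float64)
    rows, cols, rels = (np.asarray(x, dtype=np.int64) for x in (rows, cols, rels))
    if A.ndim == 2:
        # single relation: every candidate uses the same K x K matrix, as in evaluate_all
        A = A[:, :, None]
        rels = np.zeros_like(rels)
    scores = np.zeros(rows.shape[0], dtype=np.float64)
    for r in np.unique(rels):
        m = rels == r
        # (C[rows] @ A_r) is (P_r, K); a row-wise dot with C[cols] gives c_i^T A_r c_j
        scores[m] = np.einsum("pk,pk->p", C[rows[m]] @ A[:, :, r], C[cols[m]])
    return scores
```

Multiplying the result elementwise by per-pair `fcompat` scores would recover `coarse_val * local_val`. Grouping by relation keeps memory at $O(P K)$ rather than materialising a $(P, K, K)$ stack of relation slices.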
In [None]:
import os
import json
import time
import math
from typing import List, Tuple, Dict, Any
import numpy as np
import scipy.sparse as sp
from scipy import stats
from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_curve, auc
from sklearn.metrics import precision_score, recall_score
from sklearn.metrics import normalized_mutual_info_score as nmi_score
from sklearn.metrics import adjusted_rand_score as ari_score
from sklearn.preprocessing import normalize
from sklearn.utils import resample
import networkx as nx

DATA_DIR = "./data" 
PRICH_JSONL = os.path.join(DATA_DIR, "prich.jsonl")          # gold fine-grained prior (node i,j,r,score)
CPRIOR_JSONL = os.path.join(DATA_DIR, "cprior.jsonl")        # teacher LLM prior 
CAND_PRED_NPZ = os.path.join(DATA_DIR, "pred_candidates.npz") # contains rows,cols,rels,scores_pred (arrays)
AW_NPZ = os.path.join(DATA_DIR, "AW.npz")                    # co-occurrence adjacency (sparse)
ACOARSE_NPY = os.path.join(DATA_DIR, "Acoarse.npy")          # coarse KxKxR learned scores 
ACOARSE_TARGET_NPY = os.path.join(DATA_DIR, "Acoarse_target.npy") # coarse target aggregated scores 
C_ASSIGN_NPY = os.path.join(DATA_DIR, "C_assign.npy")        # soft assignment C (N x K) from GNNcluster 
Z_EMB_NPY = os.path.join(DATA_DIR, "Z.npy")                  # structural embeddings (N x dz)
H_FINAL_NPY = os.path.join(DATA_DIR, "H_final.npy")          # final node embeddings (N x dmodel)
FACET_GATES_NPY = os.path.join(DATA_DIR, "gates.npy")        # gating values (N x M) optional
COMMUNITY_LABELS_JSON = os.path.join(DATA_DIR, "community_labels.json") # optional ground-truth community map

REPORT_OUT = os.path.join(DATA_DIR, "evaluation_report.json")

# ---------- helpers ----------
def assert_exists(path: str):
    if not os.path.exists(path):
        raise FileNotFoundError(f"Required artifact missing: {path}")

def load_prich(path: str) -> List[Dict[str, Any]]:
    assert_exists(path)
    out=[]
    with open(path,"r",encoding="utf-8") as f:
        for line in f:
            if not line.strip(): continue
            out.append(json.loads(line))
    return out

def load_candidates_npz(path: str):
    assert_exists(path)
    with np.load(path, allow_pickle=True) as npz:
        # expects arrays: rows, cols, rels, scores_pred
        missing = [k for k in ("rows","cols","rels","scores_pred") if k not in npz.files]
        if missing:
            raise KeyError(f"{path} missing arrays {missing}")
        return (npz["rows"].astype(int), npz["cols"].astype(int),
                npz["rels"].astype(int), npz["scores_pred"].astype(float))

def load_sparse(path: str):
    assert_exists(path)
    return sp.load_npz(path)

def load_npy_if_exists(path: str):
    return np.load(path) if os.path.exists(path) else None

def rmse(pred: np.ndarray, target: np.ndarray):
    return float(np.sqrt(np.mean((pred - target) ** 2)))

def precision_at_k(y_true_scores: np.ndarray, y_pred_scores: np.ndarray, k: int):
    # returns (precision@k, recall@k); a pair counts as relevant when its gold score is > 0
    idx = np.argsort(-y_pred_scores)[:k]
    if idx.size == 0:
        return 0.0, 0.0
    hits = float(np.sum(y_true_scores[idx] > 0))
    return hits / idx.size, hits / max(1, int(np.sum(y_true_scores > 0)))

# ---------- evaluation functions ----------
def evaluate_edge_rmse(cand_rows, cand_cols, cand_rels, cand_pred_scores, prich_entries, AW=None):
    # build mapping for gold targets 
    gold_map = {}  # (i,j,r) -> score
    for e in prich_entries:
        key=(int(e["i"]), int(e["j"]), int(e.get("r",0)))
        gold_map[key]=float(e.get("score",1.0))
    AW_csr = AW.tocsr() if AW is not None else None   # element lookup instead of copying a row per pair
    targets = []
    preds = []
    for i,j,r,s in zip(cand_rows, cand_cols, cand_rels, cand_pred_scores):
        key=(int(i),int(j),int(r))
        if key in gold_map:
            t = gold_map[key]
        elif AW_csr is not None:
            # fall back to the co-occurrence weight when the pair has no gold prior entry
            t = float(AW_csr[int(i), int(j)])
        else:
            t = 0.0
        preds.append(float(s))
        targets.append(float(t))
    preds = np.array(preds)
    targets = np.array(targets)
    return {"rmse": rmse(preds,targets), "num_pairs": len(preds)}

def evaluate_topk_per_relation(cand_rows, cand_cols, cand_rels, cand_pred_scores, prich_entries, K_list=(10,50,100)):
    per_rel = {}
    gold_map = {}
    for e in prich_entries:
        key=(int(e["i"]),int(e["j"]),int(e.get("r",0)))
        gold_map[key]=float(e.get("score",1.0))
    for i,j,r,s in zip(cand_rows,cand_cols,cand_rels,cand_pred_scores):
        per_rel.setdefault(int(r), {"preds":[], "golds":[]})
        per_rel[int(r)]["preds"].append(float(s))
        per_rel[int(r)]["golds"].append(float(gold_map.get((int(i),int(j),int(r)), 0.0)))
    reports={}
    for r,vals in per_rel.items():
        preds = np.array(vals["preds"]); golds = np.array(vals["golds"])
        reports[r]={}
        for K in K_list:
            p_at_k, recall_at_k = precision_at_k(golds, preds, K)
            reports[r][f"P@{K}"]=p_at_k
            reports[r][f"Recall@{K}"]=recall_at_k
    return reports

def evaluate_auc_pr(cand_rows, cand_cols, cand_rels, cand_pred_scores, prich_entries, pos_threshold=0.5):
    gold_map = {}
    for e in prich_entries:
        key=(int(e["i"]),int(e["j"]),int(e.get("r",0)))
        gold_map[key]=float(e.get("score",1.0))
    per_rel = {}
    for i,j,r,s in zip(cand_rows,cand_cols,cand_rels,cand_pred_scores):
        r=int(r)
        per_rel.setdefault(r, {"preds":[], "labels":[]})
        per_rel[r]["preds"].append(float(s))
        per_rel[r]["labels"].append(1 if gold_map.get((int(i),int(j),r),0.0) > pos_threshold else 0)
    out={}
    for r,vals in per_rel.items():
        y_true=np.array(vals["labels"]); y_score=np.array(vals["preds"])
        if y_true.sum() == 0 or y_true.sum() == len(y_true):
            out[r]={"roc_auc": None, "pr_auc": None}
            continue
        try:
            roc = roc_auc_score(y_true, y_score)
            pr = average_precision_score(y_true, y_score)
        except Exception:
            roc=None; pr=None
        out[r]={"roc_auc":roc, "pr_auc":pr}
    return out

def evaluate_coarsening_metrics(C_assign_path: str = C_ASSIGN_NPY, community_labels_path: str = COMMUNITY_LABELS_JSON):
    C = load_npy_if_exists(C_assign_path)
    if C is None:
        return {"available": False}
    # entropy per node
    eps=1e-12
    probs = np.clip(C, eps, 1.0)
    ent = -np.sum(probs * np.log(probs), axis=1)
    stats_ent = {"mean_entropy": float(ent.mean()), "median_entropy": float(np.median(ent))}
    if os.path.exists(community_labels_path):
        with open(community_labels_path,"r") as f:
            labels = json.load(f)  
        N = C.shape[0]
        y_true = np.full(N, -1, dtype=int)   # -1 marks nodes without a ground-truth label
        for nid, lab in labels.items():
            y_true[int(nid)] = int(lab)      # JSON object keys are strings, e.g. "12"
        y_pred = np.argmax(C, axis=1)
        labelled = y_true >= 0               # score only labelled nodes (assumes labels are >= 0)
        return {"available": True, "entropy_stats": stats_ent,
                "nmi": float(nmi_score(y_true[labelled], y_pred[labelled])),
                "ari": float(ari_score(y_true[labelled], y_pred[labelled]))}
    return {"available": True, "entropy_stats": stats_ent}

def spearman_coarse_vs_fine(acoarse_path: str = ACOARSE_NPY, acoarse_target_path: str = ACOARSE_TARGET_NPY):
    Acoarse = load_npy_if_exists(acoarse_path)
    Atarget = load_npy_if_exists(acoarse_target_path)
    if Acoarse is None or Atarget is None:
        return {"available": False}
    flat_pred = Acoarse.reshape(-1)
    flat_target = Atarget.reshape(-1)
    rho, p = stats.spearmanr(flat_pred, flat_target)
    return {"available": True, "spearman_rho": float(rho), "pvalue": float(p)}

def acyclicity_h_matrix(A_mat: np.ndarray):
    import scipy.linalg as la
    S = A_mat * A_mat
    try:
        expm = la.expm(S)
        return float(np.trace(expm) - A_mat.shape[0])
    except Exception:
        return None

def evaluate_acyclicity_for_relations(acoarse_path: str = ACOARSE_NPY):
    Acoarse = load_npy_if_exists(acoarse_path)
    if Acoarse is None:
        return {"available": False}
    if Acoarse.ndim == 2:
        Acoarse = Acoarse[:, :, None]   # treat a single (K, K) matrix as one relation
    res = {}
    for r in range(Acoarse.shape[2]):
        h = acyclicity_h_matrix(Acoarse[:,:,r])
        res[r]= {"h": h}
    return {"available": True, "per_relation": res}

def evaluate_embedding_drift(Z_path: str = Z_EMB_NPY, H_final_path: str = H_FINAL_NPY):
    Z = load_npy_if_exists(Z_path)
    H = load_npy_if_exists(H_final_path)
    if Z is None or H is None:
        return {"available": False}
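    if Z.shape[0] != H.shape[0]:
        raise ValueError("Z and H must have same number of nodes")
    if Z.shape[1] != H.shape[1]:
        # cosine similarity needs a shared space; with dz != d_model, Z and H are not directly comparable
        return {"available": False, "reason": f"dimension mismatch: Z has {Z.shape[1]}, H_final has {H.shape[1]}"}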
    # cosine similarities per node
    def cos_sim(a,b):
        an = a / (np.linalg.norm(a,axis=1,keepdims=True)+1e-12)
        bn = b / (np.linalg.norm(b,axis=1,keepdims=True)+1e-12)
        return np.sum(an*bn,axis=1)
    sims = cos_sim(Z, H)
    return {"available": True, "mean_cosine": float(np.mean(sims)), "median_cosine": float(np.median(sims))}

def spectral_similarity(A_coarse_path: str = ACOARSE_NPY, A_fine_candidate_npz: str = CAND_PRED_NPZ, top_k=20):
    Acoarse = load_npy_if_exists(A_coarse_path)
    if Acoarse is None:
        return {"available": False}
    A_2d = Acoarse.sum(axis=2) if Acoarse.ndim == 3 else Acoarse   # collapse relations into one K x K matrix
    A_sym = (A_2d + A_2d.T) / 2.0
    L = np.diag(A_sym.sum(axis=1)) - A_sym
    eigs_coarse = np.linalg.eigvalsh(L)
    eigs_coarse_sorted = np.sort(eigs_coarse)[-top_k:]
    C = load_npy_if_exists(C_ASSIGN_NPY)
    if C is None:
        return {"available": False, "reason":"need C_assign to aggregate fine graph"}
    rows,cols,rels,scores = load_candidates_npz(A_fine_candidate_npz)
    N = C.shape[0]
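    adj = sp.csr_matrix((scores, (rows, cols)), shape=(N, N))   # keep the fine graph sparse; duplicates are summed
    # aggregate to community level: A_agg = C^T * adj * C, then symmetrise (eigvalsh reads only one triangle)
    A_agg = C.T @ np.asarray(adj @ C)
    A_agg = (A_agg + A_agg.T) / 2.0
    Lf = np.diag(A_agg.sum(axis=1)) - A_agg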
    eigs_fine = np.linalg.eigvalsh(Lf)
    eigs_fine_sorted = np.sort(eigs_fine)[-top_k:]
    # compute relative L2 distance
    dist = np.linalg.norm(eigs_coarse_sorted - eigs_fine_sorted) / (np.linalg.norm(eigs_fine_sorted) + 1e-12)
    return {"available": True, "spectral_rel_l2": float(dist)}

def evaluate_fcompat_separability(cand_npz=CAND_PRED_NPZ, prich_path=PRICH_JSONL, fcompat_scores_npz=None):
    if fcompat_scores_npz is None or not os.path.exists(fcompat_scores_npz):
        return {"available": False}
    rows,cols,rels,preds = load_candidates_npz(cand_npz)
    fcompat = np.load(fcompat_scores_npz)["scores"]
    prich = load_prich(prich_path)
    gold_map = {(int(e["i"]),int(e["j"]),int(e.get("r",0))): float(e.get("score",1.0)) for e in prich}
    labels = []
    for i,j,r in zip(rows,cols,rels):
        labels.append(1 if gold_map.get((int(i),int(j),int(r)),0.0) > 0.5 else 0)
    labels = np.array(labels)
    pos = fcompat[labels==1]
    neg = fcompat[labels==0]
    if len(pos)==0 or len(neg)==0:
        return {"available": False, "reason": "no pos or neg"}
    tstat, p = stats.ttest_ind(pos, neg, equal_var=False)
    # Cohen's d
    d = (pos.mean() - neg.mean()) / (np.sqrt((pos.std()**2 + neg.std()**2)/2) + 1e-12)
    return {"available": True, "tstat":float(tstat), "p":float(p), "cohens_d":float(d), "pos_mean":float(pos.mean()), "neg_mean":float(neg.mean())}

def compute_motif_counts_sample(adj_matrix: np.ndarray, sample_nodes: List[int], motif="triangles"):
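    # motif counts in the subgraph induced by sample_nodes; slice the matrix first so networkx
    # never materialises the full N x N graph (node labels become 0..len(sample_nodes)-1)
    sub_adj = adj_matrix[np.ix_(sample_nodes, sample_nodes)]
    sub = nx.from_numpy_array(sub_adj, create_using=nx.DiGraph())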
    if motif=="triangles":
        # convert to undirected for triangle counting
        u = nx.Graph(sub.to_undirected())
        tri = sum(nx.triangles(u).values()) // 3
        return {"triangles": int(tri)}
    elif motif=="two_paths":
        # count length-2 paths: sum_over_nodes in-degree * out-degree
        two = 0
        for n in sub.nodes():
            two += sub.in_degree(n) * sub.out_degree(n)
        return {"two_paths": int(two)}
    return {}

# ---------- main evaluation runner ----------
def run_evaluation():
    start = time.time()
    report = {"created_at": time.time(), "notes": "CausGT-HS evaluation suite"}
    assert_exists(CAND_PRED_NPZ)
    assert_exists(PRICH_JSONL)
    cand_rows, cand_cols, cand_rels, cand_scores = load_candidates_npz(CAND_PRED_NPZ)
    prich = load_prich(PRICH_JSONL)
    AW_mat = load_sparse(AW_NPZ) if os.path.exists(AW_NPZ) else None

    report["edge_rmse"] = evaluate_edge_rmse(cand_rows, cand_cols, cand_rels, cand_scores, prich, AW=AW_mat)
    report["topk_per_relation"] = evaluate_topk_per_relation(cand_rows, cand_cols, cand_rels, cand_scores, prich, K_list=(10,50,100))
    report["auc_pr_per_relation"] = evaluate_auc_pr(cand_rows, cand_cols, cand_rels, cand_scores, prich, pos_threshold=0.5)
    report["coarsening"] = evaluate_coarsening_metrics()
    report["spearman_coarse_vs_fine"] = spearman_coarse_vs_fine()
    report["acyclicity_coarse"] = evaluate_acyclicity_for_relations()
    report["embedding_drift"] = evaluate_embedding_drift()
    report["spectral_similarity"] = spectral_similarity()
    assembled_adj = None
    assembled_path = os.path.join(DATA_DIR,"assembled_adj.npz")
    if os.path.exists(assembled_path):
        assembled_adj = sp.load_npz(assembled_path).toarray()
    else:
        Z_arr = load_npy_if_exists(Z_EMB_NPY)   # only needed for the node count
        if Z_arr is None:
            assembled_adj = None
        else:
            Nn = Z_arr.shape[0]
            A = np.zeros((Nn,Nn), dtype=float)
            for i,j,_,s in zip(cand_rows,cand_cols,cand_rels,cand_scores):
                A[int(i),int(j)] = max(A[int(i),int(j)], float(s))
            assembled_adj = A
    if assembled_adj is not None:
        num_nodes = assembled_adj.shape[0]
        samp = list(np.random.choice(num_nodes, min(2000, num_nodes), replace=False))
        report["motifs_sample"] = compute_motif_counts_sample(assembled_adj, samp, motif="triangles")
        report["motifs_sample_two_paths"] = compute_motif_counts_sample(assembled_adj, samp, motif="two_paths")
    else:
        report["motifs_sample"] = {"available": False}

    # Save report
    with open(REPORT_OUT, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)
    elapsed = time.time() - start
    print("Evaluation finished. Report saved to:", REPORT_OUT, " time(s):", round(elapsed,2))
    return report

if __name__=="__main__":
    r = run_evaluation()
    print(json.dumps(r, indent=2))

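The per-relation P@K / Recall@K numbers treat a candidate as relevant when its gold prior score is positive (so a mediated pair with $P_{direct} = 0$ counts as a negative). A standalone sketch of that definition, which divides precision by $\min(K, \text{number of candidates})$ so that a relation with few candidates is not penalised for ranking fewer than $K$ pairs; `pr_at_k_sketch` is an illustrative name, not one of the notebook's functions:

```python
import numpy as np

def pr_at_k_sketch(gold_scores, pred_scores, k):
    """Return (P@K, R@K); a candidate is relevant iff its gold score is > 0."""
    gold = np.asarray(gold_scores, dtype=np.float64)
    pred = np.asarray(pred_scores, dtype=np.float64)
    k_eff = min(k, pred.size)                         # rank everything if fewer than k candidates
    if k_eff == 0:
        return 0.0, 0.0
    top = np.argsort(-pred, kind="stable")[:k_eff]    # indices of the highest predicted scores
    hits = int(np.count_nonzero(gold[top] > 0))
    n_relevant = int(np.count_nonzero(gold > 0))
    return hits / k_eff, hits / max(n_relevant, 1)

gold = [1.0, 0.0, 1.0, 0.0, 0.5]
pred = [0.9, 0.8, 0.1, 0.7, 0.6]
```

With these toy scores, the top 2 predictions contain one of the three relevant pairs, giving P@2 = 0.5 and R@2 = 1/3.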

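`spectral_similarity` compares Laplacian spectra of the learned coarse graph and of the fine graph pooled to communities, $C^\top A_{\text{fine}} C$. Two details matter: `np.linalg.eigvalsh` reads only one triangle of its input, so both matrices should be symmetrised first, and the fine graph can stay sparse until it is pooled to $K \times K$. A sketch under those choices, assuming relations are already collapsed into a single $K \times K$ coarse matrix; `laplacian_spectrum` and `pooled_spectral_distance` are hypothetical helper names:

```python
import numpy as np
import scipy.sparse as sp

def laplacian_spectrum(A):
    """Ascending eigenvalues of L = D - S for the symmetrised matrix S = (A + A^T) / 2."""
    A = np.asarray(A, dtype=np.float64)
    S = 0.5 * (A + A.T)                  # eigvalsh assumes a symmetric input
    L = np.diag(S.sum(axis=1)) - S
    return np.linalg.eigvalsh(L)

def pooled_spectral_distance(A_coarse_2d, C, rows, cols, scores, top_k=20):
    """Relative L2 gap between the top-k Laplacian eigenvalues of A_coarse and of C^T A_fine C."""
    C = np.asarray(C, dtype=np.float64)
    N = C.shape[0]
    A_fine = sp.csr_matrix((scores, (rows, cols)), shape=(N, N))   # stays sparse
    A_pooled = C.T @ np.asarray(A_fine @ C)                        # (K, K) community-level graph
    ec = laplacian_spectrum(A_coarse_2d)[-top_k:]                  # largest eigenvalues
    ef = laplacian_spectrum(A_pooled)[-top_k:]
    return float(np.linalg.norm(ec - ef) / (np.linalg.norm(ef) + 1e-12))
```

If the coarse graph equals the pooled fine graph exactly, the distance is zero up to floating-point error; when $K$ is smaller than `top_k`, all $K$ eigenvalues are compared.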
### 3.4.2.3 The CausGT-HS Encoder ($f_{\theta}$)

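Here we address the computational challenge of scaling the encoder to large graphs. Layer 1 keeps the per-node cost bounded: each node's neighbour list is clamped to `MAX_DEGREE_CLAMP` entries, only two-hop paths $i \to a \to b$ are enumerated, and at most `MAX_MPOOL` of them are kept before they are encoded by the path transformer (loaded from a fine-tuned checkpoint) and pooled by attention.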
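One source of the scaling is a fixed path budget: with neighbour lists clamped to $H$ = `MAX_DEGREE_CLAMP` entries, a node has at most $H^2$ two-hop paths, and only `MAX_MPOOL` of them are kept, so per-node work does not grow with the graph. The cell below collects those paths with a per-node Python loop; a vectorised NumPy sketch of the same selection (same row-major order over first and second hop, same cap) follows. `two_hop_paths` is a hypothetical name, and it trades the loop for $O(NH^2)$ memory, the same size as the `second_hop` tensor that `enumerate_paths_vectorized` already builds:

```python
import numpy as np

def two_hop_paths(neigh_idx, max_paths):
    """Collect up to `max_paths` two-hop paths (i, a, b) per node.

    neigh_idx: (N, H) int array of neighbour ids padded with -1
               (the layout produced by adjacency_to_index_tensor).
    Returns an (N, max_paths, 3) int array padded with -1.
    """
    neigh_idx = np.asarray(neigh_idx, dtype=np.int64)
    N, H = neigh_idx.shape
    safe = np.where(neigh_idx < 0, 0, neigh_idx)
    second = neigh_idx[safe]                       # (N, H, H): neighbours of each neighbour
    second[neigh_idx < 0] = -1                     # padded first hop => no path
    first = np.broadcast_to(neigh_idx[:, :, None], (N, H, H)).reshape(N, H * H)
    second = second.reshape(N, H * H)
    valid = second >= 0
    m = min(max_paths, H * H)
    # a stable sort on "is padding" moves real paths to the front and preserves their order
    order = np.argsort(~valid, axis=1, kind="stable")[:, :m]
    keep = np.take_along_axis(valid, order, axis=1)
    a = np.take_along_axis(first, order, axis=1)
    b = np.take_along_axis(second, order, axis=1)
    src = np.broadcast_to(np.arange(N)[:, None], (N, m))
    out = np.full((N, max_paths, 3), -1, dtype=np.int64)
    out[:, :m, 0] = np.where(keep, src, -1)
    out[:, :m, 1] = np.where(keep, a, -1)
    out[:, :m, 2] = np.where(keep, b, -1)
    return out

# toy graph: 0 -> {1, 2}, 1 -> {2}, 2 -> {0}
neigh = [[1, 2, -1], [2, -1, -1], [0, -1, -1]]
P = two_hop_paths(neigh, max_paths=4)
```

As in the loop, backtracking paths such as $0 \to 2 \to 0$ are kept.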

#### Layer 1: LPA module (Learned Path Aggregator)
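`FacetAwareAggregator` in the cell below carries the parameters of additive (Bahdanau-style) attention: $W_h$ on path embeddings, $U$ on the node state, a bias $b$ and a scoring vector $w$. Assuming they are combined in the standard form $e_{i,m,p} = w^\top \tanh(W_h h_{i,p} + U h_{i,m} + b)$ with a softmax over each node's non-padded paths, a NumPy sketch of the pooling step looks like this. It omits `MLP_s`, whose role is not specified here, and `additive_path_attention` is a hypothetical name:

```python
import numpy as np

def additive_path_attention(H0, path_embs, path_mask, W_h, U, b, w):
    """Additive attention of each (node, facet) query over that node's paths.

    H0:        (N, M, d_model) per-facet node states (queries)
    path_embs: (N, P, d_path) encoded paths (keys and values)
    path_mask: (N, P) bool, True for real (non-padded) paths
    W_h: (d_path, h), U: (d_model, h), b: (h,), w: (h,)
    Returns (N, M, d_path) facet summaries and (N, M, P) attention weights.
    """
    path_mask = np.asarray(path_mask, dtype=bool)
    keys = path_embs @ W_h                                       # (N, P, h)
    queries = H0 @ U                                             # (N, M, h)
    e = np.tanh(queries[:, :, None, :] + keys[:, None, :, :] + b) @ w   # (N, M, P)
    e = np.where(path_mask[:, None, :], e, -np.inf)              # padded paths get zero weight
    e_max = np.max(e, axis=-1, keepdims=True)
    e_max = np.where(np.isfinite(e_max), e_max, 0.0)             # nodes with no valid path
    alpha = np.exp(e - e_max)
    alpha /= np.maximum(alpha.sum(axis=-1, keepdims=True), 1e-12)
    return alpha @ path_embs, alpha
```

A node with no valid two-hop path gets all-zero weights and a zero summary rather than NaNs, which matters because `build_path_embedding_tensors` pads unused path slots.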

In [None]:
import json
from pathlib import Path
from typing import Tuple, List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------- config ----------
RNG_SEED = 20251127
torch.manual_seed(RNG_SEED)
np.random.seed(RNG_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RNG_SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LPA_DIR = Path("out/lpa_model")
ENCODER_CKPT = LPA_DIR / "encoder_finetuned.pth"
AGGREGATOR_CKPT = LPA_DIR / "aggregator_finetuned.pth"
AW_PATH = Path("out/graphs/AW_sparse.json")
NODE_EMB_PATH = LPA_DIR / "node_embeddings.pt"
H0_PATH = LPA_DIR / "H0_init.pt"
SAVE_OUT = Path("out/lpa_layer1_output.pt")

MAX_DEGREE_CLAMP = 128
MAX_MPOOL = 256
M_FACETS = 4
DEFAULT_NODE_EMB_DIM = 64
D_MODEL = 256
D_PATH = 256

# ---------- sanity ----------
if not ENCODER_CKPT.exists():
    raise FileNotFoundError(ENCODER_CKPT)
if not AGGREGATOR_CKPT.exists():
    raise FileNotFoundError(AGGREGATOR_CKPT)
if not AW_PATH.exists():
    raise FileNotFoundError(AW_PATH)

try:
    from model.lpa_components import PathTransformerEncoder, AttentionAggregator
except Exception as e:
    raise ImportError("Ensure model/lpa_components.py defines PathTransformerEncoder and AttentionAggregator") from e

encoder = PathTransformerEncoder(
    node_emb_dim=None, d_model=D_MODEL, nhead=8, num_layers=4, dim_feedforward=512, d_path=D_PATH, max_len=32
)
aggregator = AttentionAggregator(d_path=D_PATH, path_feat_dim=8, hidden_dim=256)

enc_state = torch.load(ENCODER_CKPT, map_location="cpu")
agg_state = torch.load(AGGREGATOR_CKPT, map_location="cpu")

if "encoder_state" in enc_state:
    encoder.load_state_dict(enc_state["encoder_state"], strict=True)
else:
    encoder.load_state_dict(enc_state, strict=True)

if "aggregator_state" in agg_state:
    aggregator.load_state_dict(agg_state["aggregator_state"], strict=True)
else:
    aggregator.load_state_dict(agg_state, strict=True)

encoder.to(DEVICE).eval()
aggregator.to(DEVICE).eval()
s_value_transform = aggregator.head.to(DEVICE).eval()

# ---------- utils ----------
def load_sparse_adj(path: Path) -> torch.Tensor:
    obj = json.loads(path.read_text())
    rows = torch.tensor(obj["rows"], dtype=torch.long)
    cols = torch.tensor(obj["cols"], dtype=torch.long)
    vals = torch.tensor(obj.get("vals", [1.0] * len(rows)), dtype=torch.float32)
    N = int(obj.get("N", int(max(rows.max().item(), cols.max().item()) + 1)))
    idx = torch.stack([rows, cols], dim=0)
    return torch.sparse_coo_tensor(indices=idx, values=vals, size=(N, N)).coalesce()

def adjacency_to_index_tensor(A_sparse: torch.Tensor, max_deg: int = MAX_DEGREE_CLAMP) -> torch.Tensor:
    N = A_sparse.size(0)
    rows = A_sparse.indices()[0].cpu().tolist()
    cols = A_sparse.indices()[1].cpu().tolist()
    neighbors: List[List[int]] = [[] for _ in range(N)]
    for r, c in zip(rows, cols):
        neighbors[r].append(c)
    out = torch.full((N, max_deg), -1, dtype=torch.long)
    for i in range(N):
        nb = neighbors[i][:max_deg]
        if nb:
            out[i, :len(nb)] = torch.tensor(nb, dtype=torch.long)
    return out.to(DEVICE)

def enumerate_paths_vectorized(neigh_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    N, H = neigh_idx.shape
    P1 = neigh_idx.clone()
    safe = neigh_idx.clone()
    safe[safe == -1] = 0
    second_hop = neigh_idx[safe]            # [N, H, H]
    mask_invalid = (neigh_idx == -1).unsqueeze(-1)
    second_hop = torch.where(mask_invalid, torch.full_like(second_hop, -1), second_hop)
    return P1, second_hop

def build_path_embedding_tensors(P1: torch.Tensor, P2: torch.Tensor, node_embeddings: torch.Tensor,
                                 max_mpool: int = MAX_MPOOL) -> Tuple[torch.Tensor, torch.Tensor]:
    N, H = P1.shape
    L = 3  # nodes per path: (i, a, b) for a 2-hop path i -> a -> b
    D = node_embeddings.size(1)
    device = node_embeddings.device
    # .item() on CUDA tensors forces a device sync per call; read the indices on CPU instead.
    P1_cpu, P2_cpu = P1.cpu(), P2.cpu()
    paths_idx = torch.full((N, max_mpool, L), -1, dtype=torch.long)
    for i in range(N):
        cnt = 0
        for a_j in range(H):
            a = P1_cpu[i, a_j].item()
            if a == -1:
                continue
            for b_j in range(H):
                b = P2_cpu[i, a_j, b_j].item()
                if b == -1:
                    continue
                if cnt >= max_mpool:
                    break
                paths_idx[i, cnt, 0] = i
                paths_idx[i, cnt, 1] = a
                paths_idx[i, cnt, 2] = b
                cnt += 1
            if cnt >= max_mpool:
                break
    paths_idx = paths_idx.to(device)
    # Shift indices by +1 so that padding (-1) gathers the all-zero row 0 of emb_shift.
    pad_emb = torch.zeros((1, D), device=device)
    emb_shift = torch.cat([pad_emb, node_embeddings.to(device)], dim=0)
    safe_paths_idx = paths_idx + 1          # -1 (padding) -> 0
    paths_emb = emb_shift[safe_paths_idx]   # [N, mpool, L, D]
    # NOTE: both streams currently receive the same A_W-derived paths; a separate
    # causal-path tensor would be returned as the second value.
    return paths_emb, paths_emb

# ---------- path encoder / aggregators ----------
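# Frozen path transformer: a plain wrapper (not an nn.Module), so the parent layer neither
# registers nor trains its weights; encode() runs in eval mode without gradients.
class LPALayer1_PathEncoder: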
    def __init__(self, transformer_path_model, device):
        self.tp = transformer_path_model.to(device)
        self.tp.eval()
        self.device = device

    @torch.no_grad()
    def encode(self, paths_tensor: torch.Tensor) -> torch.Tensor:
        N, M, L, D = paths_tensor.shape
        flat = paths_tensor.reshape(N * M, L, D).to(self.device)
        emb = self.tp(flat)
        return emb.reshape(N, M, -1)

class FacetAwareAggregator(nn.Module):
    def __init__(self, d_path, d_model, M, hidden_dim=256, s_value_transform_module=None):
        super().__init__()
        self.M = M
        self.W_h = nn.Linear(d_path, hidden_dim, bias=False)
        self.U = nn.Linear(d_model, hidden_dim, bias=False)
        self.b = nn.Parameter(torch.zeros(hidden_dim))
        self.w = nn.Parameter(torch.randn(hidden_dim))
        if s_value_transform_module is None:
            raise ValueError("Provide s_value_transform_module")
        self.MLP_s = s_value_transform_module
        self.d_path = d_path

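    def forward(self, H0, path_embs):
        # H0: [N, M, d_model] facet states; path_embs: [N, mpool, d_path] encoded paths.
        # NOTE: padded path slots are not masked here, so they also receive attention weight.
        N, M, _ = H0.shape
        mpool = path_embs.shape[1]
        # Project before broadcasting: each Linear is applied once per path / per facet.
        Wh_hp = self.W_h(path_embs).unsqueeze(1)        # [N, 1, mpool, hidden]
        U_sk = self.U(H0).unsqueeze(2)                  # [N, M, 1, hidden]
        z = torch.tanh(Wh_hp + U_sk + self.b)           # [N, M, mpool, hidden]
        att_logits = torch.einsum("nmkd,d->nmk", z, self.w)
        att = F.softmax(att_logits, dim=2)              # additive attention over paths
        path_exp = path_embs.unsqueeze(1).expand(N, M, mpool, self.d_path)
        svals = self.MLP_s(path_exp.reshape(N * M * mpool, self.d_path))
        svals = svals.reshape(N, M, mpool, -1)
        ctx = torch.sum(att.unsqueeze(-1) * svals, dim=2)  # [N, M, d_s]
        return ctx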

class LPALayer1_DualStreamAggregator(nn.Module):
    def __init__(self, d_path, d_model, M, s_value_transform_module):
        super().__init__()
        # Both streams receive the same s_value_transform_module instance, so MLP_s weights are tied.
        self.corr = FacetAwareAggregator(d_path, d_model, M, s_value_transform_module=s_value_transform_module)
        self.causal = FacetAwareAggregator(d_path, d_model, M, s_value_transform_module=s_value_transform_module)

    def forward(self, H0, path_corr, path_causal):
        return self.corr(H0, path_corr), self.causal(H0, path_causal)

class LPALayer1_Primer(nn.Module):
    def __init__(self, d_s, d_model):
        super().__init__()
        self.W_proj = nn.Linear(2 * d_s, d_model)
        self.ln = nn.LayerNorm(d_model)

    def forward(self, H0, c_corr, c_causal):
        master = torch.cat([c_corr, c_causal], dim=-1)
        proj = self.W_proj(master)
        return self.ln(H0 + proj)

class CausGT_LPA_Layer1(nn.Module):
    def __init__(self, transformer_path, s_value_transform_module, d_path, d_model, M, device):
        super().__init__()
        self.path_encoder = LPALayer1_PathEncoder(transformer_path, device)
        self.agg = LPALayer1_DualStreamAggregator(d_path, d_model, M, s_value_transform_module)
        self.primer = LPALayer1_Primer(d_s=d_path, d_model=d_model)

    def forward(self, H0, paths_corr, paths_causal):
        pe_corr = self.path_encoder.encode(paths_corr)
        pe_causal = self.path_encoder.encode(paths_causal)
        c_corr, c_causal = self.agg(H0, pe_corr, pe_causal)
        return self.primer(H0, c_corr, c_causal)

# ---------- main ----------
def main():
    AW = load_sparse_adj(AW_PATH).to(DEVICE)
    AW = AW.coalesce()
    N = AW.size(0)

    neigh_idx = adjacency_to_index_tensor(AW, max_deg=MAX_DEGREE_CLAMP)
    P1, P2 = enumerate_paths_vectorized(neigh_idx)

    # `encoder` (the path transformer) and `s_value_transform` are expected from earlier cells.
    expected_node_dim = DEFAULT_NODE_EMB_DIM
    for attr in ("node_emb_dim", "input_dim", "d_token", "token_dim"):
        if hasattr(encoder, attr) and getattr(encoder, attr) is not None:
            expected_node_dim = int(getattr(encoder, attr))
            break
    if expected_node_dim == DEFAULT_NODE_EMB_DIM:
        for name, p in encoder.named_parameters():
            if p.dim() >= 2:
                expected_node_dim = p.size(1)
                break

    if NODE_EMB_PATH.exists():
        try:
            node_emb_obj = torch.load(NODE_EMB_PATH, map_location=DEVICE)
            if isinstance(node_emb_obj, dict) and "node_embeddings" in node_emb_obj:
                node_embeddings = node_emb_obj["node_embeddings"].to(DEVICE)
            elif isinstance(node_emb_obj, torch.Tensor):
                node_embeddings = node_emb_obj.to(DEVICE)
            else:
                raise RuntimeError("Unknown node embeddings format")
            if node_embeddings.size(0) != N:
                node_embeddings = torch.randn(N, expected_node_dim, device=DEVICE) * 0.01
        except Exception:
            node_embeddings = torch.randn(N, expected_node_dim, device=DEVICE) * 0.01
    else:
        node_embeddings = torch.randn(N, expected_node_dim, device=DEVICE) * 0.01

    paths_corr_tensor, paths_causal_tensor = build_path_embedding_tensors(P1, P2, node_embeddings, max_mpool=MAX_MPOOL)

    M = M_FACETS
    if H0_PATH.exists():
        try:
            h0_obj = torch.load(H0_PATH, map_location=DEVICE)
            if isinstance(h0_obj, dict) and "H0" in h0_obj:
                H0 = h0_obj["H0"].to(DEVICE)
            elif isinstance(h0_obj, torch.Tensor):
                H0 = h0_obj.to(DEVICE)
            else:
                raise RuntimeError("Unknown H0 format")
            if H0.size(0) != N or H0.size(1) != M or H0.size(2) != D_MODEL:
                H0 = torch.randn(N, M, D_MODEL, device=DEVICE) * 0.01
        except Exception:
            H0 = torch.randn(N, M, D_MODEL, device=DEVICE) * 0.01
    else:
        H0 = torch.randn(N, M, D_MODEL, device=DEVICE) * 0.01

    layer1 = CausGT_LPA_Layer1(
        transformer_path=encoder,
        s_value_transform_module=s_value_transform,
        d_path=D_PATH,
        d_model=D_MODEL,
        M=M,
        device=DEVICE
    ).to(DEVICE)
    layer1.eval()

    with torch.no_grad():
        H1 = layer1(H0=H0, paths_corr=paths_corr_tensor, paths_causal=paths_causal_tensor)
        SAVE_OUT.parent.mkdir(parents=True, exist_ok=True)
        torch.save({"H1": H1.detach().cpu()}, SAVE_OUT)
        print(f"Saved H1 to {SAVE_OUT} (shape {H1.shape})")

if __name__ == "__main__":
    main()
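`adjacency_to_index_tensor` fills its padded neighbour table with a Python loop over every edge. Because `coalesce()` sorts COO indices by row and then by column, the same table can be built without Python loops by computing each edge's position inside its row. The sketch below is a hypothetical drop-in (`adjacency_to_index_tensor_vec` is our name, not part of the pipeline) that should reproduce the loop's output, including the truncation to the first `max_deg` neighbours of each node.

```python
import torch


def adjacency_to_index_tensor_vec(A_sparse: torch.Tensor, max_deg: int) -> torch.Tensor:
    """Hypothetical loop-free variant of adjacency_to_index_tensor.

    Returns an [N, max_deg] LongTensor of neighbour ids, padded with -1.
    """
    A = A_sparse.coalesce()                       # indices sorted by (row, col)
    N = A.size(0)
    rows, cols = A.indices()
    deg = torch.bincount(rows, minlength=N)       # number of stored edges per row
    row_start = torch.cumsum(deg, dim=0) - deg    # offset of each row's first edge
    # Position of every edge within its own row: 0, 1, 2, ...
    pos = torch.arange(rows.numel(), device=rows.device) - row_start[rows]
    keep = pos < max_deg                          # same truncation as neighbors[i][:max_deg]
    out = torch.full((N, max_deg), -1, dtype=torch.long, device=rows.device)
    out[rows[keep], pos[keep]] = cols[keep]
    return out


# Usage: 3 nodes with edges 0->1, 0->2, 1->2
A = torch.sparse_coo_tensor(torch.tensor([[0, 0, 1], [1, 2, 2]]), torch.ones(3), size=(3, 3))
print(adjacency_to_index_tensor_vec(A, max_deg=2).tolist())  # [[1, 2], [2, -1], [-1, -1]]
```

The result stays on the adjacency's device; since `main()` already moves `AW` to `DEVICE`, this matches where the original puts its output.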

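The triple loop in `build_path_embedding_tensors` enumerates, for each node `i`, the 2-hop paths `(i, a, b)` in row-major `(a_j, b_j)` order and keeps the first `max_mpool` valid ones. Since `enumerate_paths_vectorized` already writes -1 into every second hop whose first hop is padding, a cumulative-sum rank over the flattened `(a_j, b_j)` grid gives the same ordering. This is a sketch under that assumption; `build_paths_idx_vec` is a hypothetical helper that only replaces the loop filling `paths_idx`, not the embedding lookup that follows it.

```python
import torch


def build_paths_idx_vec(P1: torch.Tensor, P2: torch.Tensor, max_mpool: int) -> torch.Tensor:
    """Hypothetical loop-free version of the paths_idx loop in build_path_embedding_tensors.

    P1: [N, H] first hops, P2: [N, H, H] second hops (-1 = padding).
    Returns [N, max_mpool, 3] rows (i, a, b), padded with -1, in the loop's (a_j, b_j) order.
    """
    N, H = P1.shape
    first = P1.unsqueeze(-1).expand(N, H, H).reshape(N, H * H)   # a for every (a_j, b_j)
    second = P2.reshape(N, H * H)                                  # b for every (a_j, b_j)
    valid = (first != -1) & (second != -1)
    rank = valid.long().cumsum(dim=1) - 1          # index of each valid path within its row
    keep = valid & (rank < max_mpool)              # first max_mpool valid paths per node
    node = torch.arange(N, device=P1.device).unsqueeze(1).expand(N, H * H)
    out = torch.full((N, max_mpool, 3), -1, dtype=torch.long, device=P1.device)
    r, c = node[keep], rank[keep]
    out[r, c, 0] = r
    out[r, c, 1] = first[keep]
    out[r, c, 2] = second[keep]
    return out


# Usage on a 3-node graph 0->1, 0->2, 1->2 (P2 built as in enumerate_paths_vectorized)
neigh = torch.tensor([[1, 2], [2, -1], [-1, -1]])
hop2 = neigh[neigh.clamp(min=0)]
P2 = torch.where((neigh == -1).unsqueeze(-1), torch.full_like(hop2, -1), hop2)
print(build_paths_idx_vec(neigh, P2, max_mpool=2)[0].tolist())  # [[0, 1, 2], [-1, -1, -1]]
```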

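`FacetAwareAggregator` applies its softmax over all `mpool` path slots, including padding slots whose node embeddings are all zero (after the path transformer these need not encode to zero). A padding-aware softmax is one way to keep them out of the context vector. The sketch assumes a boolean validity mask of shape `[N, mpool]`, which could be derived from `paths_idx[..., 0] != -1` inside `build_path_embedding_tensors` (that function does not currently return it); `masked_path_softmax` is a hypothetical name.

```python
import torch


def masked_path_softmax(att_logits: torch.Tensor, path_valid: torch.Tensor) -> torch.Tensor:
    """Softmax over the path axis that ignores padded slots.

    att_logits: [N, M, mpool] attention logits (one row per node facet)
    path_valid: [N, mpool] bool, True where the slot holds a real path
    Returns [N, M, mpool]; rows of nodes with no valid path are all zeros.
    """
    mask = path_valid.unsqueeze(1).expand_as(att_logits)   # broadcast over facets
    any_valid = mask.any(dim=-1, keepdim=True)
    logits = att_logits.masked_fill(~mask, float("-inf"))
    # A softmax over all -inf is NaN, so nodes without paths get zero logits here
    # and are zeroed afterwards by the mask.
    logits = torch.where(any_valid, logits, torch.zeros_like(logits))
    return torch.softmax(logits, dim=-1) * mask


# Usage: N=2 nodes, M=4 facets, mpool=5 slots; node 1 has no paths at all.
logits = torch.randn(2, 4, 5)
valid = torch.tensor([[True, True, False, False, False], [False] * 5])
att = masked_path_softmax(logits, valid)
print(att.sum(dim=-1))  # node 0: each facet row sums to 1; node 1: all zeros
```

In `forward`, `att = masked_path_softmax(att_logits, path_valid)` would take the place of the plain `F.softmax(att_logits, dim=2)`.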
#### Layer 2...L: The Sparse Dual-Stream Causal-MoE-Tokenphormer
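Reading the cell below, one `TokenphormerLayer` updates each facet token $h = h^{(\ell-1)}_{n,m}$ as follows (the notation is ours, not taken from the code). The token's input sequence is

$$
\mathbf{s} = \big[\, h;\ \bar h_{\text{coarse}};\ \mathrm{top}_K\{ h_{j,m'} : j \in \mathcal{N}(n) \text{, ranked by } h^\top h_{j,m'} \};\ T_n \,\big]
$$

where $\bar h_{\text{coarse}}$ is the mean of `H_coarse` (zeros if absent) and $T_n$ are optional text tokens. The two streams and the gate are

$$
\begin{aligned}
h^{\text{corr}} &= \mathrm{MHSA}(\mathbf{s})_0, \qquad
\pi = \mathrm{top}_k\!\big(\operatorname{softmax}\big(g([h;\ \operatorname{mean}(\mathbf{s})]) / \tau\big)\big), \\
h^{\text{causal}} &= \mathrm{LN}\Big(\textstyle\sum_{e=1}^{E} \pi_e\, \mathrm{Expert}_e(\mathbf{s})\Big)_0, \\
\lambda &= \sigma(w_\lambda^\top h + b_\lambda), \qquad
\tilde h = \lambda\, h^{\text{corr}} + (1-\lambda)\, h^{\text{causal}}, \qquad
h^{(\ell)}_{n,m} = \mathrm{LN}\big(\tilde h + \mathrm{FFN}(\tilde h)\big),
\end{aligned}
$$

with $\pi$ zero outside the $k$ selected experts. Training minimizes

$$
\mathcal{L} = \mathrm{MSE}\big(H^{(\ell)}, H^{(\ell-1)}\big) + 0.1\, \mathrm{KL}\big(\bar{p} \,\|\, \mathcal{U}_E\big) + 10^{-3}\, \overline{a^2},
$$

where $\bar p$ is the batch-mean router distribution, $\mathcal{U}_E$ the uniform distribution over the $E$ experts, and $a$ the experts' auxiliary logits.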


import json
from pathlib import Path
from typing import Optional, Tuple

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

RNG_SEED = 20251127
torch.manual_seed(RNG_SEED)
np.random.seed(RNG_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RNG_SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

LPA_DIR = Path("out/lpa_model")
ENCODER_CKPT = LPA_DIR / "encoder_finetuned.pth"
AGGREGATOR_CKPT = LPA_DIR / "aggregator_finetuned.pth"
AW_PATH = Path("out/graphs/AW_sparse.json")
NODE_EMB_PATH = LPA_DIR / "node_embeddings.pt"
H0_PATH = LPA_DIR / "H0_init.pt"
SAVE_OUT = Path("out/lpa_layer1_output.pt")   # contains {"H1": tensor}
SAVE_LAYER2 = Path("out/lpa_layer2_checkpoint.pth")

MAX_DEGREE_CLAMP = 128
MAX_MPOOL = 256
M_FACETS = 4
DEFAULT_NODE_EMB_DIM = 64
D_MODEL = 256
D_PATH = 256

# ---------- sanity ----------
if not SAVE_OUT.exists():
    raise FileNotFoundError(f"Layer1 output not found: {SAVE_OUT}")
if not AW_PATH.exists():
    raise FileNotFoundError(f"Adjacency file not found: {AW_PATH}")

# ---------- utilities from Layer1 ----------
def load_sparse_adj(path: Path) -> torch.Tensor:
    obj = json.loads(path.read_text())
    rows = torch.tensor(obj["rows"], dtype=torch.long)
    cols = torch.tensor(obj["cols"], dtype=torch.long)
    vals = torch.tensor(obj.get("vals", [1.0] * len(rows)), dtype=torch.float32)
    N = int(obj.get("N", int(max(rows.max().item(), cols.max().item()) + 1)))
    idx = torch.stack([rows, cols], dim=0)
    return torch.sparse_coo_tensor(indices=idx, values=vals, size=(N, N)).coalesce()

def adjacency_to_index_tensor(A_sparse: torch.Tensor, max_deg: int = MAX_DEGREE_CLAMP) -> torch.Tensor:
    N = A_sparse.size(0)
    rows = A_sparse.indices()[0].cpu().tolist()
    cols = A_sparse.indices()[1].cpu().tolist()
    neighbors = [[] for _ in range(N)]
    for r, c in zip(rows, cols):
        neighbors[r].append(c)
    out = torch.full((N, max_deg), -1, dtype=torch.long)
    for i in range(N):
        nb = neighbors[i][:max_deg]
        if nb:
            out[i, :len(nb)] = torch.tensor(nb, dtype=torch.long)
    return out.to(DEVICE)

# ---------- Model components ----------
class SimpleFFN(nn.Module):
    def __init__(self, d_model:int, d_ff:int, dropout:float=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )
    def forward(self, x):
        return self.net(x)

class MHSAWrapper(nn.Module):
    def __init__(self, d_model:int, nhead:int, dropout:float=0.1):
        super().__init__()
        self.mha = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.ln = nn.LayerNorm(d_model)
    def forward(self, seq, attn_mask:Optional[torch.Tensor]=None):
        attn_out, _ = self.mha(seq, seq, seq, attn_mask=attn_mask, need_weights=False)
        return self.ln(seq + attn_out)

class ExpertModule(nn.Module):
    def __init__(self, d_model:int, nhead:int, aux_hidden:int=128, aux_out:int=1, dropout:float=0.1):
        super().__init__()
        self.mhsa = MHSAWrapper(d_model, nhead, dropout=dropout)
        self.aux_head = nn.Sequential(
            nn.Linear(d_model, aux_hidden),
            nn.ReLU(inplace=True),
            nn.Linear(aux_hidden, aux_out)
        )
    def forward(self, seq):
        seq_out = self.mhsa(seq)
        cls = seq_out[:, 0, :]
        aux = self.aux_head(cls)
        return seq_out, aux

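def topk_sparse_weights(logits: torch.Tensor, k: int, temp: float):
    # Keep the top-k router probabilities per row; they are not renormalized,
    # so each row of `weights` sums to at most 1.
    probs = F.softmax(logits / temp, dim=-1)
    topk_vals, topk_idx = torch.topk(probs, k=k, dim=-1)
    B = probs.size(0)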
    device = logits.device
    weights = torch.zeros_like(probs)
    arange_b = torch.arange(B, device=device).unsqueeze(1).expand(-1, k)
    weights[arange_b.reshape(-1), topk_idx.reshape(-1)] = topk_vals.reshape(-1)
    return weights, topk_idx, probs

class SparseCausalMoE(nn.Module):
    def __init__(self, d_model:int, nhead:int, n_experts:int=8, top_k:int=2,
                 router_hidden:int=128, temp:float=1.0, dropout:float=0.1, aux_out:int=1):
        super().__init__()
        self.E = n_experts
        self.top_k = top_k
        self.temp = temp
        self.router = nn.Sequential(
            nn.Linear(d_model * 2, router_hidden),
            nn.ReLU(inplace=True),
            nn.Linear(router_hidden, n_experts),
        )
        self.experts = nn.ModuleList([ExpertModule(d_model, nhead, aux_hidden=128, aux_out=aux_out, dropout=dropout)
                                      for _ in range(n_experts)])
        self.layernorm = nn.LayerNorm(d_model)
    def forward(self, seq, self_token_vec):
        pooled = seq.mean(dim=1)
        router_in = torch.cat([self_token_vec, pooled], dim=-1)
        logits = self.router(router_in)
        weights, topk_idx, probs = topk_sparse_weights(logits, k=self.top_k, temp=self.temp)
        expert_outputs = []
        aux_logits = []
        # Dense evaluation: every expert runs on every sequence; non-selected experts get zero weight.
        for e in range(self.E):
            out_e, aux_e = self.experts[e](seq)
            expert_outputs.append(out_e)
            aux_logits.append(aux_e)
        expert_stack = torch.stack(expert_outputs, dim=0).permute(1,0,2,3)  # [B, E, S, d]
        w = weights.unsqueeze(-1).unsqueeze(-1)
        aggregated = (w * expert_stack).sum(dim=1)
        aggregated = self.layernorm(aggregated)
        aux_stack = torch.stack(aux_logits, dim=0).permute(1,0,2)  # [B, E, aux_out]
        return aggregated, aux_stack, logits, probs

class CorrelationalStream(nn.Module):
    def __init__(self, d_model:int, nhead:int, dropout:float=0.1):
        super().__init__()
        self.mhsa = MHSAWrapper(d_model, nhead, dropout=dropout)
    def forward(self, seq, attn_mask:Optional[torch.Tensor]=None):
        return self.mhsa(seq, attn_mask=attn_mask)

class TokenphormerLayer(nn.Module):
    def __init__(self,
                 d_model:int=256,
                 nhead:int=8,
                 n_experts:int=8,
                 top_k_experts:int=2,
                 k_neighbors:int=16,
                 ff_hidden:int=1024,
                 dropout:float=0.1,
                 router_temp:float=1.0):
        super().__init__()
        self.d_model = d_model
        self.k_neighbors = k_neighbors
        self.correl_stream = CorrelationalStream(d_model, nhead, dropout=dropout)
        self.causal_moe = SparseCausalMoE(d_model, nhead, n_experts, top_k_experts, router_hidden=256,
                                          temp=router_temp, dropout=dropout, aux_out=1)
        self.gate_proj = nn.Linear(d_model, 1)
        self.ffn = SimpleFFN(d_model, ff_hidden, dropout=dropout)
        self.ln_final = nn.LayerNorm(d_model)

    def build_hybrid_sequence(self,
                              H_prev_flat: torch.Tensor,
                              H_all: torch.Tensor,
                              neighbor_idx: torch.LongTensor,
                              H_coarse: Optional[torch.Tensor] = None,
                              text_tokens: Optional[torch.Tensor] = None):
        device = H_prev_flat.device
        B = H_prev_flat.size(0)
        d = self.d_model
        self_token = H_prev_flat.unsqueeze(1)
        N_nodes, M_facets, _ = H_all.shape
        M = M_facets
        assert B == N_nodes * M, "shape mismatch"
        node_ids = torch.arange(B, device=device) // M
        if H_coarse is not None:
            hcontext = H_coarse.mean(dim=0, keepdim=True).expand(B, -1).unsqueeze(1)
        else:
            hcontext = torch.zeros(B, 1, d, device=device)
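        k_total = neighbor_idx.size(1)
        cand_nodes = neighbor_idx[node_ids]  # [B, k_total]; -1 marks a padded slot
        pad = cand_nodes < 0
        # Index with the [B, k_total] tensor so the gather is [B, k_total, M, d]; clamp first,
        # because a raw -1 would silently wrap around to node N-1.
        H_all_expand = H_all[cand_nodes.clamp(min=0)]
        H_all_expand = H_all_expand.masked_fill(pad[:, :, None, None], 0.0)
        neighbor_facets = H_all_expand.reshape(B, k_total * M, d)
        attn_scores = (self_token @ neighbor_facets.transpose(1, 2)).squeeze(1)  # [B, k_total*M]
        # Padded neighbours must never outrank real ones in the top-k selection.
        facet_pad = pad.unsqueeze(-1).expand(-1, -1, M).reshape(B, k_total * M)
        attn_scores = attn_scores.masked_fill(facet_pad, float("-inf"))
        Ksel = min(self.k_neighbors, attn_scores.size(1))
        topk_vals, topk_idx = torch.topk(attn_scores, k=Ksel, dim=-1)
        selected_neighbors = torch.gather(
            neighbor_facets, 1, topk_idx.unsqueeze(-1).expand(-1, -1, d))  # [B, Ksel, d]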
        if text_tokens is not None:
            text_tokens_exp = text_tokens[node_ids]
        else:
            text_tokens_exp = torch.zeros(B, 0, d, device=device)
        seq = torch.cat([self_token, hcontext, selected_neighbors, text_tokens_exp], dim=1)
        return seq

    def forward(self,
                H_prev: torch.Tensor,
                neighbor_idx: torch.LongTensor,
                H_coarse: Optional[torch.Tensor] = None,
                text_tokens: Optional[torch.Tensor] = None):
        N, M, d = H_prev.shape
        device = H_prev.device
        B = N * M
        H_prev_flat = H_prev.reshape(B, d)
        seq = self.build_hybrid_sequence(H_prev_flat, H_prev, neighbor_idx, H_coarse, text_tokens)
        corr_out = self.correl_stream(seq)
        self_token_vec = seq[:, 0, :]
        causal_out, aux_logits, router_logits, router_probs = self.causal_moe(seq, self_token_vec)
        h_corr = corr_out[:, 0, :]
        h_causal = causal_out[:, 0, :]
        lam = torch.sigmoid(self.gate_proj(H_prev_flat)).squeeze(-1)
        hfused = lam.unsqueeze(-1) * h_corr + (1.0 - lam).unsqueeze(-1) * h_causal
        hffn = self.ffn(hfused)
        hnext_flat = self.ln_final(hfused + hffn)
        H_next = hnext_flat.reshape(N, M, d)
        return H_next, router_logits, router_probs, aux_logits

# ---------- training utilities & losses ----------
def load_layer1_H1(path: Path) -> torch.Tensor:
    obj = torch.load(path, map_location="cpu")
    if isinstance(obj, dict) and "H1" in obj:
        H1 = obj["H1"]
    elif isinstance(obj, torch.Tensor):
        H1 = obj
    else:
        raise RuntimeError("Unknown H1 format in SAVE_OUT")
    return H1

def kl_to_uniform_mean(probs: torch.Tensor):
    eps = 1e-9
    mean_p = probs.mean(dim=0)  # [E]
    E = mean_p.size(0)
    kl = (mean_p * (torch.log(mean_p + eps) - math.log(1.0 / E))).sum()
    return kl

def train_loop(layer: TokenphormerLayer,
               H_init: torch.Tensor,
               neighbor_idx: torch.LongTensor,
               epochs: int = 20,
               lr: float = 3e-4,
               weight_decay: float = 1e-6,
               save_path: Path = SAVE_LAYER2):
    layer.train()
    layer.to(DEVICE)
    H = H_init.to(DEVICE)
    optimizer = torch.optim.AdamW(layer.parameters(), lr=lr, weight_decay=weight_decay)
    scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
    N, M, d = H.shape

    for epoch in range(1, epochs + 1):
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            H_next, router_logits, router_probs, aux_logits = layer(H, neighbor_idx)
            # Self-supervised reconstruction loss: encourages stability (H_next close to H)
            recon_loss = F.mse_loss(H_next, H)
            # router load-balance loss: KL(mean_probs || uniform)
            lb_loss = kl_to_uniform_mean(router_probs)
            # aux regularization (small L2 on aux logits)
            aux_reg = aux_logits.pow(2).mean()
            total_loss = recon_loss + 0.1 * lb_loss + 1e-3 * aux_reg
        scaler.scale(total_loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(layer.parameters(), max_norm=5.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        H = H_next.detach()  
        if epoch % 1 == 0:
            print(f"Epoch {epoch:03d} | total_loss {total_loss.item():.6f} recon {recon_loss.item():.6f} lb {lb_loss.item():.6f} aux {aux_reg.item():.6e}")
        # save checkpoint each epoch
        ckpt = {
            "epoch": epoch,
            "model_state": layer.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "H_curr": H.cpu(),
        }
        save_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(ckpt, save_path)
    return layer, H

# ---------- main ----------
def main():
    AW = load_sparse_adj(AW_PATH).to(DEVICE).coalesce()
    neigh_idx = adjacency_to_index_tensor(AW, max_deg=MAX_DEGREE_CLAMP)  # [N, max_deg]

    H1 = load_layer1_H1(SAVE_OUT)  # assumed [N, M, D_MODEL]
    if H1.dim() != 3:
        raise RuntimeError("Expected H1 of shape [N, M, D_MODEL]")
    N, M, D = H1.shape
    if D != D_MODEL:
        # NOTE: an untrained, randomly initialized projection to D_MODEL.
        proj = nn.Linear(D, D_MODEL).to(DEVICE)
        H1 = proj(H1.to(DEVICE)).detach().cpu()

    desired_k_neighbors = min(32, neigh_idx.size(1))
    neigh_idx = neigh_idx[:, :desired_k_neighbors].contiguous()

    layer2 = TokenphormerLayer(d_model=D_MODEL,
                               nhead=8,
                               n_experts=8,
                               top_k_experts=2,
                               k_neighbors=min(16, desired_k_neighbors),
                               ff_hidden=1024,
                               dropout=0.1,
                               router_temp=1.0)

    neighbor_idx = neigh_idx.to(DEVICE)

    # train
    epochs = 20
    trained_layer, H_final = train_loop(layer2, H1, neighbor_idx, epochs=epochs, lr=3e-4)

    # save final H and model
    out_dir = Path("out")
    out_dir.mkdir(parents=True, exist_ok=True)
    torch.save({"H2": H_final.detach().cpu()}, out_dir / "lpa_layer2_output.pt")
    torch.save({"model_state": trained_layer.state_dict()}, out_dir / "lpa_layer2_model.pth")
    print("Saved layer2 outputs and model.")

if __name__ == "__main__":
    main()
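`SparseCausalMoE.forward` runs all `E` experts on every sequence and relies on zero entries in `weights` to drop the unselected ones, so its compute is dense even though the routing is sparse. The sketch below shows one way to dispatch each sequence only to its top-k experts. `sparse_moe_dispatch` is a hypothetical helper, and it assumes each expert is a callable mapping `[n, S, d]` to `[n, S, d]` (with `ExpertModule` that means taking the first element of its returned tuple, and the auxiliary logits of idle experts would no longer be computed). It should reproduce the dense weighted sum that precedes `self.layernorm`.

```python
from typing import Callable, Sequence

import torch
import torch.nn as nn


def sparse_moe_dispatch(seq: torch.Tensor,
                        experts: Sequence[Callable[[torch.Tensor], torch.Tensor]],
                        topk_idx: torch.Tensor,
                        topk_w: torch.Tensor) -> torch.Tensor:
    """seq: [B, S, d]; topk_idx, topk_w: [B, k] expert ids and their router weights."""
    out = torch.zeros_like(seq)
    for e, expert in enumerate(experts):
        hit = topk_idx == e                                  # [B, k]
        rows = hit.any(dim=-1).nonzero(as_tuple=True)[0]     # sequences routed to expert e
        if rows.numel() == 0:
            continue                                         # expert e is idle in this batch
        w = (topk_w * hit).sum(dim=-1)[rows]                 # weight of expert e per routed row
        out.index_add_(0, rows, w.view(-1, 1, 1) * expert(seq[rows]))
    return out


# Usage with toy linear experts, compared against the dense weighted sum.
torch.manual_seed(0)
B, S, d, E, k = 5, 3, 8, 4, 2
seq = torch.randn(B, S, d)
experts = [nn.Linear(d, d) for _ in range(E)]
probs = torch.softmax(torch.randn(B, E), dim=-1)
topk_w, topk_idx = torch.topk(probs, k=k, dim=-1)
dense_w = torch.zeros_like(probs).scatter(1, topk_idx, topk_w)
dense = sum(dense_w[:, e, None, None] * experts[e](seq) for e in range(E))
sparse = sparse_moe_dispatch(seq, experts, topk_idx, topk_w)
print(torch.allclose(dense, sparse, atol=1e-5))  # True
```

With `topk_sparse_weights`, `topk_idx` is returned directly and the matching weights are `weights.gather(1, topk_idx)`.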


### 3.4.2.4 The Amortized Proposer Network ($Q_{\phi}$)
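Reading `ProposerNet` and `train()` in the cell below, the proposer computes, for each relation $r = 1,\dots,R$ and facet states $h_{n,m}$ (notation ours):

$$
\begin{aligned}
\alpha_{r,n,m} &= \operatorname{softmax}_m\!\left(\frac{(W_Q q_r)^\top W_K h_{n,m}}{\sqrt{d_k}}\right), \qquad
\bar h_{r,n} = \sum_{m=1}^{M} \alpha_{r,n,m}\, h_{n,m}, \\
u_{r,n} &= W_2\big(\gamma^U_r \odot \mathrm{ReLU}(W_1 \bar h_{r,n} + b_1) + \beta^U_r\big) + b_2, \qquad
(\gamma^U_r, \beta^U_r) = \mathrm{FiLM}_U(e^U_r), \\
A^{(r)}_{ij} &= \sigma\big(u_{r,i}^\top v_{r,j} + b_r\big),
\end{aligned}
$$

with $v_{r,n}$ defined like $u_{r,n}$ but using $\mathrm{FiLM}_V(e^V_r)$; $W_1$ and $W_2$ are shared across relations and across $U$/$V$. The logit matrix $U_r V_r^\top + b_r$ therefore has rank at most $d_{\text{rank}} + 1$. The proposer and the EBM are then trained with

$$
\begin{aligned}
\mathcal{L}_{Q} &= \alpha_{\text{energy}}\, E_{\bar\psi}(H, A_0) + \alpha_{\text{distill}}\, \overline{|A_0 - \operatorname{sg}(A^\star)|} + \alpha_{\text{sparse}}\, \beta\, \overline{|A_0|}, \\
\mathcal{L}_{E} &= E_{\psi}(H, A^\star) + \max\big(0,\ m + E_{\psi}(H, A^\star) - E_{\psi}(H, A^-)\big),
\end{aligned}
$$

where $E_{\bar\psi}$ is the EMA copy of the energy model, $A^\star$ the Langevin-refined sample, $\operatorname{sg}$ the stop-gradient, $A^-$ a random Bernoulli graph, $m = 0.1$, and bars denote means over entries.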

import math
from pathlib import Path
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

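# External dependency: the code below relies on LangevinSampler(steps=..., step_size=..., noise=...)
# and its .sample(energy_fn=..., X_init=...) method.
from torch_energy.samplers import LangevinSampler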

# ---------- config ----------
RNG_SEED = 20251127
torch.manual_seed(RNG_SEED)
np.random.seed(RNG_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RNG_SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

OUT_DIR = Path("out")
LAYER2_OUT = OUT_DIR / "lpa_layer2_output.pt"
LAYER1_OUT = OUT_DIR / "lpa_layer1_output.pt"
SAVE_PROPOSER = OUT_DIR / "proposer_qphi_checkpoint.pth"

# hyperparams
R = 6
d_rank = 16
d_q = 64
d_k = 64
d_e = 32
film_hidden = 128
alpha_energy = 0.1
alpha_distill = 1.0
alpha_sparse = 1e-2
sparse_beta = 1.0
proposer_lr = 3e-4
ebm_lr = 3e-4
epochs = 30
mcmc_steps = 10
mcmc_step_size = 0.1
mcmc_noise = 1e-2
ema_tau = 0.995
use_rel_loop = False  # set True if memory tight

# ---------- loaders ----------
def load_HL() -> torch.Tensor:
    if LAYER2_OUT.exists():
        obj = torch.load(LAYER2_OUT, map_location="cpu")
        k = "H2" if "H2" in obj else next(iter(obj.keys()))
        H = obj[k]
    elif LAYER1_OUT.exists():
        obj = torch.load(LAYER1_OUT, map_location="cpu")
        k = "H1" if "H1" in obj else next(iter(obj.keys()))
        H = obj[k]
    else:
        raise FileNotFoundError("No Layer outputs found in out/")
    if not isinstance(H, torch.Tensor):
        raise RuntimeError("Loaded H is not a tensor")
    return H.to(DEVICE)

# ---------- proposer modules ----------
class FiLMGenerator(nn.Module):
    def __init__(self, e_dim:int, hidden_dim:int, out_dim:int):
        super().__init__()
        self.gam = nn.Sequential(nn.Linear(e_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, out_dim))
        self.bet = nn.Sequential(nn.Linear(e_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, out_dim))
    def forward(self, e):
        return self.gam(e), self.bet(e)

class SharedG(nn.Module):
    def __init__(self, x_dim:int, hidden_dim:int, out_dim:int):
        super().__init__()
        self.W1 = nn.Linear(x_dim, hidden_dim)
        self.act = nn.ReLU()
        self.W2 = nn.Linear(hidden_dim, out_dim)
        self.b1 = nn.Parameter(torch.zeros(hidden_dim))
        self.b2 = nn.Parameter(torch.zeros(out_dim))
    def forward(self, x, gamma, beta):
        h = self.W1(x) + self.b1
        h = self.act(h)
        h = h * gamma + beta
        return self.W2(h) + self.b2

class ProposerNet(nn.Module):
    def __init__(self, R:int, d_out:int, d_rank:int, d_q:int, d_k:int, d_e:int, film_hidden:int,
                 use_rel_loop:bool=False):
        super().__init__()
        self.R = R
        self.d_out = d_out
        self.d_rank = d_rank
        self.d_q = d_q
        self.d_k = d_k
        self.d_e = d_e
        self.use_rel_loop = use_rel_loop

        self.q_rel = nn.Parameter(torch.randn(R, d_q) * 0.02)
        self.WQ = nn.Linear(d_q, d_k, bias=False)
        self.WK = nn.Linear(d_out, d_k, bias=False)

        self.eU = nn.Parameter(torch.randn(R, d_e) * 0.02)
        self.eV = nn.Parameter(torch.randn(R, d_e) * 0.02)

        self.film_U = FiLMGenerator(d_e, film_hidden, film_hidden)
        self.film_V = FiLMGenerator(d_e, film_hidden, film_hidden)
        self.g_shared = SharedG(x_dim=d_out, hidden_dim=film_hidden, out_dim=d_rank)

        self.b_rel = nn.Parameter(torch.zeros(R))

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.WQ.weight)
        nn.init.xavier_uniform_(self.WK.weight)
        nn.init.xavier_uniform_(self.g_shared.W1.weight)
        nn.init.xavier_uniform_(self.g_shared.W2.weight)

    def relation_pooling(self, H: torch.Tensor) -> torch.Tensor:
        N, M, d = H.shape
        K = self.WK(H.view(N*M, d)).view(N, M, self.d_k)
        Q = self.WQ(self.q_rel)  # [R, d_k]
        Q_exp = Q.unsqueeze(1).unsqueeze(2)      # [R,1,1,d_k]
        K_exp = K.unsqueeze(0)                   # [1,N,M,d_k]
        att_logits = (Q_exp * K_exp).sum(dim=-1) / math.sqrt(self.d_k)  # [R,N,M]
        att = F.softmax(att_logits, dim=-1)
        pooled = (att.unsqueeze(-1) * H.unsqueeze(0)).sum(dim=2)  # [R,N,d_out]
        return pooled

    def hyper_generate_UV(self, pooled: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        R, N, d = pooled.shape
        device = pooled.device
        if self.use_rel_loop:
            U = torch.zeros(R, N, self.d_rank, device=device)
            V = torch.zeros(R, N, self.d_rank, device=device)
            for r in range(R):
                h_r = pooled[r]
                gamma_u, beta_u = self.film_U(self.eU[r:r+1].expand(N, -1))
                gamma_v, beta_v = self.film_V(self.eV[r:r+1].expand(N, -1))
                U[r] = self.g_shared(h_r, gamma_u, beta_u)
                V[r] = self.g_shared(h_r, gamma_v, beta_v)
            return U, V
        pooled_flat = pooled.reshape(R*N, d)
        eU_exp = self.eU.unsqueeze(1).expand(-1, N, -1).reshape(R*N, self.d_e)
        eV_exp = self.eV.unsqueeze(1).expand(-1, N, -1).reshape(R*N, self.d_e)
        gamma_u, beta_u = self.film_U(eU_exp)
        gamma_v, beta_v = self.film_V(eV_exp)
        outU = self.g_shared(pooled_flat, gamma_u, beta_u)
        outV = self.g_shared(pooled_flat, gamma_v, beta_v)
        U = outU.view(R, N, self.d_rank)
        V = outV.view(R, N, self.d_rank)
        return U, V

    def forward(self, H: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        pooled = self.relation_pooling(H)  # [R,N,d_out]
        U, V = self.hyper_generate_UV(pooled)  # [R,N,d_rank]
        # Low-rank logits per relation, batched over r: A_logits[r] = U[r] @ V[r]^T + b_rel[r]
        A_logits = torch.bmm(U, V.transpose(1, 2)) + self.b_rel.view(-1, 1, 1)  # [R, N, N]
        A_probs = torch.sigmoid(A_logits)
        return U, V, A_logits, A_probs

# ---------- simple EBM ----------
class EnergyModel(nn.Module):
    def __init__(self, R:int, d_in:int, d_hidden:int=256, d_latent:int=64):
        super().__init__()
        self.R = R
        self.encoder = nn.Sequential(
            nn.Linear(d_in, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_latent),
            nn.ReLU()
        )
        self.Rmats = nn.Parameter(torch.randn(R, d_latent, d_latent) * 0.01)
        self.head = nn.Sequential(nn.Linear(R, 128), nn.ReLU(), nn.Linear(128, 1))
    def forward(self, H: torch.Tensor, A_probs: torch.Tensor) -> torch.Tensor:
        N, M, d = H.shape
        Hpool = H.mean(dim=1)
        z = self.encoder(Hpool)
        energies = []
        for r in range(self.R):
            Rmat = self.Rmats[r]
            S = z @ Rmat @ z.t()
            S_sig = torch.sigmoid(S)
            diff = (A_probs[r] - S_sig).abs().mean()
            energies.append(diff)
        stats = torch.stack(energies)
        out = self.head(stats.unsqueeze(0))
        return out.squeeze()

# ---------- EMA helper ----------
def ema_update(model_ema: nn.Module, model: nn.Module, tau: float):
    for p_ema, p in zip(model_ema.parameters(), model.parameters()):
        p_ema.data.mul_(tau).add_(p.data, alpha=(1.0 - tau))

# ---------- training ----------
def train():
    H = load_HL()
    N, M, d_found = H.shape
    proposer = ProposerNet(R=R, d_out=d_found, d_rank=d_rank, d_q=d_q, d_k=d_k, d_e=d_e,
                           film_hidden=film_hidden, use_rel_loop=use_rel_loop).to(DEVICE)
    ebm = EnergyModel(R=R, d_in=d_found).to(DEVICE)
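    ebm_ema = EnergyModel(R=R, d_in=d_found).to(DEVICE)
    ebm_ema.load_state_dict(ebm.state_dict())
    # The EMA copy changes only through ema_update; stop autograd from accumulating .grad on it.
    ebm_ema.requires_grad_(False)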

    opt_prop = torch.optim.AdamW(proposer.parameters(), lr=proposer_lr, weight_decay=1e-6)
    opt_ebm = torch.optim.AdamW(ebm.parameters(), lr=ebm_lr, weight_decay=1e-6)

    sampler = LangevinSampler(steps=mcmc_steps, step_size=mcmc_step_size, noise=mcmc_noise)

    for epoch in range(1, epochs + 1):
        proposer.train()
        ebm.train()

        U, V, A_logits, A0 = proposer(H)

        ebm_ema.eval()
        L_energy = ebm_ema(H, A0)
        if L_energy.dim() != 0:
            L_energy = L_energy.mean()

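        # Langevin updates need dE/dX, so do not wrap the sampler in torch.no_grad();
        # start from a detached proposal and detach the refined sample.
        A_star = sampler.sample(energy_fn=lambda X: ebm(H, X), X_init=A0.detach()).detach()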

        L_distill = F.l1_loss(A0, A_star.detach(), reduction="mean")
        L_sparse = A0.abs().mean()

        total_prop_loss = alpha_energy * L_energy + alpha_distill * L_distill + alpha_sparse * (sparse_beta * L_sparse)

        opt_prop.zero_grad()
        total_prop_loss.backward()
        torch.nn.utils.clip_grad_norm_(proposer.parameters(), max_norm=5.0)
        opt_prop.step()

        with torch.no_grad():
            A_neg = torch.bernoulli(torch.rand_like(A0)).to(DEVICE)
        E_pos = ebm(H, A_star.detach())
        E_neg = ebm(H, A_neg)
        margin = 0.1
        hinge = F.relu(margin + E_pos - E_neg)
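        # NOTE: the E_pos term is unbounded below (the EBM head ends in an unconstrained Linear),
        # so energies can drift; a small penalty on energy magnitudes is a common safeguard.
        ebm_loss = E_pos + hinge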

        opt_ebm.zero_grad()
        ebm_loss.backward()
        torch.nn.utils.clip_grad_norm_(ebm.parameters(), max_norm=5.0)
        opt_ebm.step()

        ema_update(ebm_ema, ebm, ema_tau)

        if epoch % 1 == 0:
            print(f"Epoch {epoch:03d} | prop_loss {total_prop_loss.item():.6f} (E {L_energy.item():.6f} D {L_distill.item():.6f} S {L_sparse.item():.6f}) | ebm_loss {ebm_loss.item():.6f}")

        if epoch % 5 == 0 or epoch == epochs:
            ckpt = {
                "epoch": epoch,
                "proposer": proposer.state_dict(),
                "ebm": ebm.state_dict(),
                "ebm_ema": ebm_ema.state_dict(),
                "U": U.detach().cpu(),
                "V": V.detach().cpu(),
                "A0": A0.detach().cpu()
            }
            SAVE_PROPOSER.parent.mkdir(parents=True, exist_ok=True)
            torch.save(ckpt, SAVE_PROPOSER)

    final = {
        "proposer": proposer.state_dict(),
        "ebm": ebm.state_dict(),
        "ebm_ema": ebm_ema.state_dict()
    }
    torch.save(final, SAVE_PROPOSER.with_suffix(".final.pth"))
    print("Done. Saved proposer and ebm checkpoints.")

if __name__ == "__main__":
    train()
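Section 3.4 describes EBM training as a contrastive (InfoNCE) objective, whereas the loop above uses a margin hinge on one positive and one random negative. If several negatives are drawn per step, an InfoNCE loss over energies can be written as below: treating $-E$ as a logit, it is the cross-entropy of picking the positive among $1+K$ candidates. This is a sketch; `energy_infonce` is a hypothetical helper, and the temperature `tau` is an assumed knob that does not appear in the config.

```python
import torch
import torch.nn.functional as F


def energy_infonce(e_pos: torch.Tensor, e_neg: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
    """InfoNCE over energies.

    e_pos: [B] energies of positive graphs; e_neg: [B, K] energies of K negatives each.
    Lower energy means more plausible, so -E / tau acts as the classification logit.
    """
    logits = -torch.cat([e_pos.unsqueeze(1), e_neg], dim=1) / tau   # [B, 1 + K]
    target = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)  # positive at 0
    return F.cross_entropy(logits, target)


# Usage: a positive well below its negatives gives a small loss.
e_pos = torch.tensor([0.0])
e_neg = torch.tensor([[5.0, 6.0, 7.0]])
print(energy_infonce(e_pos, e_neg).item())  # about 0.0101
```

In `train()`, `E_pos.view(1)` and a stack of energies for several `A_neg` draws reshaped to `[1, K]` would play the roles of `e_pos` and `e_neg`.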

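Both training cells import `LangevinSampler` from `torch_energy.samplers`, whose implementation is not shown here. For readers without that dependency, a minimal hand-written Langevin refinement with the same three knobs (`steps`, `step_size`, `noise`) is sketched below. It is an assumption about what the sampler does, not its implementation: each step moves $X$ down the energy gradient, adds Gaussian noise, and clamps to $[0, 1]$ so that $A^\star$ stays a valid edge-probability tensor. The noise scale is decoupled from the step size (instead of the textbook $\sqrt{2\eta}$), a common practical choice in EBM training. `langevin_refine` is a hypothetical name.

```python
from typing import Callable, Optional, Tuple

import torch


def langevin_refine(energy_fn: Callable[[torch.Tensor], torch.Tensor],
                    x_init: torch.Tensor,
                    steps: int = 10,
                    step_size: float = 0.1,
                    noise: float = 1e-2,
                    clamp: Optional[Tuple[float, float]] = (0.0, 1.0)) -> torch.Tensor:
    """Hypothetical stand-in for LangevinSampler.sample:
    x <- clamp(x - step_size * dE/dx + noise * N(0, I))."""
    x = x_init.detach().clone()
    with torch.enable_grad():          # dE/dx is needed even if the caller is under no_grad
        for _ in range(steps):
            x.requires_grad_(True)
            energy = energy_fn(x).sum()
            (grad,) = torch.autograd.grad(energy, x)   # gradient w.r.t. x only, not model params
            with torch.no_grad():
                x = x - step_size * grad + noise * torch.randn_like(x)
                if clamp is not None:
                    x = x.clamp(*clamp)
    return x.detach()


# Usage: a quadratic energy with its minimum at 0.3; without noise, x converges to 0.3.
target = torch.full((2, 4, 4), 0.3)


def energy(X: torch.Tensor) -> torch.Tensor:
    return ((X - target) ** 2).sum()


x_star = langevin_refine(energy, torch.rand(2, 4, 4), steps=200, step_size=0.1, noise=0.0)
print(torch.allclose(x_star, target, atol=1e-4))  # True
```

Inside `train()` this would replace the sampler call as `A_star = langevin_refine(lambda X: ebm(H, X), A0, steps=mcmc_steps, step_size=mcmc_step_size, noise=mcmc_noise)`.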

### 3.4.2.5 Self-Supervised Training: The Causal Energy Curriculum
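The configuration of this cell includes a `lambda_dag` weight, but the loss that uses it falls outside this excerpt. One common choice for a differentiable acyclicity term is the NOTEARS constraint $h(A) = \operatorname{tr}\big(e^{A \circ A}\big) - N$, which is zero exactly when the weighted graph $A$ has no directed cycles. The sketch below assumes that choice; whether the curriculum actually uses it is not confirmed here, and `dag_penalty` is a hypothetical name. It accepts a single `[N, N]` matrix or a per-relation stack `[R, N, N]`.

```python
import torch


def dag_penalty(A: torch.Tensor) -> torch.Tensor:
    """NOTEARS-style acyclicity penalty h(A) = tr(exp(A * A)) - N.

    A: [N, N] or [R, N, N] weighted adjacency; returns one value per matrix.
    Zero iff the graph has no directed cycles (diagonal self-loops count as cycles).
    """
    N = A.size(-1)
    W = A * A                                     # elementwise square keeps entries >= 0
    return torch.linalg.matrix_exp(W).diagonal(dim1=-2, dim2=-1).sum(-1) - N


# Usage: a 3-node chain 0->1->2 versus the same chain closed into a cycle.
chain = torch.tensor([[0., 1., 0.], [0., 0., 1.], [0., 0., 0.]])
cycle = chain.clone()
cycle[2, 0] = 1.0
print(dag_penalty(torch.stack([chain, cycle])))  # first entry ~0, second about 0.504
```

Because self-loops count as cycles, the diagonal of the proposer's `A_probs` would contribute to this penalty unless it is masked out.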

import math
from pathlib import Path
from typing import Optional, Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from torch_energy.samplers import LangevinSampler

# ---------- config  ----------
RNG_SEED = 20251127
torch.manual_seed(RNG_SEED)
np.random.seed(RNG_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RNG_SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

OUT = Path("out")
LAYER2_OUT = OUT / "lpa_layer2_output.pt"   
LAYER1_OUT = OUT / "lpa_layer1_output.pt"
C_PATH = OUT / "C_soft.pt"                  
H_AUG_PATH = OUT / "H_aug.pt"               

SAVE_CKPT = OUT / "causal_energy_curriculum_ckpt.pth"

# model hyperparams
d_rank = 16
proposer_lr = 3e-4
ebm_lr = 3e-4
epochs = 40
mcmc_steps_pos = 12
mcmc_step_size = 0.08
mcmc_noise = 1e-2
sld_ascent_steps = 6
sld_step = 0.05
sld_noise = 1e-2
ema_tau = 0.995
negatives_easy = 8
negatives_hard = 8

# loss weights 
alpha_energy = 0.1
alpha_distill = 1.0
alpha_sparse = 1e-2
lambda_coarse = 1.0
lambda_fine = 1.0
lambda_consistency = 10.0
lambda_dag = 10.0
lambda_sparse_coarse = 1.0
mu_outdeg = 2.0
lambda_deg = 0.1
invariance_enable = True  

# ---------- utilities ----------
def load_H():
    if LAYER2_OUT.exists():
        o = torch.load(LAYER2_OUT, map_location="cpu")
        k = "H2" if "H2" in o else next(iter(o.keys()))
        H = o[k]
    elif LAYER1_OUT.exists():
        o = torch.load(LAYER1_OUT, map_location="cpu")
        k = "H1" if "H1" in o else next(iter(o.keys()))
        H = o[k]
    else:
        raise FileNotFoundError("No layer outputs found in out/ (expected lpa_layer2_output.pt or lpa_layer1_output.pt)")
    if not isinstance(H, torch.Tensor):
        raise RuntimeError("Loaded H is not a tensor")
    return H.to(DEVICE)

def load_C():
    if not C_PATH.exists():
        raise FileNotFoundError(f"Soft assignment matrix C not found at {C_PATH}")
    C = torch.load(C_PATH, map_location="cpu")
    if not isinstance(C, torch.Tensor):
        raise RuntimeError("Loaded C is not a tensor")
    return C.to(DEVICE)

def load_H_aug():
    if not H_AUG_PATH.exists():
        raise FileNotFoundError(f"Augmented environment H not found at {H_AUG_PATH} (required when invariance enabled)")
    H = torch.load(H_AUG_PATH, map_location="cpu")
    if not isinstance(H, torch.Tensor):
        raise RuntimeError("Loaded H_aug is not a tensor")
    return H.to(DEVICE)

# ---------- proposer (Q_phi) (FiLM hypernetwork) ----------
class FiLMGenerator(nn.Module):
    def __init__(self, e_dim:int, hidden_dim:int, out_dim:int):
        super().__init__()
        self.gam = nn.Sequential(nn.Linear(e_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, out_dim))
        self.bet = nn.Sequential(nn.Linear(e_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, out_dim))
    def forward(self, e):
        return self.gam(e), self.bet(e)

class SharedG(nn.Module):
    def __init__(self, x_dim:int, hidden_dim:int, out_dim:int):
        super().__init__()
        self.W1 = nn.Linear(x_dim, hidden_dim)
        self.act = nn.ReLU()
        self.W2 = nn.Linear(hidden_dim, out_dim)
        self.b1 = nn.Parameter(torch.zeros(hidden_dim))
        self.b2 = nn.Parameter(torch.zeros(out_dim))
    def forward(self, x, gamma, beta):
        h = self.W1(x) + self.b1
        h = self.act(h)
        h = h * gamma + beta
        return self.W2(h) + self.b2

class ProposerNet(nn.Module):
    def __init__(self, R:int, d_out:int, d_rank:int, d_q:int=64, d_k:int=64, d_e:int=32, film_hidden:int=128, use_rel_loop:bool=False):
        super().__init__()
        self.R = R
        self.d_out = d_out
        self.d_rank = d_rank
        self.d_q = d_q
        self.d_k = d_k
        self.d_e = d_e
        self.use_rel_loop = use_rel_loop

        self.q_rel = nn.Parameter(torch.randn(R, d_q) * 0.02)
        self.WQ = nn.Linear(d_q, d_k, bias=False)
        self.WK = nn.Linear(d_out, d_k, bias=False)

        self.eU = nn.Parameter(torch.randn(R, d_e) * 0.02)
        self.eV = nn.Parameter(torch.randn(R, d_e) * 0.02)

        self.film_U = FiLMGenerator(d_e, film_hidden, film_hidden)
        self.film_V = FiLMGenerator(d_e, film_hidden, film_hidden)
        self.g_shared = SharedG(x_dim=d_out, hidden_dim=film_hidden, out_dim=d_rank)

        self.b_rel = nn.Parameter(torch.zeros(R))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.WQ.weight)
        nn.init.xavier_uniform_(self.WK.weight)
        nn.init.xavier_uniform_(self.g_shared.W1.weight)
        nn.init.xavier_uniform_(self.g_shared.W2.weight)

    def relation_pooling(self, H: torch.Tensor) -> torch.Tensor:
        N, M, d = H.shape
        K = self.WK(H.view(N*M, d)).view(N, M, self.d_k)
        Q = self.WQ(self.q_rel)  # [R, d_k]
        Q_exp = Q.unsqueeze(1).unsqueeze(2)
        K_exp = K.unsqueeze(0)
        att_logits = (Q_exp * K_exp).sum(dim=-1) / math.sqrt(self.d_k)
        att = F.softmax(att_logits, dim=-1)
        pooled = (att.unsqueeze(-1) * H.unsqueeze(0)).sum(dim=2)
        return pooled  # [R, N, d_out]

    def hyper_generate_UV(self, pooled: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        R, N, d = pooled.shape
        device = pooled.device
        if self.use_rel_loop:
            U = torch.zeros(R, N, self.d_rank, device=device)
            V = torch.zeros(R, N, self.d_rank, device=device)
            for r in range(R):
                h_r = pooled[r]
                gamma_u, beta_u = self.film_U(self.eU[r:r+1].expand(N, -1))
                gamma_v, beta_v = self.film_V(self.eV[r:r+1].expand(N, -1))
                U[r] = self.g_shared(h_r, gamma_u, beta_u)
                V[r] = self.g_shared(h_r, gamma_v, beta_v)
            return U, V
        pooled_flat = pooled.reshape(R*N, d)
        eU_exp = self.eU.unsqueeze(1).expand(-1, N, -1).reshape(R*N, self.d_e)
        eV_exp = self.eV.unsqueeze(1).expand(-1, N, -1).reshape(R*N, self.d_e)
        gamma_u, beta_u = self.film_U(eU_exp)
        gamma_v, beta_v = self.film_V(eV_exp)
        outU = self.g_shared(pooled_flat, gamma_u, beta_u)
        outV = self.g_shared(pooled_flat, gamma_v, beta_v)
        U = outU.view(R, N, self.d_rank)
        V = outV.view(R, N, self.d_rank)
        return U, V

    def forward(self, H: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        pooled = self.relation_pooling(H)
        U, V = self.hyper_generate_UV(pooled)
        R_, N, dr = U.shape
        A_logits = torch.zeros(R_, N, N, device=H.device)
        for r in range(R_):
            Ur = U[r]
            Vr = V[r]
            A_logits[r] = Ur @ Vr.t() + self.b_rel[r]
        A_probs = torch.sigmoid(A_logits)
        return U, V, A_logits, A_probs

# ---------- Hierarchical EBM (E_total) ----------
class HierarchicalEBM(nn.Module):
    def __init__(self, R:int, d_node:int, K_coarse:int, hidden:int=256, latent:int=64):
        super().__init__()
        self.R = R
        self.K = K_coarse
        self.encoder = nn.Sequential(nn.Linear(d_node, hidden), nn.ReLU(), nn.Linear(hidden, latent), nn.ReLU())
        # coarse_head is defined but not used in forward(); coarse logits come from E @ Rmats[r] @ E.T instead
        self.coarse_head = nn.Sequential(nn.Linear(latent, hidden), nn.ReLU(), nn.Linear(hidden, K_coarse*K_coarse))
        # per-relation small matrices for local scoring (used in local fine energy)
        self.Rmats = nn.Parameter(torch.randn(R, latent, latent) * 0.01)
        self.fine_head = nn.Sequential(nn.Linear(R, 128), nn.ReLU(), nn.Linear(128, 1))

    def forward(self, H: torch.Tensor, A_probs: torch.Tensor, C: torch.Tensor) -> torch.Tensor:
        # H: [N, M, d_node], A_probs: [R, N, N], C: [N, K]
        N, M, d = H.shape
        Hpool = H.mean(dim=1)  # [N, d_node]
        z = self.encoder(Hpool)  # [N, latent]
        # Community embeddings: assignment-weighted average of node latents, E = C^T z / (column mass of C) -> [K, latent]
        E = (C.t() @ z) / (C.sum(dim=0, keepdim=True).t().clamp(min=1.0))
        # Relation-specific coarse adjacency predicted from the community embeddings:
        # Acoarse_pred[r] = sigmoid(E @ W_r @ E^T), reusing W_r = Rmats[r]
        Acoarse_pred = []
        for r in range(self.R):
            W = self.Rmats[r]
            S = E @ W @ E.t()  # [K, K]
            Acoarse_pred.append(torch.sigmoid(S))
        Acoarse_pred = torch.stack(Acoarse_pred, dim=0)  # [R, K, K]

        # Coarse graph aggregated from the fine proposals: A_hat_coarse[r] = C^T @ A_probs[r] @ C
        # (i, j index nodes; k, l index communities)
        A_hat_coarse = torch.einsum('ik,rij,jl->rkl', C, A_probs, C)  # [R, K, K]
        # Consistency energy: squared mismatch between predicted and aggregated coarse graphs
        E_consistency = ((Acoarse_pred - A_hat_coarse).pow(2)).sum()

        E_coarse_global = 0.0
        for r in range(self.R):
            Ar = Acoarse_pred[r]
            Ar_sq = Ar * Ar
            h = torch.trace(torch.matrix_exp(Ar_sq)) - Ar_sq.size(0)
            E_coarse_global = E_coarse_global + lambda_dag * h + lambda_sparse_coarse * Ar.abs().sum()

        E_coarse = E_coarse_global

        fine_terms = []
        for r in range(self.R):
            Rmat = self.Rmats[r]
            S = z @ Rmat @ z.t()
            S_sig = torch.sigmoid(S)
            diff = (A_probs[r] - S_sig).abs().mean()
            fine_terms.append(diff)
        fine_stack = torch.stack(fine_terms)
        E_fine = self.fine_head(fine_stack.unsqueeze(0)).squeeze()

        # total energy
        E_total = lambda_coarse * E_coarse + lambda_fine * E_fine + lambda_consistency * E_consistency
        return E_total

def sgld_ascent_on_A(A_init: torch.Tensor, H: torch.Tensor, ebm: nn.Module, steps: int, step_size: float, noise: float) -> torch.Tensor:
    logits = torch.logit(A_init.clamp(1e-6, 1 - 1e-6))
    logits = logits.detach().clone().to(H.device).requires_grad_(True)
    for _ in range(steps):
        probs = torch.sigmoid(logits)
        e = ebm(H, probs)
        g = torch.autograd.grad(e, logits, retain_graph=False, create_graph=False)[0]
        logits = logits + step_size * g + noise * torch.randn_like(logits)
        logits = logits.detach().requires_grad_(True)
    return torch.sigmoid(logits.detach())

# ---------- contrastive InfoNCE loss ----------
def info_nce_loss(E_pos: torch.Tensor, E_negs: torch.Tensor, temp: float = 1.0) -> torch.Tensor:
    # -log( exp(-E_pos/T) / (exp(-E_pos/T) + sum_m exp(-E_neg_m/T)) ), computed with logsumexp so that
    # large energies do not underflow exp() to 0 and turn the ratio into 0/0 = NaN.
    # E_pos: scalar or [B]; E_negs: [M] or [M, B]
    if E_pos.dim() == 0:
        logits = torch.cat([(-E_pos / temp).reshape(1), -E_negs.reshape(-1) / temp])
        return -(logits[0] - torch.logsumexp(logits, dim=0))
    logits = torch.cat([(-E_pos / temp).unsqueeze(0), -E_negs / temp], dim=0)  # [1 + M, B]
    return -(logits[0] - torch.logsumexp(logits, dim=0)).mean()

# ---------- training loop ----------
def train():
    H = load_H()  # [N, M, d_out]
    N, M, d_out = H.shape
    C = load_C()  # [N, K]
    K = C.size(1)
    H_aug = None
    if invariance_enable:
        H_aug = load_H_aug()  # [N, M, d_out] for augmented environment

    R_used = 6  # number of relation types R; set this to the size of your relation vocabulary

    proposer = ProposerNet(R=R_used, d_out=d_out, d_rank=d_rank).to(DEVICE)
    ebm = HierarchicalEBM(R=R_used, d_node=d_out, K_coarse=K).to(DEVICE)
    ebm_ema = HierarchicalEBM(R=R_used, d_node=d_out, K_coarse=K).to(DEVICE)
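    ebm_ema.load_state_dict(ebm.state_dict())
    ebm_ema.requires_grad_(False)  # the EMA critic is updated only by the EMA rule, never by an optimizer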

    opt_prop = torch.optim.AdamW(proposer.parameters(), lr=proposer_lr, weight_decay=1e-6)
    opt_ebm = torch.optim.AdamW(ebm.parameters(), lr=ebm_lr, weight_decay=1e-6)

    sampler_pos = LangevinSampler(steps=mcmc_steps_pos, step_size=mcmc_step_size, noise=mcmc_noise)

    tasks = ["contrastive", "proposer", "invariance", "aux"]
    log_vars = nn.ParameterDict({t: nn.Parameter(torch.tensor(0.0)) for t in tasks}).to(DEVICE)
    opt_task = torch.optim.AdamW(list(log_vars.parameters()), lr=1e-3)

    H = H.to(DEVICE)
    if H_aug is not None:
        H_aug = H_aug.to(DEVICE)

    for epoch in range(1, epochs + 1):
        proposer.train()
        ebm.train()

        # forward proposer
        U, V, A_logits, A0 = proposer(H)  # A0 [R,N,N]

        # proposer Lenergy via EMA EBM
        ebm_ema.eval()
        with torch.set_grad_enabled(True):
            L_energy = ebm_ema(H, A0, C)
            if L_energy.dim() != 0:
                L_energy = L_energy.mean()

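        # MCMC refinement to get A_star (positive examples). Langevin updates need dE/dA, so the
        # sampler is not wrapped in no_grad; detaching the initial state and the result provides
        # the stop-gradient, so nothing flows back into the proposer through the refinement.
        A_star = sampler_pos.sample(energy_fn=lambda Ap: ebm(H, Ap, C), X_init=A0.detach()).detach()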

        # Ldistill and Lsparse
        L_distill = F.l1_loss(A0, A_star.detach(), reduction="mean")
        L_sparse = A0.abs().mean()
        L_total_proposer = alpha_energy * L_energy + alpha_distill * L_distill + alpha_sparse * L_sparse

        # CONTRASTIVE LOSS for EBM
        # Positive energy at the MCMC-refined graph A_star
        E_pos = ebm(H, A_star.detach(), C)
        # Easy negatives: sparse random Bernoulli graphs (edge probability 0.01)
        E_negs_easy = [ebm(H, torch.bernoulli(torch.full_like(A0, 0.01)), C) for _ in range(negatives_easy)]
        # Hard negatives: a few noisy energy-ascent steps starting from A_star.
        # sgld_ascent_on_A calls its energy model as ebm(H, probs), so C is bound here.
        def energy_with_C(H_, Ap):
            return ebm(H_, Ap, C)
        E_negs_hard = [
            ebm(H, sgld_ascent_on_A(A_star, H, energy_with_C, steps=sld_ascent_steps, step_size=sld_step, noise=sld_noise), C)
            for _ in range(negatives_hard)
        ]
        E_negs = torch.stack(E_negs_easy + E_negs_hard)  # [negatives_easy + negatives_hard]
        L_contrastive = info_nce_loss(E_pos, E_negs, temp=1.0)

        # INVARIANCE 
        if invariance_enable and H_aug is not None:
            # Proposer outputs for the augmented environment (a second view of the same nodes)
            U2, V2, A_logits2, A02 = proposer(H_aug)
            A_star2 = sampler_pos.sample(energy_fn=lambda Ap: ebm(H_aug, Ap, C), X_init=A02.detach()).detach()
            E1 = ebm(H, A_star.detach(), C)
            E2 = ebm(H_aug, A_star2.detach(), C)
            Linv_energy = ((E1 - E2).pow(2)).mean()
            # Representation-level invariance: as a heuristic proxy for a causal embedding Z,
            # average the proposer's U factors over relations to get one [N, d_rank] embedding per view
            Z1 = U.mean(dim=0)  # [N, d_rank]
            Z2 = U2.mean(dim=0)
            Linv_repr = F.mse_loss(Z1, Z2, reduction="mean")
            Linv = 0.1 * Linv_energy + 1.0 * Linv_repr
        else:
            Linv = torch.tensor(0.0, device=DEVICE)
            Linv_energy = torch.tensor(0.0, device=DEVICE)
            Linv_repr = torch.tensor(0.0, device=DEVICE)

        # L_aux: proxies for MoE-style specialization and load balancing.
        # Specialization proxy: a small L2 penalty on the low-rank factors U and V.
        L_moespec = (U.pow(2).mean() + V.pow(2).mean()) * 1e-4
        # Load-balance proxy: the loss below is approximately std / mean of the total edge mass per
        # relation (its coefficient of variation), so minimizing it spreads mass evenly across relations
        relation_mass = A0.sum(dim=(1,2))  # [R]
        load_bal = relation_mass.mean() / (relation_mass.std() + 1e-8)
        L_loadbalance = (1.0 / (load_bal + 1e-8))
        L_aux = L_moespec + 1e-2 * L_loadbalance

        # Compose the task losses with learned uncertainty weighting:
        # each task loss l becomes 0.5 * exp(-s) * l + 0.5 * s, where s is a learned log-variance.
        def weighted(l, name):
            s = log_vars[name]
            return 0.5 * torch.exp(-s) * l + 0.5 * s

        total_loss_tasks = (weighted(L_contrastive, "contrastive") + weighted(L_total_proposer, "proposer")
                            + weighted(Linv, "invariance") + weighted(L_aux, "aux"))

        # Compute every gradient with a single backward pass before any optimizer step.
        # Stepping one optimizer first modifies parameters in place that a later backward
        # through the same graph still needs, which raises a RuntimeError in autograd.
        opt_prop.zero_grad()
        opt_ebm.zero_grad()
        opt_task.zero_grad()
        total_loss_tasks.backward()
        torch.nn.utils.clip_grad_norm_(proposer.parameters(), max_norm=5.0)
        torch.nn.utils.clip_grad_norm_(ebm.parameters(), max_norm=5.0)
        opt_prop.step()
        opt_ebm.step()
        opt_task.step()

        # Optionally update encoder / causal stream if you have access; here we assume H fixed.
        # EMA update of EBM critic
        for p_ema, p in zip(ebm_ema.parameters(), ebm.parameters()):
            p_ema.data.mul_(ema_tau).add_(p.data, alpha=(1.0 - ema_tau))

        if epoch % 1 == 0:
            print(f"Epoch {epoch:03d} | prop_loss {L_total_proposer.item():.6f} | contrastive {L_contrastive.item():.6f} | Linv {Linv.item():.6f} | Laux {L_aux.item():.6f}")

        if epoch % 5 == 0 or epoch == epochs:
            ck = {
                "epoch": epoch,
                "proposer_state": proposer.state_dict(),
                "ebm_state": ebm.state_dict(),
                "ebm_ema_state": ebm_ema.state_dict(),
                "log_vars": {k: v.detach().cpu() for k, v in log_vars.items()},
                "U": U.detach().cpu(),
                "V": V.detach().cpu(),
                "A0": A0.detach().cpu(),
                "A_star": A_star.detach().cpu()
            }
            SAVE_CKPT.parent.mkdir(parents=True, exist_ok=True)
            torch.save(ck, SAVE_CKPT)

    # final save
    final = {
        "proposer_state": proposer.state_dict(),
        "ebm_state": ebm.state_dict(),
        "ebm_ema_state": ebm_ema.state_dict(),
        "log_vars": {k: v.detach().cpu() for k, v in log_vars.items()}
    }
    torch.save(final, SAVE_CKPT.with_suffix(".final.pth"))
    print("Training finished. Checkpoint saved.")

if __name__ == "__main__":
    train()
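The coarse term in `HierarchicalEBM.forward` uses the NOTEARS acyclicity function $h(A)=\operatorname{tr}(e^{A\circ A})-d$. The standalone check below is a sketch independent of the model; `notears_h` is a name introduced only for this illustration. It shows the function vanishes on a DAG and becomes positive once a cycle appears.

```python
import torch


def notears_h(A: torch.Tensor) -> torch.Tensor:
    """NOTEARS-style acyclicity measure h(A) = tr(exp(A * A)) - d for a square matrix A."""
    A_sq = A * A  # elementwise square, so every entry is non-negative
    return torch.trace(torch.matrix_exp(A_sq)) - A.size(0)


# 3-node chain 0 -> 1 -> 2 (a DAG)
dag = torch.tensor([[0.0, 1.0, 0.0],
                    [0.0, 0.0, 1.0],
                    [0.0, 0.0, 0.0]], dtype=torch.float64)
# the same chain plus a back edge 2 -> 0, which closes a cycle
cyclic = dag.clone()
cyclic[2, 0] = 1.0

print(float(notears_h(dag)))     # essentially 0
print(float(notears_h(cyclic)))  # about 0.504
```

`Acoarse_pred` passes through a sigmoid, so its entries are never exactly zero. In the model, this term therefore acts as a soft penalty that shrinks toward zero rather than reaching it.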

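The consistency energy aggregates each fine graph as $C^\top A^{(r)} C$. A single `torch.einsum` can express this by tying the node index of `C` to the row and column indices of `A_probs`. The sketch below uses random toy tensors with arbitrarily chosen sizes to confirm that the einsum matches the explicit per-relation product.

```python
import torch

torch.manual_seed(0)
R, N, K = 3, 5, 2
A_probs = torch.rand(R, N, N)                 # fine-level edge probabilities [R, N, N]
C = torch.softmax(torch.randn(N, K), dim=1)   # soft community assignment [N, K]

# A_coarse[r, k, l] = sum_{i,j} C[i, k] * A_probs[r, i, j] * C[j, l]
A_coarse = torch.einsum('ik,rij,jl->rkl', C, A_probs, C)

# Reference: explicit C^T @ A_r @ C for each relation
A_ref = torch.stack([C.t() @ A_probs[r] @ C for r in range(R)])

print(A_coarse.shape)                              # torch.Size([3, 2, 2])
print(torch.allclose(A_coarse, A_ref, atol=1e-5))  # True
```

`torch.matmul` broadcasts over leading dimensions, so `C.t() @ A_probs @ C` gives the same `[R, K, K]` result without einsum.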

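The curriculum cell imports `LangevinSampler` from `torch_energy.samplers` and calls it as `sample(energy_fn=..., X_init=...)`. If that package is not available, a minimal stand-in with the same call shape could look like the sketch below. `SimpleLangevinSampler` is a hypothetical class written for this note; it is not that package's API.

The sketch runs unadjusted Langevin updates $x \leftarrow x - \eta\nabla E(x) + \sigma\,\xi$. After each step it clamps the iterate to $[0,1]$ so it stays a valid edge-probability tensor; the clamping is an assumption of this sketch.

```python
from typing import Callable

import torch


class SimpleLangevinSampler:
    """Hypothetical stand-in: unadjusted Langevin dynamics on a probability tensor.

    Per step: x <- clamp(x - step_size * dE/dx + noise * N(0, I), 0, 1).
    """

    def __init__(self, steps: int, step_size: float, noise: float):
        self.steps = steps
        self.step_size = step_size
        self.noise = noise

    def sample(self, energy_fn: Callable[[torch.Tensor], torch.Tensor], X_init: torch.Tensor) -> torch.Tensor:
        x = X_init.detach().clone()
        for _ in range(self.steps):
            x.requires_grad_(True)
            # enable_grad so the energy gradient is available even if the caller is inside torch.no_grad()
            with torch.enable_grad():
                e = energy_fn(x)
                (g,) = torch.autograd.grad(e.sum(), x)
            with torch.no_grad():
                x = x - self.step_size * g + self.noise * torch.randn_like(x)
                x = x.clamp(0.0, 1.0)
        return x.detach()


# Toy check: a quadratic energy whose minimum is 0.3 in every entry
torch.manual_seed(0)
target = torch.full((2, 4, 4), 0.3)
sampler = SimpleLangevinSampler(steps=200, step_size=0.05, noise=1e-3)
X0 = torch.rand(2, 4, 4)
X = sampler.sample(energy_fn=lambda A: ((A - target) ** 2).sum(), X_init=X0)
print(float((X - target).abs().max()))  # small: the chain settles near the minimum
```

With the EBM above, the energy function would be `lambda Ap: ebm(H, Ap, C)`, exactly as in the training loop.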
### 3.4.2.6 Final Output and Inference (Reasoning with Uncertainty)

import math
from pathlib import Path
from typing import List, Tuple, Dict, Optional

import torch
import torch.nn.functional as F
import numpy as np

from proposer_qphi_final import ProposerNet  # proposer implementation
from causal_energy_curriculum import HierarchicalEBM  # hierarchical EBM

# ---------- config ----------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
OUT = Path("out")
LAYER2_OUT = OUT / "lpa_layer2_output.pt"
LAYER1_OUT = OUT / "lpa_layer1_output.pt"
C_PATH = OUT / "C_soft.pt"
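# Training writes final checkpoints via Path.with_suffix(".final.pth"), which replaces the trailing ".pth",
# so "<name>.pth" becomes "<name>.final.pth" (proposer base name assumed to be proposer_qphi_checkpoint.pth)
PROPOSER_CKPT = OUT / "proposer_qphi_checkpoint.final.pth"
EBM_CKPT = OUT / "causal_energy_curriculum_ckpt.final.pth"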
ENCODER_CKPT = OUT / "lpa_model" / "encoder_finetuned.pth"
RESULTS_OUT = OUT / "inference_results_reuse.pt"

Nenc = 5
Nprop = 10
epistemic_threshold = 0.1
k_chains = 32
T_sgld = 300
burn_in = int(T_sgld * 0.2)
init_noise_std = 1e-2
sgld_step0 = 0.02
sgld_noise = 1e-3
Tmax = 5.0
seed = 20251127

torch.manual_seed(seed)
np.random.seed(seed)

# ---------- loaders ----------
def load_H() -> torch.Tensor:
    if LAYER2_OUT.exists():
        o = torch.load(LAYER2_OUT, map_location="cpu")
        k = "H2" if "H2" in o else next(iter(o.keys()))
        H = o[k]
    elif LAYER1_OUT.exists():
        o = torch.load(LAYER1_OUT, map_location="cpu")
        k = "H1" if "H1" in o else next(iter(o.keys()))
        H = o[k]
    else:
        raise FileNotFoundError("No layer outputs found in out/")
    return H.to(DEVICE)

def load_C() -> torch.Tensor:
    obj = torch.load(C_PATH, map_location="cpu")
    return obj.to(DEVICE)

def load_models() -> Tuple[ProposerNet, HierarchicalEBM]:
    ck_p = torch.load(PROPOSER_CKPT, map_location="cpu")
    proposer_state = ck_p.get("proposer") or ck_p.get("proposer_state") or ck_p
    # infer dims
    d_out = None; R = None; d_rank = None
    for k,v in proposer_state.items():
        if k.endswith("WK.weight"):
            d_out = v.size(1)
        if k.endswith("q_rel"):
            R = v.size(0)
        if k.endswith("g_shared.W2.weight"):
            d_rank = v.size(0)
    if d_out is None or R is None or d_rank is None:
        raise RuntimeError("Could not infer proposer dims from checkpoint")
    proposer = ProposerNet(R=R, d_out=d_out, d_rank=d_rank).to(DEVICE)
    proposer.load_state_dict(proposer_state, strict=True)
    ebm_ck = torch.load(EBM_CKPT, map_location="cpu")
    ebm_state = ebm_ck.get("ebm_state") or ebm_ck.get("ebm") or ebm_ck
    K = load_C().size(1)
    ebm = HierarchicalEBM(R=R, d_node=d_out, K_coarse=K).to(DEVICE)
    ebm.load_state_dict(ebm_state, strict=True)
    proposer.eval(); ebm.eval()
    return proposer, ebm

# ---------- helpers ----------
def reconstruct_A_from_factors(U: torch.Tensor, V: torch.Tensor, b_rel: Optional[torch.Tensor]=None) -> torch.Tensor:
    R, N, d = U.shape
    A_logits = torch.zeros(R, N, N, device=U.device)
    for r in range(R):
        A_logits[r] = U[r] @ V[r].t()
        if b_rel is not None:
            A_logits[r] = A_logits[r] + b_rel[r]
    return torch.sigmoid(A_logits)

def temperature_schedule_linear(Tmax: float, steps: int):
    def fn(t):
        if steps <= 1:
            return 1.0
        return Tmax - (Tmax - 1.0) * (t / float(steps - 1))
    return fn

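def mc_dropout_proposer_ensemble(proposer: ProposerNet, H: torch.Tensor, Nenc:int, Nprop:int) -> Dict:
    # Epistemic ensemble by MC dropout: train() keeps dropout layers active while sampling.
    # The spread is zero unless the proposer contains dropout (the ProposerNet in 3.4.2.5 has none),
    # and the outer Nenc loop reuses the same H, i.e. no encoder sampling happens here.
    samples_U=[]; samples_V=[]; samples_A=[]
    proposer.train()
    with torch.no_grad():
        for _ in range(Nenc):
            Hs = H  # placeholder for an encoder MC sample; currently always the fixed H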
            for _ in range(Nprop):
                U,V,logits,A = proposer(Hs)
                samples_U.append(U.cpu())
                samples_V.append(V.cpu())
                samples_A.append(A.cpu())
    U_stack = torch.stack(samples_U, dim=0).to(DEVICE)
    V_stack = torch.stack(samples_V, dim=0).to(DEVICE)
    A_stack = torch.stack(samples_A, dim=0).to(DEVICE)
    return {
        "U_stack": U_stack, "V_stack": V_stack, "A_stack": A_stack,
        "U_mean": U_stack.mean(dim=0), "V_mean": V_stack.mean(dim=0),
        "A_mean": A_stack.mean(dim=0), "A_var": A_stack.var(dim=0, unbiased=False)
    }

def select_candidate_edges(A_mean: torch.Tensor, threshold: float=0.1) -> List[Tuple[int,int,int]]:
    R,N,_ = A_mean.shape
    idx = (A_mean > threshold).nonzero(as_tuple=False)
    return [(int(t[0]), int(t[1]), int(t[2])) for t in idx]

def compute_edge_stats_from_ensemble(A_stack: torch.Tensor, edges: List[Tuple[int,int,int]]):
    stats = {}
    if len(edges)==0:
        return stats
    for (r,i,j) in edges:
        vals = A_stack[:, r, i, j].cpu().numpy()
        stats[(r,i,j)] = {"mu": float(vals.mean()), "var": float(vals.var()), "ci": (float(np.percentile(vals,2.5)), float(np.percentile(vals,97.5)))}
    return stats

def sgld_factor_space(U_init: torch.Tensor, V_init: torch.Tensor, H: torch.Tensor, ebm: HierarchicalEBM,
                      C: torch.Tensor, steps:int, step_size:float, noise_scale:float, temp_fn) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    k,R,N,d = U_init.shape
    U = U_init.clone().to(DEVICE)
    V = V_init.clone().to(DEVICE)
    for t in range(steps):
        Tt = temp_fn(t)
        eta = step_size
        for c in range(k):
            Uc = U[c].detach().clone().requires_grad_(True)
            Vc = V[c].detach().clone().requires_grad_(True)
            A_logits = torch.zeros(R, N, N, device=DEVICE)
            for r in range(R):
                A_logits[r] = Uc[r] @ Vc[r].t()
            A_probs = torch.sigmoid(A_logits)
            E = ebm(H, A_probs, C)
            E_scaled = E / Tt
            grad_U, grad_V = torch.autograd.grad(E_scaled, (Uc, Vc), retain_graph=False, create_graph=False)
            noise_U = torch.randn_like(Uc) * math.sqrt(2.0 * eta) * noise_scale
            noise_V = torch.randn_like(Vc) * math.sqrt(2.0 * eta) * noise_scale
            Uc_next = Uc - eta * grad_U + noise_U
            Vc_next = Vc - eta * grad_V + noise_V
            U[c] = Uc_next.detach()
            V[c] = Vc_next.detach()
    A_final = torch.zeros(k, R, N, N, device=DEVICE)
    for c in range(k):
        A_final[c] = reconstruct_A_from_factors(U[c], V[c])
    return U, V, A_final

# ---------- main ----------
def main():
    H = load_H()
    C = load_C()
    proposer, ebm = load_models()
    proposer.to(DEVICE); ebm.to(DEVICE)
    proposer.eval(); ebm.eval()

    mc = mc_dropout_proposer_ensemble(proposer, H, Nenc=Nenc, Nprop=Nprop)
    U_stack = mc["U_stack"]; V_stack = mc["V_stack"]; A_stack = mc["A_stack"]
    U_mean = mc["U_mean"]; V_mean = mc["V_mean"]; A_mean = mc["A_mean"]

    edges = select_candidate_edges(A_mean, threshold=epistemic_threshold)
    epistemic_stats = compute_edge_stats_from_ensemble(A_stack, edges)

    R, N, d_rank = U_mean.shape
    U0 = U_mean.detach().clone()
    V0 = V_mean.detach().clone()
    U_chains = U0.unsqueeze(0).expand(k_chains, -1, -1, -1).clone() + torch.randn(k_chains, R, N, d_rank, device=DEVICE) * init_noise_std
    V_chains = V0.unsqueeze(0).expand(k_chains, -1, -1, -1).clone() + torch.randn(k_chains, R, N, d_rank, device=DEVICE) * init_noise_std

    temp_fn = temperature_schedule_linear(Tmax, T_sgld)
    U_final, V_final, A_final = sgld_factor_space(U_chains, V_chains, H, ebm, C, steps=T_sgld, step_size=sgld_step0, noise_scale=sgld_noise, temp_fn=temp_fn)

    A_final_np = A_final.cpu().numpy()
    struct_stats = {}
    for (r,i,j) in edges:
        vals = A_final_np[:, r, i, j]
        struct_stats[(r,i,j)] = {"mu": float(vals.mean()), "var": float(vals.var()), "ci": (float(np.percentile(vals,2.5)), float(np.percentile(vals,97.5)))}

    combined = {}
    for e in edges:
        ep = epistemic_stats.get(e, {"mu":None,"var":0.0})
        st = struct_stats.get(e, {"mu":None,"var":0.0})
        combined[e] = {
            "epistemic_mu": ep.get("mu"),
            "epistemic_var": ep.get("var",0.0),
            "structural_mu": st.get("mu"),
            "structural_var": st.get("var",0.0),
            "combined_var": float(ep.get("var",0.0) + st.get("var",0.0))
        }

    out = {
        "A_mean": A_mean.cpu(),
        "A_var_epistemic": mc["A_var"].cpu(),
        "epistemic_edges": epistemic_stats,
        "structural_edges": struct_stats,
        "combined_edges": combined,
        "A_chains_final": A_final.cpu(),
        "U_chains_final": U_final.cpu(),
        "V_chains_final": V_final.cpu()
    }
    RESULTS_OUT.parent.mkdir(parents=True, exist_ok=True)
    torch.save(out, RESULTS_OUT)
    print(f"Saved inference results to {RESULTS_OUT}")

if __name__ == "__main__":
    main()
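`mc_dropout_proposer_ensemble` switches the proposer to `train()` so that dropout stays active during sampling. However, the `ProposerNet` defined in 3.4.2.5 has no dropout layers. If the imported `proposer_qphi_final.ProposerNet` is the same class, every ensemble member is identical and `A_var` is zero.

The toy sketch below shows the effect. `TinyHead` and `mc_variance` are hypothetical names introduced for illustration; `TinyHead` is a small two-layer head with an optional `nn.Dropout`.

```python
import torch
import torch.nn as nn


class TinyHead(nn.Module):
    """Hypothetical stand-in for a proposer head; p_drop=0 means no dropout layer at all."""

    def __init__(self, d_in: int, d_hidden: int, p_drop: float):
        super().__init__()
        layers = [nn.Linear(d_in, d_hidden), nn.ReLU()]
        if p_drop > 0:
            layers.append(nn.Dropout(p_drop))
        layers.append(nn.Linear(d_hidden, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(self.net(x))  # edge-probability-like output in (0, 1)


def mc_variance(model: nn.Module, x: torch.Tensor, n_samples: int) -> torch.Tensor:
    """Per-output variance across stochastic forward passes (MC dropout)."""
    model.train()  # keeps nn.Dropout active at inference time
    with torch.no_grad():
        samples = torch.stack([model(x) for _ in range(n_samples)])
    return samples.var(dim=0, unbiased=False)


torch.manual_seed(0)
x = torch.randn(8, 4)
var_plain = mc_variance(TinyHead(4, 32, p_drop=0.0), x, n_samples=50)
var_drop = mc_variance(TinyHead(4, 32, p_drop=0.2), x, n_samples=50)
print(float(var_plain.max()), float(var_drop.mean()))  # ~0 versus a clearly positive value
```

A proposer would need a comparable `nn.Dropout` for `Nprop` samples to carry epistemic information. Inside `SharedG`, between `act` and `W2`, would be one possible location, but that is an assumption, not part of the model above.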

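`sgld_factor_space` updates one chain at a time and keeps only each chain's final state; `burn_in` is defined in the config but never used. The sketch below is a vectorized variant that updates all chains with one autograd call and averages the post-burn-in states. It keeps the original's update rule and temperature scaling and, like the original, adds no `b_rel` bias.

The sketch assumes an energy function that maps a batch `[k, R, N, N]` to one energy per chain. `sgld_chains` and `toy_energy` are hypothetical names introduced here. With the `HierarchicalEBM` above, such a function could be `lambda A: torch.stack([ebm(H, A[c], C) for c in range(A.size(0))])`.

```python
import math
from typing import Callable, Tuple

import torch


def sgld_chains(U0: torch.Tensor, V0: torch.Tensor,
                energy_fn: Callable[[torch.Tensor], torch.Tensor],
                steps: int, burn_in: int, step_size: float, noise_scale: float,
                temp_fn: Callable[[int], float]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Annealed SGLD over k independent chains in factor space.

    U0, V0: [k, R, N, d]; energy_fn: A [k, R, N, N] -> energies [k].
    Returns the final A [k, R, N, N] and the post-burn-in mean of A over steps and chains [R, N, N].
    """
    if not 0 <= burn_in < steps:
        raise ValueError("burn_in must satisfy 0 <= burn_in < steps")
    U, V = U0.detach().clone(), V0.detach().clone()
    kept_sum, n_kept = None, 0
    for t in range(steps):
        U.requires_grad_(True)
        V.requires_grad_(True)
        A = torch.sigmoid(torch.einsum("krnd,krmd->krnm", U, V))
        # chains are independent, so the gradient of the summed energy is each chain's own gradient
        E = energy_fn(A).sum() / temp_fn(t)
        grad_U, grad_V = torch.autograd.grad(E, (U, V))
        with torch.no_grad():
            scale = math.sqrt(2.0 * step_size) * noise_scale
            U = U - step_size * grad_U + scale * torch.randn_like(U)
            V = V - step_size * grad_V + scale * torch.randn_like(V)
            if t >= burn_in:
                A_t = torch.sigmoid(torch.einsum("krnd,krmd->krnm", U, V))
                kept_sum = A_t if kept_sum is None else kept_sum + A_t
                n_kept += 1
    with torch.no_grad():
        A_final = torch.sigmoid(torch.einsum("krnd,krmd->krnm", U, V))
    return A_final, kept_sum.mean(dim=0) / n_kept


torch.manual_seed(0)
k, R, N, d = 4, 2, 5, 3
target = torch.rand(R, N, N)


def toy_energy(A: torch.Tensor) -> torch.Tensor:
    # hypothetical energy: squared distance to a fixed target, one value per chain
    return ((A - target) ** 2).sum(dim=(1, 2, 3))


U0 = 0.1 * torch.randn(k, R, N, d)
V0 = 0.1 * torch.randn(k, R, N, d)
T = 300
temp_fn = lambda t: 5.0 - 4.0 * t / (T - 1)  # linear 5 -> 1, same shape as temperature_schedule_linear(5.0, T)
A_final, A_post_mean = sgld_chains(U0, V0, toy_energy, steps=T, burn_in=60,
                                   step_size=0.1, noise_scale=1e-3, temp_fn=temp_fn)
print(A_final.shape, A_post_mean.shape)  # torch.Size([4, 2, 5, 5]) torch.Size([2, 5, 5])
```

The per-edge spread across chains (`A_final.var(dim=0)`) plays the role of `struct_stats` in `main()`. Averaging over post-burn-in steps as well gives each edge more samples than the `k_chains` final states alone.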