In [1]:
!pip install gcastle torch numpy pandas networkx scikit-learn

Collecting gcastle
  Downloading gcastle-1.0.4-py3-none-any.whl.metadata (7.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting n

# Causal discovery: DQN agent vs. GraN-DAG (with ground-truth evaluation)

In [40]:
# -*- coding: utf-8 -*-
"""
Causal-Discovery (DQN vs. GraN-DAG) with Copula/Gaussian BIC, CAM pruning, live metrics

- Fast Gaussian BIC and Gaussian Copula (nonparanormal) BIC
- Greedy warm-start (beam + reversals)
- Double DQN + Polyak target updates
- CAM pruning (linear regression) post-training
- SHD/FDR/TPR/etc printed every eval, plus best-by-ValBIC and best-by-TPR snapshots

Training does NOT use GT. GT (if provided) is used ONLY for evaluation.
"""

import os, random, warnings
from collections import deque

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import networkx as nx

from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
from scipy.stats import rankdata, norm

# Opponent
from castle.algorithms import GraNDAG

warnings.filterwarnings("ignore", category=UserWarning)

# -------------------- Config --------------------
# Input locations (Colab-style absolute paths).
DATA_CSV   = "/content/data.csv"
GT_NPY     = "/content/adj.npy"    # optional; evaluation only
G_ITER     = 1000                  # GraN-DAG iterations

# Training schedule: number of episodes, and how often (in episodes) metrics are printed.
N_EPISODES = 400
EVAL_EVERY = 20
SEED       = 42

# Search / reward shaping. EDGE_BUDGET_RATIO also caps "add" actions inside the environment.
EDGE_BUDGET_RATIO = 1.1           # warm-start cap (~1.1 * p edges)
LAMBDA_L1         = 0.02          # sparsity penalty in reward
ACTION_COST       = 0.05          # small penalty per committed edit
CAM_TH            = 0.3           # CAM pruning threshold

SCORE_TYPE = "copula"             # "copula" (robust) or "gaussian"

# Torch placement / precision shared by the Q-networks and replay tensors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE  = torch.double
# Seed every RNG used below (NumPy global, Python `random`, torch CPU and CUDA).
np.random.seed(SEED); random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)


# -------------------- IO --------------------
def load_data(csv_path=DATA_CSV):
    """
    Load a CSV of numeric observations and return it column-standardised.

    Parameters
    ----------
    csv_path : str
        Path to a CSV with a header row. If the first column is an unnamed
        index column it is used as the index instead of as data.

    Returns
    -------
    np.ndarray or None
        float64 array of shape (n_samples, n_vars) with zero mean and unit
        (population) standard deviation per column; constant columns are only
        centred. None is returned (after printing an error) when the file is
        missing or no complete numeric row remains.
    """
    if not os.path.exists(csv_path):
        print(f"Error: '{csv_path}' not found."); return None
    df = pd.read_csv(csv_path, header=0)
    if df.columns[0].lower().startswith("unnamed"):
        df = pd.read_csv(csv_path, header=0, index_col=0)
    # Non-numeric cells become NaN; columns that are entirely NaN are discarded.
    df = df.apply(pd.to_numeric, errors="coerce").dropna(axis=1, how="all")
    # Remaining NaNs would turn the column mean/std (and every downstream BIC)
    # into NaN, so rows that are not fully numeric are dropped as well.
    n_rows = len(df)
    df = df.dropna(axis=0, how="any")
    if len(df) < n_rows:
        print(f"[load_data] dropped {n_rows - len(df)} rows with missing/non-numeric values")
    if df.empty:
        print(f"Error: '{csv_path}' has no complete numeric rows."); return None
    X = df.values.astype(np.float64)
    mu, sd = X.mean(0, keepdims=True), X.std(0, keepdims=True); sd[sd == 0] = 1.0
    X = (X - mu) / sd
    print(f"Loaded data: {X.shape[0]} samples, {X.shape[1]} vars")
    return X

def load_truth(npy_path=GT_NPY, p=None):
    """
    Read the optional ground-truth adjacency matrix from a .npy file.

    When `p` is given and the stored matrix is not (p, p), the leading
    p x p block is taken. Returns the float64 matrix, or None (with a
    notice) when the file does not exist.
    """
    if os.path.exists(npy_path):
        truth = np.load(npy_path).astype(np.float64)
        shape_mismatch = (p is not None) and (truth.shape != (p, p))
        if shape_mismatch:
            print(f"[align] trimming GT from {truth.shape} to {(p,p)}")
            truth = truth[:p, :p]
        print("Loaded ground truth:", truth.shape)
        return truth

    print("No ground truth file; metrics will be limited.")
    return None


# -------------------- Opponent (GraN-DAG) --------------------
def get_grandag_adj(data, iterations=G_ITER, hidden_dim=16, hidden_num=2, lr=5e-4, h_threshold=1e-6):
    """
    Fit the GraN-DAG baseline on `data` and return its learned graph.

    The keyword arguments are forwarded to the `GraNDAG` constructor
    (`mu_init` is fixed at 1e-3). The model's `causal_matrix` is returned
    as a fresh float64 array.
    """
    print(f"\nRunning GraN-DAG (iterations={iterations})...")
    settings = dict(
        input_dim=data.shape[1],
        hidden_dim=hidden_dim,
        hidden_num=hidden_num,
        lr=lr,
        iterations=iterations,
        h_threshold=h_threshold,
        mu_init=1e-3,
    )
    learner = GraNDAG(**settings)
    learner.learn(data)
    print("GraN-DAG done.")
    return np.array(learner.causal_matrix, dtype=np.float64)


# -------------------- Metrics --------------------
def binarize(A):
    """
    Return a 0/1 integer adjacency matrix with a zeroed diagonal.

    Any non-zero entry counts as an edge. `A` may be an ndarray or any
    array-like (e.g. nested lists); the input is never modified.
    """
    B = (np.asarray(A) != 0).astype(int)
    np.fill_diagonal(B, 0)  # self-loops are never reported as edges
    return B

def shd_binary(A, B):
    """
    Structural Hamming distance between two directed graphs.

    Counts one unit for every node pair whose adjacency (ignoring direction)
    differs, plus one unit for every pair that is adjacent in both graphs but
    oriented differently. Both inputs are binarised first and must be square
    matrices of the same shape.
    """
    A = binarize(A); B = binarize(B)
    skel_a = A | A.T
    skel_b = B | B.T
    skeleton_diff = int(np.triu(skel_a ^ skel_b, 1).sum())

    # A pair is mis-oriented if EITHER direction differs. Looking at both
    # triangles keeps the result independent of node ordering when a pair is
    # bidirected in one of the graphs; for DAG inputs the count is unchanged.
    entry_diff = A ^ B
    pair_diff = entry_diff | entry_diff.T
    orient_mismatch = int(np.triu(pair_diff & skel_a & skel_b, 1).sum())
    return skeleton_diff + orient_mismatch

def eval_against_gt(pred_adj, GT):
    """
    Compare a predicted directed graph with the ground truth.

    Returns None when `GT` is None or the shapes differ. Otherwise returns a
    dict with edge counts, FDR, TPR, FPR (each rounded to 4 decimals), SHD and
    the number of predicted edges ("nnz"). Ratios with an empty denominator
    are reported as 0.
    """
    if GT is None or pred_adj.shape != GT.shape:
        return None
    P = binarize(pred_adj); T = binarize(GT)
    pred_edge = (P == 1)
    true_edge = (T == 1)
    # Diagonal cells can never hold an edge (binarize zeroes them), so they
    # must not be counted as true negatives or FPR is deflated.
    off_diag = ~np.eye(*P.shape, dtype=bool)
    tp = int((pred_edge & true_edge).sum())
    fp = int((pred_edge & ~true_edge).sum())
    fn = int((~pred_edge & true_edge).sum())
    tn = int((~pred_edge & ~true_edge & off_diag).sum())
    fdr = fp / max(tp + fp, 1)
    tpr = tp / max(tp + fn, 1)
    fpr = fp / max(fp + tn, 1)
    return {
        "total_edges_original":  int(T.sum()),
        "total_edges_predicted": int(P.sum()),
        "correct_edges":         tp,
        "fdr": round(fdr, 4),
        "tpr": round(tpr, 4),
        "fpr": round(fpr, 4),
        "shd": shd_binary(P, T),
        "nnz": int(P.sum()),
    }


# -------------------- Scorers --------------------
class GaussianBIC:
    """
    Gaussian BIC via covariance blocks (no per-step OLS).

    The scatter matrix is built from column-centred data, so every node
    regression implicitly includes an intercept. This keeps the residual sum
    of squares of a node with parents comparable to that of a node without
    parents even when the supplied data are not exactly centred.
    Higher scores are better.
    """
    def __init__(self, X):
        X = np.asarray(X, dtype=np.float64)
        self.X = X
        self.n, self.p = X.shape
        Xc = X - X.mean(axis=0, keepdims=True)
        self.S = Xc.T @ Xc           # centred scatter matrix
        self.yTy = np.diag(self.S)   # centred sum of squares per variable

    def node_rss(self, parents, j):
        """Residual sum of squares of variable `j` regressed (with intercept) on `parents`."""
        if len(parents) == 0:
            return float(self.yTy[j])
        P = np.array(parents, dtype=int)
        Spp = self.S[np.ix_(P, P)]
        Spy = self.S[P, j]
        try:
            coef = np.linalg.solve(Spp, Spy)
        except np.linalg.LinAlgError:
            # Collinear parents: fall back to the minimum-norm solution.
            coef = np.linalg.pinv(Spp) @ Spy
        rss = self.yTy[j] - Spy @ coef
        return float(max(rss, 1e-12))

    def bic(self, A):
        """
        Score adjacency matrix `A` (A[i, j] != 0 means i -> j).

        Uses a single pooled residual variance over all nodes and returns
        log-likelihood minus 0.5 * k * log(n), where k counts one intercept
        per node plus one coefficient per edge.
        """
        n, p = self.n, self.p
        A = (np.asarray(A) != 0)  # any non-zero weight is an edge
        total_rss = 0.0
        num_params = p  # intercepts
        for j in range(p):
            parents = np.flatnonzero(A[:, j]).tolist()
            total_rss += self.node_rss(parents, j)
            num_params += len(parents)
        dof = max(n - num_params, 1)
        sigma2 = max(total_rss / dof, 1e-9)
        loglik = -0.5 * n * (p * np.log(2 * np.pi * sigma2) + 1.0)
        return float(loglik - 0.5 * num_params * np.log(n))

def gaussian_copula_transform(X):
    """
    Nonparanormal transform of each column: rank -> (rank - 0.5) / n -> Phi^-1.

    Ties receive their average rank. Columns that are (numerically) constant
    get a tiny jitter from a fixed-seed generator so that they can be ranked.
    The resulting normal scores are re-centred and re-scaled per column.
    """
    data = np.asarray(X, dtype=np.float64)
    n_rows, _ = data.shape
    margin = 1e-6
    jitter_rng = np.random.default_rng(0)

    scores = np.empty_like(data)
    for col_idx, column in enumerate(data.T):
        if np.std(column) < 1e-12:
            column = column + jitter_rng.normal(0, 1e-9, size=n_rows)
        ranks = rankdata(column, method="average")           # 1..n
        quantiles = np.clip((ranks - 0.5) / n_rows, margin, 1 - margin)
        scores[:, col_idx] = norm.ppf(quantiles)

    scores -= scores.mean(axis=0, keepdims=True)
    spread = scores.std(axis=0, keepdims=True)
    spread[spread == 0] = 1.0
    scores /= spread
    return scores

class CopulaBIC(GaussianBIC):
    """Gaussian BIC evaluated on the nonparanormal (Gaussian-copula) scores of the data."""
    def __init__(self, X):
        # Map every marginal to normal scores first, then reuse the Gaussian scorer as is.
        super().__init__(gaussian_copula_transform(X))


# -------------------- Warm-start (greedy + reversals) --------------------
def warm_start_greedy_bic(Xtr, Xva, edge_budget, scorer_cls=GaussianBIC,
                          max_passes=5, topk_per_pass=5, restarts=5, seed=SEED):
    """
    Greedy hill-climbing over DAGs (add top-K / prune / reverse), scored on `Xva`.

    Parameters
    ----------
    Xtr : array (n, p)      only its column count is used here.
    Xva : array (m, p)      data handed to `scorer_cls`; every move is judged on it.
    edge_budget : number    no edge is added once the graph holds this many edges.
    scorer_cls : class      called as scorer_cls(Xva); must expose .bic(A), higher = better.
    max_passes : int        maximum add/prune/reverse sweeps per run.
    topk_per_pass : int     at most this many additions are committed per sweep.
    restarts : int          number of runs (at least one run is always performed).
    seed : int              seeds the node orderings used by runs after the first.

    Returns
    -------
    np.ndarray (p, p) float64 adjacency (A[i, j] == 1 means i -> j) of the best run.

    The search itself is deterministic for a given node ordering. The first
    run scans nodes in natural order; each further restart scans them in a
    seeded random order, so restarts explore different first-improvement paths
    instead of repeating the same run.
    """
    rng = np.random.RandomState(seed)
    p = Xtr.shape[1]
    fb = scorer_cls(Xva)

    def is_dag(M): return nx.is_directed_acyclic_graph(nx.DiGraph(M))

    def one_run(order):
        A = np.zeros((p, p), dtype=np.float64)
        best = fb.bic(A)
        for _ in range(max_passes):
            improved = False
            # forward: score every legal single-edge addition, then commit the Top-K
            if A.sum() < edge_budget:
                cands = []
                for i in order:
                    for j in order:
                        if i == j or A[i, j] == 1: continue
                        trial = A.copy(); trial[i, j] = 1.0
                        if not is_dag(trial): continue
                        s = fb.bic(trial)
                        if s > best: cands.append((s - best, i, j))
                cands.sort(reverse=True, key=lambda x: x[0])
                for _, i, j in cands[:topk_per_pass]:
                    if A.sum() >= edge_budget: break
                    # Gains were measured against the old graph; re-check on the current one.
                    trial = A.copy(); trial[i, j] = 1.0
                    if not is_dag(trial): continue
                    s = fb.bic(trial)
                    if s > best:
                        A, best, improved = trial, s, True
            # backward: prune while any single removal improves the score
            pruned = True
            while pruned:
                pruned = False
                for i in order:
                    for j in order:
                        if A[i, j] == 0: continue
                        trial = A.copy(); trial[i, j] = 0.0
                        s = fb.bic(trial)
                        if s > best:
                            A, best, improved, pruned = trial, s, True, True
            # reversal sweep
            for i in order:
                for j in order:
                    if A[i, j] != 1: continue
                    trial = A.copy(); trial[i, j] = 0.0; trial[j, i] = 1.0
                    if not is_dag(trial): continue
                    s = fb.bic(trial)
                    if s > best:
                        A, best, improved = trial, s, True
            if not improved: break
        return A, best

    bestA, bestS = None, -np.inf
    for run in range(max(1, int(restarts))):
        order = list(range(p)) if run == 0 else rng.permutation(p).tolist()
        A, s = one_run(order)
        if s > bestS: bestA, bestS = A, s
    return bestA


# -------------------- Environment --------------------
class CausalDiscoveryEnv:
    """
    Graph-editing environment for the DQN agent.

    State  : the flattened p x p adjacency matrix (A[i, j] == 1 means i -> j).
    Action : one of add / remove / reverse for every ordered node pair.
    Reward : Δ Val-BIC - λ1 * edges - action_cost - small step penalty;
             rejected moves (edge budget, nothing to reverse, cycle) leave the
             graph untouched and yield a fixed negative reward.
    """
    def __init__(self, data, grandag_adj,
                 val_frac=0.2,
                 edge_budget_ratio=EDGE_BUDGET_RATIO,
                 lambda_l1=LAMBDA_L1, action_cost=ACTION_COST,
                 warm_start=True,
                 score_type=SCORE_TYPE):
        self.full = data
        self.n_samples, self.n_nodes = data.shape
        self.grandag_adj = grandag_adj.astype(np.float64)

        # Seeded shuffle; the leading (1 - val_frac) share is "train", the rest validation.
        order = np.arange(self.n_samples)
        np.random.RandomState(SEED).shuffle(order)
        n_train = int((1.0 - val_frac) * self.n_samples)
        self.Xtr = self.full[order[:n_train]]
        self.Xva = self.full[order[n_train:]]

        wants_copula = str(score_type).lower() == "copula"
        self.scorer_cls = CopulaBIC if wants_copula else GaussianBIC
        self.bic_va = self.scorer_cls(self.Xva)

        d = self.n_nodes
        self.state_space_shape = (d * d,)
        self.n_actions = 3 * d * (d - 1)
        self.action_map = self._create_action_map()

        self.current_adj = np.zeros((d, d), dtype=np.float64)
        self.max_steps = 10 * d
        self.current_step = 0

        self.edge_budget = int(max(1, edge_budget_ratio * d))
        self.lambda_l1 = float(lambda_l1)
        self.action_cost = float(action_cost)

        # Optional greedy initial graph; every episode restarts from a copy of it.
        self._warm_adj = None
        if warm_start:
            print("\n[Warm-start] searching ...")
            self._warm_adj = warm_start_greedy_bic(
                self.Xtr, self.Xva,
                edge_budget=self.edge_budget,
                scorer_cls=self.scorer_cls,
                max_passes=5, topk_per_pass=5, restarts=5, seed=SEED
            )
            print("[Warm-start] edges:", int(np.sum(self._warm_adj)))

    def _create_action_map(self):
        """Map action index -> (op, i, j); for each ordered pair the ops are add, remove, reverse."""
        triples = [
            (op, i, j)
            for i in range(self.n_nodes)
            for j in range(self.n_nodes) if i != j
            for op in ("add", "remove", "reverse")
        ]
        return dict(enumerate(triples))

    def _val_bic(self, A):
        """Validation score of adjacency `A` (higher is better)."""
        return self.bic_va.bic(A)

    def reset(self):
        """Start a new episode from the warm-start graph (or the empty graph) and return the state."""
        self.current_step = 0
        if self._warm_adj is None:
            self.current_adj[:] = 0.0
        else:
            self.current_adj = self._warm_adj.copy()
        return self.current_adj.flatten().copy()

    def _rejected(self, penalty):
        """Consume one step without changing the graph and hand back the fixed `penalty`."""
        self.current_step += 1
        return self.current_adj.flatten(), penalty, self.current_step >= self.max_steps, {}

    def step(self, action_idx):
        """Apply one edit; returns (state, reward, done, info) with an empty info dict."""
        op, i, j = self.action_map[action_idx]
        before = self.current_adj.copy()

        # Moves that are refused outright.
        if op == "add" and np.sum(self.current_adj) >= self.edge_budget:
            return self._rejected(-0.2)
        if op == "reverse" and not (self.current_adj[i, j] == 1.0):
            return self._rejected(-0.2)

        proposal = self.current_adj.copy()
        if op == "add":
            proposal[i, j] = 1.0
        elif op == "remove":
            proposal[i, j] = 0.0
        elif op == "reverse":
            proposal[i, j] = 0.0
            proposal[j, i] = 1.0
        np.fill_diagonal(proposal, 0.0)

        # Edits that would create a cycle are refused with a larger penalty.
        if not nx.is_directed_acyclic_graph(nx.DiGraph(proposal)):
            return self._rejected(-0.5)

        self.current_adj = proposal
        self.current_step += 1
        reward = self._reward(before, self.current_adj)
        finished = self.current_step >= self.max_steps
        return self.current_adj.flatten().copy(), reward, finished, {}

    def _reward(self, prev_adj, new_adj):
        """Clipped per-node Val-BIC gain minus edge-count, edit and step penalties."""
        bic_before = self._val_bic(prev_adj)
        bic_after = self._val_bic(new_adj)
        gain = (bic_after - bic_before) / max(self.n_nodes, 1)
        gain = float(np.clip(gain, -100.0, 10.0))
        edge_term = - self.lambda_l1 * float(np.sum(new_adj))
        if (new_adj != prev_adj).any():
            edit_term = - self.action_cost   # charged only when the graph actually changed
        else:
            edit_term = 0.0
        step_term = -0.002
        combined = gain + edge_term + edit_term + step_term
        return float(np.clip(combined, -100.0, 20.0))


# -------------------- Agent (Double DQN + Polyak) --------------------
class QNetwork(nn.Module):
    """MLP mapping a flattened adjacency state to one Q-value per action (two hidden layers of 256 units)."""
    def __init__(self, state_size, action_size):
        super().__init__()
        widths = (state_size, 256, 256, action_size)
        stack = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(n_in, n_out, dtype=DTYPE))
            stack.append(nn.ReLU())
        stack.pop()  # the output layer is linear: drop the trailing activation
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        return self.layers(x)

class CausalAgent:
    """
    Double-DQN agent with a replay buffer, linearly decayed epsilon-greedy
    exploration and a Polyak-averaged target network.
    """
    def __init__(self, state_size, action_size):
        self.action_size = action_size

        # Online network and a target network that starts as an exact copy.
        self.q = QNetwork(state_size, action_size).to(device)
        self.t = QNetwork(state_size, action_size).to(device)
        self.t.load_state_dict(self.q.state_dict())
        self.tau = 0.005
        self.opt = optim.Adam(self.q.parameters(), lr=5e-4)
        self.gamma = 0.95

        # Exploration schedule (linear in the number of act() calls).
        self.eps_start, self.eps_end = 1.0, 0.05
        self.eps_decay_steps = 250_000
        self.total_steps = 0
        self.epsilon = self.eps_start

        # Replay memory and mini-batch size.
        self.mem = deque(maxlen=120_000)
        self.batch = 256

    def _update_eps(self):
        """Advance the step counter and move epsilon along its linear schedule."""
        self.total_steps += 1
        progress = min(1.0, self.total_steps / self.eps_decay_steps)
        self.epsilon = self.eps_start + progress * (self.eps_end - self.eps_start)

    def remember(self, s, a, r, ns, d):
        """Store one transition (state, action, reward, next_state, done)."""
        self.mem.append((s, a, r, ns, d))

    def act(self, state):
        """Pick an action epsilon-greedily; every call also decays epsilon."""
        self._update_eps()
        explore = random.random() <= self.epsilon
        if explore:
            return random.randrange(self.action_size)
        state_t = torch.tensor(state, dtype=DTYPE, device=device).unsqueeze(0)
        with torch.no_grad():
            q_values = self.q(state_t)
        return int(q_values.argmax(dim=1).item())

    def replay(self):
        """One gradient step on a sampled mini-batch, followed by a soft target update."""
        if len(self.mem) < self.batch:
            return

        states, actions, rewards, next_states, dones = zip(*random.sample(self.mem, self.batch))
        s  = torch.tensor(np.array(states), dtype=DTYPE, device=device)
        a  = torch.tensor(list(actions), dtype=torch.long, device=device).unsqueeze(1)
        r  = torch.tensor(list(rewards), dtype=DTYPE, device=device).unsqueeze(1)
        ns = torch.tensor(np.array(next_states), dtype=DTYPE, device=device)
        d  = torch.tensor(list(dones), dtype=DTYPE, device=device).unsqueeze(1)

        q_sa = self.q(s).gather(1, a)
        with torch.no_grad():
            # Double DQN: the online net chooses the next action, the target net values it.
            greedy_next = self.q(ns).argmax(1, keepdim=True)
            next_value = self.t(ns).gather(1, greedy_next)
            target = r + (1.0 - d) * self.gamma * next_value

        criterion = nn.MSELoss()
        loss = criterion(q_sa, target)
        self.opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.q.parameters(), 5.0)
        self.opt.step()

        # Polyak averaging: target <- (1 - tau) * target + tau * online.
        with torch.no_grad():
            for target_p, online_p in zip(self.t.parameters(), self.q.parameters()):
                target_p.data.mul_(1 - self.tau).add_(self.tau * online_p.data)


# -------------------- CAM pruning --------------------
def graph_prunned_by_coef(parent_mat, X, th=CAM_TH):
    """
    Prune a parent matrix by linear-regression coefficient size.

    parent_mat[i, j] != 0 (|value| > 0.1) marks j as a candidate parent of i.
    Each variable is regressed on its candidate parents; a parent is kept when
    the absolute coefficient exceeds `th`. Returns a 0/1 float64 parent matrix
    in the same row-per-child layout.
    """
    n_vars = parent_mat.shape[0]
    regressor = LinearRegression()
    rows = []
    for child in range(n_vars):
        weights = np.zeros(n_vars)
        candidate = np.abs(parent_mat[child]) > 0.1
        if np.sum(candidate) > 0:
            regressor.fit(X[:, candidate], X[:, child])
            # Coefficients come back in column order of the selected parents.
            weights[candidate] = regressor.coef_
        rows.append(weights)
    return (np.abs(np.vstack(rows)) > th).astype(np.float64)

def cam_prune_linear_from_A(A_directed, X, th=CAM_TH):
    """
    CAM-style pruning for an adjacency in this file's convention (A[i, j]: i -> j).

    The pruning helper expects one row per child, so the matrix is transposed
    on the way in and the result is transposed back.
    """
    return graph_prunned_by_coef(A_directed.T, X, th=th).T


# -------------------- Main --------------------
if __name__ == "__main__":
    # Data (standardised) and optional ground truth; GT is used for reporting only.
    X = load_data(DATA_CSV)
    if X is None: raise SystemExit
    p = X.shape[1]
    GT = load_truth(GT_NPY, p=p)

    # Opponent
    Gdag = get_grandag_adj(X, iterations=G_ITER)
    if Gdag.shape != (p, p):
        Gdag = Gdag[:p, :p]
    Gdag_bin = binarize(Gdag)

    # Env + Agent
    env = CausalDiscoveryEnv(
        X, Gdag,
        val_frac=0.2,
        edge_budget_ratio=EDGE_BUDGET_RATIO,
        lambda_l1=LAMBDA_L1, action_cost=ACTION_COST,
        warm_start=True,
        score_type=SCORE_TYPE  # "copula" or "gaussian"
    )
    agent = CausalAgent(state_size=env.state_space_shape[0], action_size=env.n_actions)

    # The GraN-DAG graph is fixed for the whole run, so its validation score and
    # GT metrics are computed once here instead of at every evaluation.
    valBIC_gran = env._val_bic(Gdag_bin)
    met_G = eval_against_gt(Gdag_bin, GT)  # None when GT is missing or mis-shaped

    # Track best snapshots
    best_valbic = -1e18
    best_agent_adj = None
    best_tpr = -1.0
    best_agent_adj_by_tpr = None

    # Training loop
    print("\nStarting training...")
    for ep in range(N_EPISODES):
        s = env.reset()
        total = 0.0; done = False
        while not done:
            a = agent.act(s)
            ns, r, done, _ = env.step(a)
            agent.remember(s, a, r, ns, done)
            s = ns; total += r
            agent.replay()

        msg = f"Ep {ep+1:04d}/{N_EPISODES}  Reward={total:8.3f}  eps={agent.epsilon:.3f}"

        if (ep + 1) % EVAL_EVERY == 0:
            # The graph evaluated is the one the environment holds at episode end.
            A_now = binarize(env.current_adj)
            valBIC_agent = env._val_bic(A_now)

            # keep best-by-ValBIC (no GT)
            if valBIC_agent > best_valbic:
                best_valbic = valBIC_agent
                best_agent_adj = A_now.copy()

            # CAM (eval only)
            A_cam = cam_prune_linear_from_A(A_now, X, th=CAM_TH).astype(int)

            msg += f" | ValBIC(A)={valBIC_agent:.1f}  ValBIC(G)={valBIC_gran:.1f}"

            if GT is not None and GT.shape == A_now.shape:
                met_A  = eval_against_gt(A_now, GT)
                met_AC = eval_against_gt(A_cam, GT)

                # keep best-by-TPR snapshot (analysis only)
                if met_A["tpr"] > best_tpr:
                    best_tpr = met_A["tpr"]
                    best_agent_adj_by_tpr = A_now.copy()

                msg += (f"\n   A(raw): {met_A}"
                        f"\n   A+CAM: {met_AC}"
                        f"\n   G-DAG: {met_G}")
        print(msg)

    print("\nTraining finished.")
    A_final = binarize(env.current_adj)
    A_cam_final = cam_prune_linear_from_A(A_final, X, th=CAM_TH).astype(int)

    # Final reporting
    print("\n--- Final (Validation) ---")
    print("ValBIC(agent):    ", env._val_bic(A_final))
    print("ValBIC(agent+CAM):", env._val_bic(A_cam_final))
    print("ValBIC(GraN-DAG): ", valBIC_gran)

    if GT is not None and GT.shape == A_final.shape:
        print("\n--- Final (GT) Metrics ---")
        print("Agent (raw): ", eval_against_gt(A_final, GT))
        print("Agent+CAM:  ", eval_against_gt(A_cam_final, GT))
        print("GraN-DAG:   ", met_G)

        if best_agent_adj is not None:
            print("\n--- Best-by-Validation snapshot ---")
            print("A(best ValBIC): ", eval_against_gt(best_agent_adj, GT))

        if best_agent_adj_by_tpr is not None:
            print("\n--- Best-by-TPR snapshot (analysis only) ---")
            print(f"Best TPR = {best_tpr:.4f}")
            print("A(best TPR): ", eval_against_gt(best_agent_adj_by_tpr, GT))


Loaded data: 10000 samples, 37 vars
Loaded ground truth: (37, 37)

Running GraN-DAG (iterations=1000)...


Training Iterations: 100%|██████████| 1000/1000 [00:26<00:00, 37.88it/s]


GraN-DAG done.

[Warm-start] searching ...
[Warm-start] edges: 15

Starting training...
Ep 0001/400  Reward=-218.937  eps=0.999
Ep 0002/400  Reward=-162.700  eps=0.997
Ep 0003/400  Reward=-198.243  eps=0.996
Ep 0004/400  Reward=-177.633  eps=0.994
Ep 0005/400  Reward=-240.795  eps=0.993
Ep 0006/400  Reward=-201.221  eps=0.992
Ep 0007/400  Reward=-203.703  eps=0.990
Ep 0008/400  Reward=-175.785  eps=0.989
Ep 0009/400  Reward=-207.475  eps=0.987
Ep 0010/400  Reward=-157.953  eps=0.986
Ep 0011/400  Reward=-139.005  eps=0.985
Ep 0012/400  Reward=-194.031  eps=0.983
Ep 0013/400  Reward=-139.522  eps=0.982
Ep 0014/400  Reward=-204.042  eps=0.980
Ep 0015/400  Reward=-205.354  eps=0.979
Ep 0016/400  Reward=-214.088  eps=0.978
Ep 0017/400  Reward=-165.958  eps=0.976
Ep 0018/400  Reward=-230.213  eps=0.975
Ep 0019/400  Reward=-224.088  eps=0.973
Ep 0020/400  Reward=-199.901  eps=0.972 | ValBIC(A)=-193334.6  ValBIC(G)=-203003.2
   A(raw): {'total_edges_original': 46, 'total_edges_predicted': 40, 

# Standalone script without ground truth (`dag_discovery_nogt.py`)

In [44]:
%%writefile dag_discovery_nogt.py
# -*- coding: utf-8 -*-
"""
Causal-Discovery (DQN vs. GraN-DAG) — NO GT VERSION

- Opponent: GraN-DAG (iterations configurable)
- Scoring: Gaussian BIC or Gaussian-Copula (nonparanormal) BIC on held-out validation
- Warm-start: greedy add/prune/reverse by ValBIC with edge budget
- Agent: Double DQN + Polyak target updates (no GT)
- Reward: Δ ValBIC - λ1 * edges - action_cost - small step penalty (no GT)
- Early stopping: by Validation BIC ONLY (no GT)
- SAVES the adjacency with BEST ValBIC immediately to --out (and optional --ckpt)

Run:
  python dag_discovery_nogt.py --data /content/data.csv --out /content/agent_adj.npy \
         --iters 1000 --episodes 400 --eval-every 20 --score copula
"""

import argparse
import os
import random
import warnings
from collections import deque

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import networkx as nx

from sklearn.linear_model import LinearRegression
from scipy.stats import rankdata, norm

from castle.algorithms import GraNDAG

warnings.filterwarnings("ignore", category=UserWarning)

# --------------------------- Utilities ---------------------------
def set_seed(seed: int):
    """Seed NumPy's global RNG, Python's `random`, and PyTorch (CPU, plus every CUDA device if available)."""
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def load_data(csv_path: str):
    """
    Load a CSV of numeric observations and return it column-standardised.

    Parameters
    ----------
    csv_path : str
        Path to a CSV with a header row. If the first column is an unnamed
        index column it is used as the index instead of as data.

    Returns
    -------
    np.ndarray
        float64 array (n_samples, n_vars), zero mean and unit (population)
        standard deviation per column; constant columns are only centred.

    Raises
    ------
    FileNotFoundError
        If `csv_path` does not exist.
    ValueError
        If no complete numeric row remains after cleaning.
    """
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"Data file not found: {csv_path}")
    df = pd.read_csv(csv_path, header=0)
    if df.columns[0].lower().startswith("unnamed"):
        df = pd.read_csv(csv_path, header=0, index_col=0)
    # Non-numeric cells become NaN; columns that are entirely NaN are discarded.
    df = df.apply(pd.to_numeric, errors="coerce").dropna(axis=1, how="all")
    # Remaining NaNs would turn the column mean/std (and every downstream BIC)
    # into NaN, so rows that are not fully numeric are dropped as well.
    n_rows = len(df)
    df = df.dropna(axis=0, how="any")
    if len(df) < n_rows:
        print(f"[data] Dropped {n_rows - len(df)} rows with missing/non-numeric values")
    if df.empty:
        raise ValueError(f"No complete numeric rows in: {csv_path}")
    X = df.values.astype(np.float64)
    mu = X.mean(axis=0, keepdims=True)
    sd = X.std(axis=0, keepdims=True); sd[sd == 0] = 1.0
    X = (X - mu) / sd
    print(f"[data] Loaded: {X.shape[0]} samples, {X.shape[1]} vars")
    return X

def binarize(A):
    """Return a 0/1 integer copy of `A` (any non-zero entry is an edge) with the diagonal cleared."""
    edges = np.not_equal(np.asarray(A), 0).astype(int)
    np.fill_diagonal(edges, 0)  # no self-loops
    return edges


def get_grandag_adj(data, iterations=1000, hidden_dim=16, hidden_num=2, lr=5e-4, h_threshold=1e-6):
    """
    Fit the GraN-DAG baseline on `data` and return its learned graph.

    The keyword arguments are forwarded to the `GraNDAG` constructor
    (`mu_init` is fixed at 1e-3). The model's `causal_matrix` is returned
    as a fresh float64 array.
    """
    print(f"\nRunning GraN-DAG (iterations={iterations})...")
    settings = dict(
        input_dim=data.shape[1],
        hidden_dim=hidden_dim,
        hidden_num=hidden_num,
        lr=lr,
        iterations=iterations,
        h_threshold=h_threshold,
        mu_init=1e-3,
    )
    learner = GraNDAG(**settings)
    learner.learn(data)
    print("GraN-DAG done.")
    return np.array(learner.causal_matrix, dtype=np.float64)



# # --------------------------- Opponent ---------------------------
# def get_grandag_adj(data,
#                     iterations=1000,
#                     hidden_dim=16,
#                     hidden_num=2,
#                     lr=5e-4,
#                     h_threshold=1e-6):
#     device_type = "gpu" if torch.cuda.is_available() else "cpu"
#     print(f"[GraN-DAG] iters={iterations}, device={device_type}")
#     model = GraNDAG(
#         input_dim=data.shape[1],
#         hidden_dim=hidden_dim,
#         hidden_num=hidden_num,
#         lr=lr,
#         iterations=iterations,
#         h_threshold=h_threshold,
#         mu_init=1e-3,
#         model_name="NonLinGauss",
#         nonlinear="leaky-relu",
#         optimizer="sgd",
#         norm_prod="paths",
#         device_type=device_type,
#     )
#     model.learn(data)
#     G = np.array(model.causal_matrix, dtype=np.float64)
#     print("[GraN-DAG] done.")
#     return G

# --------------------------- Scorers ---------------------------
class GaussianBIC:
    """
    Gaussian BIC computed from covariance blocks (no per-step OLS fit).

    The scatter matrix is built from column-centred data, so every node
    regression implicitly includes an intercept. This keeps the residual sum
    of squares of a node with parents comparable to that of a node without
    parents even when the supplied data are not exactly centred.
    Higher scores are better.
    """
    def __init__(self, X: np.ndarray):
        X = np.asarray(X, dtype=np.float64)
        self.X = X
        self.n, self.p = X.shape
        Xc = X - X.mean(axis=0, keepdims=True)
        self.S = Xc.T @ Xc           # centred scatter matrix
        self.yTy = np.diag(self.S)   # centred sum of squares per variable

    def node_rss(self, parents, j):
        """Residual sum of squares of variable `j` regressed (with intercept) on `parents`."""
        if len(parents) == 0:
            return float(self.yTy[j])
        P = np.array(parents, dtype=int)
        Spp = self.S[np.ix_(P, P)]
        Spy = self.S[P, j]
        try:
            coef = np.linalg.solve(Spp, Spy)
        except np.linalg.LinAlgError:
            # Collinear parents: fall back to the minimum-norm solution.
            coef = np.linalg.pinv(Spp) @ Spy
        rss = self.yTy[j] - Spy @ coef
        return float(max(rss, 1e-12))

    def bic(self, A: np.ndarray):
        """
        Score adjacency matrix `A` (A[i, j] != 0 means i -> j).

        Uses a single pooled residual variance over all nodes and returns
        log-likelihood minus 0.5 * k * log(n), where k counts one intercept
        per node plus one coefficient per edge.
        """
        n, p = self.n, self.p
        A = (A != 0).astype(int)
        total_rss = 0.0
        num_params = p  # intercepts
        for j in range(p):
            parents = np.where(A[:, j] == 1)[0].tolist()
            total_rss += self.node_rss(parents, j)
            num_params += len(parents)
        dof = max(n - num_params, 1)
        sigma2 = max(total_rss / dof, 1e-9)
        loglik = -0.5 * n * (p * np.log(2 * np.pi * sigma2) + 1.0)
        return float(loglik - 0.5 * num_params * np.log(n))

def gaussian_copula_transform(X):
    """
    Nonparanormal transform of each column: rank -> (rank - 0.5) / n -> Phi^-1.

    Ties receive their average rank. Columns that are (numerically) constant
    get a tiny jitter from a fixed-seed generator so that they can be ranked.
    The resulting normal scores are re-centred and re-scaled per column.
    """
    data = np.asarray(X, dtype=np.float64)
    n_rows, _ = data.shape
    margin = 1e-6
    jitter_rng = np.random.default_rng(0)

    scores = np.empty_like(data)
    for col_idx, column in enumerate(data.T):
        if np.std(column) < 1e-12:
            column = column + jitter_rng.normal(0, 1e-9, size=n_rows)
        ranks = rankdata(column, method="average")
        quantiles = np.clip((ranks - 0.5) / n_rows, margin, 1 - margin)
        scores[:, col_idx] = norm.ppf(quantiles)

    scores -= scores.mean(axis=0, keepdims=True)
    spread = scores.std(axis=0, keepdims=True)
    spread[spread == 0] = 1.0
    scores /= spread
    return scores

class CopulaBIC(GaussianBIC):
    """Gaussian BIC evaluated on the nonparanormal (Gaussian-copula) scores of the data."""
    def __init__(self, X):
        # Map every marginal to normal scores first, then reuse the Gaussian scorer as is.
        super().__init__(gaussian_copula_transform(X))

# --------------------------- Warm-start ---------------------------
def warm_start_greedy_bic(Xtr, Xva, edge_budget, scorer_cls=GaussianBIC,
                          max_passes=5, topk_per_pass=5, restarts=5, seed=42):
    """
    Greedy hill-climbing over DAGs (add top-K / prune / reverse), scored on `Xva`.

    Parameters
    ----------
    Xtr : array (n, p)      only its column count is used here.
    Xva : array (m, p)      data handed to `scorer_cls`; every move is judged on it.
    edge_budget : number    no edge is added once the graph holds this many edges.
    scorer_cls : class      called as scorer_cls(Xva); must expose .bic(A), higher = better.
    max_passes : int        maximum add/prune/reverse sweeps per run.
    topk_per_pass : int     at most this many additions are committed per sweep.
    restarts : int          number of runs (at least one run is always performed).
    seed : int              seeds the node orderings used by runs after the first.

    Returns
    -------
    np.ndarray (p, p) of 0/1 ints (A[i, j] == 1 means i -> j) from the best run.

    The search itself is deterministic for a given node ordering. The first
    run scans nodes in natural order; each further restart scans them in a
    seeded random order, so restarts explore different first-improvement paths
    instead of repeating the same run.
    """
    rng = np.random.RandomState(seed)
    p = Xtr.shape[1]
    fb = scorer_cls(Xva)

    def is_dag(M): return nx.is_directed_acyclic_graph(nx.DiGraph(M))

    def one_run(order):
        A = np.zeros((p, p), dtype=np.float64)
        best = fb.bic(A)
        for _ in range(max_passes):
            improved = False
            # forward (top-k adds): score every legal single-edge addition first
            if A.sum() < edge_budget:
                cands = []
                for i in order:
                    for j in order:
                        if i == j or A[i, j] == 1: continue
                        trial = A.copy(); trial[i, j] = 1.0
                        if not is_dag(trial): continue
                        s = fb.bic(trial)
                        if s > best: cands.append((s - best, i, j))
                cands.sort(reverse=True, key=lambda x: x[0])
                for _, i, j in cands[:topk_per_pass]:
                    if A.sum() >= edge_budget: break
                    # Gains were measured against the old graph; re-check on the current one.
                    trial = A.copy(); trial[i, j] = 1.0
                    if not is_dag(trial): continue
                    s = fb.bic(trial)
                    if s > best:
                        A, best, improved = trial, s, True
            # backward (prune) while any single removal improves the score
            pruned = True
            while pruned:
                pruned = False
                for i in order:
                    for j in order:
                        if A[i, j] == 0: continue
                        trial = A.copy(); trial[i, j] = 0.0
                        s = fb.bic(trial)
                        if s > best:
                            A, best, improved, pruned = trial, s, True, True
            # reversals
            for i in order:
                for j in order:
                    if A[i, j] != 1: continue
                    trial = A.copy(); trial[i, j] = 0.0; trial[j, i] = 1.0
                    if not is_dag(trial): continue
                    s = fb.bic(trial)
                    if s > best:
                        A, best, improved = trial, s, True
            if not improved: break
        return A, best

    bestA, bestS = np.zeros((p, p), dtype=np.float64), -np.inf
    for run in range(max(1, int(restarts))):
        order = list(range(p)) if run == 0 else rng.permutation(p).tolist()
        A, s = one_run(order)
        if s > bestS:
            bestA, bestS = A, s
    return (bestA != 0).astype(int)

# --------------------------- Environment ---------------------------
class CausalDiscoveryEnv:
    """
    Graph-editing environment over a directed adjacency matrix.

    Each action adds, removes or reverses one ordered edge (i, j). Edits that
    exceed the edge budget, reverse a missing edge, or create a cycle are
    rejected with a fixed penalty and leave the graph untouched.

    Reward for an accepted edit = clipped per-node change in validation BIC
    - lambda_l1 * (number of edges) - action_cost (only if the graph changed)
    - a small constant step penalty.
    """
    def __init__(self, data, grandag_adj,
                 val_frac=0.2,
                 edge_budget_ratio=1.1,
                 lambda_l1=0.02, action_cost=0.05,
                 warm_start=True,
                 score_type="copula",
                 seed=42):
        self.full = data
        self.n_samples, self.n_nodes = data.shape
        self.grandag_adj = (grandag_adj != 0).astype(int)

        # Seeded row shuffle followed by a train / validation split.
        order = np.arange(self.n_samples)
        shuffler = np.random.RandomState(seed)
        shuffler.shuffle(order)
        n_train = int((1.0 - val_frac) * self.n_samples)
        self.Xtr = self.full[order[:n_train]]
        self.Xva = self.full[order[n_train:]]

        # Only the validation split gets a scorer; rewards are Val-BIC based.
        wants_copula = str(score_type).lower() == "copula"
        Scorer = CopulaBIC if wants_copula else GaussianBIC
        self.bic_va = Scorer(self.Xva)

        n = self.n_nodes
        self.state_space_shape = (n * n,)
        self.n_actions = 3 * n * (n - 1)   # three edit types per ordered pair
        self.action_map = self._create_action_map()

        self.current_adj = np.zeros((n, n), dtype=np.float64)
        self.max_steps = 10 * n
        self.current_step = 0

        self.edge_budget = int(max(1, edge_budget_ratio * n))
        self.lambda_l1 = float(lambda_l1)
        self.action_cost = float(action_cost)

        # Optional greedy initial graph that every episode restarts from.
        self._warm_adj = None
        if warm_start:
            print("[warm] searching greedy init ...")
            self._warm_adj = warm_start_greedy_bic(
                self.Xtr, self.Xva,
                edge_budget=self.edge_budget,
                scorer_cls=Scorer,
                max_passes=5, topk_per_pass=5, restarts=5, seed=seed
            )
            print(f"[warm] edges={int(np.sum(self._warm_adj))}")

    def _create_action_map(self):
        """Map action index -> (op, i, j), enumerating ordered pairs i != j row by row."""
        n = self.n_nodes
        edits = (
            (op, src, dst)
            for src in range(n)
            for dst in range(n)
            if src != dst
            for op in ("add", "remove", "reverse")
        )
        return dict(enumerate(edits))

    def _val_bic(self, A):
        """Score adjacency ``A`` with the validation-split scorer."""
        return self.bic_va.bic(A)

    def reset(self):
        """Start a new episode from the warm-start graph (or the empty graph) and return the flat state."""
        self.current_step = 0
        if self._warm_adj is None:
            self.current_adj[:] = 0.0
        else:
            self.current_adj = self._warm_adj.copy().astype(np.float64)
        return self.current_adj.flatten().copy()

    def step(self, action_idx: int):
        """Apply one edit and return ``(flat_state, reward, done, info)``."""
        op, i, j = self.action_map[action_idx]
        before = self.current_adj.copy()

        def rejected(penalty):
            # An invalid edit still consumes a step but leaves the graph as-is.
            self.current_step += 1
            return self.current_adj.flatten(), penalty, self.current_step >= self.max_steps, {}

        # Adding is refused once the edge budget is reached.
        if op == "add" and np.sum(self.current_adj) >= self.edge_budget:
            return rejected(-0.2)
        # Reversing requires the edge i -> j to be present.
        if op == "reverse" and self.current_adj[i, j] != 1.0:
            return rejected(-0.2)

        trial = self.current_adj.copy()
        if op == "add":
            trial[i, j] = 1.0
        elif op == "remove":
            trial[i, j] = 0.0
        elif op == "reverse":
            trial[i, j] = 0.0
            trial[j, i] = 1.0
        np.fill_diagonal(trial, 0.0)

        # Cycle-creating edits get a harsher penalty than other invalid edits.
        if not nx.is_directed_acyclic_graph(nx.DiGraph(trial)):
            return rejected(-0.5)

        self.current_adj = trial
        self.current_step += 1
        reward = self._reward(before, self.current_adj)
        finished = self.current_step >= self.max_steps
        return self.current_adj.flatten().copy(), reward, finished, {}

    def _reward(self, prev_adj, new_adj):
        """Combine the clipped Val-BIC gain with sparsity, action and step penalties."""
        bic_before = self._val_bic(prev_adj)
        bic_after = self._val_bic(new_adj)
        gain = (bic_after - bic_before) / max(self.n_nodes, 1)
        gain = float(np.clip(gain, -100.0, 10.0))

        step_pen = -0.002
        sparsity = -self.lambda_l1 * float(np.sum(new_adj))
        if (new_adj != prev_adj).any():
            act_pen = -self.action_cost
        else:
            act_pen = 0.0

        total = gain + sparsity + act_pen + step_pen
        return float(np.clip(total, -100.0, 20.0))

# --------------------------- Agent (Double DQN) ---------------------------
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Floating-point dtype (float64) shared by the Q-network layers and the replay tensors.
DTYPE  = torch.double

class QNetwork(nn.Module):
    """Fully connected Q-value estimator: state -> 256 -> 256 -> one value per action."""

    def __init__(self, state_size, action_size):
        super().__init__()
        widths = (state_size, 256, 256, action_size)
        modules = []
        for k in range(len(widths) - 1):
            if k > 0:
                # ReLU sits between consecutive linear layers, never after the output.
                modules.append(nn.ReLU())
            modules.append(nn.Linear(widths[k], widths[k + 1], dtype=DTYPE))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Return the Q-values for a batch of flattened states."""
        return self.layers(x)

class CausalAgent:
    """
    Double-DQN agent with a replay buffer, linear epsilon decay and Polyak target updates.

    Parameters
    ----------
    state_size, action_size : sizes of the flattened state and the discrete action set.
    lr : Adam learning rate for the online network.
    gamma : discount factor used in the bootstrap target.
    eps_start, eps_end : initial and final exploration rates.
    eps_decay_steps : number of ``act`` calls over which epsilon is annealed linearly.
        A value <= 0 disables annealing and uses ``eps_end`` from the first step.
    batch : minibatch size sampled from the replay buffer.
    buffer : maximum number of stored transitions.
    tau : Polyak coefficient for the soft target-network update.
    """
    def __init__(self, state_size, action_size, lr=5e-4, gamma=0.95,
                 eps_start=1.0, eps_end=0.05, eps_decay_steps=250_000,
                 batch=256, buffer=120_000, tau=0.005):
        self.action_size = action_size
        # Online network (trained) and target network (slowly tracking copy).
        self.q = QNetwork(state_size, action_size).to(device)
        self.t = QNetwork(state_size, action_size).to(device)
        self.t.load_state_dict(self.q.state_dict())
        self.tau = tau
        self.opt = optim.Adam(self.q.parameters(), lr=lr)
        self.gamma = gamma

        self.eps_start, self.eps_end = eps_start, eps_end
        self.eps_decay_steps = eps_decay_steps
        self.total_steps = 0
        self.epsilon = eps_start

        self.mem = deque(maxlen=buffer)
        self.batch = batch

    def _update_eps(self):
        """Advance the step counter and anneal epsilon linearly from eps_start to eps_end."""
        self.total_steps += 1
        if self.eps_decay_steps > 0:
            frac = min(1.0, self.total_steps / self.eps_decay_steps)
        else:
            # No decay horizon: avoid dividing by zero (or a negative fraction)
            # and go straight to the final exploration rate.
            frac = 1.0
        self.epsilon = self.eps_start + frac * (self.eps_end - self.eps_start)

    def remember(self, s, a, r, ns, d):
        """Store one transition (state, action, reward, next_state, done)."""
        self.mem.append((s, a, r, ns, d))

    def act(self, state):
        """Pick an action epsilon-greedily; every call also advances the epsilon schedule."""
        self._update_eps()
        if random.random() <= self.epsilon:
            return random.randrange(self.action_size)
        st = torch.tensor(state, dtype=DTYPE, device=device).unsqueeze(0)
        with torch.no_grad():
            qv = self.q(st)
        return int(qv.argmax(dim=1).item())

    def replay(self):
        """Run one Double-DQN gradient step on a random minibatch, then soft-update the target net."""
        # Wait until the buffer can supply a full minibatch.
        if len(self.mem) < self.batch: return
        batch = random.sample(self.mem, self.batch)
        s  = torch.tensor(np.array([e[0] for e in batch]), dtype=DTYPE, device=device)
        a  = torch.tensor([e[1] for e in batch], dtype=torch.long, device=device).unsqueeze(1)
        r  = torch.tensor([e[2] for e in batch], dtype=DTYPE, device=device).unsqueeze(1)
        ns = torch.tensor(np.array([e[3] for e in batch]), dtype=DTYPE, device=device)
        d  = torch.tensor([e[4] for e in batch], dtype=DTYPE, device=device).unsqueeze(1)

        q_sa = self.q(s).gather(1, a)
        with torch.no_grad():
            # Double DQN: the online net selects the next action, the target net evaluates it.
            na_online = self.q(ns).argmax(1, keepdim=True)
            q_next = self.t(ns).gather(1, na_online)
            # (1 - d) removes the bootstrap term for transitions flagged as done.
            target = r + (1.0 - d) * self.gamma * q_next

        loss = nn.MSELoss()(q_sa, target)
        self.opt.zero_grad(); loss.backward()
        nn.utils.clip_grad_norm_(self.q.parameters(), 5.0)
        self.opt.step()

        # Polyak averaging: target <- (1 - tau) * target + tau * online.
        with torch.no_grad():
            for tparam, qparam in zip(self.t.parameters(), self.q.parameters()):
                tparam.data.mul_(1 - self.tau).add_(self.tau * qparam.data)

# --------------------------- CAM pruning (optional, not used for saving) ---------------------------
def cam_prune_linear_from_A(A_directed, X, th=0.3):
    """
    Prune edges whose linear-regression coefficient is small.

    Every node with at least one parent in ``A_directed`` (entry [i, j] != 0
    means i -> j) is regressed on its parents using the columns of ``X``; an
    edge survives only if the absolute coefficient exceeds ``th``.

    Returns a 0/1 integer matrix in the same parent-row / child-column layout.
    """
    adjacency = (A_directed != 0).astype(int)
    parents_of = adjacency.T               # row c flags the parents of node c
    d = parents_of.shape[0]
    model = LinearRegression()

    weights = np.zeros((d, d))             # weights[child, parent]
    for child in range(d):
        mask = parents_of[child] > 0
        if not mask.any():
            continue                       # root node: its row stays all zero
        model.fit(X[:, mask], X[:, child])
        # Scatter the fitted coefficients back to the parents' column positions.
        weights[child, mask] = model.coef_

    kept = (np.abs(weights) > th).astype(int)
    return kept.T

# --------------------------- Main ---------------------------
def main():
    """
    Train the DQN agent from the command line and save the best adjacency by validation BIC.

    Ground truth is never read here: model selection uses the environment's
    validation BIC only. The GraN-DAG graph is computed once as a reporting baseline.

    The fallback save is driven by whether *this run* saved a snapshot, not by
    whether ``--out`` exists on disk. An existence check is unreliable for two
    reasons: a file left by an earlier run would be reported as this run's
    result, and ``np.save`` appends ``.npy`` to paths lacking that suffix, so
    the check could fail and the latest graph would overwrite the best one.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--data", required=True)
    ap.add_argument("--out", required=True)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--iters", type=int, default=1000)
    ap.add_argument("--episodes", type=int, default=400)
    ap.add_argument("--eval-every", type=int, default=20)
    ap.add_argument("--score", choices=["gaussian", "copula"], default="copula")
    ap.add_argument("--edge-budget-ratio", type=float, default=1.1)
    ap.add_argument("--lambda-l1", type=float, default=0.02)
    ap.add_argument("--action-cost", type=float, default=0.05)
    ap.add_argument("--patience", type=int, default=10)
    ap.add_argument("--min-delta", type=float, default=1e-3)
    ap.add_argument("--ckpt", default=None, help="Optional extra path to also save best DAG")
    args = ap.parse_args()

    set_seed(args.seed)

    X = load_data(args.data)
    p = X.shape[1]

    # Opponent (not used for saving, just as baseline for reporting)
    Gdag = get_grandag_adj(X, iterations=args.iters)
    if Gdag.shape != (p, p):
        Gdag = Gdag[:p, :p]
    Gdag_bin = binarize(Gdag)

    # Env + Agent
    env = CausalDiscoveryEnv(
        X, Gdag,
        val_frac=0.2,
        edge_budget_ratio=args.edge_budget_ratio,
        lambda_l1=args.lambda_l1,
        action_cost=args.action_cost,
        warm_start=True,
        score_type=args.score,
        seed=args.seed
    )
    agent = CausalAgent(state_size=env.state_space_shape[0], action_size=env.n_actions)

    # Best-tracking (ValBIC only, NO GT); save immediately when improved
    best_valbic = -np.inf
    patience_ctr = 0
    saved_best = False  # True once this run has written a best snapshot

    print("\n[train] starting ...")
    for ep in range(args.episodes):
        s = env.reset()
        total = 0.0; done = False
        while not done:
            a = agent.act(s)
            ns, r, done, _ = env.step(a)
            agent.remember(s, a, r, ns, done)
            s = ns; total += r
            agent.replay()

        print(f"Ep {ep+1:04d}/{args.episodes} | Reward={total:9.3f} | eps={agent.epsilon:.3f}")

        if (ep + 1) % args.eval_every == 0:
            # Evaluate the graph the episode ended on.
            A_now = binarize(env.current_adj)
            vb_agent = env._val_bic(A_now)
            vb_cam   = env._val_bic(cam_prune_linear_from_A(A_now, X, th=0.3))
            vb_gd    = env._val_bic(Gdag_bin)

            print(f"  ValBIC | agent={vb_agent: .3f} | agent+CAM={vb_cam: .3f} | GraN-DAG={vb_gd: .3f}")

            if vb_agent > best_valbic + args.min_delta:
                best_valbic = vb_agent
                np.save(args.out, A_now)  # <-- SAVE THE BEST ADJ IMMEDIATELY
                if args.ckpt:
                    np.save(args.ckpt, A_now)
                saved_best = True
                patience_ctr = 0
                print(f"  [SAVE] Best ValBIC improved to {vb_agent:.6f} → {args.out}")
            else:
                patience_ctr += 1
                print(f"  [ES] patience {patience_ctr}/{args.patience}")

            if patience_ctr >= args.patience:
                print("  [ES] Early stopping triggered.")
                break

    print("\n[train] finished.")

    # If nothing was saved during this run (e.g. no evaluation happened or no
    # score improved), save the latest adjacency so --out reflects this run.
    if not saved_best:
        A_final = binarize(env.current_adj)
        np.save(args.out, A_final)
        print(f"[fallback] Saved latest adjacency → {args.out}")

    print(f"[done] Best-by-ValBIC adjacency saved at: {args.out}")

# Script entry point: run training only when the file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()


Overwriting dag_discovery_nogt.py


In [None]:
# Train the agent. --iters is the GraN-DAG iteration count and --episodes the number of DQN episodes.
# Val-BIC is evaluated every --eval-every episodes; training stops early after --patience evaluations
# without an improvement larger than --min-delta. The best adjacency is written to --out.
!python dag_discovery_nogt.py \
  --data /content/data.csv \
  --out  /content/agent_adj.npy \
  --iters 1000 \
  --episodes 380 \
  --eval-every 20 \
  --score copula \
  --patience 10 \
  --min-delta 1e-3


2025-08-17 22:11:22,657 - /usr/local/lib/python3.11/dist-packages/castle/backend/__init__.py[line:36] - INFO: You can use `os.environ['CASTLE_BACKEND'] = backend` to set the backend(`pytorch` or `mindspore`).
2025-08-17 22:11:22,715 - /usr/local/lib/python3.11/dist-packages/castle/algorithms/__init__.py[line:36] - INFO: You are using ``pytorch`` as the backend.
[data] Loaded: 10000 samples, 37 vars

Running GraN-DAG (iterations=1000)...
2025-08-17 22:11:22,786 - /usr/local/lib/python3.11/dist-packages/castle/algorithms/gradient/gran_dag/torch/gran_dag.py[line:269] - INFO: GPU is available.
Training Iterations: 100% 1000/1000 [00:26<00:00, 38.40it/s]
GraN-DAG done.
[warm] searching greedy init ...
[warm] edges=15

[train] starting ...
Ep 0001/380 | Reward= -218.937 | eps=0.999
Ep 0002/380 | Reward= -162.700 | eps=0.997
Ep 0003/380 | Reward= -198.243 | eps=0.996
Ep 0004/380 | Reward= -177.633 | eps=0.994
Ep 0005/380 | Reward= -240.795 | eps=0.993
Ep 0006/380 | Reward= -201.221 | eps=0.99

In [29]:
import os
import pandas as pd
import numpy as np

# Default input locations: ground-truth adjacency (.npy) and the observational data (.csv).
GT_NPY = "/content/adj.npy"
DATA_CSV   = "/content/data.csv"

def load_data(csv_path=DATA_CSV):
    """
    Load a CSV of observations and return it standardized column-wise.

    Non-numeric cells are coerced to NaN. Columns that are entirely NaN are
    dropped, and then any row that still contains a NaN is dropped too.
    Without the row drop, a single bad cell would make the column mean NaN and
    turn the whole standardized column into NaN.

    Returns a float64 array of shape (samples, variables) with zero mean and
    unit standard deviation per column (constant columns are only centred),
    or None when the file is missing or no complete numeric row remains.
    """
    if not os.path.exists(csv_path):
        print(f"Error: '{csv_path}' not found."); return None
    df = pd.read_csv(csv_path, header=0)
    # An "Unnamed" first column is a saved index, not a variable: re-read with it as the index.
    if df.columns[0].lower().startswith("unnamed"):
        df = pd.read_csv(csv_path, header=0, index_col=0)
    df = df.apply(pd.to_numeric, errors="coerce").dropna(axis=1, how="all")

    n_rows = len(df)
    df = df.dropna(axis=0, how="any")
    if len(df) < n_rows:
        print(f"Dropped {n_rows - len(df)} rows with missing or non-numeric values.")
    if len(df) == 0:
        print(f"Error: '{csv_path}' has no complete numeric rows."); return None

    X = df.values.astype(np.float64)
    # Guard against division by zero for constant columns.
    mu, sd = X.mean(0, keepdims=True), X.std(0, keepdims=True); sd[sd == 0] = 1.0
    X = (X - mu) / sd
    print(f"Loaded data: {X.shape[0]} samples, {X.shape[1]} vars")
    return X



def load_truth(npy_path=GT_NPY, p=None):
    """
    Read the ground-truth adjacency matrix from a ``.npy`` file.

    When ``p`` is given and the stored matrix is not ``(p, p)``, it is cut
    down to its leading ``p`` rows and columns. Returns the matrix as
    float64, or None if the file does not exist.
    """
    if os.path.exists(npy_path):
        truth = np.load(npy_path).astype(np.float64)
        shape_differs = p is not None and truth.shape != (p, p)
        if shape_differs:
            print(f"[align] trimming GT from {truth.shape} to {(p,p)}")
            truth = truth[:p, :p]
        print("Loaded ground truth:", truth.shape)
        return truth

    print("No ground truth file; metrics will be limited.")
    return None






# Load the standardized data, then the ground truth aligned to the same number of variables.
# NOTE: load_data returns None when the CSV is missing, in which case X.shape below fails.
X = load_data(DATA_CSV)
p = X.shape[1]
GT = load_truth(GT_NPY, p=p)

Loaded data: 10000 samples, 37 vars
Loaded ground truth: (37, 37)


In [30]:
import numpy as np
# Load the adjacency written by the training run (--out) and show the sum of its entries,
# its shape and its top-left 5x5 corner.
A = np.load("/content/agent_adj.npy")
A.sum(), A.shape, A[:5,:5]


(np.int64(40),
 (37, 37),
 array([[0, 1, 0, 1, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 1, 0, 0]]))

In [31]:
# Display the agent's adjacency matrix (NumPy elides the interior of large arrays).
A

array([[0, 1, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])

In [32]:
# Display the ground-truth adjacency matrix for a visual comparison with A.
GT

array([[0., 1., 1., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 1.],
       [0., 0., 0., ..., 0., 0., 0.]])

In [33]:
import os
import time
import logging
import subprocess
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Optional


def binarize_adj(A: np.ndarray) -> np.ndarray:
    """Return a new 0/1 integer matrix marking the non-zero entries of ``A``, with the diagonal cleared."""
    is_edge = A != 0
    binary = is_edge.astype(int)
    # Self-loops are never counted as edges.
    np.fill_diagonal(binary, 0)
    return binary


def shd_binary(A: np.ndarray, B: np.ndarray) -> int:
    """
    Structural Hamming distance between two directed graphs.

    Both inputs are binarized first. Each unordered node pair contributes at
    most 1: either the pair is adjacent in exactly one graph (skeleton
    difference), or it is adjacent in both but the directed entries differ
    (orientation mismatch). A reversed edge therefore counts once, not twice.

    The orientation check looks at both directions of a pair, so the result
    is symmetric in ``A`` and ``B`` even when an input contains a 2-cycle
    (i -> j and j -> i).
    """
    A = binarize_adj(A)
    B = binarize_adj(B)
    # Undirected skeletons: a pair is adjacent if either direction is present.
    Au = ((A + A.T) > 0).astype(int)
    Bu = ((B + B.T) > 0).astype(int)
    undirected_diff = int(np.sum(np.triu(Au ^ Bu, 1)))
    common_u = Au & Bu
    # Entry-wise disagreement, symmetrized so a difference in either direction
    # of a pair is visible in the upper triangle.
    directed_diff = A ^ B
    pair_diff = directed_diff | directed_diff.T
    orient_mismatch = int(np.sum(np.triu(pair_diff & common_u, 1)))
    return undirected_diff + orient_mismatch


def compute_metrics(pred_adj: np.ndarray, gt_adj: np.ndarray) -> Optional[dict]:
    """
    Compare a predicted adjacency with the ground truth entry by entry.

    Returns None when there is no ground truth or the shapes differ; otherwise
    a dict with edge counts, FDR, TPR, FPR (each rounded to 4 decimals), SHD
    and the number of predicted edges. Counts are taken over every matrix
    entry, so the (always zero) diagonal contributes to the true negatives.
    """
    if gt_adj is None:
        return None
    if pred_adj.shape != gt_adj.shape:
        return None

    P = binarize_adj(pred_adj)
    T = binarize_adj(gt_adj)
    predicted = P == 1
    actual = T == 1

    tp = int(np.count_nonzero(predicted & actual))
    fp = int(np.count_nonzero(predicted & ~actual))
    fn = int(np.count_nonzero(~predicted & actual))
    tn = int(np.count_nonzero(~predicted & ~actual))

    # Denominators are floored at 1 so empty graphs give 0 instead of dividing by zero.
    fdr = fp / max(tp + fp, 1)
    tpr = tp / max(tp + fn, 1)
    fpr = fp / max(fp + tn, 1)

    n_pred = int(P.sum())
    metrics = {}
    metrics["total_edges_gt"] = int(T.sum())
    metrics["total_edges_pred"] = n_pred
    metrics["correct_edges"] = tp
    metrics["fdr"] = round(fdr, 4)
    metrics["tpr"] = round(tpr, 4)
    metrics["fpr"] = round(fpr, 4)
    metrics["shd"] = shd_binary(P, T)
    metrics["nnz"] = n_pred
    return metrics


In [34]:
# Binarize the loaded agent adjacency and, when ground truth is available, print its edge-recovery metrics.
A_final = binarize_adj(A)
if GT is not None:
    print("\n--- Final (GT) Metrics ---")
    print("Agent (raw): ", compute_metrics(A_final, GT))



--- Final (GT) Metrics ---
Agent (raw):  {'total_edges_gt': 46, 'total_edges_pred': 40, 'correct_edges': 9, 'fdr': 0.775, 'tpr': 0.1957, 'fpr': 0.0234, 'shd': 64, 'nnz': 40}
