# Benchmark Simulation for Two-Layer Causal Feature Identification

In [None]:
# ===== Benchmark Script for Two-Layer Causal Simulation =====
# Goal: Evaluate layer-wise causal feature identification
import os
import sys
import math
import logging
import warnings
import numpy as np
import pandas as pd
from itertools import combinations
from scipy.stats import norm
from sklearn.feature_selection import mutual_info_classif
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import normalize as sk_normalize
from sklearn.preprocessing import minmax_scale
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import scanpy as sc
import ray
from ray import tune
import gc

# Cap the OpenBLAS / OpenMP / MKL / numexpr thread pools at one thread each.
# NOTE(review): numpy, scipy, sklearn and torch are already imported above this point;
# these variables are usually read when the backing library loads, so confirm they
# still take effect here (or move the assignments above the imports).
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

from velorama.train import *
from velorama.utils import *
from CauTrigger.utils import set_seed
from CauTrigger.model import CauTrigger2L, CauTrigger1L
from CauTrigger.dataloaders import generate_two_layer_synthetic_data

# NOTE(review): this entry is appended after the CauTrigger imports above have already
# executed, so it cannot influence them — confirm whether it should come first.
sys.path.append("../")  # add the CauTrigger root directory to the import path
warnings.filterwarnings("ignore")  # silence every warning for the rest of the session
logging.getLogger("matplotlib.font_manager").disabled = True  # mute matplotlib font-manager log records
plt.rcParams["pdf.fonttype"] = 42  # embed fonts in PDFs as Type 42 (TrueType)
plt.rcParams["font.sans-serif"] = ["Arial"]
plt.rcParams["font.family"] = "sans-serif"

In [None]:
def run_CauTrigger(adata, output_dir=None, mode="SHAP", topk=20, is_log1p=False, full_input=False):
    """
    Score both simulated layers with CauTrigger in two consecutive stages.

    Stage 1 fits ``CauTrigger1L`` on ``adata.obsm["layer1"]`` and reads its
    feature weights. Stage 2 fits ``CauTrigger2L`` on ``adata.obsm["layer2"]``,
    passing layer1 columns as ``obsm["X_down"]``: the ``topk`` highest-weighted
    columns, or every column when ``full_input`` is true.

    Side effects: NumPy's global RNG is seeded with 42. When ``is_log1p`` is
    true, ``adata.X`` and both ``obsm`` layers are overwritten in place with
    their ``log1p`` values and ``adata._log1p_applied`` is set.
    NOTE(review): that flag is written but never read here, so calling this
    function twice on the same object applies log1p twice — confirm callers.

    ``output_dir`` is accepted but not used by this function.

    Return:
        dict {
            "layer1": DataFrame of weights indexed by the layer1 gene names,
            "layer2": DataFrame of weights indexed by the layer2 gene names
        }
    """
    np.random.seed(42)

    if is_log1p:
        print("CauTrigger log1p normalization applied to adata.X, adata.obsm['layer1'], and adata.obsm['layer2']")
        adata.X = np.log1p(adata.X)
        for layer_key in ("layer1", "layer2"):
            adata.obsm[layer_key] = np.log1p(adata.obsm[layer_key])
        adata._log1p_applied = True
        print("[INFO] log1p transformation applied to adata.obsm['layer1'] and ['layer2'].")

    # Both models are built, trained and queried with exactly the same settings.
    model_kwargs = dict(
        n_latent=10,
        n_hidden=128,
        n_layers_encoder=0,
        n_layers_decoder=0,
        n_layers_dpd=0,
        dropout_rate_encoder=0.0,
        dropout_rate_decoder=0.0,
        dropout_rate_dpd=0.0,
        use_batch_norm="none",
        use_batch_norm_dpd=True,
        decoder_linear=False,
        dpd_linear=True,
        init_weight=None,
        init_thresh=0.4,
        attention=False,
        att_mean=False,
    )
    train_kwargs = dict(max_epochs=200, stage_training=True, weight_scheme="sim")
    weight_kwargs = dict(method=mode, normalize=False, sort_by_weight=True)

    # === Stage 1: layer1 (the layer closer to Y) ===
    is_layer1 = adata.var["layer"] == "layer1"
    layer1_vars = adata.var_names[is_layer1]
    adata_layer1 = AnnData(
        X=adata.obsm["layer1"],
        obs=adata.obs.copy(),
        var=adata.var.loc[layer1_vars].copy(),
    )

    model_1L = CauTrigger1L(adata_layer1, **model_kwargs)
    model_1L.train(**train_kwargs)
    df_layer1, _ = model_1L.get_up_feature_weights(**weight_kwargs)
    print("df_layer1", df_layer1.head(20))

    # === Stage 2: layer2 (upstream), conditioned on layer1 columns ===
    is_layer2 = adata.var["layer"] == "layer2"
    layer2_vars = adata.var_names[is_layer2]

    # Weights come back sorted; restore the column order of obsm["layer1"]
    # so positional indices below address the right columns.
    df_layer1 = df_layer1.loc[layer1_vars]
    topk_indices = df_layer1["weight"].values.argsort()[-topk:]
    if full_input:
        X_down = adata.obsm["layer1"]
    else:
        X_down = adata.obsm["layer1"][:, topk_indices]

    adata_layer2 = AnnData(
        X=adata.obsm["layer2"],
        obs=adata.obs.copy(),
        var=adata.var.loc[layer2_vars].copy(),
        obsm={"X_down": X_down},
    )

    model_2L = CauTrigger2L(adata_layer2, **model_kwargs)
    model_2L.train(**train_kwargs)
    df_layer2, _ = model_2L.get_up_feature_weights(**weight_kwargs)
    print("df_layer2", df_layer2.head(10))
    df_layer2 = df_layer2.loc[layer2_vars]

    # Both outputs must be indexed exactly by their layer's gene names.
    assert df_layer1.index.equals(layer1_vars)
    assert df_layer2.index.equals(layer2_vars)

    return {"layer1": df_layer1, "layer2": df_layer2}

In [None]:
def run_PC(adata, output_dir):
    """
    Score every feature with a PC-algorithm style conditional-independence search.

    The feature matrix ``adata.X`` (log1p-transformed when it looks like counts) is
    stacked with ``adata.obs["labels"]`` as the last column, the correlation matrix of
    that table is fed to the skeleton search, and the score of a feature is
    ``1 - p`` where ``p`` is the largest p-value recorded between that feature and the
    label column. Scores are L1-normalized before being returned as a 1D array.

    ``output_dir`` is accepted but not used by this function.
    """

    def gauss_ci_test(suff_stat, i, j, K):
        """Return the two-sided p-value for zero partial correlation of i and j given K."""
        corr_matrix = suff_stat["C"]
        n_samples = suff_stat["n"]

        if len(K) == 0:
            # No conditioning set: plain correlation.
            r = corr_matrix[i, j]
        elif len(K) == 1:
            # One conditioning variable: closed-form first-order partial correlation.
            k = K[0]
            r = (corr_matrix[i, j] - corr_matrix[i, k] * corr_matrix[j, k]) / math.sqrt(
                (1 - corr_matrix[i, k] ** 2) * (1 - corr_matrix[j, k] ** 2)
            )
        else:
            # Larger sets: read the partial correlation off the (pseudo-)inverse of the
            # correlation sub-matrix over [i, j] + K.
            sub_corr = corr_matrix[np.ix_([i, j] + K, [i, j] + K)]
            precision_matrix = np.linalg.pinv(sub_corr)
            r = (-1 * precision_matrix[0, 1]) / math.sqrt(abs(precision_matrix[0, 0] * precision_matrix[1, 1]))

        # Clamp away from +/-1 so the transform below stays finite.
        r = max(min(r, 0.99999), -0.99999)
        # Fisher z-transform: 0.5 * log1p(2r / (1 - r)) == 0.5 * log((1 + r) / (1 - r)).
        z = 0.5 * math.log1p((2 * r) / (1 - r))
        z_standard = z * math.sqrt(n_samples - len(K) - 3)
        p_value = 2 * (1 - norm.cdf(abs(z_standard)))

        return p_value

    def get_neighbors(G, x, exclude_y):
        """Return the nodes currently adjacent to x in G, leaving out exclude_y."""
        return [i for i, connected in enumerate(G[x]) if connected and i != exclude_y]

    def skeleton(suff_stat, alpha):
        """
        Prune a complete undirected graph with conditional-independence tests.

        Returns the adjacency matrix (int array), the table of separating sets, and a
        matrix holding the largest p-value recorded for every tested pair (pairs that
        were never tested keep the initial value 0).
        """
        p_value_mat = np.zeros_like(suff_stat["C"])
        n_nodes = suff_stat["C"].shape[0]
        # O[x][y] holds the conditioning set that separated x and y (empty if none found).
        O = [[[] for _ in range(n_nodes)] for _ in range(n_nodes)]
        # Start from the complete graph without self-loops.
        G = [[i != j for i in range(n_nodes)] for j in range(n_nodes)]
        pairs = [(i, j) for i in range(n_nodes) for j in range(i + 1, n_nodes)]

        done = False
        l = 0  # size of the conditioning sets tried in the current sweep

        while not done and any(any(row) for row in G):
            done = True

            for x, y in pairs:
                if G[x][y]:
                    # Conditioning sets are drawn from x's current neighbours only.
                    neighbors = get_neighbors(G, x, y)
                    if len(neighbors) >= l:
                        done = False
                        for K in combinations(neighbors, l):
                            p_value = gauss_ci_test(suff_stat, x, y, list(K))
                            # Keep the largest p-value seen for this pair.
                            if p_value > p_value_mat[x][y]:
                                p_value_mat[x][y] = p_value_mat[y][x] = p_value
                            if p_value >= alpha:
                                # Independence not rejected: drop the edge, remember the separator.
                                G[x][y] = G[y][x] = False
                                O[x][y] = O[y][x] = list(K)
                                break
            l += 1

        return np.asarray(G, dtype=int), O, p_value_mat

    def extend_cpdag(G, O):
        """
        Orient edges of the skeleton G using the separating sets O.

        In the matrix, ``g[i][j] == 1 and g[j][i] == 0`` encodes i -> j, while both
        entries equal to 1 encode an undirected edge. G is modified in place.
        """
        n_nodes = G.shape[0]

        def rule1(g):
            # For i -> j and undirected j - k with i, k non-adjacent: orient j -> k.
            pairs = [(i, j) for i in range(n_nodes) for j in range(n_nodes) if g[i][j] == 1 and g[j][i] == 0]
            for i, j in pairs:
                all_k = [
                    k for k in range(n_nodes) if (g[j][k] == 1 and g[k][j] == 1) and (g[i][k] == 0 and g[k][i] == 0)
                ]
                for k in all_k:
                    g[j][k] = 1
                    g[k][j] = 0
            return g

        def rule2(g):
            # For undirected i - j with a directed chain i -> k -> j: orient i -> j.
            pairs = [(i, j) for i in range(n_nodes) for j in range(n_nodes) if g[i][j] == 1 and g[j][i] == 1]
            for i, j in pairs:
                all_k = [
                    k for k in range(n_nodes) if (g[i][k] == 1 and g[k][i] == 0) and (g[k][j] == 1 and g[j][k] == 0)
                ]
                if len(all_k) > 0:
                    g[i][j] = 1
                    g[j][i] = 0
            return g

        def rule3(g):
            # For undirected i - j with two non-adjacent k1, k2 such that i - k and k -> j: orient i -> j.
            pairs = [(i, j) for i in range(n_nodes) for j in range(n_nodes) if g[i][j] == 1 and g[j][i] == 1]
            for i, j in pairs:
                all_k = [
                    k for k in range(n_nodes) if (g[i][k] == 1 and g[k][i] == 1) and (g[k][j] == 1 and g[j][k] == 0)
                ]
                if len(all_k) >= 2:
                    for k1, k2 in combinations(all_k, 2):
                        if g[k1][k2] == 0 and g[k2][k1] == 0:
                            g[i][j] = 1
                            g[j][i] = 0
                            break
            return g

        # Collider step: for x - y - z with x, z non-adjacent and y absent from the
        # separating set of (x, z), orient x -> y <- z.
        pairs = [(i, j) for i in range(n_nodes) for j in range(n_nodes) if G[i][j] == 1]
        for x, y in sorted(pairs, key=lambda x: (x[1], x[0])):
            all_z = [z for z in range(n_nodes) if G[y][z] == 1 and z != x]
            for z in all_z:
                if G[x][z] == 0 and y not in O[x][z]:
                    G[x][y] = G[z][y] = 1
                    G[y][x] = G[y][z] = 0

        # Apply the three propagation rules until a full pass changes nothing.
        old_G = np.zeros((n_nodes, n_nodes))
        while not np.array_equal(old_G, G):
            old_G = G.copy()
            G = rule1(G)
            G = rule2(G)
            G = rule3(G)

        return np.array(G)

    def pc(suff_stat, alpha=0.5, verbose=False):
        """Run skeleton search plus orientation; return (cpdag, p-value matrix)."""
        G, O, pvm = skeleton(suff_stat, alpha)
        cpdag = extend_cpdag(G, O)
        if verbose:
            print(cpdag)
        return cpdag, pvm

    # The call below overrides pc()'s default alpha of 0.5.
    alpha = 0.05
    X = adata.X
    # Rough check for count data: integer dtype or large values trigger log1p.
    if np.issubdtype(X.dtype, np.integer) or X.max() > 100:
        X = np.log1p(X)
    y = adata.obs["labels"].values
    # The label becomes the last column, hence the last node of the graph.
    data = pd.DataFrame(np.column_stack((X, y)))
    # Only the p-value matrix is used below; the oriented graph is discarded.
    cpdag, pvm = pc(suff_stat={"C": data.corr().values, "n": data.shape[0]}, alpha=alpha)
    # Last column without its own row: p-values between each feature and the label.
    pv = pvm[:-1, -1]
    arr = np.array(1 - pv).reshape(1, -1)  # must be 2D for sk_normalize
    normalized_arr = sk_normalize(arr, norm="l1", axis=1)
    return normalized_arr.flatten()


def run_VAE(adata, output_dir):
    """
    Score features by input-gradient saliency of a small supervised VAE.

    A VAE with an extra label-prediction head (``DPD``) is trained for 200 epochs on
    ``adata.X`` (log1p-transformed when it looks like counts) and
    ``adata.obs["labels"]``. The importance of a feature is the mean absolute gradient
    of the label BCE loss with respect to that input, L1-normalized over features.

    ``output_dir`` is accepted but not used by this function.

    Returns:
        1D numpy array with one score per column of ``adata.X``.
    """
    import torch
    from torch import nn
    from torch.utils.data import TensorDataset, DataLoader

    X = adata.X
    if np.issubdtype(X.dtype, np.integer) or X.max() > 100:  # Rough check for count data
        X = np.log1p(X)  # Apply log1p for count data
    y = adata.obs["labels"].values
    n_features = X.shape[1]
    features = torch.tensor(X, dtype=torch.float32)
    labels = torch.tensor(y, dtype=torch.float32).view(-1, 1)
    dataloader = DataLoader(TensorDataset(features, labels), batch_size=128, shuffle=True)

    n_hidden = 5
    n_latent = 5
    n_epochs = 200
    # Weights of the KL and label terms relative to the reconstruction loss.
    kl_weight = 0.1
    dpd_weight = 0.1

    class VAE(nn.Module):
        """Encoder/decoder pair plus a sigmoid head predicting the label from z."""

        def __init__(self, num_features):
            super().__init__()

            # The encoder emits the mean and log-variance stacked along dim 1.
            self.encoder = nn.Sequential(
                nn.Linear(num_features, n_hidden),
                nn.ReLU(),
                nn.Linear(n_hidden, 2 * n_latent),
            )

            self.decoder = nn.Sequential(
                nn.Linear(n_latent, n_hidden),
                nn.ReLU(),
                nn.Linear(n_hidden, num_features),
            )

            self.DPD = nn.Sequential(
                nn.Linear(n_latent, n_hidden),
                nn.ReLU(),
                nn.Linear(n_hidden, 1),
                nn.Sigmoid(),
            )

        def reparameterize(self, mu, logvar):
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps * std + mu

        def forward(self, x):
            mu_logvar = self.encoder(x)
            mu = mu_logvar[:, :n_latent]
            logvar = mu_logvar[:, n_latent:]
            z = self.reparameterize(mu, logvar)
            y = self.DPD(z)
            reconstructed = self.decoder(z)
            return reconstructed, y, mu, logvar

    model = VAE(n_features)
    recon_criterion = nn.MSELoss()
    dpd_criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # Train. The loss weighting is the same for every epoch: the original code
    # branched on the epoch number but both branches were identical.
    model.train()
    for _ in range(n_epochs):
        for data, targets in dataloader:
            optimizer.zero_grad()
            recon_batch, y_dpd, mu, logvar = model(data)

            re_loss = recon_criterion(recon_batch, data)
            # KL divergence to a standard normal, averaged over the batch.
            kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / data.shape[0]
            dpd_loss = dpd_criterion(y_dpd, targets)

            loss = re_loss + kl_loss * kl_weight + dpd_loss * dpd_weight
            loss.backward()
            optimizer.step()

    # Saliency: gradient of the label loss with respect to the inputs.
    # NOTE(review): forward() still samples eps in eval mode, so these scores are
    # stochastic — confirm whether the mean should be used at this point.
    model.eval()
    features.requires_grad = True
    _, y_prob, _, _ = model(features)
    loss = dpd_criterion(y_prob, labels)
    loss.backward()
    grad_features_importance = features.grad.abs().mean(dim=0)
    arr = np.array(grad_features_importance.detach().numpy()).reshape(1, -1)  # must be 2D for sk_normalize
    normalized_arr = sk_normalize(arr, norm="l1", axis=1)
    return normalized_arr.flatten()


def run_RF(adata, output_dir):
    """
    Score features with random-forest importances.

    Fits a default ``RandomForestClassifier`` on ``adata.X`` (log1p-transformed when
    it looks like counts) against ``adata.obs["labels"]`` and returns the
    L1-normalized ``feature_importances_`` as a 1D array.

    ``output_dir`` is accepted but not used by this function.
    """
    features = adata.X
    # Rough check for count data: integer dtype or large values trigger log1p.
    looks_like_counts = np.issubdtype(features.dtype, np.integer) or features.max() > 100
    if looks_like_counts:
        features = np.log1p(features)
    targets = adata.obs["labels"].values

    forest = RandomForestClassifier()
    forest.fit(features, targets)

    importances = np.array(forest.feature_importances_.flatten()).reshape(1, -1)
    return sk_normalize(importances, norm="l1", axis=1).flatten()


def run_SVM(adata, output_dir):
    """
    Score features with the coefficients of a linear SVM.

    ``adata.X`` (log1p-transformed when it looks like counts) is standardized per
    column, a linear-kernel ``SVC`` is fitted against ``adata.obs["labels"]``, and the
    flattened coefficient vector is L1-normalized and returned as a 1D array.

    ``output_dir`` is accepted but not used by this function.
    """
    X = adata.X
    if np.issubdtype(X.dtype, np.integer) or X.max() > 100:
        X = np.log1p(X)
    # Constant columns have zero standard deviation; divide those by 1 instead so
    # they become all-zero rather than NaN (same guard as run_VELORAMA).
    std = X.std(axis=0)
    std = np.where(std == 0, 1.0, std)
    X = (X - X.mean(axis=0)) / std
    y = adata.obs["labels"].values
    svm = SVC(kernel="linear")
    svm.fit(X, y)
    # NOTE(review): coefficients keep their sign, so strongly negative weights rank
    # below zero-weight features — confirm whether absolute values were intended.
    arr = np.array(svm.coef_.flatten()).reshape(1, -1)
    normalized_arr = sk_normalize(arr, norm="l1", axis=1)
    return normalized_arr.flatten()


def run_MI(adata, output_dir):
    """
    Score features by their mutual information with the label.

    ``adata.X`` (log1p-transformed when it looks like counts) is standardized per
    column, ``mutual_info_classif`` is evaluated against ``adata.obs["labels"]``, and
    the result is L1-normalized and returned as a 1D array.

    ``output_dir`` is accepted but not used by this function.
    """
    X = adata.X
    if np.issubdtype(X.dtype, np.integer) or X.max() > 100:
        X = np.log1p(X)
    # Constant columns have zero standard deviation; divide those by 1 instead so
    # they become all-zero rather than NaN (same guard as run_VELORAMA).
    std = X.std(axis=0)
    std = np.where(std == 0, 1.0, std)
    X = (X - X.mean(axis=0)) / std
    y = adata.obs["labels"].values
    arr = np.array(mutual_info_classif(X, y).flatten()).reshape(1, -1)
    normalized_arr = sk_normalize(arr, norm="l1", axis=1)
    return normalized_arr.flatten()


def run_VELORAMA(adata, output_dir):
    """
    Score every gene as a regulator with Velorama.

    Builds a cell DAG with ``construct_dag``, computes diffusion-lagged inputs,
    trains one model per regularization strength through Ray Tune, and averages the
    estimated gene-gene interaction matrix over target genes.

    The function works on a private copy of ``adata`` so the caller's object is not
    rescaled or otherwise modified. Trial outputs are written under ``output_dir``.
    Ray is always shut down on exit, including when training raises.

    Returns:
        1D numpy array with one score per gene, in ``adata.var_names`` order.
    """
    # sc.pp.scale and the uns["iroot"] assignment below modify the AnnData in place;
    # copy first so other benchmark methods still see the original data
    # (run_SCRIBE, run_DCI and run_GENIE3 do the same).
    adata = adata.copy()

    target_genes = adata.var_names.tolist()
    reg_genes = adata.var_names.tolist()

    # Regulator matrix: optional log1p, then per-gene standardization.
    X_orig = adata.X.copy()
    if np.issubdtype(X_orig.dtype, np.integer) or X_orig.max() > 100:
        X_orig = np.log1p(X_orig)
    std = X_orig.std(0)
    std[std == 0] = 1
    X = torch.FloatTensor(X_orig - X_orig.mean(0)) / std
    X = X.to(torch.float32)

    # Target matrix: per-gene standardization only.
    # NOTE(review): unlike X, Y never gets the log1p transform — confirm this is intended.
    Y_orig = adata.X.copy()
    std = Y_orig.std(0)
    std[std == 0] = 1
    Y = torch.FloatTensor(Y_orig - Y_orig.mean(0)) / std
    Y = Y.to(torch.float32)

    adata.uns["iroot"] = 0
    sc.pp.scale(adata)

    reg_target = 0
    dynamics = "pseudotime"
    ptloc = "pseudotime"
    proba = 1
    n_neighbors = 30
    velo_mode = "stochastic"
    time_series = 0
    n_comps = 20
    lag = 5
    name = "velorama_run"
    seed = 42
    hidden = 32
    penalty = "H"
    save_dir = output_dir
    lam_start = -2
    lam_end = 1
    num_lambdas = 19

    A = construct_dag(
        adata,
        dynamics=dynamics,
        ptloc=ptloc,
        proba=proba,
        n_neighbors=n_neighbors,
        velo_mode=velo_mode,
        use_time=time_series,
        n_comps=n_comps,
    )
    A = torch.FloatTensor(A)
    AX = calculate_diffusion_lags(A, X, lag)
    AY = None

    dir_name = f"{name}.seed{seed}.h{hidden}.{penalty}.lag{lag}.{dynamics}"
    target_dir = os.path.join(save_dir, dir_name)
    os.makedirs(target_dir, exist_ok=True)

    ray.init(object_store_memory=6 * 1024**3, ignore_reinit_error=True)  # can be increased
    try:
        lam_list = np.logspace(lam_start, lam_end, num=num_lambdas).tolist()

        config = {
            "name": name,
            "AX": AX,
            "AY": AY,
            "Y": Y,
            "seed": seed,
            "lr": 0.01,
            "lam": tune.grid_search(lam_list),
            "lam_ridge": 0.0,
            "penalty": penalty,
            "lag": lag,
            "hidden": [hidden],
            "max_iter": 200,
            "device": "cpu",
            "lookback": 5,
            "check_every": 10,
            "verbose": True,
            "dynamics": dynamics,
            "results_dir": save_dir,
            "dir_name": dir_name,
            "reg_target": reg_target,
        }
        resources_per_trial = {"cpu": 1, "gpu": 0, "memory": 1 * 1024**3}
        tune.run(train_model, resources_per_trial=resources_per_trial, config=config, storage_path=save_dir)

        # Collect the per-trial outputs from save_dir into the run directory.
        move_files(save_dir, target_dir)

        lam_list = [np.round(lam, 4) for lam in lam_list]
        all_lags = load_gc_interactions(
            name,
            save_dir,
            lam_list,
            hidden_dim=hidden,
            lag=lag,
            penalty=penalty,
            dynamics=dynamics,
            seed=seed,
            ignore_lag=False,
        )

        gc_mat = estimate_interactions(all_lags, lag=lag)  # tg_count x tf_count
        gc_df = pd.DataFrame(gc_mat.cpu().data.numpy(), index=target_genes, columns=reg_genes)
    finally:
        # Release the Ray runtime even when a trial or the loading step fails.
        ray.shutdown()

    # Column mean = average interaction strength of each regulator over all targets.
    return gc_df.mean(axis=0).values


def run_SCRIBE(adata, output_dir):
    """
    Score every gene as a source of restricted directed information (Scribe RDI).

    A copy of ``adata`` is log1p-transformed, scaled and given a diffusion
    pseudotime; Scribe's RDI is then computed for delays 1-3. The score of a gene is
    the mean of its ``"MAX"`` RDI values towards every other gene.

    Gene names are used exactly as they appear in ``adata.var_names`` (the previous
    implementation lower-cased them and then failed to look up any name containing
    upper-case characters).

    ``output_dir`` is accepted but not used by this function.

    Returns:
        1D numpy array with one score per gene, in ``adata.var_names`` order.
    """
    from Scribe.read_export import load_anndata

    adata = adata.copy()
    sc.pp.log1p(adata)
    adata.uns["iroot"] = 0
    sc.pp.scale(adata)
    sc.pp.neighbors(adata)
    sc.tl.dpt(adata)

    # Split cells at the median pseudotime; compute the median once, not per cell.
    pseudotime = adata.obs["dpt_pseudotime"]
    median_pseudotime = pseudotime.median()
    adata.obs["dpt_groups"] = ["0" if t < median_pseudotime else "1" for t in pseudotime]

    model = load_anndata(adata)
    model.rdi(
        delays=[1, 2, 3], number_of_processes=1, uniformization=False, differential_mode=False
    )  # dict_keys([1, 2, 3, 'MAX'])

    rdi_max = model.rdi_results["MAX"]
    gene_names = list(adata.var_names)

    # Mean outgoing RDI per source gene, excluding the gene itself. pandas' mean
    # skips NaN entries, matching the groupby-mean used previously.
    scores = [
        pd.Series(
            [rdi_max.loc[source, target] for target in gene_names if target != source],
            dtype=float,
        ).mean()
        for source in gene_names
    ]
    return np.array(scores)


def run_DCI(adata, output_dir):
    """
    Score genes by how often they appear in the difference causal graph (DCI).

    A copy of ``adata`` is log1p-transformed and split by ``adata.obs["labels"]``
    (0 vs 1). Candidate edges are the gene pairs whose absolute correlation over all
    cells is at least 0.3; ``dci`` then estimates which of them differ between the
    two groups. The score of a gene is the number of difference-graph edges touching
    it, L1-normalized over genes.

    ``output_dir`` is accepted but not used by this function.

    Returns:
        1D numpy array with one score per gene, in ``adata.var_names`` order.
    """
    from causaldag import dci
    import scipy.sparse

    adata = adata.copy()
    sc.pp.log1p(adata)

    # Densify once up front so the group split, the noise addition and the
    # correlation below all operate on a plain array.
    X_full = adata.X.toarray() if scipy.sparse.issparse(adata.X) else adata.X
    labels = adata.obs["labels"].values

    X1 = X_full[labels == 0].astype(float)
    X2 = X_full[labels == 1].astype(float)
    # Tiny jitter on both groups (astype above returned copies, so X_full is untouched).
    X1 += np.random.normal(0, 1e-6, size=X1.shape)
    X2 += np.random.normal(0, 1e-6, size=X2.shape)
    p = X1.shape[1]

    corr = np.corrcoef(X_full.T)
    threshold = 0.3  # tunable
    candidate_edges = [(i, j) for i in range(p) for j in range(i + 1, p) if abs(corr[i, j]) >= threshold]
    print(f"[INFO] Filtered candidate edges: {len(candidate_edges)}")
    difference_matrix = dci(
        X1,
        X2,
        difference_ug_method="constraint",
        difference_ug=candidate_edges,
        alpha_ug=0.05,
        alpha_skeleton=0.1,
        max_set_size=2,
    )
    ddag_edges = set(zip(*np.where(difference_matrix != 0)))
    print("len(ddag_edges)", len(ddag_edges))

    # Count, per gene index, how many edge endpoints fall on it; genes that never
    # appear keep a count of zero.
    counts = np.zeros(p)
    for edge in ddag_edges:
        for node in edge:
            counts[node] += 1

    arr = counts.reshape(1, -1)
    normalized_arr = sk_normalize(arr, norm="l1", axis=1)
    return normalized_arr.flatten()


def run_GENIE3(adata, output_dir):
    """
    Score genes by how much their GENIE3 regulatory links differ between label groups.

    A copy of ``adata`` is log1p-transformed and split by ``adata.obs["labels"]``.
    GENIE3 is run separately on the label-0 and the label-1 cells; for every link
    present in both networks the absolute difference in strength is taken, averaged
    per regulator gene, and min-max scaled.

    ``output_dir`` is accepted but not used; link lists go through a temporary file
    that is removed even when GENIE3 fails.

    Returns:
        1D array with one score per gene, in ``adata.var_names`` order. Genes that
        never appear as a regulator in the merged table are NaN after the reindex.
    """
    from GENIE3 import GENIE3, get_link_list
    import tempfile
    import os

    adata = adata.copy()
    sc.pp.log1p(adata)
    gene_names = list(adata.var_names)

    def _link_table(expression):
        """Run GENIE3 on one expression matrix and return its link list as a DataFrame."""
        # get_link_list writes to a path, so reserve a temp file name first.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp_file:
            tmp_path = tmp_file.name
        try:
            VIM = GENIE3(expression, gene_names=gene_names)
            get_link_list(VIM, gene_names=gene_names, file_name=tmp_path)
            return pd.read_csv(tmp_path, sep="\t", header=None)
        finally:
            os.remove(tmp_path)

    # Label-0 network first, then label-1, as before.
    df0 = _link_table(adata[adata.obs["labels"] == 0].X)
    df1 = _link_table(adata[adata.obs["labels"] == 1].X)

    # Columns 0 and 1 hold the two gene names of a link, column 2 its strength.
    merged = pd.merge(df1, df0, on=[0, 1], suffixes=("_net1", "_net0"))
    merged["Diff_Strength"] = (merged["2_net1"] - merged["2_net0"]).abs()
    diff_network = merged[[0, 1, "Diff_Strength"]]
    arr = diff_network.groupby(0)["Diff_Strength"].mean()
    arr_sorted = arr.reindex(adata.var_names)
    normalized_arr = minmax_scale(arr_sorted)

    return normalized_arr

In [None]:
def select_single_feature(score_array, threshold=None, topk=None):
    """
    Turn a 1D score array into binary labels (1 = selected, 0 = not selected).

    Args:
        score_array: array-like of scores, one per feature.
        threshold: select every score ``>= threshold``. Used only when ``topk`` is None.
        topk: select the ``topk`` highest scores. Takes precedence over ``threshold``.
            ``0`` selects nothing; values larger than the array select everything.

    Returns:
        Integer numpy array of 0/1 with the same length as ``score_array``.

    Raises:
        ValueError: if neither ``threshold`` nor ``topk`` is given, or ``topk`` is negative.
    """
    score_array = np.asarray(score_array)
    n = len(score_array)

    if topk is not None:
        if topk < 0:
            raise ValueError("topk must be non-negative.")
        selected = np.zeros(n, dtype=int)
        k = min(int(topk), n)
        # Slice from an explicit start index: a plain [-k:] with k == 0 would
        # return the whole array and mark every feature as selected.
        if k > 0:
            selected[np.argsort(score_array)[n - k:]] = 1
        return selected

    if threshold is not None:
        return (score_array >= threshold).astype(int)

    raise ValueError("Either threshold or topk must be specified.")


def calculate_metrics(weight_df, label_col="is_causal", score_col="weight", threshold=None, topk=None):
    """
    Compute ranking and classification metrics for one set of feature scores.

    Args:
        weight_df: DataFrame holding the ground truth and the scores.
        label_col: column with the binary ground truth (1 = causal).
        score_col: column with the continuous scores.
        threshold, topk: how scores are binarized; forwarded to
            ``select_single_feature``.

    Returns:
        dict with keys "AUROC", "AUPR", "F1", "ACC", "MCC", "Precision",
        "Recall" and "Specificity". AUROC and AUPR use the raw scores; the other
        metrics use the binarized prediction.
    """
    from sklearn.metrics import (
        roc_curve,
        auc,
        confusion_matrix,
        accuracy_score,
        matthews_corrcoef,
        f1_score,
        precision_score,
        recall_score,
        precision_recall_curve,
    )

    true_label = weight_df[label_col].values
    score_array = weight_df[score_col].values

    # Get predicted label by top-k or threshold
    pred_label = select_single_feature(score_array, threshold=threshold, topk=topk)

    # AUROC
    fpr, tpr, _ = roc_curve(true_label, score_array)
    roc_auc = auc(fpr, tpr)

    # AUPR
    precision_curve, recall_curve, _ = precision_recall_curve(true_label, score_array)
    aupr = auc(recall_curve, precision_curve)

    # Confusion matrix. Pin the label set so the matrix is always 2x2: without
    # it, data where only one class occurs yields a 1x1 matrix and cm[0, 1] fails.
    cm = confusion_matrix(true_label, pred_label, labels=[0, 1])
    TN, FP = cm[0, 0], cm[0, 1]
    specificity = TN / (TN + FP) if (TN + FP) > 0 else 0.0

    # Other metrics
    acc = accuracy_score(true_label, pred_label)
    mcc = matthews_corrcoef(true_label, pred_label)
    precision = precision_score(true_label, pred_label, pos_label=1, zero_division=0)
    recall = recall_score(true_label, pred_label, pos_label=1, zero_division=0)
    f1 = f1_score(true_label, pred_label, pos_label=1, zero_division=0)

    return {
        "AUROC": roc_auc,
        "AUPR": aupr,
        "F1": f1,
        "ACC": acc,
        "MCC": mcc,
        "Precision": precision,
        "Recall": recall,
        "Specificity": specificity,
    }


def evaluate_prediction(true_label, pred_dict, topk=10):
    """
    Evaluate each prediction array in pred_dict using standard metrics.

    Args:
        true_label: array-like of binary ground-truth labels, one per feature.
        pred_dict: mapping from sub-method name to an array-like of scores with the
            same length as ``true_label``.
        topk: number of top-scoring features counted as predicted positives.

    Returns:
        List of dicts, one per entry of ``pred_dict``, each holding "SubMethod" plus
        the metrics returned by ``calculate_metrics``.
    """
    metric_names = ["AUROC", "AUPR", "F1", "ACC", "MCC", "Precision", "Recall", "Specificity"]
    true_label = np.asarray(true_label).ravel()

    results = []
    for submethod, pred in pred_dict.items():
        # calculate_metrics expects a DataFrame with a label column and a score
        # column and returns a dict keyed by metric name.
        weight_df = pd.DataFrame({"is_causal": true_label, "weight": np.asarray(pred).ravel()})
        metrics = calculate_metrics(weight_df, label_col="is_causal", score_col="weight", topk=topk)
        row = {"SubMethod": submethod}
        row.update({metric_name: metrics[metric_name] for metric_name in metric_names})
        results.append(row)
    return results


def plot_layerwise_metrics(df, output_dir, causal_strength=0.4, p_zero=0.2):
    """
    Save box, violin and bar plots (mean ± SD) of AUROC and AUPR per method and layer.

    Args:
        df: DataFrame with columns "Method", "Layer", "AUROC" and "AUPR". The method
            names "CauTrigger_SHAP" and "VAEgrad" are displayed as "CauTrigger" and "VAE".
        output_dir: directory receiving the figures; created if missing.
        causal_strength, p_zero: only used to build the file-name tags.

    Side effects:
        Updates the global seaborn theme and matplotlib rcParams, and for each metric
        writes ``{metric}-{boxplot|violinplot|barplot}-cs{..}-p{..}`` as both PDF and PNG.
    """
    import os
    import matplotlib.pyplot as plt
    import pandas as pd

    sns.set_theme(style="white")
    plt.rcParams.update(
        {
            "font.size": 14,
            "axes.labelsize": 16,
            "axes.titlesize": 18,
            "xtick.labelsize": 12,
            "ytick.labelsize": 12,
            "legend.fontsize": 12,
            "figure.dpi": 300,
            "savefig.dpi": 300,
            "pdf.fonttype": 42,
            "ps.fonttype": 42,
            "font.family": "Arial",
        }
    )

    os.makedirs(output_dir, exist_ok=True)

    df = df.copy()
    df["Method"] = df["Method"].replace({"CauTrigger_SHAP": "CauTrigger", "VAEgrad": "VAE"})
    method_order = ["CauTrigger", "GENIE3", "SCRIBE", "PC", "VAE", "DCI", "MI", "RF", "SVM"]
    df["Method"] = pd.Categorical(df["Method"], categories=method_order, ordered=True)
    method_count = len(method_order)

    layer_levels = sorted(df["Layer"].dropna().unique().tolist())
    df["Layer"] = pd.Categorical(df["Layer"], categories=layer_levels, ordered=True)
    layer_count = max(1, len(layer_levels))

    layer_palette = {
        "layer1": "#3E4A89",  # dark blue-violet: downstream / close to the phenotype
        "layer2": "#74A9CF",  # light blue: upstream / regulatory
        "layer3": "#D9EF8B",  # yellow-green: most upstream factors
        "all": "#5B9BD5",  # plain blue for the pooled case
    }

    # Legend placement shared by all three plot types.
    legend_style = dict(
        title="Layer",
        frameon=False,
        loc="upper center",
        bbox_to_anchor=(0.5, 1.18),
        fontsize=11,
        title_fontsize=11,
        columnspacing=1.2,
        handlelength=1.5,
    )

    cs_tag = f"cs{int(causal_strength * 100)}"
    p_tag = f"p{p_zero}"

    def _draw_method_separators(ax, linewidth):
        """Draw a faint dashed vertical line between neighbouring methods."""
        for i in range(1, method_count):
            ax.axvline(i - 0.5, linestyle="--", color="lightgray", linewidth=linewidth, zorder=0)

    def _format_axes(ylabel):
        """Apply the axis labels and tick styling common to every figure."""
        plt.ylabel(ylabel, fontsize=13)
        plt.xlabel("")
        plt.xticks(rotation=30, ha="right", fontsize=11)
        plt.yticks(fontsize=11)

    def _save_current_figure(fname):
        """Finalize the layout, write the current figure as PDF and PNG, and close it."""
        sns.despine()
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f"{fname}.pdf"))
        plt.savefig(os.path.join(output_dir, f"{fname}.png"))
        plt.close()

    for metric in ["AUROC", "AUPR"]:
        metric_tag = metric.lower()

        # === Boxplot ===
        plt.figure(figsize=(7, 3.5))
        ax = sns.boxplot(
            data=df,
            x="Method",
            y=metric,
            hue="Layer",
            palette=layer_palette,
            width=0.6,
            fliersize=0,
            linewidth=0.6,
            gap=0.1,  # leave space between the boxes of different layers
        )

        # Overlay the individual runs as small, soft gray points.
        sns.stripplot(
            data=df,
            x="Method",
            y=metric,
            hue="Layer",
            dodge=True,
            color="gray",
            size=2,
            jitter=0.25,
            alpha=0.3,
            edgecolor=None,
            linewidth=0,
            legend=False,
        )

        _draw_method_separators(ax, linewidth=0.3)

        # One legend entry per layer label (first handle wins).
        handles, labels = ax.get_legend_handles_labels()
        handle_by_label = {}
        for handle, label in zip(handles, labels):
            handle_by_label.setdefault(label, handle)
        if handle_by_label:
            plt.legend(
                list(handle_by_label.values()),
                list(handle_by_label.keys()),
                ncol=len(handle_by_label),
                **legend_style,
            )

        _format_axes(metric)
        plt.ylim(0, 1.05)
        _save_current_figure(f"{metric_tag}-boxplot-{cs_tag}-{p_tag}")

        # === Violinplot ===
        plt.figure(figsize=(7, 3.5))
        ax = sns.violinplot(
            data=df,
            x="Method",
            y=metric,
            hue="Layer",
            palette=layer_palette,
            inner="quartile",
            cut=0,
            # density_norm / bw_method are the current names of the deprecated
            # scale / bw parameters.
            density_norm="width",
            bw_method=0.4,
            width=0.7,
            linewidth=1.0,
            dodge=True,
            saturation=0.8,
        )
        _draw_method_separators(ax, linewidth=0.5)
        _format_axes(metric)
        plt.ylim(0, 1.05)
        plt.legend(ncol=layer_count, **legend_style)
        _save_current_figure(f"{metric_tag}-violinplot-{cs_tag}-{p_tag}")

        # === Barplot (mean ± SD) ===
        # The summary is only needed to size the y axis; bars and error bars are
        # both computed by seaborn from the raw data, so each SD is guaranteed to
        # sit on its own bar. (Pairing ax.patches with summary rows is unreliable:
        # the bars are drawn layer by layer while the summary is ordered by method.)
        summary_df = df.groupby(["Method", "Layer"], observed=True)[metric].agg(["mean", "std"]).reset_index()
        plt.figure(figsize=(7, 3.5))
        ax = sns.barplot(
            data=df,
            x="Method",
            y=metric,
            hue="Layer",
            palette=layer_palette,
            estimator="mean",
            errorbar="sd",
            capsize=0.15,
            err_kws={"color": "black", "linewidth": 1.5},
            width=0.7,
        )
        _draw_method_separators(ax, linewidth=0.5)
        max_y = summary_df["mean"].max() + summary_df["std"].max() + 0.05
        plt.ylim(0, max(1.05, max_y))
        _format_axes(f"{metric} (Mean ± SD)")
        plt.legend(ncol=layer_count, **legend_style)
        _save_current_figure(f"{metric_tag}-barplot-{cs_tag}-{p_tag}")


def plot_aggregate_layerwise_metrics(
    root_output_dir,
    causal_strength_list,
    p_zero_list,
    spurious_mode="semi_hrc",
    n_hidden=10,
    activation="linear",
    simulate_single_cell=True,
):
    """
    Aggregate the per-case ``Layerwise_Benchmark_Metrics.csv`` files and draw summary boxplots.

    For every (causal strength, sparsity) pair the matching case directory under
    ``root_output_dir`` is looked up; cases whose CSV does not exist are skipped.
    One figure per metric (AUROC, AUPR) is produced with one panel per case found,
    laid out at most three panels wide (nine cases give a 3x3 grid). A single
    shared layer legend is placed directly below the figure title.

    Args:
        root_output_dir: Directory containing one sub-directory per case; the
            figures are also written here.
        causal_strength_list: Causal strength values of the cases to collect.
        p_zero_list: Sparsity (``p_zero``) values of the cases to collect.
        spurious_mode, n_hidden, activation, simulate_single_cell: Remaining
            components of the case directory name.

    Returns:
        None. Writes ``<tag>_boxplot_<metric>.pdf`` and ``.png`` for each metric,
        or only prints a warning when no CSV file was found.
    """
    sim_tag = "sc" if simulate_single_cell else "bulk"
    shared_parts = ["2L_counts", spurious_mode, f"hidden{n_hidden}", activation]

    all_dfs = []
    for cs in causal_strength_list:
        for pz in p_zero_list:
            # NOTE: int() truncates (e.g. 0.29 -> "cs28"). Kept as is because the case
            # directories are named with this same expression when they are created.
            case_name = "_".join([*shared_parts, f"cs{int(cs * 100):02d}", f"p{pz}", sim_tag])
            csv_path = os.path.join(root_output_dir, case_name, "Layerwise_Benchmark_Metrics.csv")
            if os.path.exists(csv_path):
                df = pd.read_csv(csv_path)
                df["ParamCombo"] = f"Causal Strength = {cs}, Sparsity = {pz}"
                all_dfs.append(df)

    if not all_dfs:
        print("[WARN] No matching benchmark files found.")
        return

    # Plot styling is applied only once there is something to draw, so the
    # early exit above leaves the global matplotlib/seaborn state untouched.
    import matplotlib.pyplot as plt

    sns.set_theme(style="white")
    plt.rcParams.update(
        {
            "font.size": 12,
            "axes.labelsize": 14,
            "axes.titlesize": 16,
            "xtick.labelsize": 12,
            "ytick.labelsize": 12,
            "legend.fontsize": 14,
            "figure.dpi": 300,
            "savefig.dpi": 300,
            "pdf.fonttype": 42,
            "ps.fonttype": 42,
            "font.family": "Arial",
        }
    )

    df = pd.concat(all_dfs, ignore_index=True)
    df["Method"] = df["Method"].replace({"CauTrigger_SHAP": "CauTrigger", "VAEgrad": "VAE"})

    # Keep the preferred display order, but (a) leave out methods that were not run and
    # (b) append any method missing from the preferred list. Values outside the declared
    # categories would otherwise turn into NaN and silently disappear from the plots.
    preferred_order = ["CauTrigger", "GENIE3", "SCRIBE", "PC", "VAE", "DCI", "MI", "RF", "SVM"]
    methods_present = set(df["Method"].dropna().unique())
    method_order = [name for name in preferred_order if name in methods_present]
    method_order += sorted(methods_present.difference(preferred_order))
    df["Method"] = pd.Categorical(df["Method"], categories=method_order, ordered=True)

    base_tag = "_".join([*shared_parts, sim_tag])
    layer_palette = {"layer1": "#3E4A89", "layer2": "#74A9CF"}

    # Size the grid from the cases actually found instead of assuming exactly nine.
    param_combos = sorted(df["ParamCombo"].unique())
    n_cols = min(3, len(param_combos))
    n_rows = -(-len(param_combos) // n_cols)  # ceiling division

    for metric in ["AUROC", "AUPR"]:
        fig, axes = plt.subplots(
            n_rows, n_cols, figsize=(6 * n_cols, 4 * n_rows), sharey=True, squeeze=False
        )
        panels = axes.flatten()
        handles, labels = [], []

        for ax, combo in zip(panels, param_combos):
            subdf = df[df["ParamCombo"] == combo]
            sns.boxplot(
                data=subdf,
                x="Method",
                y=metric,
                hue="Layer",
                palette=layer_palette,
                ax=ax,
                width=0.6,
                fliersize=0,
                linewidth=1,
            )
            ax.set_title(combo, fontsize=14)
            ax.set_xlabel("")
            ax.set_ylabel(metric)
            ax.tick_params(axis="x", rotation=30)

            # Remember one set of legend entries for the shared figure legend, then
            # drop the per-panel legend (which may be absent if nothing was drawn).
            if not handles:
                handles, labels = ax.get_legend_handles_labels()
            panel_legend = ax.get_legend()
            if panel_legend is not None:
                panel_legend.remove()
            ax.grid(False)

            # Dashed separators between neighbouring method slots on the x axis.
            for boundary in range(1, len(method_order)):
                ax.axvline(x=boundary - 0.5, linestyle="--", color="lightgray", linewidth=0.6, zorder=0)

        # Hide grid cells that have no case assigned to them.
        for spare_ax in panels[len(param_combos):]:
            spare_ax.set_visible(False)

        # Main title; the shared legend sits right underneath it.
        fig.suptitle(f"Layer-wise {metric} across methods", fontsize=18, y=1.02)
        fig.legend(handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=2, frameon=False)

        plt.tight_layout()
        out_prefix = f"{base_tag}_boxplot_{metric}"
        fig.savefig(os.path.join(root_output_dir, f"{out_prefix}.pdf"), bbox_inches="tight")
        fig.savefig(os.path.join(root_output_dir, f"{out_prefix}.png"), bbox_inches="tight")
        print(f"[INFO] Saved: {out_prefix}.pdf/.png")

        plt.close(fig)

In [None]:
def _resolve_seed_list(seed_list, n_datasets):
    """Return a fresh list of exactly ``n_datasets`` seeds without modifying the caller's list."""
    if seed_list is None:
        return list(range(n_datasets))
    seeds = list(seed_list)[:n_datasets]
    if len(seeds) < n_datasets:
        # Pad with consecutive integers above the largest supplied seed (from 0 if none given).
        next_seed = max(seeds, default=-1) + 1
        seeds.extend(range(next_seed, next_seed + n_datasets - len(seeds)))
    return seeds


def _weight_csv_path(output_dir, dataset_idx, layer_name, algo):
    """Return the CSV path that caches one method's feature weights for one layer of one dataset."""
    return os.path.join(output_dir, f"weights_dataset{dataset_idx}_{layer_name}_{algo}.csv")


def _layer_result_row(weight_df, algo, layer_name, dataset_idx, seed):
    """Score one layer's weight table and return the metrics as a flat result record."""
    metrics = calculate_metrics(weight_df, score_col="weight", label_col="is_causal", topk=20)
    return {
        "Method": algo,
        "Layer": layer_name,
        "Dataset": dataset_idx,
        "Seed": seed,
        "ScoreType": "weight",
        **metrics,
    }


def run_benchmark(
    algorithms, data_dir, output_dir, n_datasets=3, seed_list=None, save_adata=True, rerun=False, **generation_args
):
    """
    Run benchmark evaluation for causal discovery methods on synthetic two-layer datasets.
    Supports both single-layer methods and recursive multi-layer methods like CauTrigger.

    Args:
        algorithms: Names of the methods to run; each must be a key of the
            algorithm table defined below.
        data_dir: Directory under which each generated dataset is stored as
            ``dataset<k>/adata.h5ad`` (``k`` starts at 1) when ``save_adata`` is true.
        output_dir: Directory for the per-layer weight CSVs, the final
            ``Layerwise_Benchmark_Metrics.csv`` and the plots.
        n_datasets: Number of datasets to generate.
        seed_list: Optional seeds. Truncated to ``n_datasets`` entries, or padded
            with consecutive integers above its maximum. Not modified in place.
            Defaults to ``0 .. n_datasets - 1``.
        save_adata: Whether to write each generated dataset to ``data_dir``.
        rerun: If false, weight CSVs already present in ``output_dir`` are loaded
            instead of re-running the method. These files are keyed by dataset
            index (0-based) rather than by seed, so pass ``rerun=True`` after
            changing seeds or generation arguments.
        **generation_args: Forwarded to ``generate_two_layer_synthetic_data``.
            ``causal_strength`` and ``p_zero`` are additionally passed on to
            ``plot_layerwise_metrics``.

    Raises:
        ValueError: If an entry of ``algorithms`` is not a registered method.
    """
    algorithm_functions = {
        # CauTrigger scoring variants
        "CauTrigger_Model": lambda adata, out: run_CauTrigger(adata, out, mode="Model"),
        "CauTrigger_Grad": lambda adata, out: run_CauTrigger(adata, out, mode="Grad"),
        "CauTrigger_SHAP": lambda adata, out: run_CauTrigger(adata, out, mode="SHAP"),
        "CauTrigger_Ensemble": lambda adata, out: run_CauTrigger(adata, out, mode="Ensemble"),
        # Other baseline methods
        "PC": run_PC,
        "VAEgrad": run_VAE,
        "SVM": run_SVM,
        "RF": run_RF,
        "MI": run_MI,
        "DCI": run_DCI,
        # "NLBayes": run_NLBAYES,
        "GENIE3": run_GENIE3,
        # "GRNBOOST2": run_GRNBOOST2,
        "SCRIBE": run_SCRIBE,
        "VELORAMA": run_VELORAMA,
    }

    # Reject unknown method names before any dataset is generated or any method is run.
    for algo in algorithms:
        if algo not in algorithm_functions:
            raise ValueError(f"[ERROR] Algorithm '{algo}' not registered in algorithm_functions!")

    print(f"[INFO] Running benchmark on {n_datasets} datasets with algorithms: {algorithms}")

    seed_list = _resolve_seed_list(seed_list, n_datasets)
    os.makedirs(output_dir, exist_ok=True)

    all_results = []

    for i, seed in enumerate(seed_list):
        set_seed(seed)
        adata = generate_two_layer_synthetic_data(seed=seed, **generation_args)
        print(f"[INFO] Dataset {i + 1}: Generated with seed {seed}")

        if save_adata:
            dataset_dir = os.path.join(data_dir, f"dataset{i + 1}")
            os.makedirs(dataset_dir, exist_ok=True)
            adata.write(os.path.join(dataset_dir, "adata.h5ad"))

        for algo in algorithms:
            func = algorithm_functions[algo]

            if algo.startswith("CauTrigger"):
                # CauTrigger scores both layers in a single call, so the cache is
                # only usable when the weight files of both layers are present.
                cached_paths = {
                    layer_name: _weight_csv_path(output_dir, i, layer_name, algo)
                    for layer_name in ("layer1", "layer2")
                }
                if not rerun and all(os.path.exists(path) for path in cached_paths.values()):
                    print(f"[INFO] {algo} on dataset {i + 1} already exists. Loading...")
                    for layer_name, weight_path in cached_paths.items():
                        weight_df = pd.read_csv(weight_path, index_col=0)
                        all_results.append(_layer_result_row(weight_df, algo, layer_name, i, seed))
                    continue

                print(f"[INFO] Evaluating {algo} on recursive 2-layer setting (Dataset {i + 1})...")
                pred_dict = func(adata, output_dir)

                for layer_name, weight_df in pred_dict.items():
                    assert "is_causal" in weight_df.columns, (
                        f"[ERROR] 'is_causal' column missing in CauTrigger output for {layer_name}"
                    )
                    weight_path = _weight_csv_path(output_dir, i, layer_name, algo)
                    weight_df.to_csv(weight_path)
                    print(f"[INFO] Saved weights to: {weight_path}")
                    all_results.append(_layer_result_row(weight_df, algo, layer_name, i, seed))
                continue

            # Single-layer baselines are run once per layer on that layer's features only.
            for layer_name in adata.var["layer"].unique():
                weight_path = _weight_csv_path(output_dir, i, layer_name, algo)

                if not rerun and os.path.exists(weight_path):
                    print(f"[INFO] {algo} on {layer_name} (Dataset {i + 1}) already exists. Loading...")
                    weight_df = pd.read_csv(weight_path, index_col=0)
                else:
                    layer_vars = adata.var_names[adata.var["layer"] == layer_name]
                    sub_adata = adata[:, layer_vars]

                    print(f"[INFO] Evaluating {algo} on {layer_name} (Dataset {i + 1})...")
                    pred = func(sub_adata, output_dir)

                    # A method may return either a dict of columns or a bare score vector.
                    pred_dict = pred if isinstance(pred, dict) else {"weight": pred}
                    weight_df = pd.DataFrame(pred_dict, index=layer_vars)
                    weight_df["is_causal"] = adata.var.loc[layer_vars, "is_causal"].values

                    weight_df.to_csv(weight_path)
                    print(f"[INFO] Saved weights to: {weight_path}")

                all_results.append(_layer_result_row(weight_df, algo, layer_name, i, seed))

        # Free the dataset before the next one is generated to keep peak memory down.
        del adata
        gc.collect()

    # Save final metrics
    df = pd.DataFrame(all_results)
    metrics_path = os.path.join(output_dir, "Layerwise_Benchmark_Metrics.csv")
    df.to_csv(metrics_path, index=False)
    print(f"[INFO] Saved final evaluation to: {metrics_path}")

    if df.empty:
        # No datasets or no algorithms were requested: there is nothing to plot.
        print("[WARN] No benchmark results were produced; skipping plots.")
        return

    # Draw plots
    plot_layerwise_metrics(
        df,
        output_dir,
        causal_strength=generation_args.get("causal_strength", 0.4),
        p_zero=generation_args.get("p_zero", 0.2),
    )

In [None]:
# Root of the local CauTrigger checkout; all simulation data and outputs live below it.
BASE_DIR = "/mnt/e/Project_Research/CauTrigger_Project/CauTrigger-master"
case_dir = os.path.join(BASE_DIR, "simulation")

# ==== Fixed settings shared by every run ====
n_datasets = 10
spurious_mode = "semi_hrc"
n_hidden = 10
activation = "linear"
simulate_single_cell = True

# Methods to benchmark; each name must be registered in run_benchmark's algorithm table.
algorithms = ["CauTrigger_SHAP", "GENIE3", "SCRIBE", "PC", "VAEgrad", "DCI", "MI", "RF", "SVM"]

# ==== Parameter sweep ====
# Three levels (low / medium / high) per axis, i.e. nine cases in total.
# A wider alternative sweep is causal strength 0.3-0.7 (step 0.1) with p_zero 0.1-0.7 (step 0.2).
causal_strength_list = [0.3, 0.4, 0.5]  # causal strength
p_zero_list = [0.3, 0.5, 0.7]  # sparsity

for causal_strength in causal_strength_list:
    for p_zero in p_zero_list:
        # ==== Data-generation arguments for this case ====
        generation_args = dict(
            spurious_mode=spurious_mode,
            n_hidden=n_hidden,
            activation=activation,
            causal_strength=causal_strength,
            p_zero=p_zero,
            simulate_single_cell=simulate_single_cell,
        )

        # ==== Case name; the aggregate plot rebuilds the same name to locate the results ====
        tag_parts = [
            "2L_counts",
            spurious_mode,
            f"hidden{n_hidden}",
            activation,
            f"cs{int(causal_strength * 100):02d}",
            f"p{p_zero}",
            "sc" if simulate_single_cell else "bulk",
        ]
        case_name = "_".join(tag_parts)

        data_dir = os.path.join(case_dir, "data", case_name)
        output_dir = os.path.join(case_dir, "output", case_name)
        os.makedirs(data_dir, exist_ok=True)
        os.makedirs(output_dir, exist_ok=True)

        print(f"=== Running: {case_name} ===")

        run_benchmark(
            algorithms=algorithms, data_dir=data_dir, output_dir=output_dir, n_datasets=n_datasets, **generation_args
        )

# Finally draw the summary figures (only for the parameter combinations swept above).
aggregate_output_root = os.path.join(case_dir, "output")
plot_aggregate_layerwise_metrics(
    root_output_dir=aggregate_output_root,
    causal_strength_list=causal_strength_list,
    p_zero_list=p_zero_list,
    spurious_mode=spurious_mode,
    n_hidden=n_hidden,
    activation=activation,
    simulate_single_cell=simulate_single_cell,
)