In [None]:
# ===== Benchmark Script for Two-Layer Causal Simulation =====
# Goal: Evaluate layer-wise causal feature identification
import os
import sys

# Cap native thread pools. These variables are read when the BLAS / OpenMP
# runtimes are loaded, so they have to be in the environment *before*
# numpy, scipy, scikit-learn or torch are imported; setting them afterwards
# only affects child processes started later.
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

# Make the parent directory (the CauTrigger project root) importable before
# the CauTrigger imports below are resolved.
sys.path.append('../')

import gc
import logging
import math
import warnings
from itertools import combinations

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import ray
import scanpy as sc
import seaborn as sns
import torch
from ray import tune
from scipy.stats import pearsonr, spearmanr, f_oneway, norm
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import mutual_info_classif
from sklearn.preprocessing import minmax_scale
from sklearn.preprocessing import normalize as sk_normalize
from sklearn.svm import SVC

# NOTE(review): the star imports hide where helpers used later in this
# notebook (e.g. construct_dag, train_model, move_files) come from —
# consider importing them by name once the exact set is confirmed.
from velorama.train import *
from velorama.utils import *
from CauTrigger.utils import set_seed, select_features
from CauTrigger.model import CauTrigger3L, CauTrigger2L, CauTrigger1L
from CauTrigger.dataloaders import generate_two_layer_synthetic_data

# Silence library warnings and the matplotlib font-manager logger.
warnings.filterwarnings("ignore")
logging.getLogger('matplotlib.font_manager').disabled = True

# Embed TrueType fonts in PDFs (fonttype 42) and use Arial for all text.
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['font.sans-serif'] = ['Arial']
plt.rcParams['font.family'] = 'sans-serif'

In [None]:
def run_CauTrigger(adata, output_dir=None, mode="SHAP", topk=30, is_log1p=False, full_input=False):
    """
    Score features with CauTrigger in two consecutive steps.

    Step 1 trains ``CauTrigger1L`` on a copy of ``adata`` (all genes) and
    ranks every gene by its feature weight.
    Step 2 takes the ``topk`` highest-weighted genes from step 1 as the
    downstream block (``obsm["X_down"]``) and trains ``CauTrigger2L`` on the
    remaining genes to score them as upstream candidates.

    Parameters
    ----------
    adata : AnnData
        Must provide ``var["is_causal"]`` and ``var["layer"]`` (values
        ``"layer1"`` / ``"layer2"``). When ``is_log1p`` is True it must also
        provide ``obsm["layer1"]`` and ``obsm["layer2"]``.
    output_dir : optional
        Accepted for signature parity with the other ``run_*`` helpers; unused.
    mode : str
        Forwarded as ``method`` to ``get_up_feature_weights``.
    topk : int
        Number of step-1 genes promoted to the downstream block.
    is_log1p : bool
        If True, ``adata.X`` and both ``obsm`` layers are log1p-transformed
        **in place on the caller's object** and ``adata._log1p_applied`` is set.
    full_input : bool
        If True, ``obsm["X_down"]`` of the step-2 object is the full
        ``adata.X`` instead of only the top-k columns.

    Returns
    -------
    dict
        ``"step1"``: all genes with step-1 weights (0.0 where step 1 gave no
        weight); ``is_causal`` is copied only for layer1 genes, else 0.
        ``"step2"``: all genes with step-2 weights (0.0 for genes not scored
        in step 2); ``is_causal`` is copied only for layer2 genes, else 0.
        ``"all"``: one row per gene in ``adata.var_names`` order, holding the
        doubled step-1 weight for the top-k genes and the step-2 weight for
        the rest, plus ``is_causal`` and ``step`` columns.
    """
    np.random.seed(42)
    if is_log1p:
        print("CauTrigger log1p normalization applied to adata.X, adata.obsm['layer1'], and adata.obsm['layer2']")
        adata.X = np.log1p(adata.X)
        adata.obsm["layer1"] = np.log1p(adata.obsm["layer1"])
        adata.obsm["layer2"] = np.log1p(adata.obsm["layer2"])
        adata._log1p_applied = True
        print("[INFO] log1p transformation applied to adata.obsm['layer1'] and ['layer2'].")

    # Hyper-parameters shared by both models; only init_thresh differs.
    shared_model_kwargs = dict(
        n_latent=10,
        n_hidden=128,
        n_layers_encoder=0,
        n_layers_decoder=0,
        n_layers_dpd=0,
        dropout_rate_encoder=0.0,
        dropout_rate_decoder=0.0,
        dropout_rate_dpd=0.0,
        use_batch_norm='none',
        use_batch_norm_dpd=True,
        decoder_linear=False,
        dpd_linear=True,
        init_weight=None,
        attention=False,
        att_mean=False,
    )

    # === Step 1: score all genes ===
    adata_step1 = adata.copy()
    model_1L = CauTrigger1L(adata_step1, init_thresh=0.4, **shared_model_kwargs)
    model_1L.train(max_epochs=200, stage_training=True, weight_scheme="sim")
    df_step1, _ = model_1L.get_up_feature_weights(method=mode, normalize=False, sort_by_weight=True)
    print("df_step1", df_step1.head(20))

    # === Step 2: top-k genes of step 1 become X_down, the rest are scored ===
    topk_genes = df_step1["weight"].nlargest(topk).index.tolist()
    topk_gene_set = set(topk_genes)  # O(1) membership tests; the list keeps the ranking order
    all_genes = adata.var_names.tolist()
    remaining_genes = [g for g in all_genes if g not in topk_gene_set]

    # Expression blocks for the downstream (top-k) and upstream (remaining) genes.
    X_down = adata[:, topk_genes].X
    X_upstream = adata[:, remaining_genes].X

    # AnnData for step 2. The class is taken from scanpy's namespace so this
    # function does not rely on a star import to provide the name.
    adata_step2 = sc.AnnData(
        X=X_upstream,
        obs=adata.obs.copy(),
        var=adata.var.loc[remaining_genes].copy(),
        obsm={"X_down": X_down if not full_input else adata.X}
    )

    model_2L = CauTrigger2L(adata_step2, init_thresh=0.1, **shared_model_kwargs)
    model_2L.train(max_epochs=200, stage_training=True, weight_scheme="sim")
    df_step2, _ = model_2L.get_up_feature_weights(method=mode, normalize=False, sort_by_weight=True)
    print("df_step2", df_step2.head(10))

    # Mark the estimation step
    df_step1["step"] = "step1"
    df_step2["step"] = "step2"

    # Amplify only the top-k weights from step 1 so they outrank the step-2 scores
    amplify_factor = 2.0
    df_step1_topk = df_step1.loc[topk_genes].copy()
    df_step1_topk["weight"] *= amplify_factor

    # Concatenate the amplified top-k rows with all step-2 genes (unaltered)
    df_all = pd.concat([df_step1_topk, df_step2])

    # Add ground truth causal label
    df_all["is_causal"] = adata.var.loc[df_all.index, "is_causal"]

    # Assign step labels
    df_all["step"] = ["step1" if g in topk_gene_set else "step2" for g in df_all.index]

    # Keep selected columns and align with the original gene order
    df_all = df_all[["weight", "is_causal", "step"]]
    df_all = df_all.reindex(adata.var_names)

    # === Full per-step tables for separate layer evaluation ===
    df_step1_full = pd.DataFrame(index=adata.var_names)
    df_step1_full["weight"] = 0.0
    df_step1_full.loc[df_step1.index, "weight"] = df_step1["weight"]
    df_step1_full["step"] = "step1"
    df_step1_full["is_causal"] = 0
    layer1_genes = adata.var_names[adata.var["layer"] == "layer1"]
    df_step1_full.loc[layer1_genes, "is_causal"] = adata.var.loc[layer1_genes, "is_causal"]

    df_step2_full = pd.DataFrame(index=adata.var_names)
    df_step2_full["weight"] = 0.0
    df_step2_full.loc[df_step2.index, "weight"] = df_step2["weight"]
    df_step2_full["step"] = "step2"
    df_step2_full["is_causal"] = 0
    layer2_genes = adata.var_names[adata.var["layer"] == "layer2"]
    df_step2_full.loc[layer2_genes, "is_causal"] = adata.var.loc[layer2_genes, "is_causal"]

    return {
        "step1": df_step1_full,
        "step2": df_step2_full,
        "all": df_all
    }

In [None]:
def run_PC(adata, output_dir):
    """
    Score each feature by its PC-algorithm dependence on the label.

    The expression matrix (log1p-transformed when it looks like counts) is
    stacked with ``adata.obs['labels']`` as the last column, a PC skeleton /
    CPDAG is estimated from the Pearson correlation matrix, and the score of
    a feature is ``1 - p`` where ``p`` is the largest p-value recorded for the
    (feature, label) pair during the skeleton search. Scores are L1-normalised.

    ``output_dir`` is not used. The CPDAG itself is computed but only the
    p-value matrix contributes to the returned scores.

    Returns a 1-D array with one score per column of ``adata.X``.
    """
    def gauss_ci_test(suff_stat, i, j, K):
        """
        Gaussian conditional-independence test of variables ``i`` and ``j``
        given the index list ``K``; returns a two-sided p-value.

        ``suff_stat`` holds the correlation matrix under ``"C"`` and the
        sample count under ``"n"``.
        """
        corr_matrix = suff_stat["C"]
        n_samples = suff_stat["n"]

        if len(K) == 0:
            # Marginal correlation.
            r = corr_matrix[i, j]
        elif len(K) == 1:
            # First-order partial correlation in closed form.
            k = K[0]
            r = (corr_matrix[i, j] - corr_matrix[i, k] * corr_matrix[j, k]) / math.sqrt(
                (1 - corr_matrix[i, k] ** 2) * (1 - corr_matrix[j, k] ** 2)
            )
        else:
            # Higher-order partial correlation from the (pseudo-)inverse of the
            # correlation sub-matrix; i and j sit at positions 0 and 1.
            sub_corr = corr_matrix[np.ix_([i, j] + K, [i, j] + K)]
            precision_matrix = np.linalg.pinv(sub_corr)
            r = (-1 * precision_matrix[0, 1]) / math.sqrt(
                abs(precision_matrix[0, 0] * precision_matrix[1, 1])
            )

        # Clamp so the Fisher transform below stays finite.
        r = max(min(r, 0.99999), -0.99999)
        # Fisher z-transform: 0.5 * log((1 + r) / (1 - r)) written via log1p.
        z = 0.5 * math.log1p((2 * r) / (1 - r))
        z_standard = z * math.sqrt(n_samples - len(K) - 3)
        p_value = 2 * (1 - norm.cdf(abs(z_standard)))

        return p_value

    def get_neighbors(G, x, exclude_y):
        """Indices adjacent to ``x`` in adjacency ``G``, leaving out ``exclude_y``."""
        return [i for i, connected in enumerate(G[x]) if connected and i != exclude_y]

    def skeleton(suff_stat, alpha):
        """
        Estimate the undirected skeleton.

        Starts from the complete graph and, for growing conditioning-set size
        ``l``, removes the edge x–y as soon as some size-``l`` subset of x's
        other neighbours yields a p-value >= ``alpha``.

        Returns ``(G, O, p_value_mat)``: the 0/1 adjacency matrix, the
        separating set recorded for each removed pair, and the largest
        p-value seen for each tested pair.
        """
        p_value_mat = np.zeros_like(suff_stat["C"])
        n_nodes = suff_stat["C"].shape[0]
        O = [[[] for _ in range(n_nodes)] for _ in range(n_nodes)]
        G = [[i != j for i in range(n_nodes)] for j in range(n_nodes)]
        pairs = [(i, j) for i in range(n_nodes) for j in range(i + 1, n_nodes)]

        done = False
        l = 0

        # Stop when no surviving edge has at least `l` candidate conditioners.
        while not done and any(any(row) for row in G):
            done = True

            for x, y in pairs:
                if G[x][y]:
                    neighbors = get_neighbors(G, x, y)
                    if len(neighbors) >= l:
                        done = False
                        for K in combinations(neighbors, l):
                            p_value = gauss_ci_test(suff_stat, x, y, list(K))
                            # Keep the maximum p-value over all tests of this pair.
                            if p_value > p_value_mat[x][y]:
                                p_value_mat[x][y] = p_value_mat[y][x] = p_value
                            if p_value >= alpha:
                                # Independence not rejected: drop the edge and remember K.
                                G[x][y] = G[y][x] = False
                                O[x][y] = O[y][x] = list(K)
                                break
            l += 1

        return np.asarray(G, dtype=int), O, p_value_mat

    def extend_cpdag(G, O):
        """
        Orient edges of skeleton ``G`` using the separating sets ``O``.

        In the matrix, ``g[i][j] == 1 and g[j][i] == 0`` is treated as a
        directed edge i -> j, and 1 in both directions as undirected.
        Colliders are oriented first, then ``rule1``–``rule3`` are applied
        repeatedly until the matrix stops changing.
        """
        n_nodes = G.shape[0]

        def rule1(g):
            # For i -> j and undirected j - k with i, k non-adjacent: orient j -> k.
            pairs = [(i, j) for i in range(n_nodes) for j in range(n_nodes) if g[i][j] == 1 and g[j][i] == 0]
            for i, j in pairs:
                all_k = [k for k in range(n_nodes) if
                         (g[j][k] == 1 and g[k][j] == 1) and (g[i][k] == 0 and g[k][i] == 0)]
                for k in all_k:
                    g[j][k] = 1
                    g[k][j] = 0
            return g

        def rule2(g):
            # For undirected i - j with a directed chain i -> k -> j: orient i -> j.
            pairs = [(i, j) for i in range(n_nodes) for j in range(n_nodes) if g[i][j] == 1 and g[j][i] == 1]
            for i, j in pairs:
                all_k = [k for k in range(n_nodes) if
                         (g[i][k] == 1 and g[k][i] == 0) and (g[k][j] == 1 and g[j][k] == 0)]
                if len(all_k) > 0:
                    g[i][j] = 1
                    g[j][i] = 0
            return g

        def rule3(g):
            # For undirected i - j with two non-adjacent k1, k2 such that
            # i - k (undirected) and k -> j for both: orient i -> j.
            pairs = [(i, j) for i in range(n_nodes) for j in range(n_nodes) if g[i][j] == 1 and g[j][i] == 1]
            for i, j in pairs:
                all_k = [k for k in range(n_nodes) if
                         (g[i][k] == 1 and g[k][i] == 1) and (g[k][j] == 1 and g[j][k] == 0)]
                if len(all_k) >= 2:
                    for k1, k2 in combinations(all_k, 2):
                        if g[k1][k2] == 0 and g[k2][k1] == 0:
                            g[i][j] = 1
                            g[j][i] = 0
                            break
            return g

        # Collider step: for x - y - z with x, z non-adjacent and y absent from
        # the separating set of (x, z), orient x -> y <- z. G is updated while
        # iterating over the pair list built beforehand.
        pairs = [(i, j) for i in range(n_nodes) for j in range(n_nodes) if G[i][j] == 1]
        for x, y in sorted(pairs, key=lambda x: (x[1], x[0])):
            all_z = [z for z in range(n_nodes) if G[y][z] == 1 and z != x]
            for z in all_z:
                if G[x][z] == 0 and y not in O[x][z]:
                    G[x][y] = G[z][y] = 1
                    G[y][x] = G[y][z] = 0

        # Apply the three rules until a fixed point is reached.
        old_G = np.zeros((n_nodes, n_nodes))
        while not np.array_equal(old_G, G):
            old_G = G.copy()
            G = rule1(G)
            G = rule2(G)
            G = rule3(G)

        return np.array(G)

    def pc(suff_stat, alpha=0.5, verbose=False):
        """Run skeleton search and orientation; returns ``(cpdag, p_value_mat)``."""
        G, O, pvm = skeleton(suff_stat, alpha)
        cpdag = extend_cpdag(G, O)
        if verbose:
            print(cpdag)
        return cpdag, pvm

    # Significance level actually used (overrides the default of `pc`).
    alpha = 0.05
    X = adata.X
    if np.issubdtype(X.dtype, np.integer) or X.max() > 100:  # Rough check for count data
        X = np.log1p(X)  # Apply log1p for count data
    y = adata.obs['labels'].values
    # The label becomes the last column so that it is the last graph node.
    data = pd.DataFrame(np.column_stack((X, y)))
    cpdag, pvm = pc(
        suff_stat={"C": data.corr().values, "n": data.shape[0]},
        alpha=alpha
    )
    # Last column, without its final entry: max p-value of each feature vs. the label.
    pv = pvm[:-1, -1]
    arr = np.array(1 - pv).reshape(1, -1)  # sk_normalize expects a 2-D array
    normalized_arr = sk_normalize(arr, norm='l1', axis=1)
    return (normalized_arr.flatten())


def run_VAE(adata, output_dir):
    """
    Score features with the input gradients of a small supervised VAE.

    A VAE with an extra classification head (``DPD``) on the latent sample is
    trained for 200 epochs on ``adata.X`` (log1p-transformed when it looks
    like counts) against ``adata.obs['labels']``. Afterwards the absolute
    gradient of the classification loss with respect to every input value is
    averaged over cells and L1-normalised.

    ``output_dir`` is not used. No seed is set here, so results depend on the
    global torch RNG state.

    Returns a 1-D numpy array with one score per column of ``adata.X``.
    """
    import torch
    from torch import nn
    from torch.utils.data import TensorDataset, DataLoader
    X = adata.X
    if np.issubdtype(X.dtype, np.integer) or X.max() > 100:  # Rough check for count data
        X = np.log1p(X)  # Apply log1p for count data
    y = adata.obs['labels'].values
    n_features = X.shape[1]
    features = torch.tensor(X, dtype=torch.float32)
    labels = torch.tensor(y, dtype=torch.float32).view(-1, 1)
    dataset = TensorDataset(features, labels)
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

    n_hidden = 5
    n_latent = 5

    class VAE(nn.Module):
        """Encoder / decoder pair plus a sigmoid classifier on the latent sample."""

        def __init__(self, num_features):
            super().__init__()

            # Outputs mean and log-variance stacked along the feature axis.
            self.encoder = nn.Sequential(
                nn.Linear(num_features, n_hidden),
                nn.ReLU(),
                nn.Linear(n_hidden, 2 * n_latent),
            )

            self.decoder = nn.Sequential(
                nn.Linear(n_latent, n_hidden),
                nn.ReLU(),
                nn.Linear(n_hidden, num_features),
            )

            self.DPD = nn.Sequential(
                nn.Linear(n_latent, n_hidden),
                nn.ReLU(),
                nn.Linear(n_hidden, 1),
                nn.Sigmoid(),
            )

        def reparameterize(self, mu, logvar):
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps * std + mu

        def forward(self, x):
            mu_logvar = self.encoder(x)
            mu = mu_logvar[:, :n_latent]
            logvar = mu_logvar[:, n_latent:]
            z = self.reparameterize(mu, logvar)
            y = self.DPD(z)
            reconstructed = self.decoder(z)
            return reconstructed, y, mu, logvar

    model = VAE(n_features)
    recon_criterion = nn.MSELoss()
    dpd_criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # Train. The original code had an `if epoch <= 100 / else` switch whose
    # two branches used the same weights, so a single expression is equivalent.
    model.train()
    for epoch in range(200):
        for data, targets in dataloader:
            optimizer.zero_grad()
            recon_batch, y_dpd, mu, logvar = model(data)

            re_loss = recon_criterion(recon_batch, data)
            # KL divergence to the standard normal, averaged over the batch.
            kl_loss = (
                -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / data.shape[0]
            )
            dpd_loss = dpd_criterion(y_dpd, targets)

            loss = re_loss + kl_loss * 0.1 + dpd_loss * 0.1
            loss.backward()
            optimizer.step()

    model.eval()

    # Input-gradient saliency of the classification loss on the full data set.
    # The forward pass still draws a latent sample, so this step is stochastic.
    features.requires_grad = True
    _, y_prob, _, _ = model(features)
    loss = dpd_criterion(y_prob, labels)
    loss.backward()
    grads = features.grad.abs()
    grad_features_importance = grads.mean(dim=0)
    grad_df = grad_features_importance.detach().numpy()
    arr = np.array(grad_df).reshape(1, -1)
    normalized_arr = sk_normalize(arr, norm='l1', axis=1)
    return normalized_arr.flatten()


def run_RF(adata, output_dir):
    """
    Rank features by random-forest importance for predicting the label.

    ``adata.X`` is log1p-transformed when it has an integer dtype or a value
    above 100, a default ``RandomForestClassifier`` is fitted against
    ``adata.obs['labels']``, and its ``feature_importances_`` are returned
    after L1 normalisation. ``output_dir`` is not used.
    """
    expr = adata.X
    looks_like_counts = np.issubdtype(expr.dtype, np.integer) or expr.max() > 100
    if looks_like_counts:
        expr = np.log1p(expr)

    targets = adata.obs['labels'].values
    forest = RandomForestClassifier()
    forest.fit(expr, targets)

    importances = np.array(forest.feature_importances_.flatten()).reshape(1, -1)
    return sk_normalize(importances, norm='l1', axis=1).flatten()


def run_SVM(adata, output_dir):
    """
    Rank features by the coefficients of a linear SVM predicting the label.

    ``adata.X`` is log1p-transformed when it looks like counts, standardised
    per feature, and a linear-kernel ``SVC`` is fitted against
    ``adata.obs['labels']``. The flattened coefficients are L1-normalised and
    returned as a 1-D array. ``output_dir`` is not used.

    NOTE(review): coefficients keep their sign, so a strongly negative weight
    ranks below an uninformative feature — confirm that is intended.
    """
    X = adata.X
    if np.issubdtype(X.dtype, np.integer) or X.max() > 100:  # Rough check for count data
        X = np.log1p(X)  # Apply log1p for count data

    # Standardise; constant features get a unit divisor (as run_VELORAMA does)
    # so they become 0 instead of NaN, which SVC would reject.
    std = X.std(axis=0)
    std = np.where(std == 0, 1.0, std)
    X = (X - X.mean(axis=0)) / std

    y = adata.obs['labels'].values
    svm = SVC(kernel='linear')
    svm.fit(X, y)
    arr = np.array(svm.coef_.flatten()).reshape(1, -1)
    normalized_arr = sk_normalize(arr, norm='l1', axis=1)
    return normalized_arr.flatten()


def run_MI(adata, output_dir):
    """
    Rank features by their mutual information with the label.

    ``adata.X`` is log1p-transformed when it looks like counts, standardised
    per feature, and scored with ``mutual_info_classif`` against
    ``adata.obs['labels']``. Scores are L1-normalised and returned as a 1-D
    array. ``output_dir`` is not used.
    """
    X = adata.X
    if np.issubdtype(X.dtype, np.integer) or X.max() > 100:  # Rough check for count data
        X = np.log1p(X)  # Apply log1p for count data

    # Standardise; constant features get a unit divisor (as run_VELORAMA does)
    # so they become 0 instead of NaN, which the estimator would reject.
    std = X.std(axis=0)
    std = np.where(std == 0, 1.0, std)
    X = (X - X.mean(axis=0)) / std

    y = adata.obs['labels'].values
    arr = np.array(mutual_info_classif(X, y).flatten()).reshape(1, -1)
    normalized_arr = sk_normalize(arr, norm='l1', axis=1)
    return normalized_arr.flatten()


def run_VELORAMA(adata, output_dir):
    """
    Score every gene as a regulator with Velorama.

    Builds a cell DAG with ``construct_dag``, computes diffusion-lagged
    features, trains ``train_model`` through Ray Tune over a log-spaced grid
    of ``lam`` values, aggregates the result with ``load_gc_interactions`` /
    ``estimate_interactions`` and returns, for each gene, the mean of its
    column in the (target x regulator) interaction matrix.

    Parameters
    ----------
    adata : AnnData
        Input data. A private copy is scaled and annotated, so the caller's
        object is left untouched.
    output_dir : str
        Directory that receives the Ray Tune output and the run sub-directory.

    Returns
    -------
    numpy.ndarray
        1-D array with one value per gene in ``adata.var_names`` order.
    """
    # sc.pp.scale and the 'iroot' entry below modify the object they receive;
    # copy first, like the other baselines, so later methods see the
    # original data.
    adata = adata.copy()

    target_genes = adata.var_names.tolist()
    reg_genes = adata.var_names.tolist()

    X_orig = adata.X.copy()
    if np.issubdtype(X_orig.dtype, np.integer) or X_orig.max() > 100:  # Rough check for count data
        X_orig = np.log1p(X_orig)  # Apply log1p for count data
    std = X_orig.std(0)
    std[std == 0] = 1
    X = torch.FloatTensor(X_orig - X_orig.mean(0)) / std
    X = X.to(torch.float32)

    # NOTE(review): unlike X, Y is standardised without the log1p step —
    # confirm that this asymmetry is intended.
    Y_orig = adata.X.copy()
    std = Y_orig.std(0)
    std[std == 0] = 1
    Y = torch.FloatTensor(Y_orig - Y_orig.mean(0)) / std
    Y = Y.to(torch.float32)

    adata.uns['iroot'] = 0
    sc.pp.scale(adata)

    reg_target = 0
    dynamics = 'pseudotime'
    ptloc = 'pseudotime'
    proba = 1
    n_neighbors = 30
    velo_mode = 'stochastic'
    time_series = 0
    n_comps = 20
    lag = 5
    name = 'velorama_run'
    seed = 42
    hidden = 32
    penalty = 'H'
    save_dir = output_dir
    lam_start = -2
    lam_end = 1
    num_lambdas = 19

    # A is the adjacency matrix of the cell DAG.
    # AX holds the diffusion features A·X, A²·X, ... for `lag` steps
    # (per the original note: a lag x cells x genes tensor).
    A = construct_dag(adata, dynamics=dynamics, ptloc=ptloc, proba=proba,
                      n_neighbors=n_neighbors, velo_mode=velo_mode,
                      use_time=time_series, n_comps=n_comps)
    A = torch.FloatTensor(A)
    AX = calculate_diffusion_lags(A, X, lag)
    AY = None

    dir_name = '{}.seed{}.h{}.{}.lag{}.{}'.format(name, seed, hidden, penalty, lag, dynamics)
    os.makedirs(os.path.join(save_dir, dir_name), exist_ok=True)

    # The object-store size can be increased if needed.
    # ray.init(local_mode=True) is handy for debugging with breakpoints.
    ray.init(object_store_memory=6 * 1024 ** 3, ignore_reinit_error=True)
    try:
        lam_list = np.logspace(lam_start, lam_end, num=num_lambdas).tolist()

        config = {'name': name,
                  'AX': AX,
                  'AY': AY,
                  'Y': Y,
                  'seed': seed,
                  'lr': 0.01,
                  'lam': tune.grid_search(lam_list),
                  'lam_ridge': 0.0,
                  'penalty': penalty,
                  'lag': lag,
                  'hidden': [hidden],
                  'max_iter': 200,
                  'device': 'cpu',
                  'lookback': 5,
                  'check_every': 10,
                  'verbose': True,
                  'dynamics': dynamics,
                  'results_dir': save_dir,
                  'dir_name': dir_name,
                  'reg_target': reg_target}
        resources_per_trial = {"cpu": 1, "gpu": 0, "memory": 1 * 1024 ** 3}  # can be increased
        tune.run(train_model, resources_per_trial=resources_per_trial, config=config, storage_path=save_dir)

        target_dir = os.path.join(save_dir, dir_name)
        base_dir = save_dir
        move_files(base_dir, target_dir)

        # Aggregate results; all_lags is [lam_count, TG_count, TF_count, lag]
        # according to the original note.
        lam_list = [np.round(lam, 4) for lam in lam_list]
        all_lags = load_gc_interactions(name, save_dir, lam_list, hidden_dim=hidden, lag=lag, penalty=penalty,
                                        dynamics=dynamics, seed=seed, ignore_lag=False)

        gc_mat = estimate_interactions(all_lags, lag=lag)  # tg_count x tf_count
        gc_df = pd.DataFrame(gc_mat.cpu().data.numpy(), index=target_genes, columns=reg_genes)
    finally:
        # Always release the Ray runtime, even when tuning or aggregation fails.
        ray.shutdown()

    return gc_df.mean(axis=0).values


def run_SCRIBE(adata, output_dir):
    """
    Score every gene by its mean outgoing Scribe RDI value.

    A copy of ``adata`` is log1p-transformed, scaled and given a diffusion
    pseudotime (root cell 0); cells are split at the pseudotime median into
    ``dpt_groups`` '0' / '1'. Scribe's ``rdi`` is run with delays 1–3 and, for
    each gene, the ``"MAX"`` RDI values towards all *other* genes are averaged.

    ``output_dir`` is not used.

    Returns a 1-D numpy array in ``adata.var_names`` order.
    """
    from Scribe.read_export import load_anndata
    adata = adata.copy()
    sc.pp.log1p(adata)
    adata.uns['iroot'] = 0
    sc.pp.scale(adata)
    sc.pp.neighbors(adata)
    sc.tl.dpt(adata)
    # Compute the median once instead of once per cell.
    median_pseudotime = adata.obs['dpt_pseudotime'].median()
    adata.obs['dpt_groups'] = ['0' if pt < median_pseudotime else '1' for pt in adata.obs['dpt_pseudotime']]

    model = load_anndata(adata)
    model.rdi(delays=[1,2,3], number_of_processes=1, uniformization=False, differential_mode=False)  # dict_keys([1, 2, 3, 'MAX'])

    # Collect (source, value) for every ordered pair of distinct genes. The
    # original gene names are kept as keys: lower-casing them, as the earlier
    # version did, made the final lookup fail for names with capitals.
    rdi_max = model.rdi_results["MAX"]
    genes = list(adata.var_names)
    records = [
        (source, rdi_max.loc[source, target])
        for source in genes
        for target in genes
        if source != target
    ]
    pair_df = pd.DataFrame(records, columns=['Source', 'Value'])
    mean_by_source = pair_df.groupby("Source")["Value"].mean()

    return np.array(mean_by_source.loc[genes])

def run_DCI(adata, output_dir):
    """
    Score genes by how often they appear in the DCI difference graph.

    A copy of ``adata`` is log1p-transformed and restricted to the 50 genes
    with the largest absolute log2 fold change between label 1 and label 0.
    Candidate edges are gene pairs with absolute correlation >= 0.3; ``dci``
    is run on the two label groups (with a tiny Gaussian jitter added), and
    each gene's score is the number of difference-graph edge endpoints it
    contributes. Genes outside the top 50 get 0. Scores are L1-normalised.

    ``output_dir`` is not used. The jitter uses the global numpy RNG.

    Returns a 1-D numpy array in the original ``adata.var_names`` order.
    """
    from causaldag import dci
    from collections import Counter
    import scipy.sparse

    adata = adata.copy()
    sc.pp.log1p(adata)
    full_genes = adata.var_names.copy()  # keep the full gene order for the final output

    mean1 = adata[adata.obs['labels'] == 1].X.mean(axis=0)
    mean0 = adata[adata.obs['labels'] == 0].X.mean(axis=0)
    # asarray + ravel gives a flat vector for both dense and sparse means.
    logfc = np.asarray(np.log2((mean1 + 1e-9) / (mean0 + 1e-9))).ravel()
    top_idx = np.argsort(np.abs(logfc))[::-1][:50]
    adata = adata[:, top_idx]  # keep only the top-50 genes

    # Densify once so the group matrices, the jitter and the correlation all
    # work for sparse input as well.
    if scipy.sparse.issparse(adata.X):
        X_full = adata.X.toarray()
    else:
        X_full = np.asarray(adata.X)
    in_group0 = (adata.obs['labels'] == 0).to_numpy()
    in_group1 = (adata.obs['labels'] == 1).to_numpy()

    X1 = X_full[in_group0].astype(float)
    X2 = X_full[in_group1].astype(float)
    X1 += np.random.normal(0, 1e-6, size=X1.shape)
    X2 += np.random.normal(0, 1e-6, size=X2.shape)
    p = X1.shape[1]

    corr = np.corrcoef(X_full.T)
    threshold = 0.3  # tunable correlation cut-off for candidate edges
    candidate_edges = [(i, j) for i in range(p) for j in range(i + 1, p) if abs(corr[i, j]) >= threshold]
    print(f"[INFO] Filtered candidate edges: {len(candidate_edges)}")
    difference_matrix = dci(X1, X2, difference_ug_method='constraint', difference_ug=candidate_edges, alpha_ug=0.05, alpha_skeleton=0.1, max_set_size=2)
    ddag_edges = set(zip(*np.where(difference_matrix != 0)))
    print("len(ddag_edges)", len(ddag_edges))

    # Count edge endpoints per node, then map back onto the full gene list.
    # A plain lookup also copes with an empty edge set.
    node_counts = Counter(node for edge in ddag_edges for node in edge)
    count_by_gene = {adata.var_names[node]: count for node, count in node_counts.items()}
    counts = np.array([count_by_gene.get(gene, 0) for gene in full_genes], dtype=float)

    arr = counts.reshape(1, -1)
    normalized_arr = sk_normalize(arr, norm='l1', axis=1)
    return normalized_arr.flatten()


def run_GENIE3(adata, output_dir):
    """
    Score genes by how much their GENIE3 regulatory links differ between labels.

    GENIE3 is run separately on the label-0 and label-1 cells of a
    log1p-transformed copy of ``adata``. The two link lists are joined on
    (regulator, target), the absolute difference of link strength is averaged
    per regulator, and the result is min-max scaled.

    ``output_dir`` is not used; link lists go through a temporary file that is
    removed even if GENIE3 or the read-back fails.

    Returns a 1-D array in ``adata.var_names`` order.
    """
    from GENIE3 import GENIE3, get_link_list
    import tempfile
    import os
    adata = adata.copy()
    sc.pp.log1p(adata)
    adata1 = adata[adata.obs['labels'] == 1]
    adata0 = adata[adata.obs['labels'] == 0]
    gene_names = list(adata.var_names)

    def _link_table(expr):
        """Run GENIE3 on ``expr`` and return its link list as a DataFrame."""
        with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp_file:
            tmp_path = tmp_file.name
        try:
            VIM = GENIE3(expr, gene_names=gene_names)
            get_link_list(VIM, gene_names=gene_names, file_name=tmp_path)
            # get_link_list writes tab-separated rows without a header.
            return pd.read_csv(tmp_path, sep='\t', header=None)
        finally:
            os.remove(tmp_path)

    df0 = _link_table(adata0.X)
    df1 = _link_table(adata1.X)

    # Columns 0 / 1 are the two gene names, column 2 the link strength.
    merged = pd.merge(df1, df0, on=[0, 1], suffixes=('_net1', '_net0'))
    merged['Diff_Strength'] = abs(merged['2_net1'] - merged['2_net0'])
    diff_network = merged[[0, 1, 'Diff_Strength']]
    arr = diff_network.groupby(0)['Diff_Strength'].mean()
    arr_sorted = arr.reindex(adata.var_names)
    normalized_arr = minmax_scale(arr_sorted)

    return normalized_arr

In [None]:
def select_single_feature(score_array, threshold=None, topk=None):
    """
    Turn a 1-D score array into binary selection labels.

    Parameters
    ----------
    score_array : array-like
        One score per feature.
    threshold : float, optional
        Select every feature whose score is >= ``threshold``. Only used when
        ``topk`` is None.
    topk : int, optional
        Select the ``topk`` highest-scoring features. Takes precedence over
        ``threshold``. ``0`` selects nothing; values larger than the number of
        features select everything.

    Returns
    -------
    numpy.ndarray
        Integer array of the same length with 1 for selected features.

    Raises
    ------
    ValueError
        If neither ``threshold`` nor ``topk`` is given, or ``topk`` is negative.
    """
    score_array = np.asarray(score_array)
    n = len(score_array)

    if topk is not None:
        if topk < 0:
            raise ValueError("topk must be non-negative.")
        selected = np.zeros(n, dtype=int)
        # Guard topk == 0: the slice [-0:] would be the whole array and
        # select every feature instead of none.
        if topk > 0:
            selected[np.argsort(score_array)[-topk:]] = 1
        return selected

    if threshold is not None:
        return (score_array >= threshold).astype(int)

    raise ValueError("Either threshold or topk must be specified.")


def calculate_metrics(weight_df, label_col="is_causal", score_col="weight", threshold=None, topk=None):
    """
    Compute ranking and classification metrics for one scored feature table.

    Parameters
    ----------
    weight_df : pandas.DataFrame
        Must contain ``label_col`` (ground truth, 0/1) and ``score_col``.
    label_col, score_col : str
        Column names of the labels and the scores.
    threshold, topk :
        Forwarded to ``select_single_feature`` to binarise the scores; at
        least one of them must be given.

    Returns
    -------
    dict
        Keys ``AUROC``, ``AUPR``, ``F1``, ``ACC``, ``MCC``, ``Precision``,
        ``Recall`` and ``Specificity``. AUROC / AUPR use the raw scores, the
        remaining metrics use the binarised prediction.
    """
    from sklearn.metrics import (roc_curve, auc, confusion_matrix, accuracy_score, matthews_corrcoef,
                                 f1_score, precision_score, recall_score, precision_recall_curve)

    true_label = weight_df[label_col].values
    score_array = weight_df[score_col].values

    # Binary prediction from top-k or threshold
    pred_label = select_single_feature(score_array, threshold=threshold, topk=topk)

    # AUROC
    fpr, tpr, _ = roc_curve(true_label, score_array)
    roc_auc = auc(fpr, tpr)

    # AUPR
    precision_curve, recall_curve, _ = precision_recall_curve(true_label, score_array)
    aupr = auc(recall_curve, precision_curve)

    # Confusion matrix. Fixing labels=[0, 1] keeps it 2x2 even when only one
    # class occurs in truth and prediction; otherwise cm[0, 1] would raise.
    cm = confusion_matrix(true_label, pred_label, labels=[0, 1])
    TN, FP = cm[0, 0], cm[0, 1]
    specificity = TN / (TN + FP) if (TN + FP) > 0 else 0.0

    # Other metrics
    acc = accuracy_score(true_label, pred_label)
    mcc = matthews_corrcoef(true_label, pred_label)
    precision = precision_score(true_label, pred_label, pos_label=1, zero_division=0)
    recall = recall_score(true_label, pred_label, pos_label=1, zero_division=0)
    f1 = f1_score(true_label, pred_label, pos_label=1, zero_division=0)

    return {
        "AUROC": roc_auc,
        "AUPR": aupr,
        "F1": f1,
        "ACC": acc,
        "MCC": mcc,
        "Precision": precision,
        "Recall": recall,
        "Specificity": specificity
    }



def evaluate_prediction(true_label, pred_dict, topk=10):
    """
    Evaluate each score array in ``pred_dict`` against ``true_label``.

    Parameters
    ----------
    true_label : array-like
        Ground-truth 0/1 label per feature.
    pred_dict : dict
        Maps a sub-method name to a score array of the same length.
    topk : int
        Number of top-scoring features treated as predicted positives.

    Returns
    -------
    list of dict
        One record per sub-method with ``SubMethod`` plus the metrics returned
        by ``calculate_metrics``.
    """
    metric_names = ('AUROC', 'AUPR', 'F1', 'ACC', 'MCC', 'Precision', 'Recall', 'Specificity')
    labels = np.asarray(true_label).ravel()

    results = []
    for submethod, pred in pred_dict.items():
        # calculate_metrics expects a table with label and score columns and
        # returns a dict keyed by metric name. Plain arrays are used so a
        # pandas index on `pred` cannot misalign scores and labels.
        score_df = pd.DataFrame({"is_causal": labels, "weight": np.asarray(pred).ravel()})
        metrics = calculate_metrics(score_df, topk=topk)

        record = {'SubMethod': submethod}
        record.update({key: metrics[key] for key in metric_names})
        results.append(record)
    return results


def plot_layerwise_metrics1(df, output_dir, causal_strength=0.4, p_zero=0.2):
    """
    Save a box plot, violin plot and bar plot (mean ± SD) for AUROC and AUPR.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns ``Method``, ``Layer``, ``AUROC`` and ``AUPR``. The
        frame is copied, so the caller's object is not modified.
    output_dir : str
        Existing directory that receives
        ``{metric}_Comparison_{Boxplot,Violinplot,Barplot}.{pdf,png}``.
    causal_strength, p_zero :
        Only used in the figure titles.

    Notes
    -----
    Updates the global seaborn theme and ``plt.rcParams``. Methods that are
    not in the fixed method order become NaN categories and are not drawn.
    """
    # --- Global style settings ---
    sns.set_theme(style="white")
    plt.rcParams.update({
        "font.size": 14,
        "axes.labelsize": 16,
        "axes.titlesize": 18,
        "xtick.labelsize": 14,
        "ytick.labelsize": 14,
        "legend.fontsize": 14,
        "figure.dpi": 300,
        "savefig.dpi": 300,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
    })

    # Work on a copy so renaming / re-typing 'Method' does not leak to the caller.
    df = df.copy()

    # --- Standardize method names and fix their display order ---
    df['Method'] = df['Method'].replace({'CauTrigger_SHAP': 'CauTrigger','VAEgrad': 'VAE'})
    method_order = ['CauTrigger', 'GENIE3', 'SCRIBE', 'PC', 'VAE', 'DCI', 'MI', 'RF', 'SVM']
    df['Method'] = pd.Categorical(df['Method'], categories=method_order, ordered=True)

    # --- Colour palette shared by all three plot types ---
    nature_colors = ["#4C72B0", "#55A868", "#C44E52", "#8172B2", "#CCB974", "#64B5CD"]

    def _finish_figure(metric, plot_kind, ylabel):
        """Apply the common title / axis styling, save as PDF and PNG, close."""
        plt.title(f"{metric} (cs={causal_strength}, p={p_zero})", pad=15)
        plt.ylabel(ylabel)
        plt.xlabel("")
        plt.xticks(rotation=0, ha='center')
        plt.ylim(0, 1.05)
        sns.despine()
        plt.tight_layout()
        for ext in ("pdf", "png"):
            plt.savefig(os.path.join(output_dir, f'{metric}_Comparison_{plot_kind}.{ext}'))
        plt.close()

    for metric in ['AUROC', 'AUPR']:
        # --- Boxplot with points overlay ---
        plt.figure(figsize=(10, 6))
        ax = sns.boxplot(
            data=df,
            x="Method",
            y=metric,
            hue="Layer",
            palette=nature_colors,
            width=0.6,
            fliersize=0,  # hide outlier markers; the strip plot shows every point
            linewidth=1.5
        )
        sns.stripplot(
            data=df,
            x="Method",
            y=metric,
            hue="Layer",
            dodge=True,
            palette=nature_colors,
            alpha=0.5,
            jitter=0.2,
            marker='o',
            edgecolor="gray",
            linewidth=0.5
        )
        # The strip plot adds a second set of legend entries; keep the first set only.
        handles, labels = ax.get_legend_handles_labels()
        n_layers = len(df['Layer'].unique())
        plt.legend(handles[:n_layers], labels[:n_layers], title="Layer", frameon=False, loc="best")
        _finish_figure(metric, "Boxplot", metric)

        # --- Violinplot ---
        # NOTE(review): `scale` / `bw` are the older seaborn spellings (newer
        # releases call them `density_norm` / `bw_method`) — confirm against
        # the seaborn version this notebook is pinned to.
        plt.figure(figsize=(10, 6))
        sns.violinplot(
            data=df,
            x="Method",
            y=metric,
            hue="Layer",
            palette=nature_colors,
            inner="point",  # show individual points
            cut=0,
            scale="width",  # uniform violin width
            bw=0.4,  # KDE bandwidth
            width=0.7,
            linewidth=1.0,
            dodge=True,
            saturation=0.8
        )
        plt.legend(title="Layer", frameon=False, loc="best")
        _finish_figure(metric, "Violinplot", metric)

        # --- Barplot (mean ± SD) ---
        # seaborn computes the mean and SD itself, so every error bar belongs
        # to the bar it is drawn on. The earlier manual overlay zipped
        # method-major summary rows with bar patches that seaborn creates
        # hue by hue, which attached the wrong SD to most bars.
        plt.figure(figsize=(10, 6))
        sns.barplot(
            data=df,
            x="Method",
            y=metric,
            hue="Layer",
            palette=nature_colors,
            errorbar="sd",
            capsize=0.1,
            width=0.7
        )
        plt.legend(title="Layer", frameon=False, loc="upper right")
        _finish_figure(metric, "Barplot", f"{metric} (Mean ± SD)")


def plot_layerwise_metrics(df, output_dir, causal_strength=0.4, p_zero=0.2):
    """
    Draw a boxplot, a violinplot and a barplot (mean ± SD) for AUROC and AUPR.

    Parameters
    ----------
    df : pandas.DataFrame
        Benchmark table with at least the columns 'Method', 'Layer', 'AUROC'
        and 'AUPR'. The caller's frame is not modified (a copy is taken).
    output_dir : str
        Directory that receives the figures; it is created if missing.
    causal_strength, p_zero : float
        Only used to tag the output file names
        (``<metric>-<plot>-cs<NN>-p<p_zero>.pdf/.png``).

    Notes
    -----
    * 'CauTrigger_SHAP' and 'VAEgrad' are displayed as 'CauTrigger' and 'VAE'.
    * Only the nine methods of the fixed plotting order are drawn; any other
      method present in ``df`` is reported and left out.
    * An empty (or ``None``) table prints a warning and returns without
      touching matplotlib state.
    """
    import os
    import pandas as pd

    # Guard first: an empty table has no 'Method' column and would otherwise
    # fail with a KeyError further down.
    if df is None or df.empty:
        print("[WARN] plot_layerwise_metrics: empty metrics table, nothing to plot.")
        return

    import matplotlib.pyplot as plt
    import seaborn as sns

    sns.set_theme(style="white")
    plt.rcParams.update({
        "font.size": 14,
        "axes.labelsize": 16,
        "axes.titlesize": 18,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "legend.fontsize": 12,
        "figure.dpi": 300,
        "savefig.dpi": 300,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "font.family": "Arial",
    })

    df = df.copy()
    df['Method'] = df['Method'].replace({'CauTrigger_SHAP': 'CauTrigger', 'VAEgrad': 'VAE'})
    method_order = ['CauTrigger', 'GENIE3', 'SCRIBE', 'PC', 'VAE', 'DCI', 'MI', 'RF', 'SVM']

    # Methods outside the fixed order become NaN in the Categorical below and
    # vanish from every plot; say so instead of dropping them silently.
    omitted = sorted({str(m) for m in df['Method'].dropna().unique()} - set(method_order))
    if omitted:
        print(f"[WARN] Methods outside the fixed plotting order are omitted: {omitted}")

    df['Method'] = pd.Categorical(df['Method'], categories=method_order, ordered=True)
    method_count = len(method_order)
    method_rank = {name: idx for idx, name in enumerate(method_order)}

    # One explicit hue order (order of first appearance) shared by all plots,
    # so the drawing order of the bars is known when error bars are attached.
    layer_order = df['Layer'].dropna().unique().tolist()
    layer_rank = {name: idx for idx, name in enumerate(layer_order)}
    n_layers = max(1, len(layer_order))

    layer_palette = {
        'layer1': '#3E4A89',  # dark blue-violet: downstream layer, closest to the phenotype
        'layer2': '#74A9CF',  # light blue: upstream / regulatory layer
        'all': '#5B9BD5',     # plain blue: both layers pooled
    }

    metrics = ['AUROC', 'AUPR']
    os.makedirs(output_dir, exist_ok=True)

    # round() rather than int(): int(0.29 * 100) truncates to 28.
    cs_tag = f"cs{int(round(causal_strength * 100))}"
    p_tag = f"p{p_zero}"

    def _save_and_close(stem):
        # Write the current figure as PDF and PNG, then release it.
        plt.savefig(os.path.join(output_dir, f"{stem}.pdf"))
        plt.savefig(os.path.join(output_dir, f"{stem}.png"))
        plt.close()

    def _draw_method_separators(axis, linewidth):
        # Thin dashed guides between neighbouring method slots.
        for i in range(1, method_count):
            axis.axvline(i - 0.5, linestyle='--', color='lightgray', linewidth=linewidth, zorder=0)

    legend_style = dict(
        title="Layer",
        frameon=False,
        loc="upper center",
        bbox_to_anchor=(0.5, 1.18),
        fontsize=11,
        title_fontsize=11,
        columnspacing=1.2,
        handlelength=1.5,
    )

    for metric in metrics:
        metric_tag = metric.lower()

        # === Boxplot ===
        plt.figure(figsize=(7, 3.5))
        ax = sns.boxplot(
            data=df,
            x="Method",
            y=metric,
            hue="Layer",
            order=method_order,
            hue_order=layer_order,
            palette=layer_palette,
            # Exact palette colours at 90 % opacity. This replaces a post-hoc
            # recolouring loop that paired ax.patches with groupby groups and
            # mis-coloured boxes whenever the two orders differed.
            saturation=1,
            boxprops=dict(alpha=0.9),
            width=0.6,
            fliersize=0,
            linewidth=0.8,
        )

        # Individual runs as soft gray points on top of the boxes.
        sns.stripplot(
            data=df,
            x="Method",
            y=metric,
            hue="Layer",
            order=method_order,
            hue_order=layer_order,
            dodge=True,
            color='gray',
            size=2,
            jitter=0.25,
            alpha=0.3,
            edgecolor=None,
            linewidth=0,
            legend=False
        )

        _draw_method_separators(ax, linewidth=0.3)

        # Legend: keep the first handle per label (true de-duplication) and
        # skip the legend entirely when nothing was labelled.
        handles, labels = ax.get_legend_handles_labels()
        first_handle = {}
        for handle, label in zip(handles, labels):
            first_handle.setdefault(label, handle)
        if first_handle:
            plt.legend(
                list(first_handle.values()),
                list(first_handle.keys()),
                ncol=len(first_handle),
                **legend_style
            )

        plt.ylabel(metric, fontsize=13)
        plt.xlabel("")
        plt.xticks(rotation=30, ha='right', fontsize=11)
        plt.yticks(fontsize=11)
        plt.ylim(0, 1.05)
        sns.despine()
        plt.tight_layout()
        _save_and_close(f"{metric_tag}-boxplot-{cs_tag}-{p_tag}")

        # === Violinplot ===
        plt.figure(figsize=(7, 3.5))
        ax = sns.violinplot(
            data=df,
            x="Method",
            y=metric,
            hue="Layer",
            order=method_order,
            hue_order=layer_order,
            palette=layer_palette,
            inner="quartile",
            cut=0,
            scale="width",
            bw=0.4,
            width=0.7,
            linewidth=1.0,
            dodge=True,
            saturation=0.8
        )
        _draw_method_separators(ax, linewidth=0.5)
        plt.ylabel(metric, fontsize=13)
        plt.xlabel("")
        plt.xticks(rotation=30, ha='right', fontsize=11)
        plt.yticks(fontsize=11)
        plt.ylim(0, 1.05)
        plt.legend(ncol=n_layers, **legend_style)
        sns.despine()
        plt.tight_layout()
        _save_and_close(f"{metric_tag}-violinplot-{cs_tag}-{p_tag}")

        # === Barplot (mean ± SD) ===
        # observed=True + dropna: keep only (method, layer) groups that really
        # have a mean, i.e. exactly the bars seaborn is going to draw.
        summary_df = (
            df.groupby(["Method", "Layer"], observed=True)[metric]
            .agg(["mean", "std"])
            .reset_index()
            .dropna(subset=["mean"])
        )
        plt.figure(figsize=(7, 3.5))
        ax = sns.barplot(
            data=summary_df,
            x="Method",
            y="mean",
            hue="Layer",
            order=method_order,
            hue_order=layer_order,
            palette=layer_palette,
            errorbar=None,
            width=0.7
        )
        _draw_method_separators(ax, linewidth=0.5)

        # seaborn adds the bars one hue level at a time (all bars of the first
        # layer, then the next layer, ...), each level in method order. Sort
        # the summary rows the same way so that bar k belongs to row k.
        bar_rows = sorted(
            summary_df.to_dict("records"),
            key=lambda rec: (layer_rank[rec["Layer"]], method_rank[rec["Method"]]),
        )
        # Drop zero-width legend proxies and NaN-height placeholders.
        bars = [
            patch for patch in ax.patches
            if patch.get_width() > 0 and pd.notna(patch.get_height())
        ]
        if len(bars) == len(bar_rows):
            for patch, rec in zip(bars, bar_rows):
                if pd.isna(rec["std"]):
                    continue  # single-sample group: no spread to show
                ax.errorbar(
                    x=patch.get_x() + patch.get_width() / 2,
                    y=patch.get_height(),
                    yerr=rec["std"],
                    fmt='none',
                    ecolor='black',
                    elinewidth=1.5,
                    capsize=4,
                    capthick=1.5
                )
        else:
            # Better no error bars than error bars on the wrong bars.
            print(f"[WARN] {metric}: {len(bars)} bars vs {len(bar_rows)} summary rows; "
                  f"error bars skipped.")

        max_y = summary_df["mean"].max() + summary_df["std"].max() + 0.05
        plt.ylim(0, max(1.05, max_y))
        plt.ylabel(f"{metric} (Mean ± SD)", fontsize=13)
        plt.xlabel("")
        plt.xticks(rotation=30, ha='right', fontsize=11)
        plt.yticks(fontsize=11)
        plt.legend(ncol=n_layers, **legend_style)
        sns.despine()
        plt.tight_layout()
        _save_and_close(f"{metric_tag}-barplot-{cs_tag}-{p_tag}")



def plot_aggregate_layerwise_metrics(
    root_output_dir,
    causal_strength_list,
    p_zero_list,
    spurious_mode='semi_hrc',
    n_hidden=10,
    activation='linear',
    simulate_single_cell=True
):
    """
    Collect ``Layerwise_Benchmark_Metrics.csv`` from every
    (causal strength, sparsity) case directory under ``root_output_dir`` and
    draw one grid of boxplots per metric (AUROC, AUPR), one panel per case.

    Case directories are named
    ``2L_unknown_<spurious_mode>_hidden<n_hidden>_<activation>_cs<NN>_p<p_zero>_<sc|bulk>``;
    ``cs<NN>`` is built with ``int(cs * 100)`` exactly as in ``main`` so the
    same directories are found. Missing or empty CSV files are skipped.

    The grid has ``len(p_zero_list)`` columns (fewer when fewer cases were
    found) and as many rows as needed, i.e. 3x3 for the usual 3x3 sweep.
    Panels are ordered by their sorted title string; unused axes are hidden.
    A single figure-level legend is placed below the title.

    Figures are written to ``root_output_dir`` as
    ``<base_tag>_boxplot_<metric>.pdf/.png``. Returns None; if no case file is
    found a warning is printed and nothing is drawn.
    """
    import os
    import pandas as pd

    mode_tag = "sc" if simulate_single_cell else "bulk"
    all_dfs = []

    for cs in causal_strength_list:
        for pz in p_zero_list:
            case_name = "_".join([
                "2L_unknown",
                spurious_mode,
                f"hidden{n_hidden}",
                activation,
                f"cs{int(cs * 100):02d}",
                f"p{pz}",
                mode_tag,
            ])
            csv_path = os.path.join(root_output_dir, case_name, "Layerwise_Benchmark_Metrics.csv")
            if not os.path.exists(csv_path):
                continue
            try:
                case_df = pd.read_csv(csv_path)
            except pd.errors.EmptyDataError:
                # A run that produced no results leaves a column-less file.
                print(f"[WARN] Skipping empty metrics file: {csv_path}")
                continue
            if case_df.empty:
                continue
            case_df["ParamCombo"] = f"Causal Strength = {cs}, Sparsity = {pz}"
            all_dfs.append(case_df)

    if not all_dfs:
        print("[WARN] No matching benchmark files found.")
        return

    import matplotlib.pyplot as plt
    import seaborn as sns

    sns.set_theme(style="white")
    plt.rcParams.update({
        "font.size": 12,
        "axes.labelsize": 14,
        "axes.titlesize": 16,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "legend.fontsize": 14,
        "figure.dpi": 300,
        "savefig.dpi": 300,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "font.family": "Arial",
    })

    df = pd.concat(all_dfs, ignore_index=True)
    df['Method'] = df['Method'].replace({'CauTrigger_SHAP': 'CauTrigger', 'VAEgrad': 'VAE'})
    df['Method'] = pd.Categorical(df['Method'], categories=[
        'CauTrigger', 'GENIE3', 'SCRIBE', 'PC', 'VAE', 'DCI', 'MI', 'RF', 'SVM'
    ], ordered=True)

    # Prefix shared by all output files of this "layers unknown" setting.
    base_tag = "_".join([
        "2L_unknown",
        spurious_mode,
        f"hidden{n_hidden}",
        activation,
        mode_tag,
    ])

    # Size the grid from the data instead of assuming nine panels, so larger
    # sweeps are not silently truncated and smaller ones leave no blank axes.
    param_combos = sorted(df['ParamCombo'].unique())
    n_cols = max(1, min(len(p_zero_list), len(param_combos)))
    n_rows = -(-len(param_combos) // n_cols)  # ceiling division

    for metric in ["AUROC", "AUPR"]:
        fig, axes = plt.subplots(
            n_rows, n_cols,
            figsize=(6 * n_cols, 4 * n_rows),
            sharey=True,
            squeeze=False,
        )
        panel_axes = list(axes.flatten())

        for ax, combo in zip(panel_axes, param_combos):
            subdf = df[df['ParamCombo'] == combo]
            sns.boxplot(
                data=subdf,
                x="Method",
                y=metric,
                hue="Layer",
                palette={"all": "#5B9BD5"},  # colour of the pooled 'all' layer
                ax=ax,
                width=0.6,
                fliersize=0,
                linewidth=1,
            )
            ax.set_title(combo, fontsize=14)
            ax.set_xlabel("")
            ax.set_ylabel(metric)
            ax.tick_params(axis='x', rotation=30)
            panel_legend = ax.get_legend()
            if panel_legend is not None:  # seaborn may not have created one
                panel_legend.remove()
            ax.grid(False)

        # Hide grid cells that received no case.
        for unused_ax in panel_axes[len(param_combos):]:
            unused_ax.set_visible(False)

        fig.suptitle(f"Overall {metric} across methods", fontsize=18, y=1.02)

        # One shared legend directly below the title, built from the last
        # populated panel (its labelled artists survive legend removal).
        handles, labels = panel_axes[len(param_combos) - 1].get_legend_handles_labels()
        fig.legend(
            handles, labels,
            loc="upper center",
            bbox_to_anchor=(0.5, 1.0),
            ncol=2,
            frameon=False
        )

        plt.tight_layout()
        out_prefix = f"{base_tag}_boxplot_{metric}"
        fig.savefig(os.path.join(root_output_dir, f"{out_prefix}.pdf"), bbox_inches='tight')
        fig.savefig(os.path.join(root_output_dir, f"{out_prefix}.png"), bbox_inches='tight')
        print(f"[INFO] Saved: {out_prefix}.pdf/.png")

        plt.close(fig)

In [None]:
def _resolve_seed_list(seed_list, n_datasets):
    """Return exactly `n_datasets` seeds as a new list (the input is never mutated)."""
    if seed_list is None:
        return list(range(n_datasets))
    seeds = list(seed_list)[:n_datasets]
    if len(seeds) < n_datasets:
        # Pad with consecutive seeds after the largest one given (0.. if none).
        start = max(seeds) + 1 if seeds else 0
        seeds.extend(range(start, start + n_datasets - len(seeds)))
    return seeds


def _benchmark_row(algo, layer, dataset_idx, seed, score_type, weight_df):
    """Score one weight table with `calculate_metrics` and wrap it as a result record."""
    metrics = calculate_metrics(weight_df, score_col="weight", label_col="is_causal", topk=20)
    return {
        "Method": algo,
        "Layer": layer,
        "Dataset": dataset_idx,
        "Seed": seed,
        "ScoreType": score_type,
        **metrics
    }


def run_benchmark(
    algorithms,
    data_dir,
    output_dir,
    n_datasets=3,
    seed_list=None,
    save_adata=True,
    rerun=False,
    **generation_args
):
    """
    Run benchmark evaluation for causal discovery methods on synthetic two-layer datasets.

    For every seed a dataset is generated with
    ``generate_two_layer_synthetic_data(seed=seed, **generation_args)`` and
    each requested algorithm is evaluated on it.

    Parameters
    ----------
    algorithms : sequence of str
        Keys of the internal registry (e.g. 'CauTrigger_SHAP', 'PC', 'RF').
        Unknown names raise an AssertionError before any data is generated.
    data_dir : str
        Datasets are written to ``<data_dir>/dataset<i+1>/adata.h5ad`` when
        ``save_adata`` is true.
    output_dir : str
        Receives ``weights_dataset<i>_<layer>_<algo>.csv`` (``i`` is 0-based)
        and ``Layerwise_Benchmark_Metrics.csv``; created if missing.
    n_datasets : int
        Number of datasets to generate.
    seed_list : sequence of int, optional
        Seeds to use. Truncated to ``n_datasets``, or padded with consecutive
        values after its maximum. The caller's object is left untouched.
        Defaults to ``range(n_datasets)``.
    save_adata : bool
        Whether generated datasets are written to disk.
    rerun : bool
        If false, an existing 'all' weight file is loaded and re-scored
        instead of re-running the algorithm.
    **generation_args
        Forwarded to the data generator; 'causal_strength' and 'p_zero' are
        also used to tag the plots.

    Notes
    -----
    * Algorithms whose name starts with 'CauTrigger' must return a dict with
      the keys 'step1', 'step2' and 'all', each a DataFrame holding 'weight'
      and 'is_causal'. All three tables are saved; only 'all' enters the
      final metrics table.
    * Any other algorithm returns either a score vector or a dict of columns
      aligned with ``adata.var_names``.
    """
    algorithm_functions = {
        # CauTrigger scoring variants
        "CauTrigger_Model": lambda adata, out: run_CauTrigger(adata, out, mode="Model"),
        "CauTrigger_Grad": lambda adata, out: run_CauTrigger(adata, out, mode="Grad"),
        "CauTrigger_SHAP": lambda adata, out: run_CauTrigger(adata, out, mode="SHAP"),
        "CauTrigger_Ensemble": lambda adata, out: run_CauTrigger(adata, out, mode="Ensemble"),

        # Other baseline methods
        "PC": run_PC,
        "VAEgrad": run_VAE,
        "SVM": run_SVM,
        "RF": run_RF,
        "MI": run_MI,
        "DCI": run_DCI,
        "GENIE3": run_GENIE3,
        "SCRIBE": run_SCRIBE,
        "VELORAMA": run_VELORAMA,
    }

    algorithms = list(algorithms)
    print(f"[INFO] Running benchmark on {n_datasets} datasets with algorithms: {algorithms}")

    # Fail fast: reject unknown algorithms before any dataset is generated.
    for algo in algorithms:
        assert algo in algorithm_functions, f"[ERROR] Algorithm '{algo}' not registered in algorithm_functions!"

    seeds = _resolve_seed_list(seed_list, n_datasets)

    os.makedirs(output_dir, exist_ok=True)

    all_results = []
    # NOTE(review): step1/step2 metrics are collected but never written out or
    # returned; kept as in the original until it is decided where they belong.
    stepwise_results = []
    layer_name = "all"

    for i, seed in enumerate(seeds):
        set_seed(seed)
        adata = generate_two_layer_synthetic_data(seed=seed, **generation_args)
        print(f"[INFO] Dataset {i + 1}: Generated with seed {seed}")

        if save_adata:
            dataset_dir = os.path.join(data_dir, f'dataset{i + 1}')
            os.makedirs(dataset_dir, exist_ok=True)
            adata.write(os.path.join(dataset_dir, 'adata.h5ad'))

        for algo in algorithms:
            weight_path = os.path.join(output_dir, f'weights_dataset{i}_{layer_name}_{algo}.csv')

            # Cache hit: re-score the stored 'all' weights instead of re-running.
            if not rerun and os.path.exists(weight_path):
                print(f"[INFO] {algo} on {layer_name} (Dataset {i + 1}) already exists. Loading...")
                weight_df = pd.read_csv(weight_path, index_col=0)
                all_results.append(_benchmark_row(algo, layer_name, i, seed, "weight", weight_df))
                continue

            print(f"[INFO] Evaluating {algo} on {layer_name} (Dataset {i + 1})...")
            func = algorithm_functions[algo]

            if algo.startswith("CauTrigger"):
                pred_dict = func(adata, output_dir)

                # Save and evaluate each step's result: step1, step2, all
                for step_key in ["step1", "step2", "all"]:
                    step_df = pred_dict[step_key]
                    assert "is_causal" in step_df.columns, f"[ERROR] Missing 'is_causal' in {step_key} result."

                    step_weight_path = os.path.join(output_dir, f'weights_dataset{i}_{step_key}_{algo}.csv')
                    step_df.to_csv(step_weight_path)
                    print(f"[INFO] Saved {step_key} weights to: {step_weight_path}")

                    score_type = f"weight_{step_key}" if step_key != "all" else "weight"
                    row = _benchmark_row(algo, step_key, i, seed, score_type, step_df)
                    if step_key == "all":
                        all_results.append(row)
                    else:
                        stepwise_results.append(row)
            else:
                # Copy only when the baseline really runs, so a method that
                # modifies its input cannot affect the next one.
                sub_adata = adata.copy()
                pred = func(sub_adata, output_dir)

                pred_dict = pred if isinstance(pred, dict) else {"weight": pred}
                weight_df = pd.DataFrame(pred_dict, index=sub_adata.var_names)
                weight_df["is_causal"] = sub_adata.var["is_causal"].values

                weight_df.to_csv(weight_path)
                print(f"[INFO] Saved weights to: {weight_path}")

                all_results.append(_benchmark_row(algo, layer_name, i, seed, "weight", weight_df))

        del adata
        gc.collect()

    # Save final metrics
    df = pd.DataFrame(all_results)
    metrics_path = os.path.join(output_dir, 'Layerwise_Benchmark_Metrics.csv')
    df.to_csv(metrics_path, index=False)
    print(f"[INFO] Saved final evaluation to: {metrics_path}")

    # An empty table has no 'Method' column; plotting it would only crash.
    if df.empty:
        print("[WARN] No benchmark results were produced; skipping plots.")
        return

    # Draw plots
    plot_layerwise_metrics(
        df,
        output_dir,
        causal_strength=generation_args.get("causal_strength", 0.4),
        p_zero=generation_args.get("p_zero", 0.2)
    )

In [None]:
def compare_known_vs_unknown_layers(
    base_dir_unknown,
    base_dir_known,
    method="CauTrigger_SHAP",
    topk=20,
    n_datasets=10,
    output_path=None
):
    """
    Score per-layer weight files from a "layers unknown" run against the
    matching "layers known" run.

    For each dataset index ``i`` in ``range(n_datasets)`` the pairs
    (step1, layer1) and (step2, layer2) are compared using
    ``<base_dir_unknown>/weights_dataset<i>_<step>_<method>.csv`` and
    ``<base_dir_known>/weights_dataset<i>_<layer>_<method>.csv``. A pair is
    skipped with a warning if either file is missing.

    Parameters
    ----------
    base_dir_unknown, base_dir_known : str
        Output directories of the two runs.
    method : str
        Method tag used in the weight file names.
    topk : int
        Forwarded to ``calculate_metrics``.
    n_datasets : int
        Number of dataset indices to scan.
    output_path : str, optional
        If given, the result table is also written there as CSV.

    Returns
    -------
    pandas.DataFrame
        Two rows per compared pair (Mode 'unknown' and 'known') with the
        columns 'Dataset', 'Layer', 'Mode' plus whatever ``calculate_metrics``
        returns. When nothing could be compared, the frame is empty but still
        has the 'Dataset', 'Layer' and 'Mode' columns, so the CSV keeps a
        header and can be read back.
    """
    base_columns = ["Dataset", "Layer", "Mode"]
    layer_pairs = [("step1", "layer1"), ("step2", "layer2")]

    records = []
    for i in range(n_datasets):
        for u_layer, k_layer in layer_pairs:
            path_u = os.path.join(base_dir_unknown, f"weights_dataset{i}_{u_layer}_{method}.csv")
            path_k = os.path.join(base_dir_known, f"weights_dataset{i}_{k_layer}_{method}.csv")

            if not (os.path.exists(path_u) and os.path.exists(path_k)):
                print(f"[WARN] Missing file for dataset {i}, layer {u_layer} or {k_layer}")
                continue

            df_u = pd.read_csv(path_u, index_col=0)
            df_k = pd.read_csv(path_k, index_col=0)

            m_u = calculate_metrics(df_u, score_col="weight", label_col="is_causal", topk=topk)
            m_k = calculate_metrics(df_k, score_col="weight", label_col="is_causal", topk=topk)

            records.append({"Dataset": i, "Layer": u_layer, "Mode": "unknown", **m_u})
            records.append({"Dataset": i, "Layer": k_layer, "Mode": "known", **m_k})

    # pd.DataFrame([]) has no columns at all: its CSV cannot be parsed by
    # read_csv and df["Layer"] raises KeyError. Keep a minimal schema instead.
    if records:
        df_compare = pd.DataFrame(records)
    else:
        df_compare = pd.DataFrame(columns=base_columns)

    if output_path:
        df_compare.to_csv(output_path, index=False)
    return df_compare

def plot_hierarchy_comparison(df, output_dir, metric="AUROC", cs=None, p_zero=None):
    """
    Boxplot comparing per-layer performance when the layer assignment is known
    versus inferred.

    Parameters
    ----------
    df : pandas.DataFrame
        Comparison table with the columns 'Layer' (step1/layer1/step2/layer2),
        'Mode' ('unknown' or 'known') and the column named by ``metric``.
        The caller's frame is not modified.
    output_dir : str
        Existing directory receiving
        ``<metric>_HierarchyComparison[_cs<NN>][_p<p_zero>].pdf/.png``.
    metric : str
        Column plotted on the y axis.
    cs, p_zero : float, optional
        Only used to tag the output file name; omitted from it when None.

    Returns None. An empty (or ``None``) table prints a warning and nothing
    is drawn.
    """
    # An empty comparison (no matching weight files) has no 'Layer' column;
    # skip it instead of raising KeyError and aborting a parameter sweep.
    if df is None or df.empty:
        print(f"[WARN] plot_hierarchy_comparison: no data for {metric}, nothing to plot.")
        return

    import matplotlib.pyplot as plt
    import seaborn as sns
    import os

    # === Global style ===
    sns.set_theme(style="white")
    plt.rcParams.update({
        "font.size": 13,
        "axes.labelsize": 13,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "legend.fontsize": 12,
        "figure.dpi": 300,
        "savefig.dpi": 300,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "font.family": "Arial"
    })

    # === Preprocess DataFrame ===
    df = df.copy()
    # step1/step2 (inferred run) and layer1/layer2 (known run) share one axis label.
    df["Layer"] = df["Layer"].replace({
        "step1": "Layer 1",
        "layer1": "Layer 1",
        "step2": "Layer 2",
        "layer2": "Layer 2"
    })
    df["Layer"] = pd.Categorical(df["Layer"], categories=["Layer 1", "Layer 2"], ordered=True)

    df["Setting"] = df["Mode"].replace({
        "unknown": "inferred",
        "known": "known"
    })
    df["Setting"] = pd.Categorical(df["Setting"], categories=["known", "inferred"], ordered=True)

    setting_palette = {
        "inferred": "#2C2C79",  # dark blue-violet
        "known": "#9DC3E6"      # light blue
    }

    # === Plot ===
    plt.figure(figsize=(6, 4))
    ax = sns.boxplot(
        data=df,
        x="Layer",
        y=metric,
        hue="Setting",
        palette=setting_palette,
        width=0.6,
        fliersize=0,
        gap=0.15,  # leave space between the two hue boxes
        linewidth=1
    )
    sns.stripplot(
        data=df,
        x="Layer",
        y=metric,
        hue="Setting",
        dodge=True,
        color="gray",
        alpha=0.5,
        size=4,
        jitter=0.2,
        edgecolor="none",
        linewidth=0,
        legend=False  # avoid duplicated legend entries
    )

    # === Adjust legend ===
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(
        handles[:2], labels[:2],
        title=None,
        frameon=False,
        loc="upper center",
        bbox_to_anchor=(0.5, 1.12),
        ncol=2,
        handlelength=1.5,
        columnspacing=1.5
    )

    # === Plot aesthetics ===
    ax.set_xlabel("")
    ax.set_ylabel(metric)
    ax.set_ylim(0, 1.05)
    sns.despine()
    plt.tight_layout()

    # === Save ===
    # round() rather than int(): int(0.29 * 100) truncates to 28.
    cs_tag = f"cs{int(round(cs * 100))}" if cs is not None else ""
    p_tag = f"p{p_zero}" if p_zero is not None else ""
    # Join only the tags that exist so no "__" appears when one is missing.
    fname = "_".join(part for part in (f"{metric}_HierarchyComparison", cs_tag, p_tag) if part)
    plt.savefig(os.path.join(output_dir, f"{fname}.pdf"), bbox_inches='tight')
    plt.savefig(os.path.join(output_dir, f"{fname}.png"), bbox_inches='tight')
    plt.close()
    print(f"[INFO] Saved {fname}.pdf/png to: {output_dir}")


def plot_grid_hierarchy_comparison_boxplot(
    root_output_dir,
    causal_strength_list,
    p_zero_list,
    metric="AUROC",
    spurious_mode="semi_hrc",
    n_hidden=10,
    activation="linear",
    simulate_single_cell=True
):
    """
    Collect ``compare_known_unknown.csv`` from every
    (causal strength, sparsity) case directory under ``root_output_dir`` and
    draw a grid of boxplots comparing known vs inferred layers, one panel per
    case.

    Case directories are named
    ``2L_unknown_<spurious_mode>_hidden<n_hidden>_<activation>_cs<NN>_p<p_zero>_<sc|bulk>``;
    ``cs<NN>`` is built with ``int(cs * 100)`` exactly as in ``main`` so the
    same directories are found. Missing, unparsable-because-empty, or row-less
    CSV files are skipped.

    The grid has ``len(p_zero_list)`` columns (fewer when fewer cases were
    found) and as many rows as needed, i.e. 3x3 for the usual 3x3 sweep.
    Panels are ordered by their sorted title string; unused axes are hidden.

    The figure is written to ``root_output_dir`` as
    ``<base_tag>_PriorVsInferred_<metric>.pdf/.png``. Returns None; if no
    usable file is found a warning is printed and nothing is drawn.
    """
    import os
    import pandas as pd

    mode_tag = "sc" if simulate_single_cell else "bulk"
    all_dfs = []

    for cs in causal_strength_list:
        for pz in p_zero_list:
            case_name = "_".join([
                "2L_unknown",
                spurious_mode,
                f"hidden{n_hidden}",
                activation,
                f"cs{int(cs * 100):02d}",
                f"p{pz}",
                mode_tag,
            ])
            csv_path = os.path.join(root_output_dir, case_name, "compare_known_unknown.csv")
            if not os.path.exists(csv_path):
                continue
            try:
                case_df = pd.read_csv(csv_path)
            except pd.errors.EmptyDataError:
                # Written by a comparison that found no weight-file pairs.
                print(f"[WARN] Skipping empty comparison file: {csv_path}")
                continue
            if case_df.empty:
                continue
            case_df["ParamCombo"] = f"Causal Strength = {cs}, Sparsity = {pz}"
            all_dfs.append(case_df)

    if not all_dfs:
        print("[WARN] No compare_known_unknown.csv files found.")
        return

    import matplotlib.pyplot as plt
    import seaborn as sns

    sns.set_theme(style="white")
    plt.rcParams.update({
        "font.size": 12,
        "axes.labelsize": 14,
        "axes.titlesize": 16,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "legend.fontsize": 14,
        "figure.dpi": 300,
        "savefig.dpi": 300,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "font.family": "Arial",
    })

    df = pd.concat(all_dfs, ignore_index=True)

    # step1/step2 (inferred run) and layer1/layer2 (known run) share one axis label.
    df["Layer"] = df["Layer"].replace({
        "step1": "Layer 1",
        "layer1": "Layer 1",
        "step2": "Layer 2",
        "layer2": "Layer 2"
    })
    df["Layer"] = pd.Categorical(df["Layer"], categories=["Layer 1", "Layer 2"], ordered=True)

    df["Setting"] = df["Mode"].replace({
        "unknown": "inferred",
        "known": "known"
    })
    df["Setting"] = pd.Categorical(df["Setting"], categories=["known", "inferred"], ordered=True)

    setting_palette = {
        "inferred": "#2C2C79",  # dark blue-violet
        "known": "#9DC3E6"      # light blue
    }

    # === Plot ===
    # Size the grid from the data instead of assuming nine panels, so larger
    # sweeps are not silently truncated and smaller ones leave no blank axes.
    param_combos = sorted(df['ParamCombo'].unique())
    n_cols = max(1, min(len(p_zero_list), len(param_combos)))
    n_rows = -(-len(param_combos) // n_cols)  # ceiling division

    fig, axes = plt.subplots(
        n_rows, n_cols,
        figsize=(6 * n_cols, 4 * n_rows),
        sharey=True,
        squeeze=False,
    )
    panel_axes = list(axes.flatten())

    for ax, combo in zip(panel_axes, param_combos):
        subdf = df[df['ParamCombo'] == combo]
        sns.boxplot(
            data=subdf,
            x="Layer",
            y=metric,
            hue="Setting",
            palette=setting_palette,
            width=0.5,
            fliersize=0,
            linewidth=1,
            gap=0.15,  # leave space between the two hue boxes
            ax=ax
        )
        sns.stripplot(
            data=subdf,
            x="Layer",
            y=metric,
            hue="Setting",
            dodge=True,
            color="gray",
            alpha=0.4,
            size=3,
            edgecolor="none",
            linewidth=0,
            ax=ax,
            legend=False
        )
        ax.set_title(combo, fontsize=14)
        ax.set_xlabel("")
        ax.tick_params(axis='x', labelsize=14)  # font size of the "Layer 1" / "Layer 2" ticks
        ax.set_ylabel(metric)
        ax.set_ylim(0, 1.05)
        ax.grid(False)
        panel_legend = ax.get_legend()
        if panel_legend is not None:  # seaborn may not have created one
            panel_legend.remove()

    # Hide grid cells that received no case.
    for unused_ax in panel_axes[len(param_combos):]:
        unused_ax.set_visible(False)

    fig.suptitle(f"Layer-wise {metric}: Known vs Inferred", fontsize=18, y=1.02)

    # One shared legend for the whole figure, taken from the first panel.
    handles, labels = axes[0, 0].get_legend_handles_labels()
    fig.legend(
        handles[:2], labels[:2],
        title=None,
        loc="upper center",
        bbox_to_anchor=(0.5, 1.0),
        ncol=2,
        frameon=False
    )

    plt.tight_layout()
    base_tag = "_".join([
        "2L_unknown",
        spurious_mode,
        f"hidden{n_hidden}",
        activation,
        mode_tag,
    ])
    fname = f"{base_tag}_PriorVsInferred_{metric}"

    fig.savefig(os.path.join(root_output_dir, f"{fname}.pdf"), bbox_inches='tight')
    fig.savefig(os.path.join(root_output_dir, f"{fname}.png"), bbox_inches='tight')
    plt.close(fig)
    print(f"[INFO] Saved: {fname}.pdf/.png")

In [None]:
def main(base_dir=None):
    """
    Run the full "layers unknown" benchmark sweep and draw all figures.

    For every (causal_strength, p_zero) combination the benchmark is run into
    ``<base_dir>/simulation/output/<case_name>``, its CauTrigger_SHAP step1/
    step2 weights are compared with the matching ``2L_counts_...`` (layers
    known) output directory, and the per-case comparison plots are drawn.
    Afterwards the aggregate grids over all combinations are produced.

    Parameters
    ----------
    base_dir : str, optional
        Project root containing the ``simulation`` folder. Defaults to the
        original hard-coded location so ``main()`` behaves as before.
    """
    BASE_DIR = base_dir or '/mnt/e/Project_Research/CauTrigger_Project/CauTrigger-master'
    case_dir = os.path.join(BASE_DIR, 'simulation')
    output_root = os.path.join(case_dir, 'output')

    # ==== Fixed parameters ====
    n_datasets = 10
    spurious_mode = 'semi_hrc'
    n_hidden = 10
    activation = 'linear'
    simulate_single_cell = True

    algorithms = ['CauTrigger_SHAP', 'GENIE3', 'SCRIBE', 'PC', 'VAEgrad','DCI', 'MI', 'RF', 'SVM']

    # ==== Parameter grid to sweep ====
    # (A wider grid, cs 0.3-0.7 x p 0.1-0.7, used to be assigned here and was
    # immediately overwritten; only the effective 3x3 grid is kept.)
    causal_strength_list = [0.3, 0.4, 0.5]  # causal strength: low, medium, high
    p_zero_list = [0.3, 0.5, 0.7]  # sparsity: low, medium, high

    for causal_strength in causal_strength_list:
        for p_zero in p_zero_list:

            # ==== Data-generation arguments for this combination ====
            generation_args = dict(
                spurious_mode=spurious_mode,
                n_hidden=n_hidden,
                activation=activation,
                causal_strength=causal_strength,
                p_zero=p_zero,
                simulate_single_cell=simulate_single_cell,
            )

            # ==== Case name; the aggregate plotters rebuild the same string ====
            tag_parts = [
                "2L_unknown",
                spurious_mode,
                f"hidden{n_hidden}",
                activation,
                f"cs{int(causal_strength * 100):02d}",
                f"p{p_zero}",
                "sc" if simulate_single_cell else "bulk",
            ]
            case_name = "_".join(tag_parts)

            data_dir = os.path.join(case_dir, 'data', case_name)
            output_dir = os.path.join(output_root, case_name)
            os.makedirs(data_dir, exist_ok=True)
            os.makedirs(output_dir, exist_ok=True)

            print(f"=== Running: {case_name} ===")

            run_benchmark(
                algorithms=algorithms,
                data_dir=data_dir,
                output_dir=output_dir,
                n_datasets=n_datasets,
                **generation_args
            )

            # === Compare inferred layers against the "layers known" run ===
            case_known = case_name.replace("2L_unknown", "2L_counts")
            base_dir_known = os.path.join(output_root, case_known)
            df_cmp = compare_known_vs_unknown_layers(
                base_dir_unknown=output_dir,
                base_dir_known=base_dir_known,
                method="CauTrigger_SHAP",
                topk=20,
                n_datasets=n_datasets,
                output_path=os.path.join(output_dir, "compare_known_unknown.csv")
            )
            if df_cmp.empty:
                # No known-layer results for this case: plotting would fail and
                # abort the rest of the sweep, so skip only the comparison.
                print(f"[WARN] No known-vs-unknown comparison available for {case_name}; "
                      f"skipping hierarchy plots.")
            else:
                for metric in ("AUROC", "AUPR"):
                    plot_hierarchy_comparison(df_cmp, output_dir, metric=metric, cs=causal_strength, p_zero=p_zero)

    # Finally draw the summary grids (only for the combinations swept above).
    plot_aggregate_layerwise_metrics(
        root_output_dir=output_root,
        causal_strength_list=causal_strength_list,
        p_zero_list=p_zero_list,
        spurious_mode=spurious_mode,
        n_hidden=n_hidden,
        activation=activation,
        simulate_single_cell=simulate_single_cell
    )
    for metric in ("AUROC", "AUPR"):
        plot_grid_hierarchy_comparison_boxplot(
            root_output_dir=output_root,
            causal_strength_list=causal_strength_list,
            p_zero_list=p_zero_list,
            metric=metric,
            spurious_mode=spurious_mode,
            n_hidden=n_hidden,
            activation=activation,
            simulate_single_cell=simulate_single_cell
        )