# Benchmark on simulated data

## Import libraries and set working directory

In [None]:
import os
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
from itertools import combinations
from sklearn.metrics import (roc_curve, auc, confusion_matrix, accuracy_score, matthews_corrcoef, f1_score,
                             precision_score, recall_score, precision_recall_curve)
from sklearn.feature_selection import f_classif, mutual_info_classif
from sklearn.linear_model import LogisticRegression, LassoCV, Ridge
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import normalize
from scipy.stats import pearsonr, spearmanr, f_oneway, norm
import pickle
import warnings
warnings.filterwarnings("ignore")
import sys
sys.path.append('../')
import logging
logging.getLogger('matplotlib.font_manager').disabled = True

from CauTrigger.utils import set_seed
from CauTrigger.model import CauTrigger
from CauTrigger.dataloaders import generate_synthetic_jersey

# Global matplotlib defaults for every figure produced below:
# Type 42 fonts in PDF output and Arial as the sans-serif face.
plt.rcParams.update({
    'pdf.fonttype': 42,
    'font.sans-serif': ['Arial'],
    'font.family': 'sans-serif',
})

## Set your working directory

In [None]:
# Root of the project on disk; replace the placeholder before running.
BASE_DIR = '/your/working/directory'
# All benchmark inputs and outputs live under <BASE_DIR>/BenchmarkSimulation/.
case_path = os.path.join(BASE_DIR, 'BenchmarkSimulation/')
data_path, output_path = (os.path.join(case_path, sub) for sub in ('data/', 'output/'))
# Only the output folder is created here; data_path is just a path string.
os.makedirs(output_path, exist_ok=True)

## Define other methods and our model

In [9]:
# CauTrigger model
def run_model(adata, max_epochs, init_weight=None):
    """Train CauTrigger on a copy of ``adata`` and return its upstream feature weights.

    The seed is fixed to 42 before the copy is made, and the model is built
    with the fixed configuration listed below.

    Args:
        adata: Dataset object accepted by ``CauTrigger``; it is copied, so the
            caller's object is not handed to the model.
        max_epochs: Forwarded to ``model.train``.
        init_weight: Forwarded to the ``CauTrigger`` constructor.

    Returns:
        ``pandas.DataFrame`` with a single column ``'weight_value'`` taken from
        the ``'weight'`` field of the first item returned by
        ``get_up_feature_weights``.
    """
    set_seed(42)
    model_config = dict(
        n_causal=2,
        n_latent=10,
        n_hidden=128,
        n_layers_encoder=0,
        n_layers_decoder=0,
        n_layers_dpd=0,
        dropout_rate_encoder=0.0,
        dropout_rate_decoder=0.0,
        dropout_rate_dpd=0.0,
        use_batch_norm='none',
        use_batch_norm_dpd=True,
        decoder_linear=True,
        dpd_linear=False,
        init_weight=init_weight,
        init_thresh=0.0,
        update_down_weight=False,
        attention=False,
        att_mean=False,
    )
    trigger = CauTrigger(adata.copy(), **model_config)
    trigger.train(max_epochs=max_epochs, stage_training=True)
    up_weights = trigger.get_up_feature_weights(
        normalize=True, method="Model", sort_by_weight=False
    )[0]
    return pd.DataFrame({'weight_value': up_weights['weight']})


def gauss_ci_test(suff_stat, i, j, K):
    """Return the two-sided p-value of a Fisher-z test for variables ``i`` and ``j`` given ``K``.

    Args:
        suff_stat: Mapping with ``"C"`` (correlation matrix indexed as
            ``C[i, j]``) and ``"n"`` (sample count).
        i: Index of the first variable.
        j: Index of the second variable.
        K: List of conditioning-variable indices (may be empty).

    Returns:
        The p-value ``2 * (1 - Phi(|z|))`` where ``z`` is the Fisher-transformed
        (partial) correlation scaled by ``sqrt(n - len(K) - 3)``.
    """
    C = suff_stat["C"]
    n = suff_stat["n"]
    order = len(K)

    if order == 0:
        # Marginal correlation.
        rho = C[i, j]
    elif order == 1:
        # Closed-form first-order partial correlation.
        k = K[0]
        numerator = C[i, j] - C[i, k] * C[j, k]
        denominator = math.sqrt((1 - C[i, k] ** 2) * (1 - C[j, k] ** 2))
        rho = numerator / denominator
    else:
        # Higher order: read the partial correlation off the pseudo-inverse of
        # the sub-matrix whose first two rows/columns are i and j.
        idx = [i, j] + K
        P = np.linalg.pinv(C[np.ix_(idx, idx)])
        rho = (-1 * P[0, 1]) / math.sqrt(abs(P[0, 0] * P[1, 1]))

    # Clamp away from +/-1 so the transform below stays finite.
    rho = max(min(rho, 0.99999), -0.99999)
    # 0.5 * log((1 + r) / (1 - r)) written via log1p.
    fisher_z = 0.5 * math.log1p((2 * rho) / (1 - rho))
    statistic = fisher_z * math.sqrt(n - order - 3)
    return 2 * (1 - norm.cdf(abs(statistic)))


def get_neighbors(G, x, exclude_y):
    """List the indices adjacent to ``x`` in adjacency structure ``G``, skipping ``exclude_y``."""
    adjacent = []
    for idx, is_connected in enumerate(G[x]):
        if idx == exclude_y or not is_connected:
            continue
        adjacent.append(idx)
    return adjacent


def skeleton(suff_stat, alpha):
    """Prune a complete undirected graph using conditional-independence tests.

    Starting from a graph in which every pair of distinct variables is
    connected, the edge ``x - y`` is removed as soon as ``gauss_ci_test``
    returns a p-value ``>= alpha`` for some conditioning set of size ``l``
    drawn from the current neighbors of ``x`` (excluding ``y``). ``l`` grows
    by one after each full pass over the pairs.

    Args:
        suff_stat: Mapping whose ``"C"`` entry is a square matrix (its shape
            gives the number of variables); the whole mapping is passed on to
            ``gauss_ci_test``.
        alpha: Threshold compared against the test's p-value.

    Returns:
        Tuple ``(G, O, p_value_mat)``: ``G`` is the remaining adjacency as an
        integer array, ``O[x][y]`` is the conditioning set that removed edge
        ``x - y`` (empty list otherwise), and ``p_value_mat[x][y]`` is the
        largest p-value observed for that pair (0 if the pair was never
        tested).
    """
    p_value_mat = np.zeros_like(suff_stat["C"])
    n_nodes = suff_stat["C"].shape[0]
    # O[x][y]: separating set recorded when the edge x - y is removed.
    O = [[[] for _ in range(n_nodes)] for _ in range(n_nodes)]
    # G: boolean adjacency, fully connected except for the diagonal.
    G = [[i != j for i in range(n_nodes)] for j in range(n_nodes)]
    # Each unordered pair is visited once with x < y, so conditioning sets are
    # only ever drawn from the neighbors of the lower-indexed endpoint.
    pairs = [(i, j) for i in range(n_nodes) for j in range(i+1, n_nodes)]
    done = False
    l = 0  # current conditioning-set size
    while not done and any(any(row) for row in G):
        # Stays True only if no remaining edge has enough neighbors for size l.
        done = True
        for x, y in pairs:
            if G[x][y]:
                neighbors = get_neighbors(G, x, y)
                if len(neighbors) >= l:
                    done = False
                    for K in combinations(neighbors, l):
                        p_value = gauss_ci_test(suff_stat, x, y, list(K))
                        # Track the maximum p-value seen for this pair.
                        if p_value > p_value_mat[x][y]:
                            p_value_mat[x][y] = p_value_mat[y][x] = p_value
                        if p_value >= alpha:
                            # Independence not rejected: drop the edge and
                            # remember the set that separated x and y.
                            G[x][y] = G[y][x] = False
                            O[x][y] = O[y][x] = list(K)
                            break
        l += 1
    return np.asarray(G, dtype=int), O, p_value_mat


def extend_cpdag(G, O):
    """Orient edges of an adjacency matrix using separating sets and three propagation rules.

    Encoding used throughout: ``G[i][j] == 1 and G[j][i] == 1`` is an
    undirected edge, ``G[i][j] == 1 and G[j][i] == 0`` is the directed edge
    ``i -> j``, and both zero means no edge.

    Args:
        G: Square integer array indexable as ``G[i][j]``; it is modified in
            place (callers that need the input preserved must pass a copy).
        O: Nested list where ``O[x][z]`` is the separating set recorded for the
            pair ``(x, z)``.

    Returns:
        ``numpy.ndarray`` built from the oriented ``G``.
    """
    n_nodes = G.shape[0]
    def rule1(g):
        # i -> j, j - k undirected, i and k not adjacent  =>  orient j -> k.
        pairs = [(i, j) for i in range(n_nodes) for j in range(n_nodes) if g[i][j] == 1 and g[j][i] == 0]
        for i, j in pairs:
            all_k = [k for k in range(n_nodes) if (g[j][k] == 1 and g[k][j] == 1) and (g[i][k] == 0 and g[k][i] == 0)]
            for k in all_k:
                g[j][k] = 1
                g[k][j] = 0
        return g
    def rule2(g):
        # i - j undirected and a directed path i -> k -> j exists  =>  orient i -> j.
        pairs = [(i, j) for i in range(n_nodes) for j in range(n_nodes) if g[i][j] == 1 and g[j][i] == 1]
        for i, j in pairs:
            all_k = [k for k in range(n_nodes) if (g[i][k] == 1 and g[k][i] == 0) and (g[k][j] == 1 and g[j][k] == 0)]
            if len(all_k) > 0:
                g[i][j] = 1
                g[j][i] = 0
        return g
    def rule3(g):
        # i - j undirected, two non-adjacent k1, k2 with i - k undirected and
        # k -> j  =>  orient i -> j.
        pairs = [(i, j) for i in range(n_nodes) for j in range(n_nodes) if g[i][j] == 1 and g[j][i] == 1]
        for i, j in pairs:
            all_k = [k for k in range(n_nodes) if (g[i][k] == 1 and g[k][i] == 1) and (g[k][j] == 1 and g[j][k] == 0)]
            if len(all_k) >= 2:
                for k1, k2 in combinations(all_k, 2):
                    if g[k1][k2] == 0 and g[k2][k1] == 0:
                        g[i][j] = 1
                        g[j][i] = 0
                        break
        return g

    # Collider step: for x - y - z with x and z not adjacent and y absent from
    # their separating set, orient x -> y <- z. The pair list is a snapshot
    # taken before any orientation, while G itself is updated as we go.
    pairs = [(i, j) for i in range(n_nodes) for j in range(n_nodes) if G[i][j] == 1]
    for x, y in sorted(pairs, key=lambda x: (x[1], x[0])):
        all_z = [z for z in range(n_nodes) if G[y][z] == 1 and z != x]
        for z in all_z:
            if G[x][z] == 0 and y not in O[x][z]:
                G[x][y] = G[z][y] = 1
                G[y][x] = G[y][z] = 0

    # Apply the three rules repeatedly until a full pass changes nothing.
    # NOTE(review): these look like the usual Meek orientation rules 1-3 - confirm.
    old_G = np.zeros((n_nodes, n_nodes))
    while not np.array_equal(old_G, G):
        old_G = G.copy()
        G = rule1(G)
        G = rule2(G)
        G = rule3(G)

    return np.array(G)


def pc(suff_stat, alpha=0.5, verbose=False):
    """Run the skeleton search followed by edge orientation.

    Args:
        suff_stat: Passed unchanged to ``skeleton``.
        alpha: Threshold forwarded to ``skeleton``.
        verbose: When true, print the oriented adjacency matrix.

    Returns:
        Tuple ``(cpdag, p_value_matrix)``: the output of ``extend_cpdag`` and
        the p-value matrix produced by ``skeleton``.
    """
    adjacency, sepsets, p_value_matrix = skeleton(suff_stat, alpha)
    oriented = extend_cpdag(adjacency, sepsets)
    if verbose:
        print(oriented)
    return oriented, p_value_matrix


# PC algorithm model
def run_pc(X, y):
    """Score each column of ``X`` by ``1 - p`` from the PC skeleton search against ``y``.

    ``y`` is appended as the last column, the correlation matrix of the joint
    table is handed to ``pc`` with ``alpha=0.05``, and the last column of the
    returned p-value matrix (excluding the ``y``-``y`` entry) is converted to a
    score.

    Returns:
        1-D array with one score per column of ``X``.
    """
    joint = pd.DataFrame(np.column_stack((X, y)))
    stats = {"C": joint.corr().values, "n": joint.shape[0]}
    # Only the p-value matrix is needed; the oriented graph is discarded.
    _, p_values = pc(suff_stat=stats, alpha=0.05)
    return 1 - p_values[:-1, -1]


# VAE model
def run_VAE(X, y, n_hidden, n_latent):
    """Train a small supervised VAE and return gradient-based feature importances.

    The network has an encoder producing ``2 * n_latent`` outputs (mean and
    log-variance), a decoder that reconstructs the input, and a sigmoid
    classifier head on the latent sample. It is trained for 200 epochs with
    Adam (lr 0.01, batch size 128) on
    ``reconstruction MSE + 0.1 * KL + 0.1 * BCE``.

    Args:
        X: 2-D array of features, one row per sample.
        y: Labels, one per row of ``X``; they are used as binary
            cross-entropy targets.
        n_hidden: Width of the hidden layer in every sub-network.
        n_latent: Size of the latent sample.

    Returns:
        1-D ``numpy`` array with one value per column of ``X``: the mean over
        samples of the absolute gradient of the classification loss with
        respect to the input.

    Note:
        The latent sample is drawn with ``torch.randn_like`` in both training
        and the final importance pass, so results depend on the torch RNG
        state at call time.
    """
    import torch
    from torch import nn
    from torch.utils.data import TensorDataset, DataLoader

    n_features = X.shape[1]
    features = torch.tensor(X, dtype=torch.float32)
    labels = torch.tensor(y, dtype=torch.float32).view(-1, 1)
    dataloader = DataLoader(TensorDataset(features, labels), batch_size=128, shuffle=True)

    class VAE(nn.Module):
        def __init__(self, num_features):
            super().__init__()
            self.encoder = nn.Sequential(
                nn.Linear(num_features, n_hidden),
                nn.ReLU(),
                nn.Linear(n_hidden, 2 * n_latent),
            )
            self.decoder = nn.Sequential(
                nn.Linear(n_latent, n_hidden),
                nn.ReLU(),
                nn.Linear(n_hidden, num_features),
            )
            self.DPD = nn.Sequential(
                nn.Linear(n_latent, n_hidden),
                nn.ReLU(),
                nn.Linear(n_hidden, 1),
                nn.Sigmoid(),
            )

        def reparameterize(self, mu, logvar):
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps * std + mu

        def forward(self, x):
            mu_logvar = self.encoder(x)
            # First half of the encoder output is the mean, second half the log-variance.
            mu = mu_logvar[:, :n_latent]
            logvar = mu_logvar[:, n_latent:]
            z = self.reparameterize(mu, logvar)
            return self.decoder(z), self.DPD(z), mu, logvar

    model = VAE(n_features)
    recon_criterion = nn.MSELoss()
    dpd_criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    model.train()
    for _ in range(200):
        for data, targets in dataloader:
            optimizer.zero_grad()
            recon_batch, y_dpd, mu, logvar = model(data)
            re_loss = recon_criterion(recon_batch, data)
            kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / data.shape[0]
            dpd_loss = dpd_criterion(y_dpd, targets)
            # The same weighting is used for every epoch (the former
            # epoch-dependent branch had two identical arms).
            loss = re_loss + kl_loss * 0.1 + dpd_loss * 0.1
            loss.backward()
            optimizer.step()

    # Importance pass: gradient of the classification loss w.r.t. the inputs.
    model.eval()
    features.requires_grad = True
    _, y_prob, _, _ = model(features)
    dpd_criterion(y_prob, labels).backward()
    return features.grad.abs().mean(dim=0).detach().numpy()


# Other machine learning methods
def run_ml_methods(adata):
    """Compute nine baseline feature-importance scores on a standardized copy of ``adata``.

    The seed is fixed to 42, ``adata.X`` is z-scored per column, and each
    baseline is run against ``adata.obs['labels']``. Every score vector is then
    L1-normalized so that it sums to one across features.

    Args:
        adata: Object exposing ``X``, ``obs['labels']``, ``var_names`` and
            ``copy()``.

    Returns:
        ``pandas.DataFrame`` indexed by ``var_names`` with columns ``'SVM'``,
        ``'Random Forest'``, ``'LASSO'``, ``'Mutual Information'``,
        ``'Logistic Regression'``, ``'Pearson Correlation'``, ``'ANOVA'``,
        ``'PC Algorithm'`` and ``'VAE_Grad'``.
    """
    set_seed(42)
    snapshot = adata.copy()
    features = snapshot.X
    labels = snapshot.obs['labels'].values
    features = (features - features.mean(axis=0)) / features.std(axis=0)
    n_features = features.shape[1]
    classes = np.unique(labels)

    # Insertion order defines both the execution order (which matters for the
    # shared random state) and the column order of the result.
    scores = {}
    scores['SVM'] = np.abs(SVC(kernel='linear').fit(features, labels).coef_)
    scores['Random Forest'] = RandomForestClassifier().fit(features, labels).feature_importances_
    scores['LASSO'] = np.abs(LassoCV(cv=5).fit(features, labels).coef_)
    scores['Mutual Information'] = mutual_info_classif(features, labels)
    scores['Logistic Regression'] = np.abs(LogisticRegression().fit(features, labels).coef_)
    scores['Pearson Correlation'] = np.abs(
        [pearsonr(features[:, col], labels)[0] for col in range(n_features)]
    )
    # One-way ANOVA F statistic per feature, grouping samples by label.
    scores['ANOVA'] = np.array([
        f_oneway(*(features[labels == cls][:, col] for cls in classes))[0]
        for col in range(n_features)
    ])
    scores['PC Algorithm'] = run_pc(features, labels)
    scores['VAE_Grad'] = run_VAE(features, labels, n_latent=10, n_hidden=64)

    stacked = np.vstack(tuple(scores.values()))
    weight_df = pd.DataFrame(normalize(stacked, norm='l1').T, columns=list(scores))
    weight_df.index = snapshot.var_names
    return weight_df


def select_features(df, threshold=None, topk=None):
    """Turn per-feature weights into 0/1 selection flags, column by column.

    Args:
        df: ``DataFrame`` (or ``Series``) of weights indexed by feature.
        threshold: Used only when ``topk`` is falsy. Features are taken in
            descending weight order while the running sum stays
            ``<= threshold``; if none qualifies, the top feature is selected.
        topk: When truthy, the ``topk`` largest weights are selected and
            ``threshold`` is ignored.

    Returns:
        ``DataFrame`` of 0/1 flags with the same index and columns as ``df``.

    Raises:
        ValueError: If both ``threshold`` and ``topk`` are falsy.
    """
    frame = pd.DataFrame(df) if isinstance(df, pd.Series) else df
    flags_per_column = []
    for name in frame.columns:
        ranked = frame[name].sort_values(ascending=False)
        if topk:
            flags = pd.Series(0, index=frame.index)
            flags[ranked.nlargest(topk).index] = 1
        elif threshold:
            flags = (ranked.cumsum() <= threshold).astype(int)
            if flags.sum() == 0:
                # Even the single largest weight exceeds the budget: keep it anyway.
                flags[ranked.index[0]] = 1
        else:
            raise ValueError('Please pass valid argument!')
        flags_per_column.append(pd.Series(flags, name=name))
    result = pd.concat(flags_per_column, axis=1)
    result.columns = frame.columns
    # Threshold-mode flags are in ranked order; restore the caller's row order.
    return result.reindex(frame.index)


def generate_boxplot_csv(df, file_start_name='', subfolder_name='', set_ylim=True):
    """Save a per-method boxplot and return a transposed table with summary columns.

    Despite the name, no CSV is written: the function saves
    ``<file_start_name>_boxplot.pdf`` and ``.png`` into ``subfolder_name``.

    Args:
        df: ``DataFrame`` with one column per method and one row per dataset.
        file_start_name: Metric name; used as the y-axis label, in the title,
            and as the file-name prefix.
        subfolder_name: Directory the two image files are written to.
        set_ylim: When true, the y-axis is fixed to ``[0, 1]``.

    Returns:
        ``df`` transposed (methods as rows) with ``'Mean'``, ``'Max'`` and
        ``'Min'`` columns appended, each computed over the dataset columns
        only.
    """
    plt.figure(figsize=(6, 6))
    sns.boxplot(data=df, palette='Set2')
    plt.xticks(rotation=45, ha='right')
    plt.xlabel('Methods')
    plt.ylabel(f'{file_start_name}')
    plt.title(f'{file_start_name} Boxplot for Different Methods')
    plt.tight_layout()
    if set_ylim:
        plt.ylim(0, 1)
    for ext in ('pdf', 'png'):
        plt.savefig(os.path.join(subfolder_name, f'{file_start_name}_boxplot.{ext}'), format=ext)
    plt.close()

    summary = df.T
    # Compute all statistics before appending any of them, so that 'Max' and
    # 'Min' are taken over the raw per-dataset values and cannot pick up the
    # (floating-point rounded) 'Mean' column.
    stats = {
        'Mean': summary.mean(axis=1),
        'Max': summary.max(axis=1),
        'Min': summary.min(axis=1),
    }
    for stat_name, stat_values in stats.items():
        summary[stat_name] = stat_values
    return summary

## Define the benchmark function

In [10]:
def _score_feature_ranking(true_label, model_score, pred_label):
    """Compute all benchmark metrics for one method on one dataset."""
    fpr, tpr, _ = roc_curve(true_label, model_score)
    pr_precision, pr_recall, _ = precision_recall_curve(true_label, model_score)
    cm = confusion_matrix(true_label, pred_label)
    tn, fp = cm[0, 0], cm[0, 1]
    return {
        'AUROC': auc(fpr, tpr),
        'AUPR': auc(pr_recall, pr_precision),
        # ACC = (TP+TN)/(TP+TN+FP+FN)
        'ACC': accuracy_score(true_label, pred_label),
        # MCC = (TP*TN-FP*FN)/sqrt((TP+FP)*(TP+FN)*(TN+FP)*(TN+FN))
        'MCC': matthews_corrcoef(true_label, pred_label),
        # Precision = TP/(TP+FP)
        'Precision': precision_score(true_label, pred_label, pos_label=1),
        # Specificity = TN/(TN+FP)
        'Specificity': tn / (tn + fp),
        # Recall = TP/(TP+FN)
        'Recall': recall_score(true_label, pred_label, pos_label=1),
        'F1_score': f1_score(true_label, pred_label, pos_label=1),
        # Number of features flagged as candidates.
        'Candidate_num': pred_label.sum(),
    }


def run_benchmark(
        n_dataset=10,
        n_samples=200,
        oversampling_factor=5,
        max_epochs=400,
        threshold=0.3,
        topk=0,
        noise_scale=0.1,
        causal_strength=5,
        is_linear=False,
        n_up_features=50,
        n_causal=10,
        n_down_features=150,
        init_weight=None,
        resdir=None,
) -> tuple:
    """Benchmark CauTrigger against the baseline methods on synthetic datasets.

    For each of ``n_dataset`` datasets (seeded with its index) the function
    generates data with ``generate_synthetic_jersey``, runs ``run_model`` and
    ``run_ml_methods``, converts every weight vector to 0/1 selections with
    ``select_features``, and scores them against a ground truth in which the
    first ``n_causal`` of the ``n_up_features`` features are labelled causal.

    Args:
        n_dataset: Number of synthetic datasets to generate.
        n_samples, oversampling_factor, noise_scale, causal_strength,
        is_linear, n_up_features, n_causal, n_down_features: Forwarded to
            ``generate_synthetic_jersey``.
        max_epochs: Forwarded to ``run_model``.
        threshold: Cumulative-weight threshold for ``select_features`` (used
            when ``topk`` is falsy).
        topk: Number of top features to select; takes precedence when truthy.
        init_weight: Forwarded to ``run_model``.
        resdir: Parent directory for outputs. It is interpolated into the
            output path as-is, so leaving it as ``None`` writes under a
            directory literally named ``None``.

    Returns:
        Tuple ``(score_dict, adata_dict)``. ``score_dict`` maps each metric
        name to the table returned by ``generate_boxplot_csv``;
        ``adata_dict`` maps dataset index to the generated dataset.

    Side effects:
        Creates ``<resdir>/noise_..._causal_..._(topk|threshold)_...`` and
        writes ``result_weight.csv``, ``result_label.csv`` and one boxplot
        (PDF and PNG) per metric into it.
    """
    # Result key -> column name in the frame returned by run_ml_methods.
    # The key order fixes the column order of every output table.
    baseline_columns = {
        'SVM': 'SVM',
        'RF': 'Random Forest',
        'MI': 'Mutual Information',
        'LR': 'Logistic Regression',
        'LASSO': 'LASSO',
        'ANOVA': 'ANOVA',
        'PCC': 'Pearson Correlation',
        'PC': 'PC Algorithm',
        'VAE_Grad': 'VAE_Grad',
    }
    result_dict = {'CauTrigger': []}
    result_dict.update({method: [] for method in baseline_columns})
    adata_dict = {}

    for i in range(n_dataset):
        print(f"This is the {i + 1}/{n_dataset} dataset")
        set_seed(i)
        adata = generate_synthetic_jersey(n_samples=n_samples,
                                          oversampling_factor=oversampling_factor,
                                          n_up_features=n_up_features,
                                          n_down_features=n_down_features,
                                          n_causal=n_causal,
                                          n_hidden=5,
                                          n_latent=5,
                                          noise_scale=noise_scale,
                                          causal_strength=causal_strength,
                                          is_linear=is_linear)
        adata_dict[i] = adata
        # init_weight was previously accepted but silently dropped here.
        model_res = run_model(adata, max_epochs, init_weight=init_weight)
        ml_res = run_ml_methods(adata)
        result_dict['CauTrigger'].append(model_res['weight_value'])
        for method, column in baseline_columns.items():
            result_dict[method].append(ml_res[column])

    criterion = f"topk_{topk}" if topk else f"threshold_{threshold}"
    subfolder_name = f"{resdir}/noise_{noise_scale}_causal_{causal_strength}_{criterion}"
    os.makedirs(subfolder_name, exist_ok=True)

    # One column per (method, dataset), named "<method>-<1-based dataset index>".
    result_dict_df = pd.concat(
        {k + '-' + str(i + 1): v[i] for k, v in result_dict.items() for i in range(len(v))},
        axis=1,
    )
    result_dict_df_predlabel = select_features(result_dict_df, threshold, topk)
    result_dict_df.to_csv(os.path.join(subfolder_name, 'result_weight.csv'))
    result_dict_df_predlabel.to_csv(os.path.join(subfolder_name, 'result_label.csv'))

    # Ground truth: the first n_causal upstream features are the causal ones.
    true_label = np.repeat([1, 0], [n_causal, n_up_features - n_causal])
    metric_names = ['AUROC', 'AUPR', 'ACC', 'MCC', 'Precision', 'Specificity',
                    'Recall', 'F1_score', 'Candidate_num']
    metric_frames = {metric: pd.DataFrame() for metric in metric_names}
    for key, value in result_dict.items():
        per_dataset = []
        for weights in value:
            df = pd.DataFrame({'weight': weights})
            df['pred_label'] = select_features(df, threshold, topk)
            per_dataset.append(
                _score_feature_ranking(true_label, df['weight'].values, df['pred_label'].values)
            )
        for metric in metric_names:
            metric_frames[metric][f'{key}'] = [scores[metric] for scores in per_dataset]

    score_dict = {}
    for metric in metric_names:
        # Candidate counts are not bounded by 1, so their axis is left free.
        score_dict[metric] = generate_boxplot_csv(
            metric_frames[metric], metric, subfolder_name=subfolder_name,
            set_ylim=(metric != 'Candidate_num'),
        )

    return score_dict, adata_dict

In [11]:
def benchmark_main(noises, causal_strengths, n_samples, n_dataset, threshold, topk, max_epochs,
                   n_up_features, n_causal, n_down_features, resdir, is_linear):
    """Run ``run_benchmark`` over a noise x causal-strength grid and plot the results.

    Args:
        noises: Noise scales; one subplot row per value.
        causal_strengths: Causal strengths; one subplot column per value.
        n_samples, n_dataset, threshold, topk, max_epochs, n_up_features,
        n_causal, n_down_features, is_linear: Forwarded to ``run_benchmark``.
        resdir: Output directory; forwarded to ``run_benchmark`` and used for
            the grid boxplots and the radar chart.

    Returns:
        Tuple ``(benchmark_results, adata_dict_all)``: the list of score
        dictionaries in row-major grid order, and a dict keyed by
        ``(noise, causal_strength)`` holding each setting's datasets.

    Side effects:
        Updates ``plt.rcParams`` globally, prints progress, and writes
        ``<metric>_multi_boxplot`` and ``mean_radar_chart`` files (PDF and
        PNG) into ``resdir``.
    """
    benchmark_results = []
    adata_dict_all = {}
    # Row-major sweep: the result at position k belongs to grid cell divmod(k, n_cols).
    for noise in noises:
        for causal_strength in causal_strengths:
            result, adata_dict = run_benchmark(
                n_dataset=n_dataset, n_samples=n_samples, max_epochs=max_epochs,
                threshold=threshold, topk=topk, noise_scale=noise,
                n_up_features=n_up_features, n_causal=n_causal,
                n_down_features=n_down_features,
                causal_strength=causal_strength, resdir=resdir, is_linear=is_linear,
            )
            benchmark_results.append(result)
            adata_dict_all[(noise, causal_strength)] = adata_dict

    plt.rcParams.update({
        'font.size': 12,
        'pdf.fonttype': 42,
        'ps.fonttype': 42,
        'axes.labelsize': 14,
        'axes.titlesize': 14,
        'legend.fontsize': 12,
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
        'figure.dpi': 300,
        'savefig.dpi': 300,
        'axes.linewidth': 1.2,
        'grid.linewidth': 0.5,
        'font.sans-serif': ['Arial'],
        'font.family': 'sans-serif'
    })

    metrics_list = ['AUROC', 'AUPR', 'ACC', 'MCC', 'Precision', 'Specificity', 'Recall', 'F1_score']
    n_rows, n_cols = len(noises), len(causal_strengths)
    for metric in metrics_list:
        # squeeze=False always yields a 2-D array of axes, including for
        # single-row, single-column and 1x1 grids.
        fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 5, n_rows * 5), squeeze=False)
        if metric != 'Candidate_num':
            for ax in axs.flat:
                ax.set_ylim(0, 1)
        for position, res in enumerate(benchmark_results):
            row, col = divmod(position, n_cols)
            panel = axs[row, col]
            # Only the first n_dataset columns are raw scores; the rest are summary columns.
            sns.boxplot(data=res[f'{metric}'].iloc[:, :n_dataset].T, ax=panel, palette='Set2')
            plt.setp(panel.get_xticklabels(), rotation=45, ha='right')
            # NOTE(review): the title shows causal strength scaled by 0.1, unlike the
            # unscaled value used in output folder names - confirm this is intended.
            panel.set_title(f'noise: {noises[row]} & causal strength: {causal_strengths[col] * 0.1}')

        plt.suptitle(f"{metric}")
        plt.tight_layout()
        for ext in ('pdf', 'png'):
            plt.savefig(f'{resdir}/{metric}_multi_boxplot.{ext}', format=ext)
        plt.close()

    def create_mean_radar_chart(benchmark_results, metrics_list, resdir):
        """Plot, per method, the grid-averaged 'Mean' of every metric on a radar chart."""
        method_names = ['CauTrigger', 'SVM', 'RF', 'MI', 'LR', 'LASSO', 'ANOVA', 'PCC', 'PC', 'VAE_Grad']
        categories = [metric for metric in metrics_list if metric != "Candidate_num"]
        mean_values = {method: [] for method in method_names}

        for metric in categories:
            print(f"Processing metric: {metric}")
            for method_name in method_names:
                per_setting = [res[metric].loc[method_name, 'Mean'] for res in benchmark_results]
                mean_values[method_name].append(np.mean(per_setting))

        num_categories = len(categories)
        angles = [n / float(num_categories) * 2 * np.pi for n in range(num_categories)]
        # Repeat the first angle so each polygon closes.
        angles = angles + angles[:1]
        fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
        cmap = plt.get_cmap("tab20")
        colors = [cmap(frac) for frac in np.linspace(0, 1, len(method_names))]
        for color, method_name in zip(colors, method_names):
            values = mean_values[method_name]
            print(f"{method_name} values length: {len(values)}, num_categories: {num_categories}")
            if len(values) != num_categories:
                print(f"Skipping {method_name} due to mismatch in number of categories.")
                continue
            closed = values + values[:1]
            ax.plot(angles, closed, linewidth=2, linestyle='solid', label=method_name, color=color)
            ax.fill(angles, closed, color=color, alpha=0.25)

        ax.set_rlabel_position(0)
        ax.yaxis.set_tick_params(labelsize=10)
        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(categories, fontsize=12)
        ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1), fontsize=12)
        plt.savefig(os.path.join(resdir, "mean_radar_chart.pdf"), format="pdf", bbox_inches='tight')
        plt.savefig(os.path.join(resdir, "mean_radar_chart.png"), format="png", bbox_inches='tight')
        plt.close()

    create_mean_radar_chart(benchmark_results, metrics_list, resdir)

    return benchmark_results, adata_dict_all

## Run benchmark

### Nonlinear scenario

In [None]:
noises = [0.01, 0.1, 1.0]
causal_strengths = [0.1, 0.2, 0.3]
topk = 10
is_linear = False
# Output file names depend only on noise, causal strength, and the selection
# criterion, so every scenario needs its own results folder; otherwise another
# run with the same grid writes over these files.
nonlinear_resdir = os.path.join(output_path, 'nonlinear')
os.makedirs(nonlinear_resdir, exist_ok=True)
benchmark_results, adata_dict_all = benchmark_main(
    n_samples=200,
    n_dataset=3,
    threshold=None,
    topk=topk,
    max_epochs=300,
    n_up_features=100,
    n_causal=topk,  # ground truth has exactly as many causal features as are selected
    n_down_features=200,
    noises=noises,
    causal_strengths=causal_strengths,
    resdir=nonlinear_resdir,
    is_linear=is_linear,
)

### Linear scenario

In [None]:
noises = [0.01, 0.1, 1.0]
causal_strengths = [0.1, 0.2, 0.3]
topk = 10
is_linear = True
# Output file names depend only on noise, causal strength, and the selection
# criterion, so every scenario needs its own results folder; otherwise this
# run writes over files produced by another run with the same grid.
linear_resdir = os.path.join(output_path, 'linear')
os.makedirs(linear_resdir, exist_ok=True)
benchmark_results, adata_dict_all = benchmark_main(
    n_samples=200,
    n_dataset=3,
    threshold=None,
    topk=topk,
    max_epochs=300,
    n_up_features=100,
    n_causal=topk,  # ground truth has exactly as many causal features as are selected
    n_down_features=200,
    noises=noises,
    causal_strengths=causal_strengths,
    resdir=linear_resdir,
    is_linear=is_linear,
)