# Benchmark on the Norman dataset (PCC)

The data come from Norman et al., "Exploring genetic interaction manifolds constructed from rich single-cell phenotypes" (GEO accession GSE133344: https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE133344).

## Import libraries and set working directory

In [1]:
import os
import pickle
import torch
import random
import logging
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['font.sans-serif'] = ['Arial']
plt.rcParams['font.family'] = 'sans-serif'
logging.getLogger('matplotlib.font_manager').disabled = True
import numpy as np
import pandas as pd
import anndata
import scanpy as sc
import seaborn as sns
from sklearn.metrics import mean_squared_error as mse
from scipy.stats import pearsonr
import warnings
warnings.filterwarnings('ignore')
import sys
sys.path.append('../')

# Root of the workspace -- replace with your own directory before running.
BASE_DIR = '/your/working/directory'

# Folders for this case study; the output folder is created if it does not exist yet.
case_path = os.path.join(BASE_DIR, 'BenchmarkNorman')
data_path, output_path = (os.path.join(case_path, sub_dir) for sub_dir in ('data', 'output/'))
os.makedirs(output_path, exist_ok=True)

# Input files, both read from ``data_path`` by ``run_benchmark``.
data_name = 'perturb_processed.h5ad'             # single-cell perturbation data (AnnData)
prior_network_name = 'trrust_rawdata.human.tsv'  # TRRUST table: column 0 = TF, column 1 = target

In [None]:
def extract_gene_names(variable):
    """Strip the 'ctrl' placeholder and every '+' separator from a condition name.

    For a single perturbation such as 'FOXA1+ctrl' or 'ctrl+FOXA1' this yields the bare
    gene symbol 'FOXA1'. The symbols of a two-gene condition are concatenated without a
    separator ('A+B' -> 'AB'), so callers should only pass single perturbations.
    """
    # 'ctrl' and '+' share no characters, so removing them in two successive passes gives
    # exactly the same result as one combined pass over the string.
    without_ctrl = variable.replace('ctrl', '')
    return without_ctrl.replace('+', '')

## Our method

To save time, we provide our results as a pickle file. If you want to run the method yourself, comment out the first block of code in the function below (the one that loads the cached pickle file).

In [None]:
def run_ct(data_path, train_test_dict, norman_adata, TFs, nonTFs):
    """Train CauTrigger and score its held-out perturbation predictions by PCC.

    Results are cached in ``<data_path>/CauTrigger_results_pcc_all.pickle``. When that
    file exists it is loaded and returned, and nothing is trained.

    Parameters
    ----------
    data_path : str
        Directory holding (and receiving) the cached result pickle.
    train_test_dict : dict
        Split built by ``run_benchmark``. ``'single_train'`` / ``'single_test'`` hold TF
        gene symbols. ``'combo_seen{0,1,2}_train'`` / ``'combo_seen{0,1,2}'`` hold
        ``'GENE1+GENE2'`` condition names.
    norman_adata : anndata.AnnData
        Expression data with ``obs['condition']`` and ``var['gene_name']``. ``X`` is
        expected to be sparse, because ``.todense()`` is called on it.
    TFs, nonTFs : array-like of str
        Genes used as the upstream (model input) and downstream feature sets.

    Returns
    -------
    dict
        Maps ``'single'``, ``'combo_seen0'``, ``'combo_seen1'`` and ``'combo_seen2'`` to
        lists of Pearson correlations. Each correlation compares the predicted and observed
        mean change versus control over all TF and non-TF genes. The value is ``np.nan``
        when fewer than two such genes are available.
    """
    cache_file = os.path.join(data_path, 'CauTrigger_results_pcc_all.pickle')
    if os.path.exists(cache_file):
        # Only load pickles you produced yourself: unpickling can execute arbitrary code.
        with open(cache_file, 'rb') as file:
            return pickle.load(file)

    from CauTrigger.model import CauTrigger
    from CauTrigger.utils import set_seed

    # Inference device; assumed to be where CauTrigger places its module -- TODO confirm.
    device = 'cuda:0'
    # Correlations are computed over every TF and non-TF gene.
    down = np.concatenate([TFs, nonTFs])
    CauTrigger_pearsonr_dict = {'single': [], 'combo_seen0': [], 'combo_seen1': [], 'combo_seen2': []}

    def _condition_mask(conditions, pert):
        """Boolean array selecting cells whose condition is ``pert`` in either gene order."""
        reversed_pert = '+'.join(reversed(pert.split('+')))
        return np.asarray(conditions.isin([pert, reversed_pert]))

    norman_adata_TF = norman_adata[:, np.isin(norman_adata.var['gene_name'], TFs)]
    norman_adata_nonTF = norman_adata[:, np.isin(norman_adata.var['gene_name'], nonTFs)]

    # Control cells (label 0): TF expression in X, non-TF expression in obsm['X_down'].
    norman_adata_TF_ctrl = norman_adata_TF[norman_adata_TF.obs['condition'] == 'ctrl', :]
    norman_adata_nonTF_ctrl = norman_adata_nonTF[norman_adata_nonTF.obs['condition'] == 'ctrl', :]
    norman_adata_for_CT_ctrl = norman_adata_TF_ctrl.copy()
    norman_adata_for_CT_ctrl.obs['labels'] = np.repeat(0, norman_adata_for_CT_ctrl.shape[0])
    norman_adata_for_CT_ctrl.obsm['X_down'] = norman_adata_nonTF_ctrl.X.copy()

    # Perturbed training cells (label 1). 'single_train' holds bare gene symbols while
    # obs['condition'] holds 'GENE+ctrl' strings, so map the symbols to condition names
    # first; matching the bare symbols directly selected no single-perturbation cells.
    train_conditions = [f'{gene}+ctrl' for gene in train_test_dict['single_train']]
    train_conditions += [f'ctrl+{gene}' for gene in train_test_dict['single_train']]
    for key in ('combo_seen0_train', 'combo_seen1_train', 'combo_seen2_train'):
        train_conditions.extend(train_test_dict[key])
    norman_adata_TF_pert = norman_adata_TF[norman_adata_TF.obs['condition'].isin(train_conditions), :]
    norman_adata_nonTF_pert = norman_adata_nonTF[norman_adata_nonTF.obs['condition'].isin(train_conditions), :]
    norman_adata_for_CT_pert = norman_adata_TF_pert.copy()
    norman_adata_for_CT_pert.obs['labels'] = np.repeat(1, norman_adata_for_CT_pert.shape[0])
    norman_adata_for_CT_pert.obsm['X_down'] = norman_adata_nonTF_pert.X.copy()

    adata_for_train = anndata.concat([norman_adata_for_CT_ctrl, norman_adata_for_CT_pert])
    adata_for_train.X = adata_for_train.X.todense()
    adata_for_train.obsm['X_down'] = adata_for_train.obsm['X_down'].todense()
    set_seed(42)
    model = CauTrigger(
        adata_for_train,
        n_causal=2,
        n_latent=10,
        n_hidden=128,
        n_layers_encoder=0,
        n_layers_decoder=0,
        n_layers_dpd=0,
        dropout_rate_encoder=0.1,
        dropout_rate_decoder=0.1,
        dropout_rate_dpd=0.1,
        use_batch_norm='both',
        use_batch_norm_dpd=True,
        decoder_linear=True,
        dpd_linear=False,
        init_weight=None,
        init_thresh=0.0,
        update_down_weight=False,
        attention=False,
        att_mean=False,
    )
    weight_scheme = {'stage1': [0.1, 2.0, 2.0, 0.01, 0.5, 1.0, 0.0, 0.0, 0.0],
                     'stage2': [0.4, 2.0, 2.0, 0.1, 0.5, 0.5, 0.0, 0.01, 0.01],
                     'stage3': [0.8, 2.0, 2.0, 0.1, 0.1, 0.1, 1.0, 0.1, 0.1],
                     'stage4': [1, 2.0, 2.0, 0.01, 0.01, 0.01, 1.0, 1.0, 1.0]}
    # Fewer epochs for larger training sets.
    max_epochs = 200 if adata_for_train.shape[0] > 2000 else 500
    model.train(max_epochs=max_epochs, im_factor=1, weight_scheme=weight_scheme)

    # Everything below is identical for every test perturbation, so compute it once.
    feature_names = np.concatenate([norman_adata_TF.var.gene_name, norman_adata_nonTF.var.gene_name])
    ctrl = pd.DataFrame(np.concatenate([norman_adata_for_CT_ctrl.X.todense(),
                                        norman_adata_for_CT_ctrl.obsm['X_down'].todense()], axis=1),
                        columns=feature_names)
    ctrl_deg = ctrl.filter(down).mean(0)
    # In-silico overexpression: targeted TFs are set to twice the largest control value.
    overexpression_value = 2 * norman_adata_for_CT_ctrl.X.max()
    model.eval()

    def _score(genes, pert):
        """Overexpress ``genes`` in control cells and return the PCC against ``pert`` cells."""
        adata_in = norman_adata_for_CT_ctrl.copy()
        for gene in genes:
            adata_in.X[:, adata_in.var['gene_name'] == gene] = overexpression_value
        with torch.no_grad():
            model_output = model.module.forward(torch.Tensor(adata_in.X.todense()).to(device))
        pred_all = pd.DataFrame(np.concatenate([model_output['x_up_rec1'].cpu().numpy(),
                                                model_output['x_down_rec_alpha'].cpu().numpy()], axis=1),
                                columns=feature_names)
        truth_mask = _condition_mask(norman_adata_TF.obs['condition'], pert)
        truth_all = pd.DataFrame(np.concatenate([norman_adata_TF[truth_mask, :].X.todense(),
                                                 norman_adata_nonTF[truth_mask, :].X.todense()], axis=1),
                                 columns=feature_names)
        if pred_all.filter(down).shape[1] <= 1:
            return np.nan
        pred_deg = pred_all.filter(down).mean(0)
        truth_deg = truth_all.filter(down).mean(0)
        return pearsonr(pred_deg.values.flatten() - ctrl_deg.values.flatten(),
                        truth_deg.values.flatten() - ctrl_deg.values.flatten())[0]

    for single_test in train_test_dict['single_test']:
        CauTrigger_pearsonr_dict['single'].append(_score([single_test], single_test + '+ctrl'))
    for key in ['combo_seen0', 'combo_seen1', 'combo_seen2']:
        for combo in train_test_dict[key]:
            CauTrigger_pearsonr_dict[key].append(_score(combo.split('+'), combo))

    with open(cache_file, 'wb') as file:
        pickle.dump(CauTrigger_pearsonr_dict, file)
    return CauTrigger_pearsonr_dict

## GEARS method

To save time, we provide our results as a pickle file. If you want to run the method yourself, comment out the first block of code in the function below (the one that loads the cached pickle file).

Please prepare or install GEARS. To save time, we also provide our trained GEARS model. If you want to train it yourself, uncomment the training lines below and comment out the `gears_model.load_pretrained` line.

In [None]:
def run_gears(data_path, norman_adata, TFs, nonTFs):
    """Score GEARS predictions on its simulation test subgroups by PCC.

    Results are cached in ``<data_path>/GEARS_results_pcc_all.pickle``. When that file
    exists it is loaded and returned, and GEARS is not loaded.

    Parameters
    ----------
    data_path : str
        GEARS data directory. It also holds the pretrained model folder
        ``gears_model_pcc``, the ``splits`` folder and the cached result pickle.
    norman_adata : anndata.AnnData
        Data providing observed expression, with ``obs['condition']`` and
        ``var['gene_name']``.
    TFs, nonTFs : array-like of str
        Gene names; the PCC is computed over their union.

    Returns
    -------
    dict
        Maps ``'single'``, ``'combo_seen0'``, ``'combo_seen1'`` and ``'combo_seen2'`` to
        lists of Pearson correlations between predicted and observed change versus
        control. The value is ``np.nan`` when fewer than two genes are available.
    """
    cache_file = os.path.join(data_path, 'GEARS_results_pcc_all.pickle')
    if os.path.exists(cache_file):
        # Only load pickles you produced yourself: unpickling can execute arbitrary code.
        with open(cache_file, 'rb') as file:
            return pickle.load(file)

    from gears import PertData, GEARS
    GEARS_pearsonr_dict = {'single': [], 'combo_seen0': [], 'combo_seen1': [], 'combo_seen2': []}
    down = np.concatenate([TFs, nonTFs])
    pert_data = PertData(data_path=data_path, gene_set_path=os.path.join(data_path, 'essential_all_data_pert_genes.pkl'))
    pert_data.load(data_path=data_path)
    pert_data.prepare_split(split='simulation', seed=1)
    pert_data.get_dataloader(batch_size=32, test_batch_size=128)
    gears_model = GEARS(pert_data, device='cuda:0', weight_bias_track=False)

    # To train GEARS yourself, uncomment these lines and comment out load_pretrained below.
    # gears_model.model_initialize(hidden_size=64)
    # gears_model.train(epochs=10, lr=1e-4)
    # os.makedirs(os.path.join(data_path, 'gears_model_pcc'), exist_ok=True)
    # gears_model.save_model(os.path.join(data_path, 'gears_model_pcc'))

    gears_model.load_pretrained(os.path.join(data_path, 'gears_model_pcc'))

    # prepare_split writes this subgroup file; run_cpa reuses it.
    with open(os.path.join(data_path, 'splits', 'data_simulation_1_0.75_subgroup.pkl'), 'rb') as file:
        subgroups = pickle.load(file)
    test_subgroup = subgroups['test_subgroup']

    gene_names = norman_adata.var.gene_name
    conditions = norman_adata.obs['condition']
    # The mean control profile is the same for every perturbation, so compute it once.
    ctrl = pd.DataFrame(norman_adata[np.asarray(conditions == 'ctrl'), :].X.toarray().mean(0)).T
    ctrl.columns = gene_names
    ctrl_deg = ctrl.filter(down)

    def _score(genes, pert):
        """Return the PCC between the GEARS prediction for ``genes`` and observed ``pert`` cells."""
        pred = pd.DataFrame(gears_model.predict([genes])['_'.join(genes)]).T
        pred.columns = gene_names
        # run_benchmark merges 'A+B'/'B+A' into one spelling, while GEARS subgroup names
        # come from the raw file; accept either order so the truth cells are found.
        reversed_pert = '+'.join(reversed(pert.split('+')))
        truth_mask = np.asarray(conditions.isin([pert, reversed_pert]))
        truth = pd.DataFrame(norman_adata[truth_mask, :].X.toarray().mean(0)).T
        truth.columns = gene_names
        if pred.filter(down).shape[1] <= 1:
            return np.nan
        pred_deg = pred.filter(down)
        truth_deg = truth.filter(down)
        return pearsonr(pred_deg.values.flatten() - ctrl_deg.values.flatten(),
                        truth_deg.values.flatten() - ctrl_deg.values.flatten())[0]

    for unseen_single in test_subgroup['unseen_single']:
        gene = extract_gene_names(unseen_single)
        GEARS_pearsonr_dict['single'].append(_score([gene], gene + '+ctrl'))
    for key in ['combo_seen0', 'combo_seen1', 'combo_seen2']:
        for combo in test_subgroup[key]:
            GEARS_pearsonr_dict[key].append(_score(combo.split('+'), combo))

    with open(cache_file, 'wb') as file:
        pickle.dump(GEARS_pearsonr_dict, file)
    return GEARS_pearsonr_dict

## CPA method

We provide our result in pickle form file to save your time, if you want to run by your own, please comment out the first block codes below.

Please prepare or install CPA. This part uses the intermediate file 'data_simulation_1_0.75_subgroup.pkl' generated by the GEARS part above, so make sure that file exists before running it.

In [None]:
def run_cpa(data_path, norman_adata, TFs, nonTFs):
    """Train CPA and score its predictions on the GEARS test subgroups by PCC.

    Results are cached in ``<data_path>/CPA_results_pcc_all.pickle``. When that file
    exists it is loaded and returned, and nothing is trained. Otherwise the subgroup file
    written by the GEARS step (``splits/data_simulation_1_0.75_subgroup.pkl``) must exist.

    Parameters
    ----------
    data_path : str
        Directory with the ``splits`` folder. Receives the CPA model, the filtered
        prediction AnnData and the cached result pickle.
    norman_adata : anndata.AnnData
        Expression data with ``obs['condition']``, ``obs['dose_val']``,
        ``obs['cell_type']``, ``obs['condition_name']`` and ``var['gene_name']``.
    TFs, nonTFs : array-like of str
        Gene names; the PCC is computed over their union.

    Returns
    -------
    dict
        Maps ``'single'``, ``'combo_seen0'``, ``'combo_seen1'`` and ``'combo_seen2'`` to
        lists of Pearson correlations between predicted and observed change versus
        control. The value is ``np.nan`` when fewer than two genes are available.
    """
    cache_file = os.path.join(data_path, 'CPA_results_pcc_all.pickle')
    if os.path.exists(cache_file):
        # Only load pickles you produced yourself: unpickling can execute arbitrary code.
        with open(cache_file, 'rb') as file:
            return pickle.load(file)

    down = np.concatenate([TFs, nonTFs])
    cpa_pearsonr_dict = {'single': [], 'combo_seen0': [], 'combo_seen1': [], 'combo_seen2': []}
    with open(os.path.join(data_path, 'splits', 'data_simulation_1_0.75_subgroup.pkl'), 'rb') as file:
        subgroups = pickle.load(file)
    test_subgroup = subgroups['test_subgroup']
    val_subgroup = subgroups['val_subgroup']

    def _reverse(condition):
        """Return the condition with its gene order reversed ('A+B' -> 'B+A')."""
        return '+'.join(reversed(condition.split('+')))

    def _all_spellings(subgroup):
        """Flatten a subgroup dict into condition names, in both gene orders."""
        # run_benchmark merges 'A+B'/'B+A' into one spelling, while subgroup names come
        # from the raw GEARS data, so either spelling may be the one present in the cells.
        names = [name for group in subgroup.values() for name in group]
        return names + [_reverse(name) for name in names]

    def run_cpa_model(data_path, norman_adata, test_subgroup=test_subgroup, val_subgroup=val_subgroup):
        """Train CPA, predict every cell from resampled controls and keep only test cells."""
        import cpa
        norman_adata_cpa = norman_adata.copy()
        norman_adata_cpa.obs['condition'] = pd.Categorical(norman_adata_cpa.obs['condition'])
        test_names = _all_spellings(test_subgroup)
        norman_adata_cpa.obs['split'] = 'train'
        # NOTE(review): test perturbations are labelled 'valid' (passed as valid_split) and
        # val perturbations 'ood' (passed as test_split). With early stopping enabled below,
        # this may let test perturbations drive model selection -- confirm this is intended.
        norman_adata_cpa.obs.loc[norman_adata_cpa.obs['condition'].isin(test_names), 'split'] = 'valid'
        norman_adata_cpa.obs.loc[
            norman_adata_cpa.obs['condition'].isin(_all_spellings(val_subgroup)), 'split'] = 'ood'
        cpa.CPA.setup_anndata(norman_adata_cpa,
                              perturbation_key='condition',
                              control_group='ctrl',
                              dosage_key='dose_val',
                              categorical_covariate_keys=['cell_type'],
                              is_count_data=False,
                              deg_uns_key='rank_genes_groups_cov',
                              deg_uns_cat_key='condition_name',
                              max_comb_len=2,
                              )
        model_params = {
            "n_latent": 32,
            "recon_loss": "nb",
            "doser_type": "linear",
            "n_hidden_encoder": 256,
            "n_layers_encoder": 4,
            "n_hidden_decoder": 256,
            "n_layers_decoder": 2,
            "use_batch_norm_encoder": True,
            "use_layer_norm_encoder": False,
            "use_batch_norm_decoder": False,
            "use_layer_norm_decoder": False,
            "dropout_rate_encoder": 0.2,
            "dropout_rate_decoder": 0.0,
            "variational": False,
            "seed": 8206,
        }
        trainer_params = {
            "n_epochs_kl_warmup": None,
            "n_epochs_adv_warmup": 50,
            "n_epochs_mixup_warmup": 10,
            "n_epochs_pretrain_ae": 10,
            "mixup_alpha": 0.1,
            "lr": 0.0001,
            "wd": 3.2170178270865573e-06,
            "adv_steps": 3,
            "reg_adv": 10.0,
            "pen_adv": 20.0,
            "adv_lr": 0.0001,
            "adv_wd": 7.051355554517135e-06,
            "n_layers_adv": 2,
            "n_hidden_adv": 128,
            "use_batch_norm_adv": True,
            "use_layer_norm_adv": False,
            "dropout_rate_adv": 0.3,
            "step_size_lr": 25,
            "do_clip_grad": False,
            "adv_loss": "cce",
            "gradient_clip_value": 5.0,
        }
        cpa_model = cpa.CPA(adata=norman_adata_cpa,
                            split_key='split',
                            train_split='train',
                            valid_split='valid',
                            test_split='ood',
                            **model_params,
                            )
        os.makedirs(os.path.join(data_path, 'CPA_model'), exist_ok=True)
        cpa_model.train(max_epochs=2000,
                        use_gpu=True,
                        batch_size=2048,
                        plan_kwargs=trainer_params,
                        early_stopping_patience=5,
                        check_val_every_n_epoch=5,
                        save_path=os.path.join(data_path, 'CPA_model', 'PCC'),
                        )
        # Keep the observed expression, then replace X with randomly resampled control
        # cells so that CPA predicts each perturbation starting from a control profile.
        norman_adata_cpa.layers['truth'] = norman_adata_cpa.X.copy()
        ctrl_adata = norman_adata_cpa[norman_adata_cpa.obs['condition'] == 'ctrl'].copy()
        norman_adata_cpa.X = ctrl_adata.X[np.random.choice(ctrl_adata.n_obs, size=norman_adata_cpa.n_obs, replace=True), :]
        cpa_model.predict(norman_adata_cpa, batch_size=2048)
        norman_adata_cpa.layers['CPA_pred'] = norman_adata_cpa.obsm['CPA_pred'].copy()
        # Subsetting an AnnData also subsets its layers, so no per-layer copy is needed.
        subset_indices = norman_adata_cpa.obs['condition'].isin(test_names)
        norman_adata_cpa_filter = norman_adata_cpa[subset_indices].copy()
        norman_adata_cpa_filter.uns.clear()
        norman_adata_cpa_filter.obsm.clear()
        norman_adata_cpa_filter.write_h5ad(os.path.join(data_path, 'CPA_result_adata.h5ad'))
        return norman_adata_cpa_filter

    norman_adata_cpa_filter = run_cpa_model(data_path, norman_adata)
    gene_names = norman_adata_cpa_filter.var.gene_name
    filter_conditions = norman_adata_cpa_filter.obs['condition']

    def _mean_row(matrix):
        """Return the column means of a dense or sparse 2-D matrix as a one-row DataFrame."""
        # Dense .mean(0) is 1-D and sparse .mean(0) is a 1xN matrix; reshape both to 1xN.
        frame = pd.DataFrame(np.asarray(matrix.mean(0)).reshape(1, -1))
        frame.columns = gene_names
        return frame

    # X holds the resampled control cells, so its mean is the control profile; the value
    # is the same for every perturbation, so compute it once.
    ctrl_deg = _mean_row(norman_adata_cpa_filter.X).filter(down)

    def _score(pert):
        """Return the PCC between CPA-predicted and observed mean change for ``pert``."""
        mask = np.asarray(filter_conditions.isin([pert, _reverse(pert)]))
        pred = _mean_row(norman_adata_cpa_filter.layers['CPA_pred'][mask, :])
        truth = _mean_row(norman_adata_cpa_filter.layers['truth'][mask, :])
        if pred.filter(down).shape[1] <= 1:
            return np.nan
        pred_deg = pred.filter(down)
        truth_deg = truth.filter(down)
        return pearsonr(pred_deg.values.flatten() - ctrl_deg.values.flatten(),
                        truth_deg.values.flatten() - ctrl_deg.values.flatten())[0]

    for unseen_single in test_subgroup['unseen_single']:
        gene = extract_gene_names(unseen_single)
        cpa_pearsonr_dict['single'].append(_score(gene + '+ctrl'))
    for key in ['combo_seen0', 'combo_seen1', 'combo_seen2']:
        for combo in test_subgroup[key]:
            cpa_pearsonr_dict[key].append(_score(combo))

    with open(cache_file, 'wb') as file:
        pickle.dump(cpa_pearsonr_dict, file)
    return cpa_pearsonr_dict

## Define the comparison function

The perturbations 'RHOXF2BB+ctrl', 'LYL1+IER5L', 'ctrl+IER5L', 'KIAA1804+ctrl', 'IER5L+ctrl', 'RHOXF2BB+ZBTB25' and 'RHOXF2BB+SET' are filtered out because GEARS reports 'These perturbations are not in the GO graph and their perturbation can thus not be predicted' for them. The same message also appears in https://github.com/snap-stanford/GEARS/blob/master/demo/data_tutorial.ipynb

In [None]:
def run_benchmark(data_path, output_path, data_name, prior_network_name):
    """Compare CauTrigger, GEARS and CPA on the Norman data by mean PCC.

    Steps:
    - Load ``data_name`` from ``data_path`` and drop the perturbations GEARS cannot predict.
    - Merge the 'A+B' / 'B+A' spellings of the same perturbation.
    - Use the TRRUST prior network to split genes into TFs and non-TFs.
    - Build a seeded train/test split of TF perturbations.
    - Run (or load the cached results of) the three methods.
    - Save a grouped bar plot to ``output_path`` as ``Barplot_Mean_PCC.pdf`` / ``.png``.

    Parameters
    ----------
    data_path : str
        Directory containing ``data_name`` and ``prior_network_name``.
    output_path : str
        Directory the figures are written to.
    data_name : str
        File name of the ``.h5ad`` expression data.
    prior_network_name : str
        File name of a headerless table: column 0 = regulator (TF), column 1 = target.
    """
    # Perturbations missing from GEARS' GO graph; removed for every method.
    excluded_conditions = ['RHOXF2BB+ctrl', 'LYL1+IER5L', 'ctrl+IER5L', 'KIAA1804+ctrl', 'IER5L+ctrl',
                           'RHOXF2BB+ZBTB25', 'RHOXF2BB+SET']
    norman_adata = sc.read_h5ad(os.path.join(data_path, data_name))
    # .copy() materialises the subset now instead of implicitly when .obs is reassigned below.
    norman_adata = norman_adata[~norman_adata.obs['condition'].isin(excluded_conditions), :].copy()

    # When both 'A+B' and 'B+A' occur, map both onto the lexicographically smaller spelling.
    conditions = np.unique(norman_adata.obs['condition'])
    present = set(conditions)
    condition_map = {}
    for condition in conditions:
        if '+' not in condition:
            continue
        reversed_condition = '+'.join(reversed(condition.split('+')))
        if reversed_condition in present:
            condition_map[condition] = min(condition, reversed_condition)
    obs = norman_adata.obs.copy()
    obs['condition'] = obs['condition'].replace(condition_map)
    obs['condition_name'] = obs['condition'].apply(lambda x: f'K562_{x}_1+1' if x != 'ctrl' else x)
    norman_adata.obs = obs
    norman_adata.obs['condition'] = norman_adata.obs['condition'].astype(str)
    norman_adata.obs['condition'] = pd.Categorical(norman_adata.obs['condition'],
                                                   categories=norman_adata.obs['condition'].unique())

    # Every gene perturbed in at least one non-control condition.
    conditions = np.unique(norman_adata.obs['condition'])
    conditions = conditions[conditions != 'ctrl']
    conditions_gene = {gene for condition in conditions for gene in condition.split('+')}
    conditions_gene.discard('ctrl')
    conditions_gene = list(conditions_gene)

    Trrust = pd.read_table(os.path.join(data_path, prior_network_name), header=None)
    Trrust_TF = Trrust.iloc[:, 0].dropna().unique()
    Trrust_nonTF = np.setdiff1d(Trrust.iloc[:, 1].dropna().unique(), Trrust_TF)
    TFs = np.intersect1d(Trrust_TF, norman_adata.var['gene_name'])
    nonTFs = np.intersect1d(Trrust_nonTF, norman_adata.var['gene_name'])
    TFs_to_pert = np.intersect1d(TFs, conditions_gene)

    def _genes_in(condition, first_pool, second_pool):
        """Return True if gene 1 of ``condition`` is in ``first_pool`` and gene 2 in ``second_pool``."""
        genes = condition.split('+')
        return genes[0] in first_pool and genes[1] in second_pool

    pairs_to_pert = [condition for condition in conditions if _genes_in(condition, TFs_to_pert, TFs_to_pert)]

    # Seeded split: 75% of perturbed TFs are "seen"; combos are grouped by how many of
    # their two genes are seen (0, 1 or 2). The draw order below fixes the random stream.
    np.random.seed(seed=1)
    single_train = np.random.choice(TFs_to_pert, int(len(TFs_to_pert) * 0.75), replace=False)
    single_test = np.setdiff1d(TFs_to_pert, single_train)
    pairs_to_pert_seen0 = [c for c in pairs_to_pert if _genes_in(c, single_test, single_test)]
    pairs_to_pert_seen1 = [c for c in pairs_to_pert
                           if _genes_in(c, single_train, single_test) or _genes_in(c, single_test, single_train)]
    pairs_to_pert_seen2 = [c for c in pairs_to_pert if _genes_in(c, single_train, single_train)]
    pairs_to_pert_seen0_train = np.random.choice(pairs_to_pert_seen0, 1, replace=False)
    pairs_to_pert_seen0_test = np.setdiff1d(pairs_to_pert_seen0, pairs_to_pert_seen0_train)
    pairs_to_pert_seen1_train = np.random.choice(pairs_to_pert_seen1, 10, replace=False)
    pairs_to_pert_seen1_test = np.setdiff1d(pairs_to_pert_seen1, pairs_to_pert_seen1_train)
    pairs_to_pert_seen2_train = np.random.choice(pairs_to_pert_seen2, 6, replace=False)
    pairs_to_pert_seen2_test = np.setdiff1d(pairs_to_pert_seen2, pairs_to_pert_seen2_train)
    train_test_dict = {'single_train': single_train, 'single_test': single_test,
                       'combo_seen0_train': pairs_to_pert_seen0_train, 'combo_seen0': pairs_to_pert_seen0_test,
                       'combo_seen1_train': pairs_to_pert_seen1_train, 'combo_seen1': pairs_to_pert_seen1_test,
                       'combo_seen2_train': pairs_to_pert_seen2_train, 'combo_seen2': pairs_to_pert_seen2_test}

    CauTrigger_pearsonr_dict = run_ct(data_path, train_test_dict, norman_adata, TFs, nonTFs)
    GEARS_pearsonr_dict = run_gears(data_path, norman_adata, TFs, nonTFs)
    cpa_pearsonr_dict = run_cpa(data_path, norman_adata, TFs, nonTFs)

    # Insertion order fixes the row order of the plotting table (method-major).
    results = {'CauTrigger': CauTrigger_pearsonr_dict, 'GEARS': GEARS_pearsonr_dict, 'CPA': cpa_pearsonr_dict}
    for scores in results.values():
        scores['all'] = np.concatenate(list(scores.values()))
    data = [[key, method, value]
            for method, scores in results.items()
            for key in cpa_pearsonr_dict
            for value in scores[key]]
    df = pd.DataFrame(data, columns=['Dataset', 'Method', 'Value'])

    plt.figure(figsize=(12, 6))
    # NOTE: seaborn >= 0.12 deprecates ``ci`` in favour of ``errorbar='sd'``.
    sns.barplot(data=df, x='Dataset', y='Value', hue='Method', ci='sd', capsize=0.1, palette=sns.color_palette()[3:6])
    plt.ylabel('Mean PCC')
    plt.xlabel('')
    plt.legend(title='Method')
    plt.ylim(0, 1)
    plt.savefig(os.path.join(output_path, 'Barplot_Mean_PCC.pdf'), bbox_inches='tight')
    plt.savefig(os.path.join(output_path, 'Barplot_Mean_PCC.png'), bbox_inches='tight')

## Run

In [None]:
# Run all three methods (each loads its cached pickle if present) and save
# Barplot_Mean_PCC.pdf / Barplot_Mean_PCC.png into output_path.
run_benchmark(data_path, output_path, data_name, prior_network_name)