# Benchmark in Norman dataset (ACC)

The data come from Norman et al., 'Exploring genetic interaction manifolds constructed from rich single-cell phenotypes' (GEO accession GSE133344): https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE133344

## Import libraries and set working directory

In [3]:
import os
import pickle
import torch
import logging
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['font.sans-serif'] = ['Arial']
plt.rcParams['font.family'] = 'sans-serif'
logging.getLogger('matplotlib.font_manager').disabled = True
import numpy as np
import pandas as pd
import anndata
import scanpy as sc
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
import sys
sys.path.append('../')

# Root directory of the analysis; replace the placeholder before running.
BASE_DIR = '/your/working/directory'
# Inputs and cached model predictions live under data_path; tables and plots go to output_path.
case_path = os.path.join(BASE_DIR, 'BenchmarkNorman/')
data_path = os.path.join(case_path, 'data/')
output_path = os.path.join(case_path, 'output/')
os.makedirs(output_path, exist_ok=True)
# Processed perturb-seq AnnData, read by run_benchmark from data_path.
data_name = 'perturb_processed.h5ad'
# TRRUST table (no header): column 0 is treated as TFs, column 1 as their targets.
prior_network_name = 'trrust_rawdata.human.tsv'

## True GI subtypes from GEARS

Ground-truth GI subtype labels for the double perturbations, copied from the GEARS repository: https://github.com/yhr91/GEARS_misc/blob/f88211870dfa89c38a2eedbd69ca1abd28a25f3c/gears/inference.py#L796

In [4]:
# Ground-truth genetic-interaction (GI) subtype -> list of double perturbations 'A+B'
# (copied from GEARS, see link above). run_benchmark keeps only the pairs whose two
# genes are both perturbed TFs, and get_cutoff derives per-subtype score cutoffs
# from the measured expression of those pairs.
GIs = {
    'NEOMORPHIC': ['CBL+TGFBR2',
                   'KLF1+TGFBR2',
                   'MAP2K6+SPI1',
                   'SAMD1+TGFBR2',
                   'TGFBR2+C19orf26',
                   'TGFBR2+ETS2',
                   'CBL+UBASH3A',
                   'CEBPE+KLF1',
                   'DUSP9+MAPK1',
                   'FOSB+PTPN12',
                   'PLK4+STIL',
                   'PTPN12+OSR2',
                   'ZC3HAV1+CEBPE'],
    'ADDITIVE': ['BPGM+SAMD1',
                 'CEBPB+MAPK1',
                 'CEBPB+OSR2',
                 'DUSP9+PRTG',
                 'FOSB+OSR2',
                 'IRF1+SET',
                 'MAP2K3+ELMSAN1',
                 'MAP2K6+ELMSAN1',
                 'POU3F2+FOXL2',
                 'SAMD1+PTPN12',
                 'SAMD1+UBASH3B',
                 'SAMD1+ZBTB1',
                 'SGK1+TBX2',
                 'TBX3+TBX2',
                 'ZBTB10+SNAI1'],
    'EPISTASIS': ['AHR+KLF1',
                  'MAPK1+TGFBR2',
                  'TGFBR2+IGDCC3',
                  'TGFBR2+PRTG',
                  'UBASH3B+OSR2',
                  'DUSP9+ETS2',
                  'KLF1+CEBPA',
                  'MAP2K6+IKZF3',
                  'ZC3HAV1+CEBPA'],
    'REDUNDANT': ['CDKN1C+CDKN1A',
                  'MAP2K3+MAP2K6',
                  'CEBPB+CEBPA',
                  'CEBPE+CEBPA',
                  'CEBPE+SPI1',
                  'ETS2+MAPK1',
                  'FOSB+CEBPE',
                  'FOXA3+FOXA1'],
    'SYNERGY': ['CNN1+UBASH3A',
                'ETS2+MAP7D1',
                'FEV+CBFA2T3',
                'FEV+ISL2',
                'FEV+MAP7D1',
                'PTPN12+UBASH3A',
                'CBL+CNN1',
                'CBL+PTPN12',
                'CBL+PTPN9',
                'CBL+UBASH3B',
                'FOXA3+FOXL2',
                'FOXA3+HOXB9',
                'FOXL2+HOXB9',
                'UBASH3B+CNN1',
                'UBASH3B+PTPN12',
                'UBASH3B+PTPN9',
                'UBASH3B+ZBTB25',
                'AHR+FEV',
                'DUSP9+SNAI1',
                'FOXA1+FOXF1',
                'FOXA1+FOXL2',
                'FOXA1+HOXB9',
                'FOXF1+FOXL2',
                'FOXF1+HOXB9',
                'FOXL2+MEIS1',
                'IGDCC3+ZBTB25',
                'POU3F2+CBFA2T3',
                'PTPN12+ZBTB25',
                'SNAI1+DLX2',
                'SNAI1+UBASH3B'],
    'SUPPRESSOR': ['CEBPB+PTPN12',
                   'CEBPE+CNN1',
                   'CEBPE+PTPN12',
                   'CNN1+MAPK1',
                   'ETS2+CNN1',
                   'ETS2+IGDCC3',
                   'ETS2+PRTG',
                   'FOSB+UBASH3B',
                   'IGDCC3+MAPK1',
                   'LYL1+CEBPB',
                   'MAPK1+PRTG',
                   'PTPN12+SNAI1']
}

## Define functions to calculate GI scores and cutoffs of each subtype

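GI scores include magnitude (`mag`), model fit (`corr_fit`), equality of contribution (`eq_contr`) and similarity (`dcor`). All are computed from the expression changes (relative to control) after perturbing A, B and A+B. `mag` and `corr_fit` come from a Theil–Sen regression of the A+B change on the A and B changes. `dcor` and `eq_contr` are based on distance correlations. In the cutoff tables they are stored as SSA (`mag`), RED (`dcor`), NEO (`corr_fit`) and EPI (`eq_contr`).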

In [5]:
def calculate_GIs(first_expr, second_expr, double_expr):
    """Score a genetic interaction from single- and double-perturbation expression changes.

    Parameters
    ----------
    first_expr, second_expr : numpy.ndarray
        Expression change (vs. control) after perturbing gene A alone and gene B alone.
        Callers in this notebook pass 1-D vectors over the same genes.
    double_expr : numpy.ndarray
        Expression change after the A+B double perturbation.

    Returns
    -------
    dict
        ``ts``: the fitted ``TheilSenRegressor`` (double ~ c1*first + c2*second, no intercept);
        ``c1``, ``c2``: its coefficients; ``mag``: sqrt(c1**2 + c2**2);
        ``dcor``: distance correlation of the stacked singles with the double;
        ``dcor_singles`` / ``dcor_first`` / ``dcor_second``: pairwise distance correlations;
        ``corr_fit``: Pearson r between fitted and observed double;
        ``dominance``: |log10(c1 / c2)|;
        ``eq_contr``: min / max of ``dcor_first`` and ``dcor_second``.
    """
    from sklearn.linear_model import TheilSenRegressor
    from dcor import distance_correlation

    # One column per single perturbation; built from the inputs before they are transposed.
    design = np.array([first_expr, second_expr]).T
    single_a, single_b, combined = first_expr.T, second_expr.T, double_expr.T

    regressor = TheilSenRegressor(fit_intercept=False,
                                  max_subpopulation=1e5,
                                  max_iter=1000,
                                  random_state=1000)
    regressor.fit(design, combined.ravel())
    fitted = regressor.predict(design)
    coef_a, coef_b = regressor.coef_[0], regressor.coef_[1]

    dcor_a = distance_correlation(single_a, combined)
    dcor_b = distance_correlation(single_b, combined)
    pair_dcors = [dcor_a, dcor_b]

    return {
        'ts': regressor,
        'c1': coef_a,
        'c2': coef_b,
        'mag': np.sqrt((coef_a ** 2 + coef_b ** 2)),
        'dcor': distance_correlation(design, combined),
        'dcor_singles': distance_correlation(single_a, single_b),
        'dcor_first': dcor_a,
        'dcor_second': dcor_b,
        'corr_fit': np.corrcoef(fitted.flatten(), combined.flatten())[0, 1],
        'dominance': np.abs(np.log10(coef_a / coef_b)),
        'eq_contr': np.min(pair_dcors) / np.max(pair_dcors),
    }

In [6]:
def get_cutoff(data_path, norman_adata, TFs, nonTFs, GIs):
    """Score the measured GIs and derive per-subtype score cutoffs from the ground truth.

    For every pair 'A+B' in ``GIs``, the mean non-TF expression of conditions 'A+ctrl',
    'B+ctrl' and 'A+B' minus the mean of 'ctrl' is passed to ``calculate_GIs``. The
    scores are stored as SSA (``mag``), RED (``dcor``), NEO (``corr_fit``) and
    EPI (``eq_contr``). Each subtype's cutoff is the extreme of its own members:
    SYNERGY min SSA, SUPPRESSOR max SSA, ADDITIVE [min SSA, max SSA],
    NEOMORPHIC max NEO, EPISTASIS max EPI, REDUNDANT min RED.

    Parameters
    ----------
    data_path, TFs
        Unused; kept so existing callers keep working.
    norman_adata : AnnData
        Cells with ``obs['condition']`` and genes with ``var['gene_name']``.
    nonTFs : array-like
        Names of the genes used for scoring.
    GIs : dict[str, list[str]]
        GI subtype -> list of pair names 'A+B'.

    Returns
    -------
    thresh : dict
        Subtype -> cutoff (a ``[low, high]`` list for ADDITIVE).
    GI_cutoff_df : pandas.DataFrame
        Indexed by pair (index name 'value'), with columns key, SSA, RED, NEO, EPI.
    GI_cutoff_dict : dict[str, list[str]]
        Subtype -> pairs, in table order.

    Raises
    ------
    ValueError
        If a required condition has no cells. Previously this produced NaN means that were
        passed into the regression.
    """
    score_keys = {'SSA': 'mag', 'RED': 'dcor', 'NEO': 'corr_fit', 'EPI': 'eq_contr'}
    norman_adata_nonTF = norman_adata[:, np.isin(norman_adata.var['gene_name'], nonTFs)]
    condition = norman_adata_nonTF.obs['condition']
    mean_cache = {}

    def mean_expression(name):
        """Mean expression over cells of condition ``name``, memoized (ctrl/singles repeat)."""
        if name not in mean_cache:
            cells = norman_adata_nonTF[condition == name, :]
            if cells.n_obs == 0:
                raise ValueError(f"no cells with condition {name!r}")
            X = cells.X
            dense = X.todense() if hasattr(X, 'todense') else np.asarray(X)
            mean_cache[name] = np.squeeze(np.array(dense.mean(0)))
        return mean_cache[name]

    rows = [(gi_type, pair) for gi_type, pairs in GIs.items() for pair in pairs]
    GI_cutoff_df = pd.DataFrame(rows, columns=['key', 'value']).set_index('value')
    GI_cutoff_df[list(score_keys)] = None
    for pairs in GIs.values():
        for AB in pairs:
            ctrl = mean_expression('ctrl')
            pertA, pertB = AB.split('+')[0], AB.split('+')[1]
            results = calculate_GIs(mean_expression(pertA + '+ctrl') - ctrl,
                                    mean_expression(pertB + '+ctrl') - ctrl,
                                    mean_expression(AB) - ctrl)
            for column, key in score_keys.items():
                GI_cutoff_df.at[AB, column] = results[key]

    def scores(gi_type, column):
        """Scores in ``column`` of the pairs labelled ``gi_type``."""
        return GI_cutoff_df.loc[GI_cutoff_df['key'] == gi_type, column]

    thresh = {'SYNERGY': scores('SYNERGY', 'SSA').min(),
              'SUPPRESSOR': scores('SUPPRESSOR', 'SSA').max(),
              'ADDITIVE': [scores('ADDITIVE', 'SSA').min(), scores('ADDITIVE', 'SSA').max()],
              'NEOMORPHIC': scores('NEOMORPHIC', 'NEO').max(),
              'EPISTASIS': scores('EPISTASIS', 'EPI').max(),
              'REDUNDANT': scores('REDUNDANT', 'RED').min()}

    GI_cutoff_dict = {}
    for pair, gi_type in GI_cutoff_df['key'].items():
        GI_cutoff_dict.setdefault(gi_type, []).append(pair)
    return thresh, GI_cutoff_df, GI_cutoff_dict

## GEARS method

In [None]:
def run_gears(data_path, pairs_to_pert_GI_GT):
    """Predict single and double perturbation responses with GEARS, one model per pair.

    If ``GEARS_pred_pert_down_1012.pickle`` exists in ``data_path``, its contents are
    returned. Otherwise, for every pair 'A+B' (the keys of ``pairs_to_pert_GI_GT``):
    a GEARS split with ``test_perts`` set to that pair is prepared, the pretrained
    model under ``gears_model_GI/<pair>`` is loaded, and predictions for [A, B], [A] and
    [B] are collected. The result is pickled to the cache file and returned.

    Returns
    -------
    dict
        ``{'A+B': {'A+B': pred_AB, 'A': pred_A, 'B': pred_B}, ...}``, where each prediction
        is the first value of the dict returned by ``GEARS.predict``.
    """
    cache_file = os.path.join(data_path, 'GEARS_pred_pert_down_1012.pickle')
    if os.path.exists(cache_file):
        with open(cache_file, "rb") as handle:
            return pickle.load(handle)

    predictions = {}
    from gears import PertData, GEARS

    def first_prediction(model, genes):
        # Only one perturbation is requested, so take the single value of the returned dict.
        return list(model.predict([genes]).values())[0]

    for combo in pairs_to_pert_GI_GT.keys():
        genes = combo.split('+')
        gene_a, gene_b = genes[0], genes[1]
        pert_data = PertData(data_path=data_path,
                             gene_set_path=os.path.join(data_path, 'essential_all_data_pert_genes.pkl'))
        pert_data.load(data_path=data_path)
        pert_data.prepare_split(test_perts=combo, seed=1)
        pert_data.get_dataloader(batch_size=32, test_batch_size=128)
        model = GEARS(pert_data, device='cuda:0', weight_bias_track=False)

        # The checkpoints loaded below were produced by:
        #   model.model_initialize(hidden_size=64); model.train(epochs=20, lr=1e-4)
        #   model.save_model(os.path.join(data_path, 'gears_model_GI', combo))
        model.load_pretrained(os.path.join(data_path, 'gears_model_GI', combo))

        predictions[combo] = {
            combo: first_prediction(model, genes),
            gene_a: first_prediction(model, [gene_a]),
            gene_b: first_prediction(model, [gene_b]),
        }

    with open(cache_file, "wb") as handle:
        pickle.dump(predictions, handle)
    return predictions

## CPA method

In [None]:
def run_cpa(data_path, norman_adata, pairs_to_pert_GI_GT):
    """Predict single and double perturbation responses with CPA, one model per held-out pair.

    If ``cpa_pred_pert_down_1012.pickle`` exists in ``data_path``, its contents are returned.
    Otherwise, for every pair 'A+B' (the keys of ``pairs_to_pert_GI_GT``):

    * a CPA model is trained with the cells of 'A+B' held out ('ood') and a random 10% of
      the remaining cells used as the validation split;
    * every cell's expression is replaced by a randomly drawn control profile, and CPA
      predicts the response to the cell's own condition;
    * the mean prediction over cells of 'A+B', 'A+ctrl' and 'B+ctrl' is recorded.

    Returns
    -------
    dict
        ``{'A+B': {'A+B': mean_pred_AB, 'A': mean_pred_A, 'B': mean_pred_B}, ...}``.
        The dict is also pickled to the cache file.
    """
    cache_file = os.path.join(data_path, 'cpa_pred_pert_down_1012.pickle')
    if os.path.exists(cache_file):
        with open(cache_file, "rb") as file:
            return pickle.load(file)

    import cpa
    model_params = {
        "n_latent": 32,
        "recon_loss": "nb",
        "doser_type": "linear",
        "n_hidden_encoder": 256,
        "n_layers_encoder": 4,
        "n_hidden_decoder": 256,
        "n_layers_decoder": 2,
        "use_batch_norm_encoder": True,
        "use_layer_norm_encoder": False,
        "use_batch_norm_decoder": False,
        "use_layer_norm_decoder": False,
        "dropout_rate_encoder": 0.2,
        "dropout_rate_decoder": 0.0,
        "variational": False,
        "seed": 8206,
    }
    trainer_params = {
        "n_epochs_kl_warmup": None,
        "n_epochs_adv_warmup": 50,
        "n_epochs_mixup_warmup": 10,
        "n_epochs_pretrain_ae": 10,
        "mixup_alpha": 0.1,
        "lr": 0.0001,
        "wd": 3.2170178270865573e-06,
        "adv_steps": 3,
        "reg_adv": 10.0,
        "pen_adv": 20.0,
        "adv_lr": 0.0001,
        "adv_wd": 7.051355554517135e-06,
        "n_layers_adv": 2,
        "n_hidden_adv": 128,
        "use_batch_norm_adv": True,
        "use_layer_norm_adv": False,
        "dropout_rate_adv": 0.3,
        "step_size_lr": 25,
        "do_clip_grad": False,
        "adv_loss": "cce",
        "gradient_clip_value": 5.0,
    }
    # Seeded so the validation split and the control resampling are reproducible.
    rng = np.random.default_rng(model_params["seed"])

    cpa_pred_pert_down = {}
    for combo in pairs_to_pert_GI_GT.keys():
        norman_adata_cpa = norman_adata.copy()
        norman_adata_cpa.obs['condition'] = pd.Categorical(norman_adata_cpa.obs['condition'])
        norman_adata_cpa.obs['split'] = 'train'
        norman_adata_cpa.obs.loc[norman_adata_cpa.obs['condition'] == combo, 'split'] = 'ood'
        train_indices = norman_adata_cpa.obs.index[norman_adata_cpa.obs['split'] == 'train']
        validation_size = int(0.1 * len(train_indices))
        validation_indices = rng.choice(train_indices, size=validation_size, replace=False)
        # This label must equal `valid_split` passed to cpa.CPA below. It used to be
        # 'validation' while CPA was told valid_split='valid', so no cell carried the
        # label CPA was asked to validate on.
        norman_adata_cpa.obs.loc[validation_indices, 'split'] = 'valid'

        cpa.CPA.setup_anndata(norman_adata_cpa,
                              perturbation_key='condition',
                              control_group='ctrl',
                              dosage_key='dose_val',
                              categorical_covariate_keys=['cell_type'],
                              is_count_data=False,
                              deg_uns_key='rank_genes_groups_cov',
                              deg_uns_cat_key='condition_name',
                              max_comb_len=2,
                              )
        cpa_model = cpa.CPA(adata=norman_adata_cpa,
                            split_key='split',
                            train_split='train',
                            valid_split='valid',
                            test_split='ood',
                            **model_params,
                            )
        cpa_model.train(max_epochs=2000,
                        use_gpu=True,
                        batch_size=2048,
                        plan_kwargs=trainer_params,
                        early_stopping_patience=5,
                        check_val_every_n_epoch=5,
                        save_path=os.path.join(data_path, 'CPA_model', combo)
                        )

        # Start every cell from a randomly drawn control profile; CPA then applies the
        # perturbation recorded in the cell's own 'condition'.
        ctrl_X = norman_adata_cpa[norman_adata_cpa.obs['condition'] == 'ctrl'].copy().X
        norman_adata_cpa.X = ctrl_X[rng.choice(ctrl_X.shape[0], size=norman_adata_cpa.n_obs, replace=True), :]
        cpa_model.predict(norman_adata_cpa, batch_size=2048)

        predictions = np.asarray(norman_adata_cpa.obsm['CPA_pred'])
        condition = norman_adata_cpa.obs['condition']
        pertA, pertB = combo.split('+')[0], combo.split('+')[1]
        cpa_pred_pert_down[combo] = {
            combo: predictions[(condition == combo).values, :].mean(0),
            pertA: predictions[(condition == pertA + '+ctrl').values, :].mean(0),
            pertB: predictions[(condition == pertB + '+ctrl').values, :].mean(0),
        }

    with open(cache_file, "wb") as file:
        pickle.dump(cpa_pred_pert_down, file)
    return cpa_pred_pert_down

## Our method (CauTrigger)

In [None]:
def run_ct(data_path, norman_adata, pairs_to_pert_GI_GT, TFs, nonTFs):
    """Train one CauTrigger model per TF pair and predict single/double perturbation effects.

    For every pair 'A+B' (the keys of ``pairs_to_pert_GI_GT``), a model is trained on
    control cells (label 0) versus cells of 'A+ctrl' or 'B+ctrl' (label 1). TF expression
    is used as ``X`` and non-TF expression as ``obsm['X_down']``. The sampled control cells
    are then perturbed in silico by setting the A and/or B column to twice their maximum
    TF expression, and the mean model outputs are recorded.

    Results are cached in ``CauTrigger_pred_pert_1012.pickle`` under ``data_path``.

    Returns
    -------
    tuple(dict, dict, dict)
        ``(pred_up, pred_down, pred_all)``. Each maps 'A+B' to
        ``{'A+B': ..., 'A': ..., 'B': ...}`` holding the mean ``x_up_rec1`` (TF genes),
        mean ``x_down_rec_alpha`` (non-TF genes), or their concatenation. The same tuple is
        returned on a fresh run and from the cache. A fresh run used to return only
        ``pred_all``, which broke the caller's 3-way unpacking.
    """
    cache_file = os.path.join(data_path, 'CauTrigger_pred_pert_1012.pickle')
    if os.path.exists(cache_file):
        with open(cache_file, "rb") as file:
            CauTrigger_pred_pert_up, CauTrigger_pred_pert_down, CauTrigger_pred_pert_all = pickle.load(file)
        return CauTrigger_pred_pert_up, CauTrigger_pred_pert_down, CauTrigger_pred_pert_all

    from CauTrigger.model import CauTrigger
    from CauTrigger.utils import set_seed

    norman_adata_TF = norman_adata[:, np.isin(norman_adata.var['gene_name'], TFs)]
    norman_adata_nonTF = norman_adata[:, np.isin(norman_adata.var['gene_name'], nonTFs)]
    norman_adata_for_CT_ctrl = norman_adata_TF[norman_adata_TF.obs['condition'] == 'ctrl', :].copy()
    norman_adata_for_CT_ctrl.obsm['X_down'] = norman_adata_nonTF[
        norman_adata_nonTF.obs['condition'] == 'ctrl', :].X.copy()

    def with_genes_set(adata, genes, value):
        """Return a copy of ``adata`` whose columns for ``genes`` all equal ``value``."""
        perturbed = adata.copy()
        names = perturbed.var['gene_name'].values
        for gene in genes:
            perturbed.X[:, names == gene] = value
        return perturbed

    CauTrigger_pred_pert_up, CauTrigger_pred_pert_down, CauTrigger_pred_pert_all = {}, {}, {}
    for pair_to_pert in pairs_to_pert_GI_GT.keys():
        pertA = pair_to_pert.split('+')[0]
        pertB = pair_to_pert.split('+')[1]
        single_conditions = [pertA + '+ctrl', pertB + '+ctrl']

        # Training data: the A-only and B-only cells plus an equally sized random control sample.
        norman_adata_for_CT_pert = norman_adata_TF[
            norman_adata_TF.obs['condition'].isin(single_conditions), :].copy()
        norman_adata_for_CT_pert.obsm['X_down'] = norman_adata_nonTF[
            norman_adata_nonTF.obs['condition'].isin(single_conditions), :].X.todense()
        # NOTE(review): this draw happens before set_seed below, so the first pair's control
        # sample is not seeded.
        norman_adata_for_CT_ctrl1 = norman_adata_for_CT_ctrl[
            np.random.choice(norman_adata_for_CT_ctrl.n_obs, size=norman_adata_for_CT_pert.n_obs,
                             replace=False), :]
        adata_for_train = anndata.concat([norman_adata_for_CT_ctrl1, norman_adata_for_CT_pert])
        adata_for_train.X = adata_for_train.X.todense()
        adata_for_train.obs['labels'] = np.repeat([0, 1],
                                                  [norman_adata_for_CT_ctrl1.n_obs, norman_adata_for_CT_pert.n_obs])
        adata_for_train.obsm['X_down'] = adata_for_train.obsm['X_down'].todense()

        set_seed(42)
        model = CauTrigger(
            adata_for_train,
            n_causal=2,
            n_latent=10,
            n_hidden=128,
            n_layers_encoder=0,
            n_layers_decoder=0,
            n_layers_dpd=0,
            dropout_rate_encoder=0.1,
            dropout_rate_decoder=0.1,
            dropout_rate_dpd=0.1,
            use_batch_norm='none',
            use_batch_norm_dpd=True,
            decoder_linear=True,
            dpd_linear=False,
            init_weight=None,
            init_thresh=0.0,
            update_down_weight=False,
            attention=False,
            att_mean=False,
        )
        max_epochs = 100 if adata_for_train.shape[0] > 2000 else 200
        model.train(max_epochs=max_epochs, im_factor=1, weight_scheme='norman')

        # In-silico perturbation level: twice the largest TF expression in the control sample.
        pert_level = 2 * norman_adata_for_CT_ctrl1.X.max()
        outputs = {}
        model.eval()
        with torch.no_grad():
            # Evaluated in the order A+B, A, B.
            for key, genes in ((pair_to_pert, (pertA, pertB)), (pertA, (pertA,)), (pertB, (pertB,))):
                perturbed = with_genes_set(norman_adata_for_CT_ctrl1, genes, pert_level)
                output = model.module.forward(torch.Tensor(perturbed.X.todense()).to('cuda:0'))
                outputs[key] = (output['x_up_rec1'].cpu().numpy().mean(0),
                                output['x_down_rec_alpha'].cpu().numpy().mean(0))

        CauTrigger_pred_pert_up[pair_to_pert] = {key: up for key, (up, _) in outputs.items()}
        CauTrigger_pred_pert_down[pair_to_pert] = {key: down for key, (_, down) in outputs.items()}
        CauTrigger_pred_pert_all[pair_to_pert] = {key: np.concatenate([up, down])
                                                  for key, (up, down) in outputs.items()}

    with open(cache_file, "wb") as file:
        pickle.dump((CauTrigger_pred_pert_up, CauTrigger_pred_pert_down, CauTrigger_pred_pert_all), file)
    return CauTrigger_pred_pert_up, CauTrigger_pred_pert_down, CauTrigger_pred_pert_all

## Run and compare

In [None]:
# Score column prefix -> key of the calculate_GIs result it stores.
_GI_SCORE_KEYS = {'SSA': 'mag', 'RED': 'dcor', 'NEO': 'corr_fit', 'EPI': 'eq_contr'}
# Method column suffix -> label used in the plot (plot order follows this dict).
_METHOD_LABELS = {'ct': 'CauTrigger', 'gears': 'GEARS', 'cpa': 'CPA'}
# Conditions excluded from the benchmark.
_EXCLUDED_CONDITIONS = ['RHOXF2BB+ctrl', 'LYL1+IER5L', 'ctrl+IER5L', 'KIAA1804+ctrl', 'IER5L+ctrl',
                        'RHOXF2BB+ZBTB25', 'RHOXF2BB+SET']


def _score_method_predictions(GI_results_df, pairs, predictions, ctrl, suffix, gene_mask=None):
    """Add ``<metric>_<suffix>`` columns to ``GI_results_df`` holding one method's GI scores.

    ``predictions[pair]`` maps 'A', 'B' and 'A+B' to predicted mean expression. When
    ``gene_mask`` is given, the predictions are first restricted to the masked genes so
    that they line up with ``ctrl``.
    """
    GI_results_df[[f'{metric}_{suffix}' for metric in _GI_SCORE_KEYS]] = None
    for pair in pairs:
        pertA, pertB = pair.split('+')[0], pair.split('+')[1]
        profiles = [predictions[pair][key] for key in (pertA, pertB, pair)]
        if gene_mask is not None:
            profiles = [profile[gene_mask] for profile in profiles]
        first, second, together = (profile - ctrl for profile in profiles)
        results = calculate_GIs(first, second, together)
        for metric, key in _GI_SCORE_KEYS.items():
            GI_results_df.at[pair, f'{metric}_{suffix}'] = results[key]


def _binarize_scores(GI_scores, thresh):
    """Replace every method's raw scores in ``GI_scores`` (in place) with 1/0 hits.

    Only the column tested for a row's true subtype ('value') can be a hit:
    SSA for SYNERGY (>= cutoff), SUPPRESSOR (<= cutoff) and ADDITIVE (within [low, high]);
    NEO for NEOMORPHIC (<=), EPI for EPISTASIS (<=) and RED for REDUNDANT (>=).
    All other score cells are set to 0, so a row's per-method sum is 1 exactly for a hit.
    """
    truth = GI_scores['value']
    columns = {metric: [f'{metric}_{suffix}' for suffix in _METHOD_LABELS] for metric in _GI_SCORE_KEYS}

    def mark(rows, metric, is_hit):
        GI_scores.loc[rows, columns[metric]] = is_hit(GI_scores.loc[rows, columns[metric]]).astype(int)

    low, high = thresh['ADDITIVE']
    mark(truth == 'SYNERGY', 'SSA', lambda v: v >= thresh['SYNERGY'])
    mark(truth == 'SUPPRESSOR', 'SSA', lambda v: v <= thresh['SUPPRESSOR'])
    mark(truth == 'ADDITIVE', 'SSA', lambda v: (v >= low) & (v <= high))
    GI_scores.loc[truth.isin(['NEOMORPHIC', 'EPISTASIS', 'REDUNDANT']), columns['SSA']] = 0
    for gi_type, metric, is_hit in (('NEOMORPHIC', 'NEO', lambda v: v <= thresh['NEOMORPHIC']),
                                    ('EPISTASIS', 'EPI', lambda v: v <= thresh['EPISTASIS']),
                                    ('REDUNDANT', 'RED', lambda v: v >= thresh['REDUNDANT'])):
        mark(truth == gi_type, metric, is_hit)
        GI_scores.loc[truth != gi_type, columns[metric]] = 0


def run_benchmark(data_path, output_path, data_name, prior_network_name, GIs):
    """Benchmark CauTrigger, GEARS and CPA on recovering ground-truth GI subtypes.

    Steps:

    1. Load ``data_name``, drop a fixed list of conditions, and merge each reversed
       duplicate condition onto its lexicographically smaller name.
    2. Split genes into TFs and non-TFs using the TRRUST table ``prior_network_name``.
    3. Keep the labelled pairs in ``GIs`` whose two genes are both perturbed TFs.
    4. Derive score cutoffs from measured data (``get_cutoff``).
    5. Score each method's predicted non-TF responses with ``calculate_GIs``. A prediction
       counts as correct when its score passes the cutoff of the true subtype.

    Writes ``GI_results_df.csv`` (raw scores) and ``plot_ACC_cutoff_down.{png,pdf}``
    (mean accuracy on SYNERGY and REDUNDANT pairs per method) to ``output_path``.
    """
    norman_adata = sc.read_h5ad(os.path.join(data_path, data_name))
    norman_adata = norman_adata[~norman_adata.obs['condition'].isin(_EXCLUDED_CONDITIONS), :].copy()

    # Map 'A+B' and 'B+A' (when both exist) to whichever name sorts first.
    observed = set(np.unique(norman_adata.obs['condition']))
    valid_pairs = {
        (min(pair, reversed_pair), max(pair, reversed_pair))
        for pair in observed if '+' in pair
        for reversed_pair in ['+'.join(reversed(pair.split('+')))]
        if reversed_pair in observed
    }
    condition_map = {member: canonical for canonical, other in valid_pairs for member in (canonical, other)}
    obs = norman_adata.obs.copy()
    obs['condition'] = obs['condition'].replace(condition_map)
    obs['condition_name'] = obs['condition'].apply(lambda x: f'K562_{x}_1+1' if x != 'ctrl' else x)
    norman_adata.obs = obs
    norman_adata.obs['condition'] = norman_adata.obs['condition'].astype(str)
    norman_adata.obs['condition'] = pd.Categorical(norman_adata.obs['condition'],
                                                   categories=norman_adata.obs['condition'].unique())

    conditions = np.unique(norman_adata.obs['condition'])
    conditions = conditions[conditions != 'ctrl']
    conditions_gene = {gene for condition in conditions for gene in condition.split('+')}
    conditions_gene.discard('ctrl')

    Trrust = pd.read_table(os.path.join(data_path, prior_network_name), header=None)
    Trrust_TF = Trrust.iloc[:, 0].dropna().unique()
    Trrust_nonTF = np.setdiff1d(Trrust.iloc[:, 1].dropna().unique(), Trrust_TF)
    TFs = np.intersect1d(Trrust_TF, norman_adata.var['gene_name'])
    nonTFs = np.intersect1d(Trrust_nonTF, norman_adata.var['gene_name'])
    TFs_to_pert = set(np.intersect1d(TFs, sorted(conditions_gene)))
    pairs_to_pert = {condition for condition in conditions
                     if condition.split('+')[0] in TFs_to_pert and condition.split('+')[1] in TFs_to_pert}
    pairs_to_pert_GI_GT = {pair: gi_type for gi_type, pairs in GIs.items() for pair in pairs
                           if pair in pairs_to_pert}
    reverse_pairs_to_pert_GI_GT = {}
    for pair, gi_type in pairs_to_pert_GI_GT.items():
        reverse_pairs_to_pert_GI_GT.setdefault(gi_type, []).append(pair)

    _, CauTrigger_pred_pert_down, _ = run_ct(data_path, norman_adata, pairs_to_pert_GI_GT, TFs, nonTFs)
    GEARS_pred_pert_all = run_gears(data_path, pairs_to_pert_GI_GT)
    cpa_pred_pert_all = run_cpa(data_path, norman_adata, pairs_to_pert_GI_GT)
    thresh, GI_cutoff_df, _ = get_cutoff(data_path, norman_adata, TFs, nonTFs, reverse_pairs_to_pert_GI_GT)

    # Ground-truth scores computed from measured expression.
    pairs = list(pairs_to_pert_GI_GT)
    GI_results_df = pd.DataFrame.from_dict(pairs_to_pert_GI_GT, orient='index', columns=['value'])
    GI_results_df[list(_GI_SCORE_KEYS)] = None
    for pair in pairs:
        for column in _GI_SCORE_KEYS:
            GI_results_df.at[pair, column] = GI_cutoff_df.loc[pair, column]

    # Each method is scored on non-TF genes against the measured control mean.
    down_mask = norman_adata.var['gene_name'].isin(nonTFs).values
    down_index = norman_adata.var.index[down_mask]
    ctrl = np.array(norman_adata[norman_adata.obs['condition'] == 'ctrl', down_index].X.mean(0)).flatten()
    # CauTrigger's down predictions already cover only non-TF genes; GEARS/CPA cover all genes.
    _score_method_predictions(GI_results_df, pairs, CauTrigger_pred_pert_down, ctrl, 'ct')
    _score_method_predictions(GI_results_df, pairs, GEARS_pred_pert_all, ctrl, 'gears', gene_mask=down_mask)
    _score_method_predictions(GI_results_df, pairs, cpa_pred_pert_all, ctrl, 'cpa', gene_mask=down_mask)
    GI_results_df.to_csv(os.path.join(output_path, 'GI_results_df.csv'))

    score_columns = [f'{metric}_{suffix}' for suffix in _METHOD_LABELS for metric in _GI_SCORE_KEYS]
    # Explicit copy: this table is modified in place below.
    GI_results_df2 = GI_results_df[['value'] + score_columns].copy()
    _binarize_scores(GI_results_df2, thresh)
    # Numeric dtype so the sums and group means below are plain float arithmetic.
    GI_results_df2[score_columns] = GI_results_df2[score_columns].astype(float)

    # Accuracy per method and subtype, shown for SYNERGY and REDUNDANT pairs only.
    evaluated = GI_results_df2[GI_results_df2['value'].isin(['SYNERGY', 'REDUNDANT'])]
    GI_table = pd.DataFrame({
        label: evaluated[[f'{metric}_{suffix}' for metric in _GI_SCORE_KEYS]].sum(axis=1)
        for suffix, label in _METHOD_LABELS.items()
    })
    GI_table['Type'] = evaluated['value']
    GI_table = GI_table.groupby('Type').mean()
    GI_table['Type'] = GI_table.index
    GI_table = pd.melt(GI_table, id_vars=['Type'], var_name='Method', value_name='Score')

    plt.figure()
    sns.barplot(x='Type', y='Score', hue='Method', data=GI_table, palette=sns.color_palette()[3:6])
    plt.title('')
    plt.xlabel('')
    plt.ylabel('Accuracy')
    plt.savefig(os.path.join(output_path, 'plot_ACC_cutoff_down.png'), bbox_inches='tight')
    plt.savefig(os.path.join(output_path, 'plot_ACC_cutoff_down.pdf'), bbox_inches='tight')

In [None]:
# Run the whole benchmark. It loads the data, obtains (or loads cached) CauTrigger, GEARS
# and CPA predictions, scores them against the GI cutoffs, and writes the results table
# and accuracy plot to output_path.
run_benchmark(data_path, output_path, data_name, prior_network_name, GIs)