In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm
from functools import partial
from multiprocessing import shared_memory
from multiprocessing.dummy import Pool
from sklearn.ensemble import RandomForestRegressor
import multiprocessing as mp
from itertools import chain, combinations
import sys
import os
from scipy import stats

from sklearn.preprocessing import Normalizer
from sklearn.ensemble import RandomForestRegressor

import mp_run

import concurrent.futures
os.environ["OMP_NUM_THREADS"] = "1" # export OMP_NUM_THREADS=1
os.environ["OPENBLAS_NUM_THREADS"] = "1" # export OPENBLAS_NUM_THREADS=1
os.environ["MKL_NUM_THREADS"] = "1" # export MKL_NUM_THREADS=1
os.environ["VECLIB_MAXIMUM_THREADS"] = "1" # export VECLIB_MAXIMUM_THREADS=1
os.environ["NUMEXPR_NUM_THREADS"] = "1" # export NUMEXPR_NUM_THREADS=1

# styling:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
plt.style.use(['ggplot'])
sns.set_palette("deep")

In [2]:
# --- run configuration ---
perturbation_factor = 3     # each perturbed input is raised by this many standard deviations
num_rf_predictors = 500     # mp_run.regr_perturbation runs per target gene (paired samples for the t-tests)

target_tf = 'AT2G46680'     # HB7; perturbed together with each WRKY regulator

induction_flag = -1         # sign filter: +1 keeps positive t statistics, -1 negative ones
dataset = 'GSE88798'        # sub-directory of data/ holding the expression tables
mp_threads = 10             # number of worker processes intended for the multiprocessing pool

# Optional command-line override for script runs (e.g. after nbconvert).
# Kept disabled: inside Jupyter, sys.argv holds the kernel's own arguments
# (typically three of them), so this branch would misfire there.
# Use int(), not bool(), for induction_flag: bool("-1") is True, and the flag
# must stay the sign +1/-1 that p-values are filtered by.
# if len(sys.argv) >= 3:
#     induction_flag = int(sys.argv[1])
#     mp_threads = int(sys.argv[2])

assert induction_flag in (-1, 1), "induction_flag is used as a sign multiplier and must be +1 or -1"
assert num_rf_predictors >= 2, "paired t-tests need at least two predictor runs"
assert mp_threads >= 1, "the worker pool needs at least one process"

In [3]:
# Regulator table (expects 'Gene' and 'Symbol' columns) and the shared gene panel.
tf_df = pd.read_csv('data/wrky_regulators.csv')
common_genes = pd.read_csv('data/arabidopsis_common_genes.csv', index_col=0).index

# Expression tables restricted to (and ordered by) the common gene panel.
ts_source_df = pd.read_csv('data/{}/ts_source_exp.csv'.format(dataset), index_col=0).loc[common_genes]
ts_target_df = pd.read_csv('data/{}/ts_target_exp.csv'.format(dataset), index_col=0).loc[common_genes]

# Regulators that have expression data, de-duplicated in file order.
# A set intersection would iterate in string-hash order, which changes between
# interpreter runs and would shuffle the TF columns of the output table.
tf_genes = tf_df['Gene']
tf_list = tf_genes[tf_genes.isin(ts_target_df.index)].drop_duplicates().reset_index(drop=True)

# perturb_exp needs target_tf among the inputs; fail here with a clear message.
if not tf_list.eq(target_tf).any():
    raise KeyError('target_tf {!r} has no expression data in dataset {}'.format(target_tf, dataset))



In [4]:
def _build_perturbation_inputs(ts_train_X):
    """Return the (combined, alone) perturbation matrices, one row per training TF.

    Every row starts at the column means of ``ts_train_X``. In row ``i`` the
    i-th TF is raised by ``perturbation_factor`` standard deviations, except in
    the row of ``target_tf`` itself. ``combined`` additionally raises
    ``target_tf`` (HB7) in every row, so each WRKY is perturbed together with
    it; ``alone`` perturbs each WRKY on its own.
    """
    input_mean = ts_train_X.mean()
    bumps = ts_train_X.std() * perturbation_factor
    n_inputs = len(input_mean)

    combined_row = input_mean.copy()
    combined_row[target_tf] += bumps[target_tf]
    combined = np.tile(combined_row.values, (n_inputs, 1))
    alone = np.tile(input_mean.values, (n_inputs, 1))

    for i, tf_name in enumerate(input_mean.index):
        if tf_name == target_tf:
            continue
        combined[i, i] += bumps[tf_name]
        alone[i, i] += bumps[tf_name]
    return combined, alone


def _directional_pvalues(results, target_tf_index, induction_flag, n_tfs):
    """Paired t-test p-value per TF, or 1 when the t statistic has the unwanted sign.

    For TF ``i`` the test compares ``results[:, 0, i]`` with
    ``results[:, 0, target] + results[:, 1, i]``. For the target TF itself,
    ``results[:, 0, target]`` is tested against zero.
    A p-value is kept only when ``induction_flag * t > 0``.
    """
    target_effect = results[:, 0, target_tf_index]
    p_values = []
    for i in range(n_tfs):
        if i == target_tf_index:
            # Paired against zeros == one-sample test of mean zero.
            t_val, p_val = stats.ttest_rel(target_effect, np.zeros(len(target_effect)))
        else:
            t_val, p_val = stats.ttest_rel(results[:, 0, i], target_effect + results[:, 1, i])
        p_values.append(p_val if induction_flag * t_val > 0 else 1)
    return p_values


def perturb_exp(target_genes, induction_flag, output_path):
    """Score every TF in ``tf_list`` against each target gene and write a p-value table.

    For each gene in ``target_genes``, its row of ``ts_target_df`` is the
    response and the ``tf_list`` rows of ``ts_source_df`` are the inputs.
    ``mp_run.regr_perturbation`` is run ``num_rf_predictors`` times in a pool of
    ``mp_threads`` processes on the perturbation design from
    ``_build_perturbation_inputs``. Its stacked outputs are reduced to one
    p-value per TF by ``_directional_pvalues``.

    Parameters
    ----------
    target_genes : iterable of str
        Gene IDs present in ``ts_target_df``; becomes the row index of the output.
    induction_flag : int
        Sign filter (+1 or -1); p-values whose t statistic has the other sign
        are written as 1.
    output_path : str
        CSV destination. Rows are target genes; columns are the ``Symbol`` of
        each TF in ``tf_df``.

    Raises
    ------
    KeyError
        If ``target_tf`` is not one of the training TFs.
    """
    # Everything below except the response is independent of the target gene,
    # so it is built once. (Excluding the target from its own inputs, i.e.
    # tf_list[tf_list != target_gene], would require rebuilding per gene.)
    train_gene_index = tf_list
    ts_train_X = ts_source_df.T[train_gene_index]
    train_tfs = ts_train_X.columns
    if target_tf not in train_tfs:
        raise KeyError('target_tf {!r} is not among the training TFs'.format(target_tf))
    target_tf_index = int(np.flatnonzero(train_tfs == target_tf)[0])
    perturbation_input, perturbation_input_alt = _build_perturbation_inputs(ts_train_X)

    p_val_res_list = []
    # One pool of the configured size for all genes.
    # NOTE(review): if mp_run.regr_perturbation draws from the global NumPy RNG
    # rather than seeding from its index argument, reusing workers changes the
    # random streams compared with a fresh pool per gene -- confirm.
    with mp.Pool(processes=mp_threads) as pool:
        for target_gene in tqdm(target_genes):
            ts_train_y = ts_target_df.loc[target_gene]
            func = partial(mp_run.regr_perturbation, ts_train_X, ts_train_y,
                           perturbation_input, perturbation_input_alt)
            results = np.array(pool.map(func, range(num_rf_predictors)))
            p_val_res_list.append(
                _directional_pvalues(results, target_tf_index, induction_flag, len(train_tfs)))

    # Look up symbols without mutating the global tf_df's index.
    symbols = tf_df.set_index('Gene').loc[tf_list, 'Symbol']
    # Explicit 2-D shape so an empty target list still yields a valid (0, n_tfs) table.
    p_values = np.array(p_val_res_list, dtype=float).reshape(len(p_val_res_list), len(tf_list))
    out_df = pd.DataFrame(data=p_values, index=target_genes, columns=symbols)

    out_df.to_csv(output_path)

In [5]:
target_df = pd.read_csv('data/wrky_targets_neg.csv')
deg_genes = target_df['Gene']

# Candidate targets present in the expression panel, de-duplicated in file order.
# Iterating a set here would make the output row order vary between runs.
panel_targets = deg_genes[deg_genes.isin(ts_target_df.index)].drop_duplicates()

# Drop targets whose mean expression is exactly zero (nothing to model).
has_signal = (ts_target_df.loc[panel_targets].mean(axis=1) != 0.0).to_numpy()
target_genes = panel_targets[has_signal].reset_index(drop=True)

# Negative-regulation run: induction_flag = -1, matching the *_neg input/output files.
perturb_exp(target_genes, -1, './output/rf_wrky_inf_pval_{}_neg.csv'.format(dataset))


100%|██████████| 3/3 [00:25<00:00,  8.34s/it]
