In [None]:
import pandas as pd
from operator import index
import numpy as np
from tqdm import tqdm
from functools import partial
from multiprocessing import shared_memory
from multiprocessing.dummy import Pool
from sklearn.ensemble import RandomForestRegressor
import multiprocessing as mp
from itertools import chain, combinations


In [None]:
# Size of each in-silico perturbation, in standard deviations: every gene in a
# perturbed combination is shifted up by perturbation_factor * (its std).
# (An unused scratch list of all 3-combinations of range(100) used to be built
# here; it was never read, so it has been removed.)
perturbation_factor = 3


In [None]:
# Transcription-factor catalogue (tab-separated). The rest of this notebook
# reads only its Gene_ID column.
tf_df = pd.read_csv('data/Ath_TF_list.txt', sep='\t')
tf_list = tf_df.loc[:, 'Gene_ID']

In [None]:
# Expression matrix: genes as rows (first column is the index), samples as columns.
ts_df = pd.read_csv('data/GSE97500/expression.tsv', sep='\t', index_col=0)
meta_df = pd.read_csv('data/GSE97500/meta_data.tsv', sep='\t')

# Keep only time-series samples.
ts_exp_index = meta_df[meta_df['isTs']]
# Drop rows whose is1stLast flag is 'f'. Presumably these are first time points,
# which have no previous sample to pair with (TODO confirm against the metadata spec).
ts_exp_non_first = ts_exp_index[ts_exp_index['is1stLast'] != 'f']
# Paired sample names: the target sample (condName) and its predecessor (prevCol).
ts_exp_index_target = ts_exp_non_first['condName']
ts_exp_index_source = ts_exp_non_first['prevCol']
assert (set(ts_exp_index_target) | set(ts_exp_index_source)) <= set(ts_df.columns), \
    'meta_data references samples missing from the expression matrix'

# Regulators are the TFs that are actually measured in the expression matrix.
# Keep the matrix's row order (deduplicated) so the result is reproducible.
# Iterating a Python set of strings depends on hash randomization and can
# change between kernel sessions.
regulator_gene_index = pd.Series(ts_df.index[ts_df.index.isin(tf_list)].unique().tolist())


In [105]:
# Gene -> 'impact' scores from the AT1G14040 rankings file, ordered lowest to highest.
# NOTE(review): passing names= makes pandas treat the file as headerless. If the
# CSV has a header row, it will be read as a data row. Confirm the file format.
ranking_df = (
    pd.read_csv('output/GSE97500/AT1G14040_rankings.csv', index_col=0, names=['impact'])
    .sort_values('impact')
)

In [106]:
# Number of top / bottom ranked genes to use.
n_top_genes = 100

# ranking_df is sorted ascending by 'impact', so its tail holds the highest-impact
# genes and its head the lowest-impact genes.
top_positive_influence_genes = ranking_df.tail(n_top_genes).index
top_negative_influence_genes = ranking_df.head(n_top_genes).index

# Per-gene mean and std across all samples of the expression matrix. Select the
# genes first, then transpose, rather than transposing the whole matrix.
# data_mean is the unperturbed model input; data_std sets the perturbation size.
top_positive_expression = ts_df.loc[top_positive_influence_genes].T
data_mean = top_positive_expression.mean()
data_std = top_positive_expression.std()

# warm_start stays at its default (False). With warm_start=True, re-running the
# fit cell without re-running this one keeps the previously fitted trees instead
# of retraining on the current data. That is a hidden-state bug in a notebook.
regr = RandomForestRegressor(random_state=42, n_estimators=300, n_jobs=20)

In [107]:
# Lagged design. Each training row holds the top genes' expression in the sample
# named by prevCol. The matching target is AT1G14040 in the paired condName sample.
# Selecting with .loc avoids copying and transposing every gene just to keep 100.
train_X = ts_df.loc[top_positive_influence_genes, ts_exp_index_source].T
train_y = ts_df.loc['AT1G14040', ts_exp_index_target]
assert len(train_X) == len(train_y), 'source and target samples must pair one-to-one'

In [108]:
regr = regr.fit(train_X, train_y)
# Unperturbed baseline: predict from the mean profile. Use a one-row DataFrame whose
# columns are the feature genes, so sklearn checks them against the names seen at
# fit time (a bare ndarray triggers a feature-name warning and skips that check).
base_prediction = regr.predict(data_mean.to_frame().T)[0]
# Normalising scale for effects: std (ddof=1) of AT1G14040 across all samples.
# Reading the one row avoids transposing and reducing the whole matrix.
y_std = ts_df.loc['AT1G14040'].std()
assert y_std > 0, 'target gene has zero/undefined variance; normalised effects would be inf/NaN'

In [109]:
# All 3-gene combinations of the top positive influencers, followed by all 2-gene
# combinations (single-gene perturbations are not included).
perturbation_list = list(chain(
    combinations(top_positive_influence_genes, 3),
    combinations(top_positive_influence_genes, 2),
))

In [110]:
# Readable label for each perturbation, e.g. "GENE_A; GENE_B; GENE_C". Labels are
# position-aligned with perturbation_list. (The former dead initialisation of
# perturbation_result_list was removed: the scoring cell assigns it outright.)
perturbation_list_names = ['; '.join(perturbation_genes) for perturbation_genes in perturbation_list]

In [111]:
# Build one model input per perturbation. Start from the mean profile and shift
# every gene in the combination up by perturbation_factor standard deviations.
# Filling a preallocated NumPy matrix avoids copying a pandas Series for each of
# the ~166k combinations.
feature_genes = data_mean.index
gene_column = {gene: col for col, gene in enumerate(feature_genes)}
perturbation_shift = (data_std.loc[feature_genes] * perturbation_factor).to_numpy()

perturbed_input = np.tile(data_mean.to_numpy(), (len(perturbation_list), 1))
for row, perturbation_genes in enumerate(tqdm(perturbation_list)):
    cols = [gene_column[gene] for gene in perturbation_genes]
    perturbed_input[row, cols] += perturbation_shift[cols]

# Predict from a DataFrame so the feature names match those seen at fit time.
perturbed_prediction = regr.predict(pd.DataFrame(perturbed_input, columns=feature_genes))
# Effect size: change in the predicted target, in units of the target's std.
perturbation_result_list = (perturbed_prediction - base_prediction) / y_std

100%|██████████| 166650/166650 [00:09<00:00, 17388.48it/s]


In [112]:
# Labels of the five perturbations with the largest effect, largest first.
np.array(perturbation_list_names)[np.argsort(perturbation_result_list)[-5:][::-1]]

array(['AT4G22680; AT3G02940; AT5G56620',
       'AT3G02940; AT5G56620; AT5G40330',
       'AT3G02940; AT5G56620; AT1G12260',
       'AT3G02940; AT5G56620; AT1G18835',
       'AT4G22680; AT3G02940; AT5G40330'], dtype='<U31')

In [104]:
# Effect sizes of the five strongest perturbations, largest first.
# NOTE(review): this cell was last executed as In [104], before the cells that
# compute perturbation_result_list (In [105]-[111]). Its saved output may be stale,
# so re-run it after the perturbation cell.
np.sort(perturbation_result_list)[-5:][::-1]


array([0.37705886, 0.36532599, 0.34913108, 0.33936352, 0.33892069])