In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestRegressor
import shap
import sklearn.metrics as metrics

In [2]:
# Selector for the simulated data set to load; it is the leading number of the
# data folder name. The original note lists 5, 40 or 100 as the options.
# NOTE(review): "mr" presumably stands for master regulators — confirm, and
# confirm that the corresponding folders exist under data/.
n_mr = 5
data_file = f"data/{n_mr}_mr_50_cond/simulated_noNoise.txt"

# Tab-separated table whose first row holds the column names.
data = pd.read_csv(data_file, sep="\t", header=0)

N_genes = 100  # total no. of genes
N_TFs = N_genes  # no. of regressor (TF) columns; set equal to the gene count here

In [9]:
# Random-forest hyper-parameters (tree count and feature budget follow the GENIE paper).
random_state = 42  # fixed seed so repeated runs build the same forests
criterion = 'squared_error'  # split-quality measure (variance reduction equivalent)
n_estimators = 1000  # number of trees in each forest
# Number of candidate features examined at each split (scikit-learn applies
# max_features per split, not per tree).
max_features = int(np.sqrt(N_genes - 1))

In [10]:
# Normalize expression data to unit variance (with_mean=False: scale only, no centring)
data_n = StandardScaler(with_mean=False).fit_transform(data.to_numpy())

# The loop reads columns [0, N_TFs) as regressors and column N_genes + j as the
# target of gene j, so the table must be at least this wide. Checking up front
# gives a clear error before any forest is trained.
# NOTE(review): this assumes a [regressors | targets] column layout — confirm
# against the data file.
n_cols_required = max(N_TFs, 2 * N_genes)
if data_n.shape[1] < n_cols_required:
    raise ValueError(
        f"Expression table has {data_n.shape[1]} columns but {n_cols_required} are required "
        f"({N_TFs} regressor columns, target columns {N_genes}..{2 * N_genes - 1})."
    )

# Initialize matrices
W = np.zeros(shape=(N_genes, N_TFs))  # W[j, i]: importance of regressor column i for gene j
Fscores = np.zeros(shape=(N_genes,))  # per-gene score of the fitted forest on its training data

# The regressor matrix is the same for every target gene, so slice it once.
X = data_n[:, :N_TFs]

for j in np.arange(0, N_genes):
    # expression of target gene j
    Gj = data_n[:, N_genes + j]

    # fit an RF model to predict gene expression from TF
    # n_jobs=-1 builds the trees on all cores; with a fixed random_state the
    # fitted forest is the same as with a single job.
    M_rf = RandomForestRegressor(
        criterion=criterion,
        n_estimators=n_estimators,
        max_features=max_features,
        random_state=random_state,
        n_jobs=-1,
    ).fit(X, Gj)

    # train score (R^2 on the same samples used for fitting, hence optimistic)
    Fscores[j] = M_rf.score(X, Gj)

    # Get the weights for all edges connecting TFs to gene j
    W[j, :] = M_rf.feature_importances_

    # Optional alternative (disabled): feature importance based on SHAP values
    # explainer = shap.TreeExplainer(M_rf)
    # shap_values = explainer(X)

In [20]:
# Reshape the importance matrix to long format: one row per matrix entry, with
# TF_ID holding the column label of W and Imp the absolute importance value.
W_df = pd.DataFrame(np.abs(W)).melt(var_name='TF_ID', value_name='Imp')

In [24]:
# Per-TF mean and standard deviation of the importance values, ranked from the
# highest mean importance to the lowest.
W_df_agg = (
    W_df
    .groupby("TF_ID", sort=False, as_index=False)
    .agg(['mean', 'std'])['Imp']
    .reset_index()
    .sort_values(by='mean', ascending=False)
)

W_df_agg

Unnamed: 0,index,mean,std
65,65,0.015657,0.016336
49,49,0.015130,0.018930
71,71,0.014680,0.017653
48,48,0.014655,0.017588
39,39,0.014575,0.017639
...,...,...,...
77,77,0.006061,0.006025
72,72,0.005905,0.007636
34,34,0.005623,0.007813
91,91,0.005515,0.008354


In [None]:
# Create an explicit 15x5 figure/axes pair for the ranking chart
fig, ax = plt.subplots(figsize=(15, 5))

# One bar per TF with a standard-deviation error bar. Bars follow the row order
# of W_df_agg (sorted by descending mean importance).
# NOTE(review): order= uses the index labels of W_df_agg, which relies on those
# labels matching the TF_ID values — confirm.
sns.barplot(
    data=W_df,
    x="TF_ID",
    y="Imp",
    order=W_df_agg.index,
    errorbar='sd',
    capsize=.2,
    errwidth=1.5,
    errcolor="black",
    ax=ax,
)

# Axis labels, title, and vertical tick labels so the TF ids stay readable
ax.set(xlabel='TF ID', ylabel='F. Importance', title="Global Ranking")
ax.tick_params(axis='x', labelrotation=90)