In [1]:
"""
This script calculates and plots the feature importance for the multi-tissue model (Fig. 4C).
"""

In [None]:

import time
import matplotlib.pyplot as plt
import os
import pandas as pd
from sklearn import metrics
from sklearn.metrics import average_precision_score
from sklearn.metrics import roc_curve, auc

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import precision_recall_curve
import multiprocessing as mp
from time import gmtime, strftime
import shap


In [2]:
"---------------------------- Load Data -----------------------"

startTime = time.time() 
path = os.path.join('..', '..', 'Data', 'Transfer_Learning_Data_hg37.csv')
All_tissues_data= pd.read_csv(path)

path = os.path.join('..', '..', 'Data', 'Transfer_Learninig_Relevant_Columns_Names_Edited.csv')
Relevant_Cols_df = pd.read_csv(path)
overlap_cols = Relevant_Cols_df['Feature'].tolist()


In [3]:
"---------------------------- Columns Names -----------------------"

rename_dict = dict(zip(overlap_cols, Relevant_Cols_df['Feature Name'].tolist()))
# print(rename_dict)
All_tissues_data.rename(columns=rename_dict, inplace=True)

cols_names = Relevant_Cols_df['Feature Name'].tolist()
# print(cols_names)

relevant_cols = [x for x in cols_names if (x not in  ['Unnamed: 0', 'Tissue', 'Pathogenic_Mutation', 'CADD | GeneID y', 'Segway'])]
# print(relevant_cols)
cadd_cols = [c for c in relevant_cols if 'CADD' in c]
tissue_cols = [c for c in relevant_cols if 'CADD' not in c]

print('cadd_cols', len(cadd_cols))
print('tissue_cols', len(tissue_cols), tissue_cols)


cadd_cols 85
tissue_cols 5 ['Tissue | Expression (GTEx)', 'Tissue | Preferential expression', 'Tissue | Overexpressed interactors', 'Tissue | Paralog compensation', 'Tissue | Biological processes']


In [4]:
"---------------------------- Preprossecing Data --------------------"
duplicated_tissues = ['Skin - Sun Exposed', 'Heart - Atrial Appendage', 'brain-1', 'brain-0', 'brain-3', 'brain-2', 'Artery - Coronary', 'Artery - Tibial']
non_relevant_tissues = ['Adipose - Subcutaneous', 'Colon - Sigmoid', 'Breast - Mammary Tissue', 'Uterus', 'Adipose - Visceral', 'Esophagus - Gastroesophageal Junction', 'Esophagus - Mucosa', 'Thyroid']
duplicated_tissues = duplicated_tissues + non_relevant_tissues
Transfer_Data_set = All_tissues_data[~All_tissues_data['Tissue'].isin(duplicated_tissues)]

print(Transfer_Data_set)
tissues_list = Transfer_Data_set['Tissue'].unique().tolist()
print('tissues_list', len(tissues_list), tissues_list)

         Tissue | Expression (GTEx)  Tissue | Preferential expression  \
0                         87.689863                         -0.466328   
1                         87.689863                         -0.466328   
2                         87.689863                         -0.466328   
3                         87.689863                         -0.466328   
4                         87.689863                         -0.466328   
...                             ...                               ...   
1903099                    0.000000                         -0.515609   
1903100                   44.074610                         -0.461328   
1903101                   44.074610                         -0.461328   
1903102                   35.596696                          0.185557   
1903103                    1.288312                         -0.667247   

         _num_interactors  _num_interactors_dif_med  \
0                    53.0                       3.0   
1            

In [5]:
'-------------------- Synthetic Dataset --- ---------------------------'

Pathogenic = Transfer_Data_set[Transfer_Data_set['Pathogenic_Mutation'] == True]
counts = len(Pathogenic)
folds = 9
Non_pathogenic = Transfer_Data_set[Transfer_Data_set['Pathogenic_Mutation'] == False].sample(n=counts * folds, axis='index', random_state=1234)
Synthetic_Dataset = pd.concat([Pathogenic, Non_pathogenic])
print(Synthetic_Dataset)

         Tissue | Expression (GTEx)  Tissue | Preferential expression  \
0                         87.689863                         -0.466328   
1                         87.689863                         -0.466328   
2                         87.689863                         -0.466328   
3                         87.689863                         -0.466328   
4                         87.689863                         -0.466328   
...                             ...                               ...   
1660476                    0.242500                         -0.405867   
620598                     1.199479                          0.215144   
1651315                    4.595589                         -0.181366   
654988                    18.071676                         -0.415887   
1868692                    0.292592                         -0.706847   

         _num_interactors  _num_interactors_dif_med  \
0                    53.0                       3.0   
1            

In [6]:
'-------------------- Train Model--- ---------------------------'

model = RandomForestClassifier(random_state=1234)#, min_samples_leaf=100, n_estimators=10
model.fit(Synthetic_Dataset[relevant_cols], Synthetic_Dataset['Pathogenic_Mutation'])
print('@ Model trained',  time.time() - startTime)

@ Model trained 69.32734179496765


In [7]:
'-------------------- SHAP Importance ---------------------------'

explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(Synthetic_Dataset[relevant_cols])
vals = np.abs(shap_values).mean(0)

print('@ Explainer created',  time.time() - startTime)

Feature_importance = pd.DataFrame(list(zip(Synthetic_Dataset[relevant_cols].columns, sum(vals))),columns=['col_name', 'feature_importance_vals'])
Feature_importance.sort_values(by=['feature_importance_vals'], ascending=False, inplace=True)
path = os.path.join('..', '..', 'Output', 'Transfer_learning_hg37_SHAP_Importance_Slim.csv')
Feature_importance.to_csv(path, index=False)
print('@ Feature Importance Saved',  time.time() - startTime)

shap.summary_plot(shap_values[1], Synthetic_Dataset[relevant_cols], show=False)
print('@ Plot created', startTime - time.time())
plt.tight_layout()
path = os.path.join('..', '..', 'Output', 'Transfer_learning_hg37_SHAP_dot2.pdf')
plt.savefig(path)
plt.close()



print('@ Finished',  time.time() - startTime)

shap.summary_plot(shap_values[1], Synthetic_Dataset[relevant_cols], show=False, plot_type='bar')
path = os.path.join('..', '..', 'Output', 'Transfer_learning_hg37_SHAP_bar2.pdf')
plt.tight_layout()

plt.savefig(path)
plt.close()

shap.summary_plot(shap_values[1], Synthetic_Dataset[relevant_cols], show=False, plot_type='violin')
path = os.path.join('..', '..', 'Output', 'Transfer_learning_hg37_SHAP_violin2.pdf')
plt.tight_layout()
plt.savefig(path)
plt.close()


@ Explainer created 133421.485517025
@ Feature Importance Saved 133421.63264775276
@ Plot created -133433.8365418911
@ Finished 133455.4794499874
