# Import modules and SPOKE node info

In [None]:
import numpy as np
import pandas as pd
from scipy import stats
from scipy.spatial.distance import cdist, pdist, squareform
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import subprocess
import time
import multiprocessing as mp
import matplotlib.pyplot as plt
import os
from collections import Counter
import seaborn as sns
import io
import requests
from sklearn.ensemble import RandomForestClassifier
import joblib

#
# SPOKE node table
node_info_df = pd.read_csv('node_info_df.tsv', sep='\t', header=0, index_col=False)

# Workshop diseases, given as node identifiers
diseases = ['DOID:9778', 'DOID:1612', 'DOID:219']

# For each disease take the Node_Name of its first matching row and
# join the words of that name with underscores
disease_names = []
for disease in diseases:
    matching_names = node_info_df.loc[node_info_df.Node == disease, 'Node_Name'].values
    disease_names.append('_'.join(matching_names[0].split()))

# Disease PSEVs for the workshop
psev_matrix = np.load('workshop_psevs.npy', allow_pickle=False)


def load_or_make_spokesig_mean_std_dist(node_info_df, example_cohort, diseases, disease_names, load_data=True):
    """Load, or compute and cache, the cohort's SPOKEsig statistics.

    When ``load_data`` is true the three arrays are read from ``mean_sig.npy``,
    ``std_sig.npy`` and ``patient_to_disease_dist.npy`` in the working
    directory.  Otherwise they are recomputed from the SPOKEsig matrix stored
    in the file named by the module-level ``int_spoke_sig_filename``, compared
    against the module-level ``psev_matrix``, and written to those same files.
    The recompute path also writes ``mean_node_val_df.tsv``: the de-duplicated
    node table plus one column per disease holding the mean z-scored node
    value of that disease's patients.

    Args:
        node_info_df: Node table; only used when recomputing.  Its
            de-duplicated rows must line up with the columns of the SPOKEsig
            matrix, because one value per column is assigned to each row.
        example_cohort: DataFrame with ``Disease`` and ``Patient_Index``
            columns; ``Patient_Index`` is used as a row index into the
            SPOKEsig matrix.  Only used when recomputing.
        diseases: Disease identifiers matched against ``example_cohort.Disease``.
        disease_names: Column names for the per-disease means, in the same
            order as ``diseases``.
        load_data: Read the cached files (default) instead of recomputing.

    Returns:
        Tuple ``(mean_sig, std_sig, patient_to_disease_dist)``: the per-column
        mean and standard deviation of the SPOKEsigs, and the cosine distance
        from every z-scored SPOKEsig row to every row of ``psev_matrix``.
    """
    if load_data:
        mean_sig = np.load('mean_sig.npy', allow_pickle=False)
        std_sig = np.load('std_sig.npy', allow_pickle=False)
        patient_to_disease_dist = np.load('patient_to_disease_dist.npy', allow_pickle=False)
        return mean_sig, std_sig, patient_to_disease_dist

    # load spoke sigs
    spoke_sigs = np.load(int_spoke_sig_filename, allow_pickle=False)

    # per-column mean and std, saved so other patients can be normalized the same way
    mean_sig = np.mean(spoke_sigs, axis=0)
    np.save('mean_sig', mean_sig, allow_pickle=False)
    std_sig = np.std(spoke_sigs, axis=0)
    np.save('std_sig', std_sig, allow_pickle=False)

    # convert to z score; a zero std gives 0/0 = NaN, which nan_to_num maps to 0,
    # so the expected divide warnings are silenced
    with np.errstate(divide='ignore', invalid='ignore'):
        spoke_sigs = np.nan_to_num((spoke_sigs - mean_sig) / std_sig)

    # create distance matrix
    patient_to_disease_dist = cdist(spoke_sigs, psev_matrix, metric='cosine')
    np.save('patient_to_disease_dist', patient_to_disease_dist, allow_pickle=False)

    # mean node value per disease; copy so the new columns are added to an
    # independent frame rather than to a selection of node_info_df
    mean_node_val_df = node_info_df.drop_duplicates().copy()
    for disease, name in zip(diseases, disease_names):
        patient_rows = example_cohort[example_cohort.Disease == disease].Patient_Index.values
        mean_node_val_df.loc[:, name] = np.mean(spoke_sigs[patient_rows], axis=0)
    del spoke_sigs
    mean_node_val_df.to_csv('mean_node_val_df.tsv', sep='\t', header=True, index=False)
    return mean_sig, std_sig, patient_to_disease_dist



# Load initial cohort

In [None]:
# File names of the initial cohort table and of its SPOKEsig matrix
int_cohort_filename = 'example_cohort.tsv'
int_spoke_sig_filename = 'example_spoke_sigs.npy'

example_cohort = pd.read_csv(int_cohort_filename, sep='\t', header=0, index_col=False)

# Integer label per patient: the position of the patient's disease in `diseases`
label_by_disease = dict(zip(diseases, np.arange(len(diseases))))
example_cohort.loc[:, 'label'] = example_cohort.Disease.map(label_by_disease)
example_cohort.head()

# Look at cohort demographics

In [None]:
# One grouped bar chart per demographic column: the number of non-null
# person_id values for every (Disease, category) pair
for col in ('gender_source_value', 'race_source_value', 'ethnicity_source_value'):
    grouped = example_cohort[['Disease', col, 'person_id']].groupby(['Disease', col])
    df = grouped.count().reset_index()
    df = df.rename(index=str, columns={'person_id': 'Count'})
    ax = sns.barplot(x='Disease', y='Count', hue=col, data=df)
    plt.show()

# Look at cohort continuous variables

In [None]:
# One box plot per continuous variable, grouped by disease
for col in ['age', 'OMOP_Count', 'SEP_Count']:
    # y must follow the loop variable; a fixed y='age' draws the age plot three times
    ax = sns.boxplot(x='Disease', y=col, data=example_cohort)
    plt.show()

# Load initial cohort SPOKEsigs data

In [None]:
# Read the cached mean / std / distance arrays from disk
# (load_data=False recomputes them and overwrites the cached files)
mean_sig, std_sig, patient_to_disease_dist = load_or_make_spokesig_mean_std_dist(
    node_info_df,
    example_cohort,
    diseases,
    disease_names,
    load_data=True,
)

# Compare patients to disease PSEVs

In [None]:
# for every row of the distance matrix, pick the disease with the smallest distance
closest_idx = np.argmin(patient_to_disease_dist, axis=1)
best_match = np.array(diseases)[closest_idx]
# number of patients whose closest disease equals their recorded disease
n_correct = np.sum(example_cohort.Disease.values == best_match)
print(n_correct)
#
# store the prediction and whether it is correct
example_cohort.loc[:, 'pred'] = best_match
example_cohort.loc[:, 'match_correct'] = example_cohort.Disease.values == example_cohort.pred.values
# row counts for every (recorded disease, predicted disease) pair
pair_columns = ['Patient_Index', 'Disease', 'pred']
match_stats_df = example_cohort[pair_columns].groupby(['Disease', 'pred']).count().reset_index()
match_stats_df

# Visualize cohort in 3D
(based on distance to diseases)

In [None]:
# 3D scatter of the cohort: each axis is the distance to one disease PSEV,
# one colour per recorded disease
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
for disease, name in zip(diseases, disease_names):
    pats = example_cohort[example_cohort.Disease == disease].Patient_Index.values
    # Draw through the 3D axes: the third positional argument of plt.scatter is
    # the marker size, so the third distance would never become a z coordinate.
    ax.scatter(
        patient_to_disease_dist[pats, 0],
        patient_to_disease_dist[pats, 1],
        patient_to_disease_dist[pats, 2],
        label=name,
    )
# distance column j is treated as belonging to diseases[j] elsewhere in this notebook
ax.set_xlabel('distance to ' + disease_names[0])
ax.set_ylabel('distance to ' + disease_names[1])
ax.set_zlabel('distance to ' + disease_names[2])
ax.legend()
plt.show()

# Load new patients from API

In [None]:
# fill new filenames
new_patient_info_filename = 'random_patient_info.tsv'
new_patient_spokesig_filename = 'random_patient_spokesigs.npy'

# load new patient data
new_cohort = pd.read_csv(new_patient_info_filename, sep='\t', header=0, index_col=False)
new_cohort.loc[:,'label'] = new_cohort.Disease.map(dict(zip(diseases, np.arange(len(diseases)))))
new_spoke_sigs = np.load(new_patient_spokesig_filename, allow_pickle=False)
# normalize new patients with the initial cohort's mean and std.
# Wherever std_sig is 0 the division gives NaN or +/-inf; those entries are set
# to 0 so they match the initial cohort (whose zero-variance columns are zeroed
# by nan_to_num) and so no NaN/inf reaches the distance computation or the classifier.
with np.errstate(divide='ignore', invalid='ignore'):
    new_spoke_sigs = np.nan_to_num(
        (new_spoke_sigs - mean_sig) / std_sig, nan=0.0, posinf=0.0, neginf=0.0
    )
new_cohort.head()

# Compare new patients to disease PSEVs

In [None]:
# cosine distance from every new patient to every disease PSEV
new_patient_to_disease_dist = cdist(new_spoke_sigs, psev_matrix, metric='cosine')
#
# closest disease per patient, and the fraction of patients matched correctly
closest_idx = np.argmin(new_patient_to_disease_dist, axis=1)
best_match = np.array(diseases)[closest_idx]
n_correct = np.sum(new_cohort.Disease.values == best_match)
print(n_correct / len(new_cohort))
#
# store the prediction and whether it is correct
new_cohort.loc[:, 'pred'] = best_match
new_cohort.loc[:, 'match_correct'] = new_cohort.Disease.values == new_cohort.pred.values
# row counts for every (recorded disease, predicted disease) pair
pair_columns = ['Patient_Index', 'Disease', 'pred']
match_stats_df = new_cohort[pair_columns].groupby(['Disease', 'pred']).count().reset_index()
match_stats_df

# Load pre-trained random forest classifier (trained on the initial cohort)

In [None]:
# NOTE: joblib.load unpickles the file, which can execute arbitrary code;
# only load model files that come from a trusted source.
rf_model_path = "random_forest_bc_cc_ibd.joblib"
clf = joblib.load(rf_model_path)

# Test Random Forest Classifier using new patients

In [None]:
pred = clf.predict(new_spoke_sigs)
# labels ordered by Patient_Index, compared element-wise with the predictions
true_labels = new_cohort.sort_values('Patient_Index').label.values
accuracy = np.sum(true_labels == pred) / len(new_cohort)
print(accuracy)  # recorded output: 0.9333333333333333

# Create node importance table

In [None]:
# Per-disease mean node values (the table saved as mean_node_val_df.tsv),
# extended with the classifier's feature importance for each row
node_table_path = 'mean_node_val_df.tsv'
classifier_results = pd.read_csv(node_table_path, sep='\t', header=0, index_col=False)
classifier_results.loc[:, 'feature_importance'] = clf.feature_importances_


# Set a threshold for node importance

In [None]:
# number of rows to show in each table
top_n = 50
# percentile (0-100) a value must exceed to count as "top"
thresh_percentile = 99
# feature importance value at that percentile
importance_values = classifier_results['feature_importance'].values
feature_thresh = np.percentile(importance_values, thresh_percentile)

# Look at key nodes for IBS

These are the top IBS nodes that pass the feature importance threshold,
sorted by their average significance in the IBS cohort.

In [None]:
col = 'irritable_bowel_syndrome'
# keep rows above the feature importance cut-off, largest disease average first
passes_importance = classifier_results.feature_importance > feature_thresh
classifier_results[passes_importance].sort_values(col, ascending=False).head(top_n)

These are the top IBS nodes that pass both the feature importance threshold and the IBS average significance threshold. Nodes are sorted by their feature importance for the RF classifier.

In [None]:
# cut-off for the current disease column at the same percentile
thresh = np.percentile(classifier_results[col].values, thresh_percentile)
# rows must exceed both cut-offs; most important features first
passes_importance = classifier_results.feature_importance > feature_thresh
passes_disease_avg = classifier_results[col] > thresh
classifier_results[passes_importance & passes_disease_avg].sort_values('feature_importance', ascending=False).head(top_n)

# Look at key nodes for breast cancer

These are the top breast cancer nodes that pass the feature importance threshold, sorted by their average significance in the breast cancer cohort.

In [None]:
col = 'breast_cancer'
# keep rows above the feature importance cut-off, largest disease average first
passes_importance = classifier_results.feature_importance > feature_thresh
classifier_results[passes_importance].sort_values(col, ascending=False).head(top_n)

These are the top breast cancer nodes that pass both the feature importance threshold and the breast cancer average significance threshold. Nodes are sorted by their feature importance for the RF classifier.

In [None]:
# cut-off for the current disease column at the same percentile
thresh = np.percentile(classifier_results[col].values, thresh_percentile)
# rows must exceed both cut-offs; most important features first
passes_importance = classifier_results.feature_importance > feature_thresh
passes_disease_avg = classifier_results[col] > thresh
classifier_results[passes_importance & passes_disease_avg].sort_values('feature_importance', ascending=False).head(top_n)

# Look at key nodes for colon cancer

These are the top colon cancer nodes that pass the feature importance threshold, sorted by their average significance in the colon cancer cohort.

In [None]:
col = 'colon_cancer'
# keep rows above the feature importance cut-off, largest disease average first
passes_importance = classifier_results.feature_importance > feature_thresh
classifier_results[passes_importance].sort_values(col, ascending=False).head(top_n)

These are the top colon cancer nodes that pass both the feature importance threshold and the colon cancer average significance threshold. Nodes are sorted by their feature importance for the RF classifier.

In [None]:
# cut-off for the current disease column at the same percentile
thresh = np.percentile(classifier_results[col].values, thresh_percentile)
# rows must exceed both cut-offs; most important features first
passes_importance = classifier_results.feature_importance > feature_thresh
passes_disease_avg = classifier_results[col] > thresh
classifier_results[passes_importance & passes_disease_avg].sort_values('feature_importance', ascending=False).head(top_n)