# Figures for Raw Count of ICD-9 & CPT 

In [30]:
import os, sys
import numpy as np
import pandas as pd 
import seaborn as sns
import matplotlib.pyplot as plt 
import glob 
import pickle 


from datetime import datetime
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all" 
from IPython.core.display import display, HTML    
display(HTML("<style>.container {width:90% !important; }</style>"))
%matplotlib inline
np.set_printoptions(precision=5, suppress=True) 

CURR_DATE = datetime.now().strftime('%Y-%m-%d')

In [31]:
# Make the project's script folder and the manuscript helper folder importable
# (appended in this order, after everything already on sys.path).
sys.path.extend([
    '/dors/capra_lab/users/abraha1/projects/PTB_phenotyping/scripts/rand_forest_ptb_classification',
    "/dors/capra_lab/users/abraha1/projects/PTB_phenotyping/scripts/rand_forest_ptb_classification/manuscript/0_helper_func",
])
from hyperparam_tune import validate_best_model
from manip_trained_models_funcs import unpack_input_data, upickle_xgbmodel, extract_train_df
from  shaply_funcs import create_descrip_dictionary, filter_shap

from collections import OrderedDict 
from cycler import cycler
import time as time 

In [32]:
# Load the rpy2 IPython extension (provides the %R / %%R magics).
# NOTE(review): no R magic is used in the cells visible here -- confirm it is still needed.
%load_ext rpy2.ipython

In [41]:
# Results folder of the XGBoost run; its 'input_data' and 'best_model'
# sub-folders are read further down in this notebook.
ROOT_DATA_DIR = "/dors/capra_lab/users/abraha1/projects/PTB_phenotyping/results/ptb_predict_machine_learning/2019-01-16_xgboost_hyperopt_icd_cpt_raw_counts" 
# Folder where this notebook writes its pickled outputs.
OUTPUT_DIR = "/dors/capra_lab/users/abraha1/projects/PTB_phenotyping/scripts/rand_forest_ptb_classification/manuscript/counts_icd_cpt/feature_importance/"
# File handed to create_descrip_dictionary() to build the description lookup.
ICD_CPT_DESCRIP_FILE ="/dors/capra_lab/users/abraha1/projects/PTB_phenotyping/data/ptb_predict_machine_learning/icd_cpt_descrip_mapping/descrip_master-col_names.txt"

In [34]:
# OUTPUT FILES 
# Date stamp embedded in stored output file names. It is defined in this cell
# because this is the first cell that uses it when the notebook is run top to
# bottom; relying on a later cell's assignment raises NameError on a fresh kernel.
DATE = "2019-04-01"
# Pickle path for the stored Shapley values.
STORED_SHAP_FILE = os.path.join(OUTPUT_DIR, '{}_shapley_icd_cpt_dicts.pickle'.format(DATE))

# Load and Format Raw Counts ICD & CPT 

## Load Data 

In [35]:
# Inputs and tuned models both live in sub-folders of ROOT_DATA_DIR.
_input_dir = os.path.join(ROOT_DATA_DIR, 'input_data')
_model_dir = os.path.join(ROOT_DATA_DIR, 'best_model')

# INPUT FILES: one feature table per feature set (ICD-9, CPT, ICD-9 + CPT)
icd_input = os.path.join(_input_dir, 'input_data_all_icd9_count_subset-2019-01-25.tsv')
cpt_input = os.path.join(_input_dir, 'input_data_all_cpt_count_subset-2019-01-26.tsv')
icd_cpt_input = os.path.join(_input_dir, 'input_data_all_icd9_cpt_count_subset-2019-01-26.tsv')

# MODEL FILES: the pickled best XGBoost model matching each input table
icd_model = os.path.join(_model_dir, 'best_xgb_model_all_icd9_count_subset-2019-01-25.pickle')
cpt_model = os.path.join(_model_dir, 'best_xgb_model_all_cpt_count_subset-2019-01-26.pickle')
icd_cpt_model = os.path.join(_model_dir, 'best_xgb_model_all_icd9_cpt_count_subset-2019-01-26.pickle')


In [37]:
# Ordered registry of the datasets: label -> paths of its input table and model.
# Insertion order (icd, cpt, icd_cpt) is the order later loops iterate in.
dataset_dict = OrderedDict([
    ('icd', {'input_file': icd_input, 'model_file': icd_model}),
    ('cpt', {'input_file': cpt_input, 'model_file': cpt_model}),
    ('icd_cpt', {'input_file': icd_cpt_input, 'model_file': icd_cpt_model}),
])

In [18]:
# Fixed date stamp used in the names of stored output files.
DATE = "2019-04-01"
# Pickle path for the dictionary holding the data required downstream.
STORED_DATA_FILE = os.path.join(
    OUTPUT_DIR,
    '{}_icd_cpt_datasets_dict.pickle'.format(DATE),
)

In [38]:
# Global figure styling: white grid, enlarged fonts, 8x8 inch default figure.
_rc_overrides = {'figure.figsize': (8, 8)}
sns.set(style='whitegrid', font_scale=1.5, rc=_rc_overrides)
sns.set_style({
    'axes.grid': True,
    'axes.edgecolor': 'k',
    'font.sans-serif': ['Arial'],
    'grid.color': '#e1e1e1',
})

# Each plotted series gets a colour paired with a line style
# (first two dotted, third solid).
_series_colors = cycler('color', ['#1b9e77', '#d95f02', '#7570b3'])
_series_linestyles = cycler('linestyle', [':', ':', '-'])
plt.rc('axes', prop_cycle=(_series_colors + _series_linestyles))

# Font sizes for figure text and legends.
fsize = 20
leg_fsize = 14

# Feature Importance Using SHAP values

In [39]:
import shap
from textwrap import wrap

In [40]:
# load descriptions dictionary 
# Built by the project helper from ICD_CPT_DESCRIP_FILE.
# NOTE(review): presumably maps ICD-9 / CPT codes to text descriptions -- confirm in shaply_funcs.
dsc_dict = create_descrip_dictionary(ICD_CPT_DESCRIP_FILE)

## Calculate Shapley values

In [None]:
# For each dataset, compute (or reload from disk) the Shapley values of the
# training partition and keep them alongside the input data frame and the
# fitted XGBoost model.
#
# Fills three dictionaries keyed by dataset label:
#   all_shap_dict        -> Shapley values returned by shap.TreeExplainer
#   all_input_df_dict    -> input data frame returned by unpack_input_data
#   all_xgb_models_dict  -> XGBoost model (only when values are computed here)
all_shap_dict = {}
all_input_df_dict = {}
all_xgb_models_dict = {}
for this_label, this_dataset in dataset_dict.items():
    # NOTE(review): both cache files are stamped with today's date (CURR_DATE),
    # so pickles written on an earlier day are not picked up -- confirm intended.
    shapley_vals_file = os.path.join(OUTPUT_DIR, '{}_{}_shapley_icd_cpt_dicts.pickle'.format(CURR_DATE, this_label))
    input_df_file = os.path.join(OUTPUT_DIR, '{}_{}_input_df.pickle'.format(CURR_DATE, this_label))

    # Treat it as a cache hit only when BOTH pickles exist: the plotting cells
    # look up all_input_df_dict[label] as well as all_shap_dict[label].
    if os.path.isfile(shapley_vals_file) and os.path.isfile(input_df_file):
        # pickle.load can execute arbitrary code; only load files this notebook wrote.
        with open(shapley_vals_file, 'rb') as shap_fh:
            all_shap_dict[this_label] = pickle.load(shap_fh)
        with open(input_df_file, 'rb') as input_df_fh:
            all_input_df_dict[this_label] = pickle.load(input_df_fh)
        print("loaded shap dict for {}".format(this_label))
        # The model is not cached by this cell, so all_xgb_models_dict has no
        # entry for this label on a cache hit.
        continue

    print("calc shap values for {}".format(this_label))

    # prepare dataset for shapley calc
    this_input_file = this_dataset['input_file']
    this_model = this_dataset['model_file']
    X_train, y_train, X_test, y_test, xgb_model, this_input_data = unpack_input_data(this_input_file, this_model)

    # Keep only rows of the 'grid_cv' partition and drop the non-feature
    # columns; .drop() returns a new frame, so this_input_data is left untouched.
    is_train_row = this_input_data['partition'] == 'grid_cv'
    train_df = this_input_data.loc[is_train_row].drop(['GRID', 'label', 'partition'], axis=1)

    # calc shap values
    explainer = shap.TreeExplainer(xgb_model)
    shap_values = explainer.shap_values(train_df)

    # store shap vals
    all_shap_dict[this_label] = shap_values
    all_input_df_dict[this_label] = this_input_data
    all_xgb_models_dict[this_label] = xgb_model

    # save pickled shapley values (context manager closes the file handle)
    with open(shapley_vals_file, 'wb') as shap_fh:
        pickle.dump(shap_values, shap_fh)
    print("\tsaved shapley values")

    # save pickled input_df
    with open(input_df_file, 'wb') as input_df_fh:
        pickle.dump(this_input_data, input_df_fh)
    print("\tsaved input df")
       

calc shap values for icd
done loading input_data_all_icd9_count_subset-2019-01-25.tsv




	saved shapley values
	saved input df
calc shap values for cpt
done loading input_data_all_cpt_count_subset-2019-01-26.tsv




	saved shapley values
	saved input df
calc shap values for icd_cpt
done loading input_data_all_icd9_cpt_count_subset-2019-01-26.tsv




## Write top n features with description

In [None]:
# Per dataset: intended to write a tsv of features and descriptions with the
# mean +/- SD of the absolute Shapley value.
for dataset_label, dataset_shap_vals in all_shap_dict.items():
    top_feats_df = filter_shap(dataset_shap_vals)
    # !TO DO WRITE FILE!

## Mean abs(Shapley) per feature

In [93]:
## Bar summary plot of the Shapley values, one figure per dataset.
# `.items()` and the trailing colon were missing from the loop header, which
# made this cell a SyntaxError.
for key, shap_vals in all_shap_dict.items():
    # show=False keeps the figure open so it can be adjusted or saved before rendering.
    shap.summary_plot(shap_vals, extract_train_df(all_input_df_dict[key]), plot_type='bar', show=False)
    # To write the figure to disk, enable the next line (it must run before plt.show()):
#     _ = plt.savefig(os.path.join(OUTPUT_DIR,'{}_bar_shap_{}.pdf'.format(key, DATE)),  orientation='landscape')
    # Render and finish this figure so the next dataset starts on a fresh one
    # instead of being drawn over the previous plot.
    plt.show()

## violin plots of shapley values

In [39]:
## violin dot plot 
# Iterates the dictionaries filled by the Shapley calculation cell. The former
# `shap_dict` is not defined anywhere earlier in this notebook, so the loop
# failed with NameError on a fresh kernel.
for key, shap_vals in all_shap_dict.items():
    print(key)
    # Same training-feature frame as used for the bar summary plot.
    shap.summary_plot(shap_vals, extract_train_df(all_input_df_dict[key]), show=False)
    _ = plt.tight_layout()
    # To write the figure to disk, enable the next line (it must run before plt.show()):
#     _ = plt.savefig(os.path.join(OUTPUT_DIR,'{}_violin_shap_{}.pdf'.format(key, DATE)),  orientation='landscape')
    # Render and finish this figure so each dataset gets its own plot.
    plt.show()

icd
cpt




icd_cpt




<Figure size 576x684 with 0 Axes>