# In [None]:
import import_ipynb
import ML_Pipeline_Overall

# In [None]:
import fiber
from fiber.cohort import Cohort
from fiber.condition import Patient, MRNs
from fiber.condition import Diagnosis
from fiber.condition import Measurement, Encounter, Drug, TobaccoUse, AlcoholUse, MetaData
from fiber.storage import yaml as fiberyaml
import pandas as pd
import pyarrow.parquet as pq
import numpy as np
import os
from functools import reduce 
import math

# In [None]:
import xgboost
import shap
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
import numpy as np
from lightgbm import LGBMClassifier
from sklearn.preprocessing import LabelEncoder, StandardScaler 
from category_encoders import OneHotEncoder, TargetEncoder
from sklearn.metrics import roc_auc_score
from sklearn.compose import ColumnTransformer
import inspect
import json
import sys
from pydoc import locate

import requests

#get all the model config and utility functions from the other file
from ML_Pipeline_Overall import *

# Run configuration for this plotting script.
persist: bool = False  # True forces the renamed-column cache to be rebuilt even if it exists
window: int = 0  # not referenced by the code visible in this script
input_filename: str = 'All3_ML_pipeline_final.pkl'  # cohort pickle; also embedded in the SHAP PNG name
output_path: str = '../plots/'  # not referenced by the code visible in this script
output_filename: str = 'Plots_first_try.pkl'  # cache file for the renamed-column DataFrame

##################
class FriendlyNamesConverter:
    """Translate raw feature column names into human-readable labels for plots.

    Raw names are ``__``-separated tokens, optionally prefixed by an
    aggregation token, e.g. ``<aggregation>__<ConditionClass>__<context>__<code>``
    or ``<ConditionClass>__<description>``. Columns starting with ``age``,
    ``gender``, ``religion`` or ``race`` are prettified by character
    replacement only.
    """

    def rename_columns(self, df):
        """Return a ``{raw_name: friendly_name}`` mapping for every column of *df*.

        The mapping is meant for ``df.rename(columns=...)``; *df* itself is not
        modified.
        """
        return {column: self.get(column) for column in df.columns}

    def _condition_class_names(self):
        """Return the set of class names exposed by ``fiber.condition``.

        Cached on the instance so converting many columns does not re-inspect
        the module for every feature. Raises ``KeyError`` if
        ``fiber.condition`` has not been imported.
        """
        names = getattr(self, '_class_names', None)
        if names is None:
            names = {
                member_name
                for member_name, _ in inspect.getmembers(
                    sys.modules['fiber.condition'],
                    inspect.isclass,
                )
            }
            self._class_names = names
        return names

    def get(self, feature):
        """Return the friendly label for a single raw feature name.

        Label formats:
        * demographic prefixes: ``_`` -> space and ``.`` -> ``|``;
        * ``Class__context__code``: ``Class | <Description>``, where the
          description is looked up via :meth:`get_description`;
        * ``Class__description``: ``Class | <Description>``;
        * an aggregation prefix appends `` (<aggregation>)`` to the label.

        Non-string names, names without ``__``, names whose class is not found
        in ``fiber.condition``, and names with an unexpected number of tokens
        are returned unchanged.
        """
        # does not support time window information inside feature name yet
        if not isinstance(feature, str):
            return feature
        if feature.startswith(('age', 'gender', 'religion', 'race')):
            return feature.replace('_', ' ').replace('.', '|')

        tokens = feature.split('__')
        if len(tokens) < 2:
            return feature

        aggregation = None
        if tokens[1] in self._condition_class_names():
            # The second token is a condition class, so the first one is
            # treated as an aggregation label.
            aggregation, tokens = tokens[0], tokens[1:]

        if len(tokens) == 3:
            class_name, context, code = tokens
            condition_class = locate(f'fiber.condition.{class_name}')
            if condition_class is None:
                # Unknown class: cannot look up a description.
                return feature
            description = self.get_description(condition_class, code, context)
            if "Lipid panel" in description:
                # Collapse every lipid-panel variant into one label.
                description = "Lipid panel"
        elif len(tokens) == 2:
            class_name, description = tokens
        else:
            # e.g. 'max__Drug' (nothing after the class) or too many tokens.
            return feature

        label = f'{class_name} | {description.capitalize()}'
        if aggregation is not None:
            return f'{label} ({aggregation})'
        return label

    def get_description(self, condition_class, code, context):
        """Look up the human-readable description of one condition code.

        Builds ``condition_class(code=code, context=context)``, calls its
        ``patients_per`` with the class's ``description_column``, and returns
        the first value of the column named ``description_column.name.lower()``.
        Presumably ``patients_per`` returns a DataFrame (TODO confirm);
        ``.iloc[0]`` raises ``IndexError`` if the result is empty.
        """
        return condition_class(
            code=code,
            context=context
        ).patients_per(
            condition_class.description_column
        )[
            condition_class.description_column.name.lower()
        ].iloc[0]

def get_column_names_from_ColumnTransformer(column_transformer):
    """Return the output column names of a fitted ``ColumnTransformer``.

    Walks ``column_transformer.transformers_`` in order. For each entry the
    names come from the transformer's ``get_feature_names()``; for a pipeline
    (anything exposing ``.steps``), the final step is asked. If the method is
    missing (``AttributeError``), the entry's raw column selection is used.

    Entries that produce no output columns are skipped so the result lines up
    with the transformed array:
    * transformer ``'drop'`` (including a dropped ``remainder``);
    * an empty column selection (its transformer is never fitted).

    :param column_transformer: a fitted ``sklearn.compose.ColumnTransformer``.
    :return: list of output column names.
    """
    col_name = []
    for entry in column_transformer.transformers_:
        transformer, raw_col_name = entry[1], entry[2]
        if isinstance(transformer, str) and transformer == 'drop':
            continue
        if (not isinstance(raw_col_name, str)
                and hasattr(raw_col_name, '__len__')
                and len(raw_col_name) == 0):
            continue
        # Duck-typed Pipeline check: names come from the last step.
        steps = getattr(transformer, 'steps', None)
        if steps:
            transformer = steps[-1][1]
        try:
            names = transformer.get_feature_names()
        except AttributeError:  # no 'get_feature_names': use raw column names
            names = raw_col_name
        if isinstance(names, str):
            col_name.append(names)
        elif isinstance(names, np.ndarray):
            col_name.extend(names.tolist())
        elif isinstance(names, (list, tuple, pd.Index)):
            col_name.extend(list(names))
    return col_name


def get_shap_importanceplot(train_df, test_df):
    """Fit a CatBoost model on the training split and save a SHAP summary plot.

    Steps:
    1. Convert ``Complication`` to numeric and split the remaining columns by
       dtype: ``object`` columns are one-hot encoded; ``float``/``int``/``uint8``
       columns are standardized. Columns of any other dtype are dropped by the
       ``ColumnTransformer`` (its default ``remainder='drop'``).
    2. Fit the transformer and ``CatBoostClassifier(**catboost_param)`` on the
       transformed training features.
    3. Compute SHAP values on the transformed training data and save a dot
       summary plot to ``shap_explained_catboost_<input_filename>.png``.

    The caller's DataFrames are not modified. *test_df* is accepted for
    interface compatibility but is not used (prospective evaluation is not
    implemented here).
    """
    target = 'Complication'
    # Work on a copy so the caller's frame is not mutated; this also avoids
    # SettingWithCopyWarning when train_df is a slice.
    train_df = train_df.copy()
    train_df[target] = pd.to_numeric(train_df[target])

    feature_cols = [c for c in train_df.columns if c != target]
    # Builtins replace the np.object/np.float/np.int aliases removed in NumPy 1.24.
    categorical_cols = [c for c in feature_cols if train_df[c].dtype in [object]]
    numerical_cols = [c for c in feature_cols if train_df[c].dtype in [float, int, 'uint8']]
    print("length of column names are ::" + str(len(categorical_cols) + len(numerical_cols)))

    column_transformer = ColumnTransformer([
        ('num', StandardScaler(), numerical_cols),
        ('cat', OneHotEncoder(), categorical_cols),
    ])

    features = train_df[train_df.columns.difference([target])]
    labels = train_df[target]

    transformed = column_transformer.fit_transform(features, labels)
    column_names = get_column_names_from_ColumnTransformer(column_transformer)
    transformed = pd.DataFrame(transformed, columns=column_names)

    catboost_classifier = CatBoostClassifier(**catboost_param)
    catboost_classifier.fit(transformed, labels)

    shap_values = shap.TreeExplainer(catboost_classifier).shap_values(transformed)
    fig = plt.figure()
    # show=False keeps the figure populated until it has been saved.
    shap.summary_plot(shap_values, transformed, plot_type='dot', show=False)
    fig.savefig(f'shap_explained_catboost_{input_filename}.png', bbox_inches='tight', dpi=600)
    plt.show()
    plt.close(fig)


if __name__ == "__main__":

    # Directory holding the cached DataFrame with friendly column names.
    plots_dir = '/home/kiwitn01/master_thesis_hypertension-complications/Machine_Learning/plots'
    # Directory holding the input cohort pickle.
    cohort_dir = '/home/kiwitn01/master_thesis_hypertension-complications/Case_Control_Cohort_Creation/For_ML_Pipeline/Split_2011/'
    cache_path = os.path.join(plots_dir, output_filename)

    # Rebuild the renamed-column cache if it is missing or a rebuild is forced.
    if not os.path.isfile(cache_path) or persist:
        print("renamed file not found, so creating one....")

        df = pd.read_pickle(os.path.join(cohort_dir, input_filename))

        rename_dict = FriendlyNamesConverter().rename_columns(df)
        print(rename_dict)
        print(df.shape)

        df = df.rename(columns=rename_dict, errors="raise")
        # Several raw columns can map to the same friendly name (e.g. all
        # "Lipid panel" descriptions); keep only the first of each.
        df = df.loc[:, ~df.columns.duplicated()]
        print(df.shape)
        os.makedirs(plots_dir, exist_ok=True)
        df.to_pickle(cache_path)

    else:
        print("renamed file found, reading it from the drive....")
        df = pd.read_pickle(cache_path)

    # Keep rows with at least NA_removal_threshold non-NA values
    # (NA_removal_threshold presumably comes from the ML_Pipeline_Overall star import).
    df = df.dropna(axis=0, thresh=NA_removal_threshold)

    print("final dataframe shape after dropping NAs" + str(df.shape))
    print(pd.crosstab(df.train_test, df.Complication))

    # train_test_split / test_df_control_ratio presumably come from the
    # ML_Pipeline_Overall star import (project helper, not sklearn's) — TODO confirm.
    train_df, test_df = train_test_split(df, test_df_control_ratio)

    get_shap_importanceplot(train_df, test_df)
