In [None]:
# Standard library
import itertools
import multiprocessing
import os
import pickle
import random
import sys
import time
import warnings
from functools import partial
from itertools import cycle, islice
from multiprocessing import Pool

# Third-party
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
import xgboost as xgb
from auto_shap.auto_shap import generate_shap_values
from imblearn.over_sampling import SMOTE
from IPython.display import display, HTML
# NOTE(review): scipy.interp is a deprecated alias of numpy.interp and may fail to
# import on recent SciPy releases; it is not used in the visible cells - confirm.
from scipy import interp
from scipy import stats
from sklearn import datasets, metrics, linear_model, svm
from sklearn.cluster import KMeans
from sklearn.datasets import (make_moons, make_circles, make_classification, 
                             make_blobs, make_checkerboard)
from sklearn.ensemble import (RandomForestClassifier, AdaBoostClassifier, 
                              ExtraTreesClassifier, GradientBoostingClassifier, 
                              BaggingClassifier, VotingClassifier)
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve, auc, roc_auc_score
from sklearn.metrics import precision_recall_curve
from sklearn.model_selection import train_test_split, StratifiedKFold, RandomizedSearchCV, GridSearchCV
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from tableone import TableOne

def set_seeds(seed):
    """Seed the global random number generators used by the notebook.

    Stores ``str(seed)`` in the ``PYTHONHASHSEED`` environment variable, then
    seeds the standard-library ``random`` module and NumPy's legacy global
    generator with ``seed``.

    NOTE(review): the interpreter reads PYTHONHASHSEED at start-up, so setting
    it here presumably only affects child processes - confirm.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Same order as before: stdlib generator first, then NumPy.
    for reseed in (random.seed, np.random.seed):
        reseed(seed)

In [None]:
# Load the cohort file and report its size.
df_final = pd.read_csv('20240727_AKI_data.csv', encoding='cp949')
print('total n')
print(len(df_final))

# Keep only the columns whose name does not contain '_avail'.
df_final_columns = list(df_final.columns)
df_final_columns2 = [name for name in df_final_columns if '_avail' not in name]
df_final2 = df_final[df_final_columns2]

# Outcome labels and the two reference scores, as plain Python lists.
y = list(df_final2['AKI_postop_7D'])
y2 = list(df_final2['AKI_nonrecov_7D'])
y3 = list(df_final2['AKI_sustained_48'])
y4 = list(df_final2['AKI_sustained_72'])
ThakarScore = list(df_final2['ThakarScore'])
CKD_eGFR_stage = list(df_final2['CKD_eGFR_stage'])

# For each outcome: event count, then event fraction over all rows.
for title, labels in (
    ('AKI incidence', y),
    ('AKI nonrecovery incidence', y2),
    ('AKI_sustained_48 incidence', y3),
    ('AKI_sustained_72 incidence', y4),
):
    n_events = np.sum(labels)
    print(title)
    print(n_events)
    print(n_events/len(df_final2))
    print('')


# Discrimination of each reference score against the first outcome (y):
# area under the ROC curve and trapezoidal area under the PR curve.
for title, scores in (
    ('AUROC, AUPRC of ThakarScore', ThakarScore),
    ('AUROC, AUPRC of CKD_eGFR_stage', CKD_eGFR_stage),
):
    print(title)
    fpr, tpr, thresholds = roc_curve(y, scores)
    auroc = auc(fpr, tpr)
    precision, recall, _ = precision_recall_curve(y, scores)
    auprc = auc(recall, precision)
    print(auroc,auprc)

In [None]:
# test set 1 year
# NOTE(review): the split is purely positional, so it presumably relies on the
# CSV rows being in chronological order with the held-out year last - confirm.
split_at = 1693
X_trainval = df_final2.iloc[:split_at].reset_index(drop=True)
X_test = df_final2.iloc[split_at:].reset_index(drop=True)

# Drop the last five columns from both feature frames.
X_trainval = X_trainval[X_trainval.columns[:-5].tolist()]
X_test = X_test[X_test.columns[:-5].tolist()]

# Then drop the three columns at positions -11, -10 and -9 of what remains.
X_trainval = X_trainval.drop(columns=X_trainval.columns[-11:-8].tolist())
X_test = X_test.drop(columns=X_test.columns[-11:-8].tolist())

# Split every outcome list at the same row position as the features.
yy_trainval, yy_test = y[:split_at], y[split_at:]
yy2_trainval, yy2_test = y2[:split_at], y2[split_at:]
yy3_trainval, yy3_test = y3[:split_at], y3[split_at:]
yy4_trainval, yy4_test = y4[:split_at], y4[split_at:]

In [None]:
# Whole cohort again: train/validation rows followed by test rows. The row index
# is not renumbered, so index labels from the two halves overlap.
X_trainvaltest = pd.concat([X_trainval, X_test], axis=0)
# Matching labels for the AKI_sustained_72 outcome, in the same row order.
y_trainvaltest = [*yy4_trainval, *yy4_test]

# scenario columns

In [None]:
columns = X_trainval.columns.tolist()

# Positional column groups used to assemble the model inputs.
# NOTE(review): these slices depend on the exact column order of the CSV - confirm
# before reusing with a different export.
input_model0_1 = columns[:34]
input_model1_1 = columns[34:40]
input_model0_2 = columns[40:62]
input_model1_2 = columns[62:87]
etccols = columns[87:]
input_model2_1 = columns[155:-2]

# Scenario 1: summary features for MAP, PP, CVP and CI plus baseline rSO2.
scenario1 = [
    'baseline_rSO2',
    'MAP_mean_list',
    'MAP_CV_list',
    'MAP_ARV_list',
    'MAP_duration_65_list',
    'MAP_duration_100_list',
    'MAP_auc_65_list',
    'MAP_auc_100_list',
    'PP_mean_list',
    'PP_duration_60_list',
    'PP_auc_60_list',
    'CVP_mean_list',
    'CVP_duration_12_list',
    'CVP_auc_12_list',
    'CI_mean_list',
    'CI_duration_2_list',
    'CI_auc_2_list',
]

# Scenario 2: the preCPB / intraCPB / postCPB variants of selected features.
scenario2 = [
    'baseline_rSO2',
    'NEW_MAP_preCPB_CV_list',
    'NEW_MAP_preCPB_ARV_list',
    'NEW_PP_preCPB_mean_list',
    'NEW_PP_preCPB_auc_60_list',
    'NEW_CI_preCPB_mean_list',
    'NEW_CI_preCPB_auc_2_list',
    'NEW_MAP_intraCPB_CV_list',
    'NEW_MAP_intraCPB_ARV_list',
    'NEW_MAP_postCPB_CV_list',
    'NEW_MAP_postCPB_ARV_list',
    'NEW_PP_postCPB_mean_list',
    'NEW_PP_postCPB_auc_60_list',
    'NEW_CI_postCPB_mean_list',
    'NEW_CI_postCPB_auc_2_list',
]

# Scenario 3: reduced set using the CPBMAPincluded PP features.
scenario3 = [
    'baseline_rSO2',
    'MAP_CV_list',
    'MAP_ARV_list',
    'NEW_PP_CPBMAPincluded_mean_list',
    'NEW_PP_CPBMAPincluded_auc_60_list',
    'CI_mean_list',
    'CI_auc_2_list',
]

# Scenario 4: everything in scenario 1 followed by the intraCPB MAP features.
scenario4 = [
    *scenario1,
    'NEW_MAP_intraCPB_mean_list',
    'NEW_MAP_intraCPB_duration_65_list',
    'NEW_MAP_intraCPB_duration_100_list',
    'NEW_MAP_intraCPB_auc_65_list',
    'NEW_MAP_intraCPB_auc_100_list',
]

In [None]:
# Nested feature sets: each model's inputs extend the previous model's inputs.
input_model0 = [*input_model0_1, *input_model0_2]
input_model1 = [*input_model0, *input_model1_1, *input_model1_2, *scenario2]
input_model2 = [*input_model1, *input_model2_1]

# Train/validation feature frames, one per feature set.
X_trainval_model0 = X_trainval[input_model0]
X_trainval_model1 = X_trainval[input_model1]
X_trainval_model2 = X_trainval[input_model2]

# Test feature frames with the same column selections.
X_test_model0 = X_test[input_model0]
X_test_model1 = X_test[input_model1]
X_test_model2 = X_test[input_model2]

# Modelling target: the AKI_sustained_72 labels as NumPy arrays.
y_trainval = np.array(yy4_trainval)
y_test = np.array(yy4_test)

# XGB

In [None]:
set_seeds(0)
model = xgb.XGBClassifier()

# Candidate values per hyper-parameter; GridSearchCV evaluates every combination.
# NOTE(review): this grid has 4*8*4*5*5*6*6*6*4 = 2,764,800 combinations, i.e.
# about 11 million fits with 4 folds. RandomizedSearchCV is imported above and may
# be what was intended - confirm.
parameter_space = dict(
    n_estimators=[100, 200, 300, 400],         # number of boosted trees
    max_depth=[2,3, 4, 5, 6, 7, 8, 9],         # maximum depth of each tree
    learning_rate=[0.0001,0.001, 0.01, 0.1],   # step size shrinkage
    subsample=[0.6, 0.7, 0.8, 0.9, 1.0],       # row subsample ratio
    colsample_bytree=[0.6, 0.7, 0.8, 0.9, 1.0],  # column subsample ratio per tree
    gamma=[0,  0.5, 1, 2, 4, 8],               # minimum loss reduction to split
    reg_alpha=[0, 0.5, 1, 2, 4, 8],            # L1 regularisation
    reg_lambda=[0, 0.5, 1, 2, 4, 8],           # L2 regularisation
    min_child_weight=[1, 2, 3, 4],             # minimum hessian sum in a child
)

folds = 4
# Shuffled stratified folds; no random_state is passed here.
skf = StratifiedKFold(n_splits=folds, shuffle=True)

# Both metrics are recorded; the best ROC AUC configuration is refit on all data.
grid_search = GridSearchCV(
    model,
    parameter_space,
    scoring=['roc_auc', 'average_precision'],
    refit='roc_auc',
    cv=skf,
    return_train_score=True,
    verbose=2,
)

grid_search.fit(X_trainval_model0, y_trainval)
best_param = grid_search.best_params_

In [None]:
# Refit XGBoost with the selected hyper-parameters and score it on the test split.
# `best_param` is overwritten by later search cells, so run this right after the
# XGB search.
classifier_XGB = xgb.XGBClassifier(**best_param)
set_seeds(0)
classifier_XGB.fit(X_trainval_model0, y_trainval)

probas_ = classifier_XGB.predict_proba(X_test_model0)
positive_scores = probas_[:, 1]  # second predict_proba column

fpr, tpr, thresholds = roc_curve(y_test, positive_scores)
print(auc(fpr, tpr))  # AUROC
precision, recall, _ = precision_recall_curve(y_test, positive_scores)
print(auc(recall, precision))  # trapezoidal AUPRC

# RF

In [None]:
set_seeds(0)
model = RandomForestClassifier()

# Candidate values per hyper-parameter (3*7*9*9*4*3 = 20,412 combinations).
# NOTE(review): scikit-learn documents 'log_loss' and 'entropy' as the same
# criterion, so one of them looks redundant - confirm.
parameter_space = dict(
    criterion=['log_loss','gini','entropy'],
    n_estimators=[25,50, 100, 150, 200,250,300],        # number of trees
    max_depth=[3, 4, 5, 6, 7, 8, 9, 10,12],             # maximum tree depth
    min_samples_split=[2, 3, 4, 5, 6, 7, 8, 9, 10],     # min samples to split a node
    min_samples_leaf=[1, 2, 3, 4],                      # min samples at a leaf
    max_features=['sqrt', 'log2', None],                # features tried per split
)

folds = 4
# Shuffled stratified folds; no random_state is passed here.
skf = StratifiedKFold(n_splits=folds, shuffle=True)

# Both metrics are recorded; the best ROC AUC configuration is refit on all data.
grid_search = GridSearchCV(
    model,
    parameter_space,
    scoring=['roc_auc', 'average_precision'],
    refit='roc_auc',
    cv=skf,
    return_train_score=True,
    verbose=2,
)

grid_search.fit(X_trainval_model0, y_trainval)
best_param = grid_search.best_params_

In [None]:
# Refit the random forest with the selected hyper-parameters and score it on the
# test split. `best_param` must still hold the RF search result at this point.
classifier_RF = RandomForestClassifier(**best_param)
set_seeds(0)
classifier_RF.fit(X_trainval_model0, y_trainval)

probas_ = classifier_RF.predict_proba(X_test_model0)
positive_scores = probas_[:, 1]  # second predict_proba column

fpr, tpr, thresholds = roc_curve(y_test, positive_scores)
print(auc(fpr, tpr))  # AUROC
precision, recall, _ = precision_recall_curve(y_test, positive_scores)
print(auc(recall, precision))  # trapezoidal AUPRC

# ET

In [None]:
# NOTE(review): this cell seeds with 42 while the other model cells use 0 - confirm
# that the difference is intentional.
set_seeds(42)
# `et` was never defined in this notebook; ExtraTreesClassifier is the imported name.
model = ExtraTreesClassifier()

# Wider grid kept for reference only. It is NOT searched: the original cell
# defined it and then immediately replaced it with the reduced grid below.
parameter_space_full = {
    'n_estimators': [50, 100, 150, 200,250,300],  # Number of trees
    'criterion': ['gini', 'entropy'],  # Splitting criterion
    'max_depth': [None, 10, 20, 30, 40, 50],  # Maximum tree depth
    'min_samples_split': [2, 5, 10],  # Minimum samples required to split an internal node
    'min_samples_leaf': [1, 2, 4],  # Minimum samples required to be at a leaf node
    'max_features': ['sqrt', 'log2', None],  # Number of features to consider when looking for the best split
    'bootstrap': [False, True],  # Whether bootstrap samples are used when building trees
    'class_weight': [None, 'balanced']  # Weights associated with classes
}

# Grid that is actually searched.
parameter_space = {
    'n_estimators': [50, 100, 150],  # Number of trees
}

folds = 4

# Shuffled stratified folds; no random_state is passed here.
skf = StratifiedKFold(n_splits=folds, shuffle = True)

# Both metrics are recorded; the best ROC AUC configuration is refit on all data.
grid_search = GridSearchCV(model, parameter_space, scoring=['roc_auc', 'average_precision'], verbose=2, return_train_score=True, cv=skf, refit='roc_auc')

grid_search.fit(X_trainval_model0, y_trainval)
best_param = grid_search.best_params_

In [None]:
# Refit extra-trees with the selected hyper-parameters and score it on the test
# split. `et` was never defined in this notebook, so the imported
# ExtraTreesClassifier is used directly.
classifier_ET=ExtraTreesClassifier(**best_param)
set_seeds(0)
classifier_ET.fit(X_trainval_model0, y_trainval)
probas_ = classifier_ET.predict_proba(X_test_model0)
fpr, tpr, thresholds = roc_curve(y_test, probas_[:, 1])
print(auc(fpr, tpr))  # AUROC
precision, recall, _ = precision_recall_curve(y_test, probas_[:, 1])
print(auc(recall, precision))  # trapezoidal AUPRC

# ENS

In [None]:
# Soft-voting ensemble of the three tuned classifiers, fitted on the
# train/validation split and scored on the test split.
clf = VotingClassifier(estimators=[('XGB', classifier_XGB),('RF', classifier_RF),('ET', classifier_ET)], voting='soft')
set_seeds(0)
clf.fit(X_trainval_model0, y_trainval)
probas_ = clf.predict_proba(X_test_model0)
fpr, tpr, thresholds = roc_curve(y_test, probas_[:,1])
auroc = auc(fpr, tpr)
precision, recall, _ = precision_recall_curve(y_test, probas_[:, 1])
auprc = auc(recall, precision)
print(auroc,auprc)
# Create the target directory first: joblib.dump does not create missing folders.
os.makedirs("20240820_models", exist_ok=True)
# The plotting cells further down reload the ensemble from this exact path.
joblib.dump(clf, "20240820_models/ENS_classifier.pkl")

In [None]:
# SHAP values for the fitted ensemble `clf` on the model-0 test features, computed
# with the auto_shap helper (returns three objects, unpacked here).
shap_values, shap_expected_value, global_shap_df = generate_shap_values(clf, X_test_model0)

In [None]:
# NumPy copy of the SHAP values for the shap plotting functions
# (shap_values exposes .to_numpy(), so presumably a DataFrame - confirm).
shap_values_np=shap_values.to_numpy()

In [None]:
# Inline SHAP summary plot of the values above against the test features.
shap.summary_plot(shap_values_np,X_test_model0)

# tableOne

In [None]:
def table1(catcolumn,catcolumns,dfdfdf2,savename):
    """Build a grouped TableOne summary and write it to an HTML file.

    The table is built twice. The first pass only collects the names of
    variables whose Shapiro-Wilk p-value is <= 0.05; the second pass passes
    those names as ``nonnormal`` and is the table that gets saved.

    Parameters
    ----------
    catcolumn : str
        Column of ``dfdfdf2`` used as the ``groupby`` variable.
    catcolumns : list
        Column names passed to TableOne as ``categorical``.
    dfdfdf2 : pandas.DataFrame
        Data to summarise.
    savename : str
        Suffix of the output file ``figures/table1_<savename>.html``.

    Returns
    -------
    None

    Notes
    -----
    ``TableOne._normality`` is replaced at class level and stays replaced after
    this function returns, as in the original implementation.
    NOTE(review): this assumes the installed tableone version calls
    ``TableOne._normality(x)`` for its normality test - confirm.
    """
    flagged_names = []

    def _normality(self, x):
        # Test only the observed values. Shapiro-Wilk was previously run on the
        # raw values, NaN included, although the length check excluded them.
        values = x.values
        observed = values[~np.isnan(values)]
        if len(observed) >= 20:
            p = stats.shapiro(observed).pvalue
        else:
            p = None
        # dropna=False argument in pivot_table does not function as expected
        # return -1 instead of None
        if pd.isnull(p):
            return -1
        if p <= 0.05:
            flagged_names.append(x.name)
        return p

    TableOne._normality = _normality

    # First pass: built only for its side effect of filling `flagged_names`.
    TableOne(dfdfdf2,categorical=catcolumns,groupby=[catcolumn],normal_test=True,pval=True,htest_name=True,decimals=3)

    # De-duplicate into a separate list so the second pass, which appends to
    # `flagged_names` again, cannot mutate the argument it was given.
    nonnormallist = list(dict.fromkeys(flagged_names))

    # Second pass: the flagged variables are treated as non-normal.
    summary = TableOne(dfdfdf2,categorical=catcolumns,groupby=[catcolumn],normal_test=True,pval=True,htest_name=True,nonnormal=nonnormallist,decimals=3)

    # Make sure the output folder exists before writing.
    os.makedirs('figures', exist_ok=True)
    summary.to_html('figures/table1_'+savename+'.html')

In [None]:
# Column names passed to TableOne as categorical variables.
categorical = [
    'op_type',
    're-do',
    'sex',
    'emergency',
    'VAD_use',
    'HTN',
    'CKD',
    'old CVA',
    'DM',
    'A.fib',
    'liver cirrhosis',
    'CHF',
    'old MI',
    'COPD',
    'Functional class',
    'acuteMI (1WK)',
    'NYHA',
    'BB',
    'CCB',
    'ACEi',
    'ARB',
    'Statin',
    'diuretics',
    'Warfarin',
    'heparinization',
    'NOAC',
    'vaso_T',
    'prima_T_bi',
    'dobu_T_bi',
    'Katz',
]

In [None]:
# Copy of the full-cohort features with the outcome added as column 'y'
# (the list is assigned by position; the source frame is left untouched).
X_trainvaltest2 = X_trainvaltest.assign(y=y_trainvaltest)

In [None]:
# Descriptive table grouped by the outcome column 'y'; 'y4' is the file suffix.
table1(
    catcolumn='y',
    catcolumns=categorical,
    dfdfdf2=X_trainvaltest2,
    savename='y4',
)

# ROC PR curves

In [None]:
def ROC_PR_curves(roc_pr):
    """Plot and save the ROC or precision-recall curve of the saved ensemble.

    The ensemble is loaded from ``20240820_models/ENS_classifier.pkl`` and
    scored on the notebook-level ``X_test_model0`` / ``y_test``.

    Parameters
    ----------
    roc_pr : str
        ``'ROC'`` to draw the ROC curve, ``'PR'`` for the precision-recall curve.

    Raises
    ------
    ValueError
        If ``roc_pr`` is neither ``'ROC'`` nor ``'PR'``. Previously such a value
        only failed late, with a NameError on the undefined legend.
    """
    # Everything that differs between the two curve types lives here.
    settings = {
        'ROC': dict(xlabel='1-Specificity', ylabel='Sensitivity', title='ROC curves',
                    legend_loc='lower right', path='figures/20230821_ROC.png'),
        'PR': dict(xlabel='Recall', ylabel='Precision', title='PR curves',
                   legend_loc='upper right', path='figures/20230821_PR.png'),
    }
    if roc_pr not in settings:
        raise ValueError(f"roc_pr must be 'ROC' or 'PR', got {roc_pr!r}")
    cfg = settings[roc_pr]

    # Load and score the model before creating the figure, so a missing model
    # file does not leave an empty figure open.
    clf_loaded = joblib.load('20240820_models/ENS_classifier.pkl')
    set_seeds(0)
    probas_ = clf_loaded.predict_proba(X_test_model0)
    positive_scores = probas_[:, 1]

    linewidth = 3
    fig, ax = plt.subplots(figsize=(25,25))
    try:
        for spine in ax.spines.values():
            spine.set_color('black')
            spine.set_linewidth(2)

        if roc_pr == 'ROC':
            fpr, tpr, _ = roc_curve(y_test, positive_scores)
            auroc = round(auc(fpr, tpr),3)
            # Chance diagonal first, so the model curve is drawn on top of it.
            ax.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r',alpha=1)
            ax.plot(fpr, tpr, color='blue',label=f'ENS model AUROC = {auroc}' ,lw=linewidth, alpha=1)
        else:
            precision, recall, _ = precision_recall_curve(y_test, positive_scores)
            auprc = round(auc(recall, precision),3)
            ax.plot(recall, precision, color='blue',label=f'ENS model AUPRC = {auprc}' ,lw=linewidth, alpha=1)

        ax.set_xlim(0.0, 1.0)
        ax.set_ylim(0.0, 1.0)
        ax.set_yticks(np.arange(0.2, 1.2, step=0.2))
        # Single tick-label size; the earlier size-28 setting was always overridden.
        ax.tick_params(axis='both', labelsize=37)

        ax.set_xlabel(cfg['xlabel'], fontsize=50)
        ax.set_ylabel(cfg['ylabel'], fontsize=50)
        ax.set_title(cfg['title'], fontsize=55)
        legend = ax.legend(loc=cfg['legend_loc'], fontsize=45)

        # Thicker legend frame and legend line samples for readability at this size.
        legend.get_frame().set_linewidth(4)
        for line in legend.get_lines():
            line.set_linewidth(8)

        # Make sure the output folder exists before saving.
        os.makedirs('figures', exist_ok=True)
        fig.savefig(cfg['path'], transparent=True)
        plt.show()
    finally:
        # Release the 25x25 inch figure even if plotting or saving fails.
        plt.close(fig)

In [None]:
# Draw and save both curve types for the saved ensemble, ROC first.
for rocprtemp in ('ROC', 'PR'):
    ROC_PR_curves(rocprtemp)

# SHAP summary plot

In [None]:
def SHAP_summary_plot():
    """Save a SHAP summary plot (top 10 features) for the saved ensemble.

    Loads the ensemble from ``20240820_models/ENS_classifier.pkl``, computes SHAP
    values on the notebook-level ``X_test_model0`` with
    ``generate_shap_values`` and writes the plot to
    ``figures/20241010_SHAP_summary.png``. Nothing is displayed or returned.
    """
    clf_loaded = joblib.load('20240820_models/ENS_classifier.pkl')
    set_seeds(0)

    shap_values, shap_expected_value, global_shap_df = generate_shap_values(clf_loaded, X_test_model0)
    shap_values_np=shap_values.to_numpy()

    # show=False keeps the figure open so it can be saved below.
    shap.summary_plot(shap_values_np,X_test_model0,max_display=10,show=False)
    savename='figures/20241010_SHAP_summary.png'
    # Make sure the output folder exists before saving.
    os.makedirs('figures', exist_ok=True)
    plt.savefig(savename,transparent=True)

    plt.close()
        
       

In [None]:
# Write the SHAP summary figure for the saved ensemble to the figures folder.
SHAP_summary_plot()

# SHAP dependence plot

In [None]:
def SHAP_dependence_plot(X_test=None,model_type=None,scenario_num=None, y_trainval=None, y_test=None,excludelistnum=None,ytype=None,seednum=None):
    """Save one SHAP dependence plot per top-ranked feature of the saved ensemble.

    Loads the ensemble from ``20240820_models/ENS_classifier.pkl``, computes SHAP
    values on the notebook-level ``X_test_model0``, ranks features by mean
    absolute SHAP value and writes a dependence plot for each of the 15
    highest-ranked ones to ``figures/20241010_SHAP_dependence_<feature>.png``.

    Parameters
    ----------
    X_test, model_type, scenario_num, y_trainval, y_test, excludelistnum, ytype, seednum
        Accepted for backward compatibility only; none of them is read. They now
        default to ``None`` so the function can be called without arguments.
        NOTE(review): the data always comes from the global ``X_test_model0``,
        not from ``X_test`` - confirm that this is intended.
    """
    n_top = 15  # number of features plotted

    clf_loaded = joblib.load('20240820_models/ENS_classifier.pkl')
    set_seeds(0)

    shap_values, shap_expected_value, global_shap_df = generate_shap_values(clf_loaded, X_test_model0)
    shap_values_np = shap_values.to_numpy()

    # Mean |SHAP| per feature, then the n_top largest in descending order.
    shap_importance = np.abs(shap_values_np).mean(axis=0)
    top_feature_indices = np.argsort(shap_importance)[-n_top:][::-1]
    top_feature_names = X_test_model0.columns[top_feature_indices]

    # Make sure the output folder exists before saving.
    os.makedirs('figures', exist_ok=True)

    for feature in top_feature_names:
        # show=False keeps the figure open so the reference line can be added.
        shap.dependence_plot(feature, shap_values_np, X_test_model0, interaction_index=None, show=False)
        plt.axhline(y=0, color='red', linestyle='--', linewidth=1.5)

        # Build a file-system-safe name from the feature: drop the characters
        # removed before and map path separators to '_', so a feature name
        # cannot point into a non-existent sub-folder.
        safe_feature = str(feature)
        for dropped in ('*', '<', '~'):
            safe_feature = safe_feature.replace(dropped, '')
        for separator in ('/', '\\'):
            safe_feature = safe_feature.replace(separator, '_')
        savename = 'figures/20241010_SHAP_dependence_' + safe_feature + '.png'

        plt.tight_layout()
        plt.savefig(savename,transparent=True)

        plt.close()

In [None]:
# Every argument is passed explicitly because the function signature lists all of
# them. The function body reads the global X_test_model0 and the saved ensemble,
# so the None placeholders have no effect on the plots.
SHAP_dependence_plot(
    X_test=X_test_model0,
    model_type=None,
    scenario_num=None,
    y_trainval=y_trainval,
    y_test=y_test,
    excludelistnum=None,
    ytype=None,
    seednum=None,
)