In [None]:
# Automatically reload imported modules before executing each cell, so edits to
# locally developed packages take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Create one results directory per (model, imputer, output kind) combination.
# Every imputer accepted by the imputer-selection cell further down is covered (not
# only the three originally listed), so later savefig() calls never hit a missing
# directory. os.makedirs is portable and, with exist_ok=True, idempotent like `mkdir -p`.
import os

for _model in ("xgb", "rf"):
    for _imputer in (
        "TrainSetImputer",
        "TrainSetMahalanobisImputer",
        "IterativeImputerEnhanced",
        "IterativeImputer",
        "GaussianProcessImputer",
        "MedianImputer",
    ):
        for _kind in ("interactions", "relevances"):
            os.makedirs(os.path.join("nhanes_output", _model, _imputer, _kind), exist_ok=True)

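# NHANES

PredDiff feature relevances and pairwise interaction relevances for a mortality model (XGBoost with a Cox objective, or a random forest) trained on NHANES I data, followed by a comparison with SHAP.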

In [None]:
from pred_diff.datasets.loadnhanes import NHANES_DataFrame
import numpy as np

from pred_diff import preddiff
from pred_diff.tools.preddiff_plotting import *
from pred_diff.imputers.impute import *
import pickle
import matplotlib.pyplot as plt
from matplotlib  import cm

from pred_diff.tools import init_plt

# Figure styling. The notebook default is matplotlib's stock style with 12 pt titles.
# For the paper figures, swap the two assignments below for:
#   init_plt.update_NHANES()
#   init_plt.update_figsize(fig_width_pt=234.88)  # 234.88 pt is one paper column
#   size_title = 6
size_title = 12
plt.style.use('default')

In [None]:
# Load the NHANES dataset wrapper. Later cells read its splits (X_strain, X_valid, X_test),
# their *_imp variants (presumably imputed -- confirm in loadnhanes), and the y labels.
nhanes_df = NHANES_DataFrame()

In [None]:
# Count samples by label sign: positive labels are reported as "not surviving",
# negative labels as "surviving".
print(
    "y>0 (not surviving):", np.sum(nhanes_df.y > 0),
    "\ny<0 (surviving):", np.sum(nhanes_df.y < 0),
)

## Fit model

In [None]:
# Model to fit: "rf" (random forest regressor) or "xgb" (XGBoost with a Cox objective).
model_selection = "rf"

In [None]:
# Fit the selected model on the training split.
if model_selection == "xgb":
    import xgboost

    # Hyper-parameters follow the NHANES I mortality analysis of the TreeExplainer study:
    # https://github.com/suinleelab/treeexplainer-study/blob/master/notebooks/mortality/NHANES%20I%20Analysis.ipynb
    params = {
        "learning_rate": 0.001,
        "n_estimators": 6765,
        "max_depth": 4,
        "subsample": 0.5,
        "reg_lambda": 5.5,
        "reg_alpha": 0,
        "colsample_bytree": 1,
    }

    # The keys of `params` are exactly XGBRegressor keyword names, so unpack them directly.
    reg = xgboost.XGBRegressor(
        **params,
        n_jobs=16,
        random_state=1,
        objective="survival:cox",
        base_score=1,
    )
    # This branch trains on X_strain, while the rf branch uses X_strain_imp.
    reg.fit(
        nhanes_df.X_strain, nhanes_df.y_strain, verbose=500,
        eval_set=[(nhanes_df.X_valid, nhanes_df.y_valid)],
        # NOTE(review): 10000 exceeds n_estimators (6765), so early stopping never triggers.
        # Recent xgboost releases expect early_stopping_rounds in the constructor, not fit().
        early_stopping_rounds=10000,
    )
elif model_selection == "rf":
    from sklearn.ensemble import RandomForestRegressor

    # Seeded like the xgb branch so the forest, and all relevances derived from it, are reproducible.
    reg = RandomForestRegressor(n_estimators=1000, max_depth=4, random_state=1)
    reg.fit(nhanes_df.X_strain_imp, nhanes_df.y_strain)
else:
    # An explicit error survives `python -O` (unlike `assert False`) and names the bad value.
    raise ValueError(f"unknown model_selection {model_selection!r}; expected 'xgb' or 'rf'")

In [None]:
def c_statistic_harrell(pred, labels, tie_credit=0.0):
    """Harrell's concordance index (C-statistic) for right-censored survival labels.

    Labels use a signed encoding: ``labels[k] > 0`` is an observed event at time
    ``labels[k]``; any other value is treated as censored at time ``abs(labels[k])``.
    A pair (i, j) is comparable when j had an event and i was still observed after it
    (``abs(labels[i]) > labels[j]``). The pair is concordant when the earlier event gets
    the higher score (``pred[j] > pred[i]``).

    Parameters
    ----------
    pred : array-like of shape (n,)
        Predicted risk scores; higher means higher predicted risk.
    labels : array-like of shape (n,)
        Signed survival times as described above. Inputs are indexed positionally,
        so a pandas Series with a non-default index is handled correctly.
    tie_credit : float, default 0.0
        Credit for comparable pairs with identical scores. The default 0.0 counts ties
        as discordant (the historical behavior of this function). Use 0.5 for the
        textbook Harrell definition.

    Returns
    -------
    float
        (concordant pairs + tie_credit * tied pairs) / comparable pairs.

    Raises
    ------
    ValueError
        If ``pred`` and ``labels`` differ in length, or no comparable pair exists.
    """
    pred = np.asarray(pred, dtype=float).ravel()
    labels = np.asarray(labels, dtype=float).ravel()
    if pred.shape != labels.shape:
        raise ValueError(
            f"pred and labels must have the same length, got {pred.size} and {labels.size}"
        )

    times = np.abs(labels)
    total = 0
    matches = 0.0
    # One vectorised comparison per event instead of a Python loop over all n^2 pairs.
    for j in np.flatnonzero(labels > 0):
        # Scores of subjects still under observation after j's event (j itself is excluded
        # because times[j] == labels[j] is not strictly greater).
        later = pred[times > labels[j]]
        total += later.size
        matches += np.count_nonzero(pred[j] > later)
        if tie_credit:
            matches += tie_credit * np.count_nonzero(pred[j] == later)

    if total == 0:
        raise ValueError(
            "no comparable pairs: need an event followed by a longer observed time"
        )
    return matches / total

In [None]:
# Harrell's C on the test split. The xgb model is scored on X_test; every other
# model is scored on the X_test_imp variant, matching how each was fit above.
print(
    c_statistic_harrell(
        reg.predict(nhanes_df.X_test if model_selection == "xgb" else nhanes_df.X_test_imp),
        nhanes_df.y_test,
    )
)

In [None]:
# Persist the fitted model. The context manager guarantees the file is flushed and closed.
with open(model_selection + ".pkl", "wb") as model_file:
    pickle.dump(reg, model_file)

In [None]:
# Reload the pickled model for the selected model type.
# NOTE: pickle.load can execute arbitrary code -- only load .pkl files this notebook wrote.
with open(model_selection + ".pkl", "rb") as model_file:
    reg = pickle.load(model_file)

## Compute relevances

In [None]:
# Imputer class passed to PredDiff as imputer_cls. Alternatives handled by the next
# cell: "TrainSetMahalanobisImputer", "IterativeImputerEnhanced", "IterativeImputer",
# "GaussianProcessImputer", "MedianImputer".
imputer_selection = "TrainSetImputer"

In [None]:
# Map the configured name to an imputer class. The names presumably come from the
# star import of pred_diff.imputers.impute. Each is only looked up inside its own
# branch, so a name that module lacks fails only when it is actually selected.
if imputer_selection == "IterativeImputer":
    imputer = IterativeImputer
elif imputer_selection == "IterativeImputerEnhanced":
    imputer = IterativeImputerEnhanced
elif imputer_selection == "TrainSetImputer":
    imputer = TrainSetImputer
elif imputer_selection == "TrainSetMahalanobisImputer":
    imputer = TrainSetMahalanobisImputer
elif imputer_selection == "GaussianProcessImputer":
    imputer = GaussianProcessImputer
elif imputer_selection == "MedianImputer":
    imputer = MedianImputer
else:
    # An explicit error survives `python -O` (unlike `assert False`) and names the bad value.
    raise ValueError(
        f"unknown imputer_selection {imputer_selection!r}; expected one of IterativeImputer, "
        "IterativeImputerEnhanced, TrainSetImputer, TrainSetMahalanobisImputer, "
        "GaussianProcessImputer, MedianImputer"
    )

if imputer_selection == "TrainSetMahalanobisImputer":
    # Shuffle the training rows before handing them to PredDiff. A seeded generator keeps
    # this row order identical across runs.
    train_ids = np.random.default_rng(0).permutation(len(nhanes_df.X_strain_imp))
    mvi = preddiff.PredDiff(reg, nhanes_df.X_strain_imp.iloc[train_ids], imputer_cls=imputer,
                            sigma=1, gpus=1, batch_size_test=256)
else:
    mvi = preddiff.PredDiff(reg, nhanes_df.X_strain_imp, imputer_cls=imputer)

# Per-feature relevances on the X_test_imp test matrix, with n_imputations=100.
m_list = mvi.relevances(nhanes_df.X_test_imp, n_imputations=100)

In [None]:
# Aggregate the relevances in m_list into global statistics labelled by feature name.
# NOTE(review): labels come from nhanes_df.X.columns while the relevances were computed on
# nhanes_df.X_test_imp -- assumes both share the same column order; confirm.
m_stats = calculate_global_preddiff_stats(m_list,nhanes_df.X.columns)


In [None]:
# Global feature-importance chart; `filename` presumably sets where the PDF is saved.
plot_global_preddiff_stats(
    m_stats,
    title="Global feature importance",
    filename=f"./nhanes_output/{model_selection}/{imputer_selection}/relevances/nhanes_global_{imputer_selection}_{model_selection}.pdf",
)

In [None]:
# Always plot age, systolic blood pressure and sex, plus the first five features in m_stats.
# Sort the union: iterating a raw set of strings yields a different order in every kernel
# (string-hash randomisation), which shuffled the plot order between runs.
selected_cols = sorted({"age", "systolic blood pressure", "sex female"}.union(m_stats.iloc[:5].col))

selected_ids = [np.where(nhanes_df.columns_all == x)[0][0] for x in selected_cols]
# Axis labels with units. The mapping is matched on the exact column name, so columns merely
# containing "age" are left alone (the old str.replace rewrote any such substring).
axis_labels = {"age": "age[a]", "systolic blood pressure": "systolic blood pressure [mmHg]"}
selected_xlabels = [axis_labels.get(x, x) for x in selected_cols]

for col, xlabel, i in zip(selected_cols, selected_xlabels, selected_ids):
    rel = m_list[i]
    x_values = nhanes_df.X_test[col]  # col == nhanes_df.columns_all[i] by construction

    fig, ax = plt.subplots()
    sc = ax.scatter(x_values, rel['mean'], c=nhanes_df.X_test["age"], cmap=cm.coolwarm, zorder=1)
    # Remember the limits chosen for the scatter so wide error bars do not rescale the axes.
    xlim, ylim = ax.get_xlim(), ax.get_ylim()
    # Matplotlib reads yerr row 0 as the distance BELOW the point and row 1 as the distance
    # ABOVE it. The rows were previously swapped, flipping asymmetric [low, high] intervals.
    yerr = np.stack([rel['mean'] - rel['low'], rel['high'] - rel['mean']], axis=0)
    ax.errorbar(x_values, rel['mean'], yerr, ecolor="grey", marker='', linestyle='None',
                elinewidth=0.5, capsize=1, alpha=0.8, zorder=0)
    if col == "white blood cells":
        xlim = (0, 20)  # manually chosen x-range for this feature
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("relevance")
    ax.set_title(col, size=size_title)
    cbar = fig.colorbar(sc, ax=ax)
    cbar.ax.set_title("age", size=size_title)
    fig.tight_layout(pad=0.1)
    fig.savefig(
        f"./nhanes_output/{model_selection}/{imputer_selection}/relevances/nhanes_{col.replace(' ', '_')}_{imputer_selection}_{model_selection}.pdf",
        bbox_inches='tight',
    )
    plt.show()
    plt.close(fig)
    

## Compute interaction relevances

In [None]:
#pick 5 most important features
interaction_vars = np.array(m_stats.iloc[:5].col) #["age","systolic blood pressure","sex female"]

interaction_cols =[]

for i in range(len(interaction_vars)):
    for j in range(i+1,len(interaction_vars)):
        interaction_cols.append([[interaction_vars[i]],[interaction_vars[j]]])

interaction_cols_txt = ["&".join(i1)+" AND \n"+"&".join(i2) for [i1,i2] in interaction_cols]

In [None]:
# Show which features enter the pairwise interaction analysis.
print(interaction_vars)

In [None]:
# Interaction relevances on the X_test_imp test matrix for every feature pair in
# interaction_cols, with n_imputations=200.
m_int = mvi.interactions(nhanes_df.X_test_imp, interaction_cols, n_imputations=200)

In [None]:
# Aggregate the interaction relevances into global statistics labelled with the readable pair names.
m_int_stats = calculate_global_preddiff_stats(m_int,interaction_cols_txt)

In [None]:
# Global interaction-importance chart; `filename` presumably sets where the PDF is saved.
plot_global_preddiff_stats(
    m_int_stats,
    title="Global interaction importance",
    filename=f"./nhanes_output/{model_selection}/{imputer_selection}/interactions/nhanes_interaction_global_{imputer_selection}_{model_selection}.pdf",
)

In [None]:
def _draw_interaction_panel(ax, pair_idx, pair_cols, pair_title, j, elinewidth, capsize):
    """Draw one interaction-relevance scatter with error bars onto ``ax``.

    The x-axis is feature ``pair_cols[j][0]``. Points are coloured by the partner
    feature ``pair_cols[1 - j][0]``. The y values are ``m_int[pair_idx]['mean']`` with
    error bars spanning ``['low', 'high']``.
    """
    x_col, color_col = pair_cols[j][0], pair_cols[1 - j][0]
    rel = m_int[pair_idx]
    x_values = nhanes_df.X_test[x_col]

    sc = ax.scatter(x_values, rel['mean'], c=nhanes_df.X_test[color_col], cmap=cm.coolwarm, zorder=1)
    # Remember the limits chosen for the scatter so wide error bars do not rescale the axes.
    xlim, ylim = ax.get_xlim(), ax.get_ylim()
    # Matplotlib reads yerr row 0 as the distance BELOW the point and row 1 as the distance
    # ABOVE it. The rows were previously swapped, flipping asymmetric [low, high] intervals.
    yerr = np.stack([rel['mean'] - rel['low'], rel['high'] - rel['mean']], axis=0)
    ax.errorbar(x_values, rel['mean'], yerr, marker='', linestyle='None', elinewidth=elinewidth,
                ecolor="grey", capsize=capsize, alpha=0.8, zorder=0)
    if x_col == "white blood cells":
        xlim = (0, 20)  # manually chosen x-range for this feature
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.set_title(pair_title, size=size_title)
    ax.set_xlabel(x_col)
    ax.set_ylabel("Interaction relevance")
    cbar = ax.figure.colorbar(sc, ax=ax)
    cbar.ax.set_title(color_col, size=size_title)


# Side-by-side view of each pair: the left panel puts the first feature on the x-axis,
# the right panel the second.
for i, (ic, ict) in enumerate(zip(interaction_cols, interaction_cols_txt)):
    fig, axs = plt.subplots(1, 2, figsize=(15, 5))
    for j in (0, 1):
        _draw_interaction_panel(axs[j], i, ic, ict, j, elinewidth=0.5, capsize=1)
    plt.show()

# Single-panel versions of the same plots, written to disk for export.
for i, (ic, ict) in enumerate(zip(interaction_cols, interaction_cols_txt)):
    for j in (0, 1):
        fig, ax = plt.subplots()
        _draw_interaction_panel(ax, i, ic, ict, j, elinewidth=1, capsize=0.5)
        fig.tight_layout(pad=0.1)
        x_name = ic[j][0].replace(" ", "_")
        partner_name = ic[1 - j][0].replace(" ", "_")
        # "_" before the imputer name was previously missing (e.g. "..._AND_ageTrainSetImputer_rf.pdf").
        fig.savefig(
            f"./nhanes_output/{model_selection}/{imputer_selection}/interactions/nhanes_interaction_{x_name}_AND_{partner_name}_{imputer_selection}_{model_selection}.pdf",
            bbox_inches='tight',
        )
        # Export figures are not displayed. Close each one so they do not accumulate in memory.
        plt.close(fig)

# SHAP comparison

TreeExplainer SHAP values for the same model, for comparison with the PredDiff relevances above.

In [None]:
import shap
import matplotlib.pyplot as plt

In [None]:
# SHAP values for the fitted tree model on the test split. As in the C-statistic cell,
# the xgb model sees X_test and any other model sees X_test_imp -- the rf model was fit
# on X_strain_imp, so explaining it on the raw X_test was inconsistent with training.
explainer = shap.TreeExplainer(reg)
shap_relevances = explainer.shap_values(
    nhanes_df.X_test if model_selection == "xgb" else nhanes_df.X_test_imp
)

In [None]:
# Global SHAP importance as a bar chart of the top 15 features, drawn into a 4x6 figure.
f = plt.figure(figsize=(4, 6))
shap.summary_plot(
    shap_relevances,
    nhanes_df.X_test,
    feature_names=nhanes_df.X.columns,
    plot_type="bar",
    max_display=15,
    plot_size=None,  # keep the figure created above instead of letting shap resize it
    show=True,
)
plt.show()