In [None]:
#versions: shap==0.48.0, pandas==2.1.1, matplotlib==3.8.0, numpy==1.26.4


In [None]:
import shap
import os
import pandas as pd
import pickle
import matplotlib.pyplot as plt
import re
import numpy as np
import math

In [None]:
# Run selection: together these constants determine which results directory
# (MODEL_DIR, built in the next cell) is loaded and explained.
FEATURE_SET = "rBL"  # one of FEATURE_SET_OPTIONS
DATASET_SEL = "CN+MCI"  # one of DATASET_OPTIONS
MODEL = "CoxnetSurvivalAnalysis"  # one of MODEL_OPTIONS
SAMPLE_WEIGHTS = True  # True or False

FEATURE_SET_OPTIONS = (
    "BL", "BL+VOL", "BL+RAD", "BL+VOL+RAD",
    "rBL", "rBL+VOL", "rBL+RAD", "rBL+VOL+RAD",
)
DATASET_OPTIONS = ("CN+MCI", "MCI")
MODEL_OPTIONS = (
    "CoxnetSurvivalAnalysis",
    "CoxPHSurvivalAnalysis",
    "ExtraSurvivalTrees",
    "GradientBoostingSurvivalAnalysis",
    "RandomSurvivalForest",
)

# Fail fast on a typo here instead of with a FileNotFoundError when the model is loaded.
assert FEATURE_SET in FEATURE_SET_OPTIONS, f"Unknown FEATURE_SET {FEATURE_SET!r}"
assert DATASET_SEL in DATASET_OPTIONS, f"Unknown DATASET_SEL {DATASET_SEL!r}"
assert MODEL in MODEL_OPTIONS, f"Unknown MODEL {MODEL!r}"
assert isinstance(SAMPLE_WEIGHTS, bool), "SAMPLE_WEIGHTS must be True or False"


In [None]:
# Directory holding this run's saved model, training/test CSVs and SHAP values.
# Sample-weighted runs use the infix "sw_" directly after the model name; other
# runs use "_". The trailing "/" is kept because later code splits this path on "/".
MODEL_DIR = "".join([
    "./results/",
    MODEL,
    "sw_" if SAMPLE_WEIGHTS else "_",
    DATASET_SEL,
    "_",
    FEATURE_SET,
    "/",
])

In [None]:
# Human-readable x-axis labels for the SHAP dependence plots:
#   explainability_mapping[key] == [label at the feature's minimum, label at its maximum]
# A key is either an exact feature (column) name or a prefix: features without an
# exact entry are matched with str.startswith against their cleaned name, trying
# keys in insertion order -- so the order of entries below matters.
explainability_mapping = {
    "FAQ": ["Intensive care not needed", "Intensive care needed"],
    "CDRSB": ["No cognitive impairment", "Advanced cognitive impairment"],
    "CDGLOBAL": ["No cognitive impairment", "Advanced cognitive impairment"],
    "mPACCtrailsB": ["Advanced cognitive impairment", "No cognitive impairment"],
    "LDELTOT": ["Memory problems", "No memory problems"],
    "LIMMTOTAL": ["Memory problems", "No memory problems"],
    "ADAS11": ["No cognitive impairment", "Advanced cognitive impairment"],
    "ADAS13": ["No cognitive impairment", "Advanced cognitive impairment"],
    "ADASQ4": ["No cognitive impairment", "Advanced cognitive impairment"],

    "RAVLT_perc_forgetting": ["No cognitive impairment", "Advanced cognitive impairment"],
    "RAVLT_forgetting": ["No cognitive impairment", "Advanced cognitive impairment"],
    # No trailing space: it used to turn into an empty last line after label wrapping.
    "APOE4_0.0": ["APOE4 allele present", "APOE4 allele not present"],
    "APOE4_2.0": ["No duplicate APOE4 allele", "Duplicate APOE4 allele"],
    "APOE4_1.0": ["No single APOE4 allele", "Single APOE4 allele"],

    "RAVLT_immediate": ["Advanced cognitive impairment", "No cognitive impairment"],
    "RAVLT_learning": ["Advanced cognitive impairment", "No cognitive impairment"],
    "PTMARRY_Unknown": ["Known marital status", "Unknown marital status"],
    "AGE": ["Young age", "Old age"],
    "PTMARRY_Divorced": ["Marital status not divorced", "Marital status divorced"],
    "PTMARRY_Widowed": ["Marital status not widowed", "Marital status widowed"],
    "PTMARRY_Never married": ["Marital status not never married", "Marital status never married"],
    "EcogPtDivatt": ["Patient reported no attention problems", "Patient reported attention problems"],
    "EcogSPDivatt": ["Study partner reported no attention problems", "Study partner reported attention problems"],

    "EcogPtLang": ["Patient reported no language problems", "Patient reported language problems"],
    "EcogPtPlan": ["Patient reported no planning problems", "Patient reported planning problems"],
    # TODO (carried over from the original; what remains to be done is not stated)
    "EcogPtVisspat": ["Patient reported no problems with visual-spatial orientation", "Patient reported problems with visual-spatial orientation"],

    "EcogSPMem": ["Study partner reported no memory problems", "Study partner reported memory problems"],
    "EcogSPLang": ["Study partner reported no language problems", "Study partner reported language problems"],
    "EcogSPPlan": ["Study partner reported no planning problems", "Study partner reported planning problems"],

    "EcogPtTotal": ["Patient reported no impairment in daily living", "Patient reported severe impairment in daily living"],
    "EcogSPTotal": ["Study partner reported no impairment in daily living", "Study partner reported severe impairment in daily living"],
    "PTRACCAT_Unknown": ["Known race", "Unknown race"],
    "PTGENDER_Female": ["Male subject", "Female subject"],
    "PTGENDER_Male": ["Female subject", "Male subject"],

    "PTETHCAT_Unknown": ["Known ethnic category", "Unknown ethnic category"],
    # Low side of this indicator means "not in the 'More than one' category"
    # (previously mislabelled "Less than one race").
    "PTRACCAT_More than one": ["Not more than one race", "More than one race"],
    "PTRACCAT_Asian": ["Non asian", "Asian"],
    "PTRACCAT_Black": ["Non black", "Black"],
    "PTRACCAT_Am Indian/Alaskan": ["Non American Indian/Alaskan", "American Indian/Alaskan"],
    "PTRACCAT_Hawaiian/Other PI": ["Non Hawaiian/Other PI", "Hawaiian/Other PI"],
    "PTETHCAT_Hisp/Latino": ["Non Hispanic/Latino", "Hispanic/Latino"],

    "MOCA": ["Advanced cognitive impairment", "No cognitive impairment"],
    "TRAB": ["Advanced cognitive impairment", "No cognitive impairment"],
    "MMSE": ["Advanced cognitive impairment", "No cognitive impairment"],
    # Volume labels; keys such as "rh_", "lh_", "Left-" and "Right-" only make
    # sense as prefixes. (The duplicate "3rd-Ventricle" entry was removed.)
    "rh_": ["Small volume", "High volume"],
    "lh_": ["Small volume", "High volume"],
    "Left-": ["Small volume", "High volume"],
    "Right-": ["Small volume", "High volume"],
    "CSF": ["Small volume", "High volume"],
    "Brain-Stem": ["Small volume", "High volume"],
    "3rd-Ventricle": ["Small volume", "High volume"],
    "lhCerebralWhiteMatterVol": ["Small volume", "High volume"],
    "EstimatedTotalIntraCranialVol": ["Small volume", "High volume"],

    "PTEDUCAT": ["Low educational attainment", "High educational attainment"],

    "MagStrength_3.0": ["Magnetic Strength is 1.5 Tesla", "Magnetic Strength is 3.0 Tesla"],
    "MagStrength_1.5": ["Magnetic Strength is 3.0 Tesla", "Magnetic Strength is 1.5 Tesla"],
    "PHS": ["Low genetic risk to develop AD", "High genetic risk to develop AD"],
    "CIR": ["Low genetic risk to develop AD", "High genetic risk to develop AD"],
}

In [None]:
# Semicolon-separated table of x-axis labels for radiomics features; the
# columns "Feature Name", "Low value" and "High value" are read further below.
ex_df = pd.read_csv(
    "./data/MappingOfRadiomicsFeatures_fin_v4.csv",
    delimiter=";",
)

In [None]:
def wrap_label(text, parts=4):
    """Return `text` broken onto at most `parts` lines, for use as a tick label.

    For k = 1 .. parts-1 a newline is inserted before the first space at or after
    position k * round(len(text) / parts), re-measuring the length after every
    insertion (the same split points as the previous copy-pasted heuristic).
    Afterwards any newline followed by more whitespace/newlines is collapsed to a
    single newline and the result is stripped. This fixes two edge cases of the
    old cleanup: several breaks landing on the same space could leave an empty
    line inside the label, and a trailing space produced an empty last line.

    Args:
        text: label to wrap (must be a str).
        parts: maximum number of lines; the default 4 reproduces earlier output.

    Returns:
        The wrapped label.
    """
    for k in range(1, parts):
        split_at = text.find(" ", k * round(len(text) / parts))
        if split_at != -1:
            text = text[:split_at] + "\n" + text[split_at:]
    return re.sub(r"\n\s*", "\n", text).strip()


# Wrap the hand-written labels in place. Keys and their order are unchanged,
# which matters because prefix matching later iterates the mapping in order.
for feature, (low, high) in list(explainability_mapping.items()):
    explainability_mapping[feature] = [wrap_label(low), wrap_label(high)]

# Add (or override) labels for the features listed in the mapping CSV.
for _, mapping_row in ex_df.iterrows():
    explainability_mapping[mapping_row["Feature Name"]] = [
        wrap_label(mapping_row["Low value"]),
        wrap_label(mapping_row["High value"]),
    ]

In [None]:
def ipcw_brier_scorer(estimator, X, y):
    """Return the negated mean integrated Brier score of `estimator` over fixed folds.

    Rather than scoring on the data it is given, the function re-splits `X`
    (positionally, via ``.iloc``) and `y` using the module-level ``index_pairs``
    directly. ``index_pairs`` is assumed to be an iterable of
    (train_indices, test_indices). For every pair:
    - the estimator is refit on the training rows, with ``weights`` as sample
      weights;
    - its survival functions are evaluated on the test rows at t = 4.0 and 8.0;
    - those predictions are scored with ``integrated_brier_score``.
    The mean over folds is negated, so a larger return value is better. The
    signature (estimator, X, y) suggests use as a scikit-learn style scorer.

    NOTE(review): ``index_pairs``, ``weights`` and ``integrated_brier_score`` are
    neither defined nor imported in this notebook, so calling the function here
    raises NameError. It is presumably kept only so that the pickled model object
    can be unpickled by reference -- confirm before removing or renaming it.
    NOTE(review): `estimator` is refit on each fold, so it is left holding the fit
    from the last fold.
    """
    score_times = [4.0, 8.0]
    fold_scores = []
    for fit_idx, eval_idx in index_pairs:
        X_fit, X_eval = X.iloc[fit_idx], X.iloc[eval_idx]
        y_fit, y_eval = y[fit_idx], y[eval_idx]
        # The test-fold weights are looked up but not used.
        fit_weights, _ = weights[fit_idx], weights[eval_idx]

        estimator.fit(X_fit, y_fit, sample_weight=fit_weights)
        survival_curves = estimator.predict_survival_function(X_eval)
        predicted = np.asarray(
            [[curve(t) for t in score_times] for curve in survival_curves]
        )
        fold_scores.append(
            integrated_brier_score(y_fit, y_eval, predicted, score_times)
        )
    return -np.mean(fold_scores)

In [None]:
# Artifacts of the selected run.
filenameCSV = MODEL_DIR + "/resultsCV.csv"
trainingDSCSV = MODEL_DIR + "/training.csv"
testDSCSV = MODEL_DIR + "/test.csv"
filename = MODEL_DIR + "/model_bayes_optimization.sav"
savename_shapley_data = MODEL_DIR + "/ShapValues_test_ADNI.pkl"

# SECURITY: pickle.load can execute arbitrary code -- only load files you trust.
with open(filename, "rb") as model_file:
    clf = pickle.load(model_file)
df_ges = pd.read_csv(trainingDSCSV).set_index(["PTID", "IMAGEUID"])
df_ges_test = pd.read_csv(testDSCSV).set_index(["PTID", "IMAGEUID"])
with open(savename_shapley_data, "rb") as input_file:
    shap_values = pickle.load(input_file)

# Feature importances below are paired with the test-set columns by position,
# so both must describe the same number of features.
assert shap_values.values.shape[1] == df_ges_test.shape[1], (
    "SHAP values and test set have a different number of features"
)


def _display_name(col):
    """Return `col` without "original_", "ctx-" and "_volume", with "_<digits>_" collapsed to "_"."""
    name = col.replace("original_", "")
    name = re.sub("_[0-9]+_", "_", name)
    name = re.sub("ctx-", "", name)
    return re.sub("_volume", "", name)


def _explanation_for(col, name):
    """Return the [low, high] x-axis labels for a feature, or None if it is unmapped.

    An exact entry for the raw column name `col` wins. Otherwise the first key of
    ``explainability_mapping`` (in insertion order) that is a prefix of the
    display name `name` is used.
    """
    if col in explainability_mapping:
        return explainability_mapping[col]
    for prefix, labels in explainability_mapping.items():
        if name.startswith(prefix):
            return labels
    return None


def _label_y_extremes(ax):
    """Keep the numeric y ticks but label the lowest "Protective" and the highest "Pathogenic"."""
    yticks = ax.get_yticks()
    y_labels = [str(np.round(tick, 3)) for tick in yticks]
    y_labels[0] = "Protective"  # replace the first y tick with text
    y_labels[-1] = "Pathogenic"
    ax.set_yticks(yticks)
    ax.set_yticklabels(y_labels)


# Most important features first (total |SHAP| over samples); ties are broken by
# column name in reverse order, exactly as the tuple sort did before.
feature_importance = np.abs(shap_values.values).sum(axis=0).tolist()
ordered_features = [
    feature
    for _, feature in sorted(zip(feature_importance, df_ges_test.columns.tolist()), reverse=True)
]
if not ordered_features:
    raise ValueError("No features to plot in " + testDSCSV)

number_cols = math.ceil(df_ges_test.shape[1] / 2)
fig, axs = plt.subplots(2, number_cols, figsize=(number_cols * 4, 8))
# Row-major flat view: panels[i] is axs[i // number_cols, i % number_cols] for a 2-D
# grid and axs[i] for the 1-D array returned when number_cols == 1.
panels = np.asarray(axs).ravel()
# Shared y-range so that all panels are directly comparable.
shap_min, shap_max = shap_values.values.min(), shap_values.values.max()

for i, col in enumerate(ordered_features):
    ax = panels[i]
    name = _display_name(col)
    feature_shap = shap_values[:, col]

    # Break the title after every underscore so long names fit the panel width.
    shap.plots.scatter(feature_shap, ax=ax, show=False, ymin=shap_min, ymax=shap_max,
                       title=name.replace("_", "_\n"))
    ax.axhline(y=0.0, color="black", linestyle="--", linewidth=1)
    ax.set_ylabel("SHAP value")
    ax.set_xlabel("")

    labels = _explanation_for(col, name)
    if labels is None:
        print("Not explained: " + name)
    else:
        values = feature_shap.data
        ax.set_xticks([values.min(), values.max()])
        ax.set_xticklabels([labels[0], labels[1]])
        # Anchor the labels inward so multi-line text does not spill into the
        # neighbouring panel (now applied in every layout, not only one column).
        for tick_index, tick_label in enumerate(ax.get_xticklabels()):
            tick_label.set_horizontalalignment("left" if tick_index == 0 else "right")

    # Only the first column of the grid keeps (annotated) y ticks.
    if i % number_cols == 0:
        _label_y_extremes(ax)
    else:
        ax.set_yticks([])

# Drop the empty panel left over when the number of features is odd.
for unused_ax in panels[len(ordered_features):]:
    unused_ax.remove()
fig.tight_layout()
os.makedirs("./plots", exist_ok=True)  # savefig does not create missing directories
savename = "./plots/" + MODEL_DIR.split("/")[-2] + ".pdf"
fig.savefig(savename, bbox_inches="tight")
plt.show()