In [None]:
# Uncomment to install SHAP into the environment this kernel runs in:
# %pip install shap

In [None]:
import os
import random
import sys

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from dotenv import load_dotenv
from joblib import load
from sklearn import metrics
import scipy.stats as st

# Put the sibling ``scripts`` directory (resolved against the current working
# directory) on the import path so the project modules imported below resolve.
module_path = os.path.abspath("../scripts")
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

import shap
# Initialise SHAP's JavaScript support; presumably needed by the interactive
# (non-matplotlib) force plot rendered near the end of the notebook.
shap.initjs()

# Load variables from a .env file (python-dotenv); the seeding lines below
# would read RANDOM_SEED from the environment.
load_dotenv()

# Seeding is currently disabled. Before re-enabling, make sure RANDOM_SEED is
# defined: os.getenv returns None for a missing variable and int(None) raises.
# torch.manual_seed(int(os.getenv("RANDOM_SEED")))
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# torch.cuda.manual_seed(int(os.getenv("RANDOM_SEED")))
# random.seed(int(os.getenv("RANDOM_SEED")))
# np.random.seed(int(os.getenv("RANDOM_SEED")))

from run_models import CauseSpecificNet, DeepHit, get_preprocessed_datasets
from utils import VTEDataLoader, get_logger, get_parent_dir, plot_roc, plot_calibration
from vte_deephit import get_datasets, get_best_params

In [None]:
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cycler

# Ten RGBA hex colours used as the default colour cycle for every axes.
color_list = [
    "#E64B35FF", "#4DBBD5FF", "#00A087FF", "#3C5488FF", "#F39B7FFF",
    "#8491B4FF", "#91D1C2FF", "#DC0000FF", "#7E6148FF", "#B09C85FF",
]

# ``plt.rcParams`` and ``matplotlib.rcParams`` are the same object, so all
# figure-wide defaults are applied through a single update call.
matplotlib.rcParams.update(
    {
        "font.family": "Arial",
        "axes.prop_cycle": cycler(color=color_list),
        "font.size": 18,
        "axes.linewidth": 2,
    }
)

In [None]:
# Logger for this notebook, created by the project's ``utils.get_logger``
# helper under the name "shap-notebook".
logger = get_logger("shap-notebook")

In [None]:
# Fetch the dataset bundle from ``vte_deephit.get_datasets``; the next cell
# reads the individual splits and label transforms from it by key.
datasets = get_datasets();

In [None]:
# Bind every entry of the dataset bundle to a notebook-level name. The keys
# are looked up in the order listed, one ``datasets.get`` call per name.
(
    x_train,
    x_test,
    y_train,
    y_train_6,
    y_test,
    x_train_ks,
    x_test_ks,
    y_test_ks,
    labtrans,
    labtrans_6,
) = (
    datasets.get(key)
    for key in (
        "x_train",
        "x_test",
        "y_train",
        "y_train_6",
        "y_test",
        "x_train_ks",
        "x_test_ks",
        "y_test_ks",
        "labtrans",
        "labtrans_6",
    )
)

In [None]:
# Feature-set name (selects the ``models/<feature>`` directory) and the number
# of saved ensemble members to load from it.
feature = "no_genes"
n = 30

logger.info(f"Running for feature: {feature}")
params = load(get_parent_dir() / f"models/{feature}/params.pkl")


def _load_ensemble_member(index):
    """Build a DeepHit model from ``params`` and load ``model_<index>.pt`` into it."""
    member = DeepHit(CauseSpecificNet(**params))
    member.load_model_weights(get_parent_dir() / f"models/{feature}/model_{index}.pt")
    return member


# One model per saved weight file, in index order 0 .. n-1.
models = [_load_ensemble_member(index) for index in range(n)]



In [None]:
# Load the saved preprocessing pipeline, fit it again on the training split,
# then apply the same fitted transform to the test split. Both outputs are
# cast to float32.
# NOTE(review): the pipeline comes from "preprocessing_fit.joblib" and is
# re-fitted here via fit_transform — confirm this reproduces the preprocessing
# the saved models were trained with.
preprocessing_path = get_parent_dir() / f"models/{feature}/preprocessing_fit.joblib"
transformation_pipeline = load(preprocessing_path)

features_train = transformation_pipeline.fit_transform(x_train)
features_train = features_train.astype("float32")

features_test = transformation_pipeline.transform(x_test)
features_test = features_test.astype("float32")
# cols = ["AGE", "ALBUMIN", "ALKPHOS", "ALT", "AST", "CALCIUM", "CHEMO_alkylating",
#         "CHEMO_antibiotic", "CHEMO_antimetabolite", "CHEMO_antimitotic", "CHEMO_cdki",
#         "CHEMO_egfri", "CHEMO_immune", "CHEMO_multikinase", "CHEMO_other", "CHEMO_parpi",
#         "CHEMO_platin", 'CHEMO_serm', 'CHEMO_vegfi',
#         'CHLORIDE', 'CO2', 'CREATININE', 'DX_delta',
#         'GLUCOSE', 'HB', 'POTASSIUM', 'PROC_delta',
#         'SODIUM', 'TBILI', 'TPROTEIN', 'UREA', "CANCER_TYPE_FINAL", "SAMPLE_TYPE"]

In [None]:
# Display the output feature names reported by the preprocessing pipeline.
transformation_pipeline.get_feature_names_out()

In [None]:
def torch_model_wrapper(x, event_idx=0, time_idx=180, ensemble=None):
    """Return the ensemble-averaged CIF slice for one event at one time index.

    ``predict_cif(x)`` is evaluated on every model, the results are averaged
    element-wise in float32, and ``mean[event_idx][time_idx, :]`` is returned.
    Called with ``x`` only (which is how ``shap.KernelExplainer`` invokes it),
    the function behaves exactly as before: event 0, index 180, global
    ``models``.

    Parameters
    ----------
    x
        Feature matrix, forwarded unchanged to each model's ``predict_cif``.
    event_idx : int, default 0
        Index into the first axis of the averaged CIF array.
    time_idx : int, default 180
        Row selected from the per-event matrix.
        NOTE(review): presumably an index on the discretised time grid —
        confirm which horizon 180 corresponds to.
    ensemble : iterable, optional
        Objects exposing ``predict_cif``. Defaults to the notebook-level
        ``models`` list.

    Returns
    -------
    The slice ``mean_cif[event_idx][time_idx, :]`` of the float32 mean.

    Raises
    ------
    ValueError
        If there is no model to average over.
    """
    members = models if ensemble is None else ensemble
    cifs = [member.predict_cif(x) for member in members]
    if not cifs:
        # np.mean over an empty list would only fail later with an opaque indexing error.
        raise ValueError("torch_model_wrapper needs at least one model to average over")
    mean_cif = np.mean(cifs, dtype=np.float32, axis=0)
    return mean_cif[event_idx][time_idx, :]

In [None]:
# Shape of the second element of y_train, which is the array used as the
# stratification target when the SHAP background sample is drawn.
y_train[1].shape

In [None]:
# Shape of the preprocessed float32 training feature matrix.
features_train.shape

In [None]:
# Feature names from the preprocessing pipeline (same call as in an earlier
# cell); they are used as DataFrame column labels in the following cells.
transformation_pipeline.get_feature_names_out()

In [None]:
# Preprocessed training features as a DataFrame whose columns are the
# pipeline's output feature names.
train_df = pd.DataFrame(features_train, columns=transformation_pipeline.get_feature_names_out())

In [None]:
# Frequency of each value in the ``cat__SEX_M`` column of the preprocessed training data.
train_df.cat__SEX_M.value_counts()

In [None]:
from sklearn.utils import resample
# Draw 20 training rows without replacement, stratified on ``y_train[1]``,
# with a fixed random state; this sample is handed to the SHAP explainer as
# its background data.
background = resample(
    features_train,
    n_samples=20,
    replace=False,
    stratify=y_train[1],
    random_state=42,
)

In [None]:
import shap
# Kernel SHAP explainer around the ensemble wrapper function, with the sampled
# background rows as its reference data.
explainer = shap.KernelExplainer(torch_model_wrapper, background)


In [None]:
# Inspect the explainer's ``expected_value``; the force plots further down
# receive it as their first argument.
explainer.expected_value

In [None]:
# Draw 1000 test rows without replacement, stratified on ``y_test[1]``, with a
# fixed random state; SHAP values are reported for this subset.
test_subset = resample(
    features_test,
    n_samples=1000,
    replace=False,
    stratify=y_test[1],
    random_state=42,
)

In [None]:
# SHAP values for ``test_subset`` are cached on disk because computing them
# means evaluating the whole ensemble many times per explained row.
# Load the cache when it exists; otherwise compute the values once and store
# them, so the notebook also runs from a clean checkout.
shap_cache_path = f"shap_{feature}.npy"
if os.path.exists(shap_cache_path):
    shap_values = np.load(shap_cache_path)
else:
    # KernelExplainer reads the sampling budget from ``nsamples``; a
    # ``n_samples`` keyword is not recognised and leaves the default in place.
    shap_values = explainer.shap_values(test_subset, nsamples=100)
    np.save(shap_cache_path, shap_values)

In [None]:
# The sampled test rows as a DataFrame labelled with the pipeline's output
# feature names, so the SHAP plots can show and look up features by name.
subset_df = pd.DataFrame(test_subset, columns=list(transformation_pipeline.get_feature_names_out()))

In [None]:
# Bar-type SHAP summary plot; ``show=False`` keeps the figure open so it can
# be written to disk afterwards.
shap.summary_plot(shap_values, subset_df, plot_type="bar", show=False)
summary_bar_path = get_parent_dir() / f"visualizations/{feature}_shap_summary_mean.svg"
plt.savefig(summary_bar_path, format="svg", dpi=300, bbox_inches="tight")

In [None]:
# Default-type SHAP summary plot at a fixed 10x10 size, saved as SVG.
# NOTE(review): this savefig call has no bbox_inches="tight", unlike the
# other figure exports in this notebook — confirm that is intended.
summary_plot_options = dict(plot_size=(10, 10), title="SHAP Main Model", show=False)
shap.summary_plot(shap_values, subset_df, **summary_plot_options)
summary_path = get_parent_dir() / f"visualizations/{feature}_shap_summary.svg"
plt.savefig(summary_path, format="svg", dpi=300)

In [None]:
# Explain the first row of the test subset on its own (1000 SHAP samples) and
# save the matplotlib force plot.
# fig = plt.figure(figsize=(20, 10))
explained_row = subset_df.iloc[0]
shap_values_single = explainer.shap_values(explained_row, nsamples=1000)
shap.force_plot(
    explainer.expected_value,
    shap_values_single,
    explained_row,
    matplotlib=True,
    show=False,
    text_rotation=20,
)
plt.savefig(
    get_parent_dir() / f"visualizations/{feature}_shap_explain_1.svg",
    format="svg",
    dpi=300,
    bbox_inches="tight",
)

In [None]:
# Explain the row at position 5 of the test subset on its own (1000 SHAP
# samples) and save the matplotlib force plot.
explained_row = subset_df.iloc[5]
shap_values_single = explainer.shap_values(explained_row, nsamples=1000)
shap.force_plot(
    explainer.expected_value,
    shap_values_single,
    explained_row,
    matplotlib=True,
    show=False,
    text_rotation=20,
)
explain_2_path = get_parent_dir() / f"visualizations/{feature}_shap_explain_2.svg"
plt.savefig(explain_2_path, format="svg", dpi=300, bbox_inches="tight")

In [None]:
# Force plot over every explained row of the test subset. No matplotlib=True
# here, so this presumably renders the interactive JavaScript version that
# shap.initjs() prepared — TODO confirm.
shap.force_plot(explainer.expected_value, shap_values, subset_df)

In [None]:
# SHAP dependence plot for the ``num__AGE`` feature, saved as SVG.
shap.dependence_plot("num__AGE", shap_values, subset_df, show=False)
dep_plot_path = get_parent_dir() / f"visualizations/{feature}_shap_dep_plot_age.svg"
plt.savefig(dep_plot_path, format="svg", dpi=300, bbox_inches="tight")


In [None]:
# SHAP dependence plot for the ``bin__TP53_alt`` feature, saved as SVG.
shap.dependence_plot("bin__TP53_alt", shap_values, subset_df, show=False)
dep_plot_path = get_parent_dir() / f"visualizations/{feature}_shap_dep_plot_TP53.svg"
plt.savefig(dep_plot_path, format="svg", dpi=300, bbox_inches="tight")

In [None]:
# SHAP dependence plot for the ``num__ALBUMIN`` feature, saved as SVG.
shap.dependence_plot("num__ALBUMIN", shap_values, subset_df, show=False)
dep_plot_path = get_parent_dir() / f"visualizations/{feature}_shap_dep_plot_ALBUMIN.svg"
plt.savefig(dep_plot_path, format="svg", dpi=300, bbox_inches="tight")

In [None]:
# SHAP dependence plot for the ``cat__SAMPLE_TYPE_Metastasis`` feature, saved as SVG.
shap.dependence_plot("cat__SAMPLE_TYPE_Metastasis", shap_values, subset_df, show=False)
dep_plot_path = get_parent_dir() / f"visualizations/{feature}_shap_dep_plot_Metastasis.svg"
plt.savefig(dep_plot_path, format="svg", dpi=300, bbox_inches="tight")

In [None]:
# SHAP dependence plot for the ``num__CHLORIDE`` feature, saved as SVG.
# NOTE(review): this file name carries a "summary_" infix that the other
# dependence-plot exports do not have; it is kept unchanged so existing
# references to the file stay valid — rename deliberately if it is a leftover.
shap.dependence_plot("num__CHLORIDE", shap_values, subset_df, show=False)
dep_plot_path = get_parent_dir() / f"visualizations/{feature}_shap_summary_dep_plot_Chloride.svg"
plt.savefig(dep_plot_path, format="svg", dpi=300, bbox_inches="tight")

In [None]:
# SHAP dependence plot for the ``num__CHEMO_antimetabolite`` feature, saved as SVG.
shap.dependence_plot("num__CHEMO_antimetabolite", shap_values, subset_df, show=False)
dep_plot_path = get_parent_dir() / f"visualizations/{feature}_shap_dep_plot_CHEMO_antimetabolite.svg"
plt.savefig(dep_plot_path, format="svg", dpi=300, bbox_inches="tight")

In [None]:
# SHAP dependence plot for the ``num__CHEMO_platin`` feature, saved as SVG.
shap.dependence_plot("num__CHEMO_platin", shap_values, subset_df, show=False)
dep_plot_path = get_parent_dir() / f"visualizations/{feature}_shap_dep_plot_CHEMO_platin.svg"
plt.savefig(dep_plot_path, format="svg", dpi=300, bbox_inches="tight")