# Dancing with censored data: How to survive with explainable survival analysis?
**ML in PL Conference 2023**


Mateusz Krzyziński, Mikołaj Spytek (**MI2.AI**)


## Agenda

- 9:00 - 9:15 - Introduction, technicalities
- 9:15 - 10:00 - ML in survival analysis, creating models `(Python and R)`
- 10:00 - 10:10 - *break*

- 10:10 - 10:20 - XAI in survival analysis
- 10:20 - 10:45 - SurvLIME `(Python and R)`
- 10:45 - 11:30 - SurvSHAP(t) and its aggregations `(Python and R)`
- 11:30 - 11:40 - *break*

- 11:40 - 11:50 - creating explainers for any models `(R with Python via reticulate)`
- 11:50 - 12:15 - global performance & global variable importance explanations `(R)`
- 12:15 - 12:30 - local variable dependence explanations `(R)`
- 12:30 - 12:45 - global variable dependence explanations `(R)`
- 12:45 - 13:00 - summary and Q&A


## Loading packages

In [None]:
import pandas as pd
import random
from sksurv.datasets import get_x_y
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.ensemble import RandomSurvivalForest

from survshap import SurvivalModelExplainer, ModelSurvSHAP, PredictSurvSHAP

## Loading data

In [None]:
df = pd.read_csv("datasets/veterans.csv")
df = df.dropna()
df = pd.get_dummies(df, drop_first=True, dtype=float)

X, y = get_x_y(df, ["status", "time"], pos_label=1)

## Creating models

In [None]:
cph_model = CoxPHSurvivalAnalysis()
cph_model.fit(X, y)

In [None]:
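random.seed(123)
# Python's random.seed does not control scikit-survival; random_state makes the forest reproducible
rsf_model = RandomSurvivalForest(n_estimators=100, random_state=123)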
rsf_model.fit(X, y)

## SurvSHAP(t)

- based on the SHAP (SHapley Additive exPlanations) method
- local explanation method (explains a single observation and its prediction)
- **main idea:** make explanations time-dependent, which is better suited to complex models
- **the explanation:** curves showing each feature's attribution to the prediction over time (they can be aggregated into single numbers)

- **formula:**
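$$\phi_{t}(\mathbf{x}_*, d) = \frac{1}{|\Pi|} \sum_{\pi \in \Pi} \left( e_{t, \mathbf{x}_*}^{\mathrm{before}(\pi, d) \cup \{d\}} - e_{t, \mathbf{x}_*}^{\mathrm{before}(\pi, d)} \right),$$
where $\Pi$ is the set of all permutations of the $p$ variables, $\mathrm{before}(\pi, d)$ denotes the subset of predictors that precede $d$ in the ordering $\pi \in \Pi$, and $e_{t, \mathbf{x}_*}^{D}$ is the expected model prediction (e.g. the survival function) at time $t$ when the variables in $D$ are fixed to their values in $\mathbf{x}_*$.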

- **local variable importance:**
$$\psi(\mathbf{x}_*, d) = \int_0^{t_m} \left| \phi_{t}(\mathbf{x}_*, d) \right| \, \mathrm{d}w(t)$$
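
To make the two formulas above concrete, here is a minimal pure-Python sketch, not the `survshap` implementation. It rests on several assumptions: the toy exponential model, its coefficients `COEFS`, the `background` rows and `x_star` are all made up for illustration; the expectation $e_{t, \mathbf{x}_*}^{D}$ is approximated by averaging predictions over background rows with the variables in $D$ overwritten; and the integral uses the trapezoidal rule with $\mathrm{d}w(t) = \mathrm{d}t$.

```python
import math
from itertools import permutations

# Hypothetical toy model: exponential survival with a log-linear hazard.
COEFS = [0.8, -0.5, 0.3]  # assumed coefficients, not fitted to any data

def survival(x, t):
    """S(t | x) = exp(-0.01 * t * exp(COEFS . x))."""
    rate = 0.01 * math.exp(sum(b * v for b, v in zip(COEFS, x)))
    return math.exp(-rate * t)

def expected_survival(x_star, known, background, t):
    """Approximate e_t^D: mean S(t) over background rows with variables in `known` set to x_star."""
    total = 0.0
    for row in background:
        mixed = [x_star[j] if j in known else row[j] for j in range(len(row))]
        total += survival(mixed, t)
    return total / len(background)

def survshap_t(x_star, background, times):
    """Return phi[d][k], the attribution of variable d at times[k], averaged over all orderings."""
    p = len(x_star)
    orderings = list(permutations(range(p)))
    phi = [[0.0] * len(times) for _ in range(p)]
    for k, t in enumerate(times):
        for pi in orderings:
            known = set()  # before(pi, d) for the first variable in the ordering
            prev = expected_survival(x_star, known, background, t)
            for d in pi:
                known.add(d)  # before(pi, d) union {d}
                curr = expected_survival(x_star, known, background, t)
                phi[d][k] += (curr - prev) / len(orderings)
                prev = curr
    return phi

def local_importance(phi_d, times):
    """psi: trapezoidal integral of |phi_t| over time (assumes dw(t) = dt)."""
    return sum(
        0.5 * (abs(phi_d[k]) + abs(phi_d[k + 1])) * (times[k + 1] - times[k])
        for k in range(len(times) - 1)
    )

background = [[0.0, 1.0, 2.0], [1.0, 0.0, 1.0], [2.0, 2.0, 0.0], [1.0, 1.0, 1.0]]
x_star = [2.0, 0.0, 1.0]
times = [0.0, 25.0, 50.0, 100.0, 200.0]

phi = survshap_t(x_star, background, times)
psi = [local_importance(phi_d, times) for phi_d in phi]
```

At every time point the attributions sum to the difference between the prediction for `x_star` and the average prediction over the background rows. Enumerating all $p!$ orderings is feasible only for a handful of variables, which is why practical implementations rely on sampling or model-specific algorithms.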

<img src="images/survshap.png" width="800"/>


Krzyziński, M., Spytek, M., Baniecki, H., & Biecek, P. (2023). **SurvSHAP(t): Time-dependent explanations of machine learning survival models**. Knowledge-Based Systems, 262, 110234. https://doi.org/10.1016/j.knosys.2022.110234

In [None]:
patientA = X.iloc[23]
patientA

In [None]:
survshap_cph_explainer = SurvivalModelExplainer(cph_model, X, y)
survshap_rsf_explainer = SurvivalModelExplainer(rsf_model, X, y)

In [None]:
survshap_cph_explanation = PredictSurvSHAP()
survshap_cph_explanation.fit(survshap_cph_explainer, patientA)

In [None]:
survshap_cph_explanation.result

In [None]:
survshap_cph_explanation.plot()

In [None]:
survshap_cph_explanation.plot(x_range=[0, 300])

In [None]:
survshap_rsf_explanation = PredictSurvSHAP()
survshap_rsf_explanation.fit(survshap_rsf_explainer, patientA)

In [None]:
survshap_rsf_explanation.plot(x_range=[0, 300])

In [None]:
survshap_rsf_explanation.simplified_result

### Using other forms of predictions   

In [None]:
survshap_rsf_explanation_chf = PredictSurvSHAP(function_type = "chf")
survshap_rsf_explanation_chf.fit(survshap_rsf_explainer, patientA)
survshap_rsf_explanation_chf.plot()
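
Here `"chf"` stands for the cumulative hazard function. In the continuous-time setting it is tied to the survival function by $H(t) = -\log S(t)$, so explanations of one can be read alongside the other. The sketch below only illustrates this identity; the helper `survival_to_chf` and the probabilities are hypothetical, and a fitted model such as a random survival forest may estimate the two functions separately, in which case the relation holds only approximately.

```python
import math

def survival_to_chf(surv):
    """Convert survival probabilities S(t) into cumulative hazard H(t) = -log S(t)."""
    return [-math.log(s) for s in surv]

# Made-up survival probabilities at increasing time points.
surv = [1.0, 0.8, 0.5, 0.25]
chf = survival_to_chf(surv)
```

A survival curve that decreases from 1 corresponds to a cumulative hazard that increases from 0, so attributions for the two function types typically have opposite signs.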

## SurvSHAP(t) aggregations 
- global (model-level) explanations built by aggregating local explanations
- can be calculated for any model, but this is computationally expensive
- for tree-based models, they can be calculated faster using the TreeSHAP algorithm (`treeshap` package)
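
One plausible way to aggregate local explanations into a global one is to average the absolute SurvSHAP(t) values over observations at each time point. The sketch below shows that idea in pure Python. It is an assumption about the aggregation, not the `ModelSurvSHAP` code, and the numbers in `local_phi`, the variable `age`, and the helper `mean_abs_shap` are hypothetical.

```python
# Hypothetical local SurvSHAP(t) curves: local_phi[i][variable] is a list over `times`.
times = [0, 100, 200]
local_phi = [
    {"karno": [0.0, 0.20, 0.10], "age": [0.0, -0.05, -0.02]},
    {"karno": [0.0, -0.30, -0.20], "age": [0.0, 0.01, 0.03]},
    {"karno": [0.0, 0.10, 0.00], "age": [0.0, 0.03, -0.01]},
]

def mean_abs_shap(local_phi):
    """Average |phi_t| over observations, separately for each variable and time point."""
    n_obs = len(local_phi)
    variables = list(local_phi[0].keys())
    n_times = len(local_phi[0][variables[0]])
    return {
        v: [sum(abs(obs[v][k]) for obs in local_phi) / n_obs for k in range(n_times)]
        for v in variables
    }

global_curves = mean_abs_shap(local_phi)

# Rank variables by the total of their mean absolute curve (a crude single-number summary).
ranking = sorted(global_curves, key=lambda v: sum(global_curves[v]), reverse=True)
```

Taking absolute values before averaging keeps positive and negative local contributions from cancelling out, so a variable that matters in opposite directions for different patients still ranks as important.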

In [None]:
model_survshap = ModelSurvSHAP(calculation_method="treeshap")

In [None]:
model_survshap.fit(survshap_rsf_explainer, X.iloc[:50])

In [None]:
model_survshap.plot_mean_abs_shap_values(x_range=[0, 300])

In [None]:
model_survshap.result

In [None]:
model_survshap.plot_shap_lines_for_all_individuals(variable = "karno", x_range=[0, 300])

In [None]:
model_survshap.plot_shap_lines_for_all_individuals(variable = "karno", x_range=[0, 300], boxplot = True)