# Dancing with censored data: How to survive with explainable survival analysis?
**ML in PL Conference 2023**


Mateusz Krzyziński, Mikołaj Spytek (**MI2.AI**)


## Agenda

- 9:00 - 9:15 - Introduction, technicalities
- 9:15 - 10:00 - ML in survival analysis, creating models `(Python and R)`
- 10:00 - 10:10 - *break*

- 10:10 - 10:20 - XAI in survival analysis
- 10:20 - 10:45 - SurvLIME `(Python and R)`
- 10:45 - 11:30 - SurvSHAP(t) and its aggregations `(Python and R)`
- 11:30 - 11:40 - *break*

- 11:40 - 11:50 - creating explainers for any models `(R with Python via reticulate)`
- 11:50 - 12:15 - global performance & global variable importance explanations `(R)`
- 12:15 - 12:30 - local variable dependence explanations `(R)`
- 12:30 - 12:45 - global variable dependence explanations `(R)`
- 12:45 - 13:00 - summary and Q&A


## Why do we need explanations for survival models?

**Complex machine learning survival models are used more and more often (also in healthcare and scientific research) because**:
- the Cox PH model cannot capture complex dependencies and relies on the strong assumption of proportional hazards, which is often not met,
- ML models are more flexible and can model complex dependencies, but they are often treated as black boxes.


<img src="images/need_quotes.png" width="1000">






**More about medical context**

The **overoptimistic** use of AI models in medicine **→** The need for a method of **validation** other than just performance validation.
> There exists a notable tendency among some researchers steeped in ML to report only the model performance metrics. XAI techniques offer a means to validate and explore survival models, effectively complementing performance validation. They enable a shift of focus towards model reliability and accountability rather than solely optimizing model performance metrics. This shift is crucial in mitigating the risk of over-optimistic assessments of complex models, thereby promoting responsible and informed model usage.

The **complexity** and **lack of interpretability** of AI models **→** The need for a method of **examining** models that would make it possible to understand their operation.
> Researchers rely on survival models to analyze the effectiveness of new therapies, explore genetic markers associated with diseases, or evaluate the impact of lifestyle factors on health outcomes. Given the vast implications of these models, ensuring that their predictions are not only accurate but also interpretable becomes a matter of utmost importance.





## XAI as model exploration stack

- There are many different types of (survival) models, and each of them has its own structure. 
- Although model-specific approaches are possible, it is much more convenient to use a model-agnostic approach. 
- In this approach, the model is wrapped in an abstract object, and the explanation methods are applied to this object (they are usually based on the analysis of the model's predictions). 

<img src="images/piramide.png" width="800">

**Read more:** https://ema.drwhy.ai
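
The wrapping idea described above can be sketched in a few lines of base R. This is only an illustration of the principle, not how `survex` is implemented: `make_explainer`, `toy_model` and `toy_predict` are hypothetical names, and the "model" is a made-up exponential survival function.

```r
# A hypothetical, minimal "explainer": a list bundling a model, data,
# and a unified prediction function.
make_explainer <- function(model, data, predict_survival, times) {
  list(
    model = model,
    data = data,
    times = times,
    # Always returns a matrix: one row per observation, one column per time point.
    predict = function(newdata) predict_survival(model, newdata, times)
  )
}

# Toy "model": exponential survival with a rate that depends on one variable.
toy_model <- list(base_rate = 0.01, beta = 0.5)
toy_predict <- function(model, newdata, times) {
  rate <- model$base_rate * exp(model$beta * newdata$x)
  exp(-outer(rate, times))
}

toy_data <- data.frame(x = c(0, 1, 2))
expl <- make_explainer(toy_model, toy_data, toy_predict, times = c(10, 50, 100))

# An explanation method only needs expl$predict(); it never inspects the model itself.
surv <- expl$predict(toy_data)
dim(surv)
```

Because every model is reduced to the same prediction interface, the same explanation code can be reused for a Cox model, a random survival forest, or a model trained in another language.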


## Tools
- In R, we will use the [`survex`](https://modeloriented.github.io/survex/) package that follows similar principles.
<br/>
<img src="images/survex.png" width="800"><br/>
Spytek, M., Krzyziński, M., Langbein, S. H., Baniecki, H., Wright, M. N., & Biecek, P. (2023). **survex: an R package for explaining machine learning survival models**. arXiv preprint arXiv:2308.16113. https://arxiv.org/abs/2308.16113

- In Python, we will use the [`survlimepy`](https://github.com/imatge-upc/SurvLIMEpy) and [`survshap`](https://github.com/MI2DataLab/survshap) packages. 
<br/>

## Loading packages

In [None]:
library(survival)
library(survex)
library(ranger)

In [None]:
options(jupyter.plot_scale=1, repr.plot.width = 6, repr.plot.height = 4)

## Creating models

In [None]:
df <- read.csv("datasets/lung_dataset.csv")
df <- df[complete.cases(df),]
head(df)

In [None]:
cph_model <- coxph(Surv(time, status) ~ ., data = df, x = TRUE, model = TRUE)

In [None]:
set.seed(123)
rsf_model <- ranger(Surv(time, status) ~ ., data = df)

## Creating explainers

For some models, explainers are created fully automatically: all necessary information, along with the predict functions, is extracted from the model object.

For example, for a Cox PH model created with the `survival` package, we only have to set `x = TRUE` and `model = TRUE` in the `survival::coxph` function when creating the model.

In [None]:
cph_explainer <- explain(cph_model)

For other models, we have to provide additional information to the `explain` function, for example the dataset.

Of course, the dataset can always be provided manually. This is useful when we want to create explanations on a test set or a specific subset of the data (for example, to answer different questions; context is crucial in explanations).

In [None]:
rsf_explainer <- explain(rsf_model,
                        data = df[3:9], 
                        y = Surv(df$time, df$status))

In [None]:
#?explain

In [None]:
cph_explainer$times

## SurvLIME

- based on LIME (Local Interpretable Model-agnostic Explanations) method 
- local explanation method (for one observation & prediction)
- **main idea:** approximate the black-box model locally with a simple interpretable surrogate, here a Cox proportional hazards model
- **the explanation:** coefficients of the locally fitted Cox model
- those coefficients are calculated based on the distance between the cumulative hazard functions of the black-box model and the surrogate model:
$$\min_{\mathbf{b}}\sum_{k=1}^{N}w_{k}\sum_{j=0}^{m}v_{kj}^{2}\left(\ln H_{j}(\mathbf{x}_{k})-\ln H_{0j}-\mathbf{b}^{\mathrm{T}}\mathbf{x}_{k}\right)^{2}\left(t_{j+1}-t_{j}\right)$$

<img src="images/survlime.png" width="600"/>

Kovalev, M. S., Utkin, L. V., & Kasimov, E. M. (2020). **SurvLIME: A method for explaining machine learning survival models**. Knowledge-Based Systems, 203, 106164. https://doi.org/10.1016/j.knosys.2020.106164

Pachón-García, C., Hernández-Pérez, C., Delicado, P., & Vilaplana, V. (2024). **SurvLIMEpy: A Python package implementing SurvLIME**. Expert Systems with Applications, 237, 121620. https://doi.org/10.1016/j.eswa.2023.121620
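
To make the SurvLIME objective more tangible, here is a minimal base R sketch that evaluates it and minimizes it with a generic optimizer. Everything in it is an assumption made for illustration: the toy "black box" satisfies the Cox form exactly, the weights $v_{kj}$ are set to 1, the neighbour weights $w_k$ come from a Gaussian kernel, and names such as `survlime_loss` are hypothetical. It is not the procedure used by `survex` or `survlimepy`.

```r
set.seed(1)

# Toy setup (all values are illustrative).
times <- c(0, 1, 2, 4, 7)            # t_0, ..., t_{m+1}
dt <- diff(times)                    # interval lengths t_{j+1} - t_j
H0 <- c(0.05, 0.12, 0.30, 0.55)      # baseline cumulative hazard H_{0j}
b_true <- c(0.8, -0.4)

# N neighbours generated around the explained observation x_star.
x_star <- c(1, 2)
N <- 50
X <- sweep(matrix(rnorm(N * 2, sd = 0.5), ncol = 2), 2, x_star, "+")

# Toy "black box": cumulative hazards H_j(x_k) that follow the Cox form exactly.
H_blackbox <- exp(X %*% b_true) %*% t(H0)    # N x (m + 1) matrix

# Neighbour weights w_k (Gaussian kernel, assumed) and v_kj fixed to 1 (simplification).
kernel_width <- 1
w <- exp(-rowSums(sweep(X, 2, x_star)^2) / kernel_width^2)
v <- matrix(1, nrow = N, ncol = length(H0))

# The objective from the formula above, as a function of the coefficients b.
survlime_loss <- function(b) {
  log_H0 <- matrix(log(H0), nrow = N, ncol = length(H0), byrow = TRUE)
  resid <- log(H_blackbox) - log_H0 - as.vector(X %*% b)
  sum(w * rowSums(sweep(v^2 * resid^2, 2, dt, "*")))
}

b_hat <- optim(c(0, 0), survlime_loss, method = "BFGS")$par
b_hat
```

Since the toy black box is itself a Cox model, the minimizer should land very close to `b_true`. For a non-linear black box no vector `b` can drive the loss to zero, which is the source of the problems discussed later in this section.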


In [None]:
patientA <- df[42, 3:9]
patientA 

In [None]:
set.seed(42)
survlime_cph_patientA <- predict_parts(cph_explainer, patientA, type = "survlime")

### Default results (coefficients)

In [None]:
plot(survlime_cph_patientA, type = "coefficients")

In [None]:
survlime_cph_patientA$result 

### Local importance (coefficients multiplied by feature values)

In [None]:
plot(survlime_cph_patientA, type = "local_importance")
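
The relation between the two plot types can be checked by hand. The numbers below are made up for illustration and are not taken from the fitted explanation:

```r
# Hypothetical surrogate coefficients and feature values of the explained observation.
coefs <- c(age = 0.02, sex = -0.5, ph.ecog = 0.4)
x <- c(age = 60, sex = 1, ph.ecog = 2)

# Local importance: each coefficient multiplied by the matching feature value.
local_importance <- coefs * x[names(coefs)]
local_importance
```

A small coefficient can still yield a large local importance when the feature is on a large scale, as with `age` here.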

### Parameters
- kernel width - how to calculate weights for observations generated as neighbours 
- number of neighbours - how many neighbours to generate
- which variables should be treated as categorical (different perturbation method)
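
The effect of the kernel width can be illustrated without any model. The sketch below assumes a Gaussian kernel on the Euclidean distance (the exact kernel used by a given implementation may differ) and uses hypothetical helper names `kernel_weights` and `ess`:

```r
set.seed(42)

# 100 neighbours generated around the explained observation (toy 2D data).
x_star <- c(0, 0)
neighbours <- matrix(rnorm(200), ncol = 2)
d <- sqrt(rowSums(sweep(neighbours, 2, x_star)^2))

# Assumed Gaussian kernel: weights decay with distance, controlled by kernel_width.
kernel_weights <- function(d, kernel_width) exp(-d^2 / kernel_width^2)

w_narrow <- kernel_weights(d, kernel_width = 0.5)
w_wide <- kernel_weights(d, kernel_width = 5)

# Effective sample size: how many neighbours effectively contribute to the fit.
ess <- function(w) sum(w)^2 / sum(w^2)
c(narrow = ess(w_narrow), wide = ess(w_wide))
```

A narrow kernel makes the surrogate very local but fits it on few effective neighbours; a wide kernel uses almost all neighbours equally and gives a more global approximation.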

In [None]:
set.seed(42)
survlime_cph_patientA_kw5 <- predict_parts(cph_explainer, patientA, type = "survlime", kernel_width = 5)

In [None]:
plot(survlime_cph_patientA_kw5)

In [None]:
set.seed(42)
survlime_cph_patientA_catvars <- predict_parts(cph_explainer, patientA, type = "survlime",
                                    categorical_variables = c("sex", "ph.ecog"))

In [None]:
plot(survlime_cph_patientA_catvars)

### Problems for non-linear models
- Predictions from the surrogate model are often far from the predictions of the black-box model, because the surrogate is too simple to approximate more complex black-box models. 
- Explanations are not stable: they can change noticeably between runs with different random seeds. 

In [None]:
set.seed(42)
survlime_rsf_patientA <- predict_parts(rsf_explainer, patientA, type = "survlime")

In [None]:
plot(survlime_rsf_patientA)

In [None]:
# Same call with a different seed: compare the plot below with the previous one
set.seed(12)
survlime_rsf_patientA <- predict_parts(rsf_explainer, patientA, type = "survlime")

In [None]:
plot(survlime_rsf_patientA)

## SurvSHAP(t)

- based on SHAP (SHapley Additive exPlanations) method 
- local explanation method (for one observation & prediction)
- **main idea:** allow for time-dependent explainability better suited to complex models
- **the explanation:** curves showing the attribution of each feature to the prediction over time (can be aggregated to single numbers)

- **formula:**
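$$\phi_{t}(\mathbf{x}_*, d) = \frac{1}{|\Pi|} \sum_{\pi \in \Pi} \left( e_{t, \mathbf{x}_*}^{\mathrm{before}(\pi, d) \cup \{d\}} - e_{t, \mathbf{x}_*}^{\mathrm{before}(\pi, d)} \right),$$
where $\Pi$ is the set of all permutations of the $p$ variables, $\mathrm{before}(\pi, d)$ denotes the subset of predictors that come before $d$ in the ordering $\pi \in \Pi$, and $e_{t, \mathbf{x}_*}^{D}$ is the expected value of the predicted survival function at time $t$ when the variables in $D$ are fixed to their values in $\mathbf{x}_*$.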

- **local variable importance:**
$$\psi(\mathbf{x}_*, d) = \int_0^{t_m} \left| \phi_{t}(\mathbf{x}_*, d)\right| \, \mathrm{d}w(t)$$

<img src="images/survshap.png" width="800"/>


Krzyziński, M., Spytek, M., Baniecki, H., & Biecek, P. (2023). **SurvSHAP(t): Time-dependent explanations of machine learning survival models**. Knowledge-Based Systems, 262, 110234. https://doi.org/10.1016/j.knosys.2022.110234
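
The permutation formula can be followed step by step on a toy example in base R. This is a sketch under stated assumptions, not the `survex` or `survshap` implementation: the survival model `predict_surv` is made up, the expectations $e_{t, \mathbf{x}_*}^{D}$ are estimated by averaging over a small background sample, and all permutations are enumerated, which is only feasible for a handful of variables.

```r
set.seed(7)

# Toy survival model with an interaction (hypothetical, for illustration only).
predict_surv <- function(X, times) {
  rate <- 0.02 * exp(0.7 * X[, "a"] - 0.3 * X[, "b"] + 0.5 * X[, "a"] * X[, "c"])
  exp(-outer(rate, times))   # rows: observations, columns: time points
}

background <- matrix(rnorm(60), ncol = 3, dimnames = list(NULL, c("a", "b", "c")))
x_star <- c(a = 1, b = -0.5, c = 2)
times <- c(5, 20, 50)

# e_{t, x*}^{D}: mean survival over the background data with variables in D fixed to x*.
expected_surv <- function(D) {
  X <- background
  for (v in D) X[, v] <- x_star[v]
  colMeans(predict_surv(X, times))
}

# All orderings of the variable names.
permutations <- function(v) {
  if (length(v) <= 1) return(list(v))
  out <- list()
  for (i in seq_along(v)) {
    for (p in permutations(v[-i])) out[[length(out) + 1]] <- c(v[i], p)
  }
  out
}

vars <- colnames(background)
perms <- permutations(vars)

# phi[j, d]: contribution of variable d at time times[j], averaged over orderings.
phi <- sapply(vars, function(d) {
  contrib <- sapply(perms, function(p) {
    before <- p[seq_len(which(p == d) - 1)]
    expected_surv(c(before, d)) - expected_surv(before)
  })
  rowMeans(contrib)
})
phi
```

Each row of `phi` corresponds to one time point, so every column is a discretized SurvSHAP(t) curve. At each time point the contributions sum to the difference between the prediction for `x_star` and the average prediction over the background data.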

In [None]:
set.seed(42)
survshap_cph_patientA <- predict_parts(cph_explainer, patientA, type = "survshap")

In [None]:
plot(survshap_cph_patientA)

In [None]:
survshap_rsf_patientA <- predict_parts(rsf_explainer, patientA, type = "survshap")

In [None]:
plot(survshap_rsf_patientA)

In [None]:
print(survshap_rsf_patientA$aggregate)

### Parameters
- number of samples in the background dataset 
- aggregation method
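
To show what an aggregation method does, the sketch below collapses one SurvSHAP(t) curve into a single number in three ways. The curve values are made up, the integral uses the trapezoidal rule with uniform weighting of time ($w(t) = t$), and the variable names are hypothetical; the options actually offered by the packages may be defined differently.

```r
# Hypothetical SurvSHAP(t) values of one variable at five time points.
times <- c(0, 10, 20, 40, 80)
phi_t <- c(0.00, 0.05, 0.08, -0.02, -0.06)
abs_phi <- abs(phi_t)

# Integral of |phi_t| over time (trapezoidal rule on the absolute values).
integral <- sum(diff(times) * (head(abs_phi, -1) + tail(abs_phi, -1)) / 2)

# Two simpler alternatives: largest and average absolute contribution.
max_abs <- max(abs_phi)
mean_abs <- mean(abs_phi)

c(integral = integral, max_abs = max_abs, mean_abs = mean_abs)
```

Taking absolute values first matters: this curve changes sign, so integrating the raw values would let positive and negative contributions cancel out.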

In [None]:
# Use only N = 10 observations as the background dataset (faster, but less precise)
survshap_rsf_patientA_N10 <- predict_parts(rsf_explainer, patientA, type = "survshap", N = 10)

In [None]:
plot(survshap_rsf_patientA_N10)

### Using other forms of predictions   

In [None]:
survshap_rsf_patientA_shap_risk <- predict_parts(rsf_explainer, patientA, type = "shap", output_type = "risk")

In [None]:
plot(survshap_rsf_patientA_shap_risk)

In [None]:
survshap_rsf_patientA_shap_chf <- predict_parts(rsf_explainer, patientA, type = "survshap", output_type = "chf")

In [None]:
plot(survshap_rsf_patientA_shap_chf)

## SurvSHAP(t) aggregations 
- global (model) explanations based on local explanations
- can be calculated for any model, but this is computationally expensive 
- for tree-based models, they can be calculated faster using the TreeSHAP algorithm (`treeshap` package)
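
The step from local to global explanations can be sketched as follows, assuming a simple mean over observations as the aggregation. The local importance values are made up for illustration:

```r
# Hypothetical local importances psi(x_i, d): rows are observations, columns are variables.
psi <- rbind(
  c(age = 0.9, sex = 0.2, pat.karno = 1.4),
  c(age = 0.4, sex = 0.6, pat.karno = 1.1),
  c(age = 0.8, sex = 0.1, pat.karno = 0.5)
)

# Global importance: average local importance of each variable, sorted for ranking.
global_importance <- sort(colMeans(psi), decreasing = TRUE)
global_importance
```

Computing `psi` requires one local explanation per observation, which is why the cost grows quickly with the size of the dataset.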

In [None]:
survshap_rsf_model <- model_survshap(rsf_explainer, rsf_explainer$data, calculation_method = "treeshap", verbose = FALSE)

In [None]:
plot(survshap_rsf_model)

In [None]:
plot(survshap_rsf_model, geom = "beeswarm")

In [None]:
plot(survshap_rsf_model, geom = "profile", variable = "pat.karno", color_variable = "sex")

In [None]:
plot(survshap_rsf_model, geom = "curves", variable = "pat.karno")

In [None]:
plot(survshap_rsf_model, geom = "curves", variable = "pat.karno", boxplot = TRUE, coef = 3)