# Dancing with censored data: How to survive with explainable survival analysis?
**ML in PL Conference 2023**


Mateusz Krzyziński, Mikołaj Spytek (**MI2.AI**)


## Agenda

- 9:00 - 9:15 - Introduction, technicalities
- 9:15 - 10:00 - ML in survival analysis, creating models `(Python and R)`
- 10:00 - 10:10 - *break*

- 10:10 - 10:20 - XAI in survival analysis
- 10:20 - 10:45 - SurvLIME `(Python and R)`
- 10:45 - 11:30 - SurvSHAP(t) and its aggregations `(Python and R)`
- 11:30 - 11:40 - *break*

- 11:40 - 11:50 - creating explainers for any models `(R with Python via reticulate)`
- 11:50 - 12:15 - global performance & global variable importance explanations `(R)`
- 12:15 - 12:30 - local variable dependence explanations `(R)`
- 12:30 - 12:45 - global variable dependence explanations `(R)`
- 12:45 - 13:00 - summary and Q&A


## Why do we need explanations for survival models?

**Complex machine learning survival models are used more and more often (also in healthcare and scientific research) because**:
- the Cox PH model cannot capture complex dependencies and relies on the strong proportional hazards assumption, which is often not met,
- ML models are more flexible and can capture complex dependencies, but they are often treated as black boxes.


<img src="images/need_quotes.png" width="1000">






### More about medical context 

The **overoptimistic** use of AI models in medicine **→** The need for a method of **validation** other than just performance validation.
> There is a notable tendency among some researchers steeped in ML to report only model performance metrics. XAI techniques offer a means to validate and explore survival models, effectively complementing performance validation. They make it possible to shift the focus towards model reliability and accountability rather than solely optimizing performance metrics. This shift is crucial for mitigating the risk of over-optimistic assessments of complex models, thereby promoting responsible and informed model usage.

The **complexity** and **lack of interpretability** of AI models **→** The need for a method of **examining** models that would make it possible to understand their operation.
> Researchers rely on survival models to analyze the effectiveness of new therapies, explore genetic markers associated with diseases, or evaluate the impact of lifestyle factors on health outcomes. Given the vast implications of these models, ensuring that their predictions are not only accurate but also interpretable becomes a matter of utmost importance.





## XAI as model exploration stack

- There are many different types of (survival) models, and each of them has its own structure.
- Although model-specific approaches are possible, it is much more convenient to use a model-agnostic approach.
- In this approach, the model is wrapped in an abstract object, and the explanation methods are applied to this object (they are usually based on the analysis of the model's predictions).
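To make the idea of wrapping concrete, here is a minimal Python sketch of a model-agnostic explainer. The class `SurvivalExplainerSketch`, the toy exponential model, and the simplified permutation-importance measure are hypothetical illustrations, not the API of `survex` or `survshap`; the only point is that the explanation method touches the model exclusively through its prediction function.

```python
from dataclasses import dataclass
from typing import Callable

import numpy as np


@dataclass
class SurvivalExplainerSketch:
    """Hypothetical minimal wrapper; NOT the survex or survshap API."""

    # any function mapping (X, times) -> survival probabilities of shape (n_obs, n_times)
    predict_survival: Callable[[np.ndarray, np.ndarray], np.ndarray]
    data: np.ndarray
    times: np.ndarray

    def predict(self, X: np.ndarray) -> np.ndarray:
        return self.predict_survival(X, self.times)

    def permutation_importance(self, column: int, seed: int = 0) -> float:
        """Mean absolute change of the predicted survival curves after permuting one column."""
        rng = np.random.default_rng(seed)
        baseline = self.predict(self.data)
        X_perm = self.data.copy()
        X_perm[:, column] = rng.permutation(X_perm[:, column])
        return float(np.mean(np.abs(self.predict(X_perm) - baseline)))


def toy_model(X: np.ndarray, times: np.ndarray) -> np.ndarray:
    # exponential model S(t | x) = exp(-t * exp(0.8 * x0)); the second feature is ignored
    rate = np.exp(0.8 * X[:, 0])
    return np.exp(-np.outer(rate, times))


rng = np.random.default_rng(1)
explainer = SurvivalExplainerSketch(toy_model, rng.normal(size=(50, 2)), np.linspace(0.1, 2.0, 20))
print(explainer.permutation_importance(0))  # positive: the model uses the first feature
print(explainer.permutation_importance(1))  # zero: the model ignores the second feature
```

Because the explainer only calls `predict_survival`, the same code would work unchanged for a Cox model, a random survival forest, or any other model whose predictions can be written as a function of `(X, times)`.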

<img src="images/piramide.png" width="800">

**Read more:** https://ema.drwhy.ai


## Tools
- In R, we will use the [`survex`](https://modeloriented.github.io/survex/) package that follows similar principles.
<br/>
<img src="images/survex.png" width="800"><br/>
Spytek, M., Krzyziński, M., Langbein, S. H., Baniecki, H., Wright, M. N., & Biecek, P. (2023). **survex: an R package for explaining machine learning survival models**. arXiv preprint arXiv:2308.16113. https://arxiv.org/abs/2308.16113

- In Python, we will use the [`survlimepy`](https://github.com/imatge-upc/SurvLIMEpy) and [`survshap`](https://github.com/MI2DataLab/survshap) packages. 
<br/>

## Important!

**Warning!** Due to some unusual dependencies of the `survlimepy` package, this notebook needs to be run with `Python 3.10` and the requirements listed in `survlimepy_requirements.txt`.

***These requirements are different from the Python requirements of all the other notebooks!!!***

In [None]:
# pip install -r survlimepy_requirements.txt

## Loading packages

In [None]:
import random
import pandas as pd
from matplotlib import pyplot as plt

import warnings

warnings.filterwarnings("ignore")

from survlimepy.survlime_explainer import SurvLimeExplainer
from sksurv.ensemble import RandomSurvivalForest
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.datasets import get_x_y

## Loading data

In [None]:
df = pd.read_csv("datasets/veterans.csv")
df = df.dropna()
df = pd.get_dummies(df, drop_first=True, dtype=float)

X, y = get_x_y(df, ["status", "time"], pos_label=1)

## Creating models

In [None]:
cph_model = CoxPHSurvivalAnalysis()
cph_model.fit(X, y)

In [None]:
# random.seed() does not affect scikit-learn estimators, so the seed is passed via random_state
rsf_model = RandomSurvivalForest(n_estimators=100, random_state=123)
rsf_model.fit(X, y)

## Creating explainers

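In the design of the `survlimepy` package, a separate explainer is created for each model from the training data (features, event indicators, and observed times) and the time points at which the model returns its predictions (`model_output_times`).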

In [None]:
cph_explainer = SurvLimeExplainer(
    training_features= X,
    training_events = [tmp[0] for tmp in y],
    training_times = [tmp[1] for tmp in y],
    model_output_times = cph_model.event_times_,
    random_state = 42)


rsf_explainer = SurvLimeExplainer(
    training_features= X,
    training_events = [tmp[0] for tmp in y],
    training_times = [tmp[1] for tmp in y],
    model_output_times = rsf_model.event_times_,
    random_state = 42)
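The `training_events` and `training_times` arguments above unpack `y` position by position. `get_x_y` returns a NumPy structured array whose first field is the event indicator and whose second field is the time (in the order given in `["status", "time"]`), so `tmp[0]` is the event and `tmp[1]` is the time. The sketch below uses a small hand-made array with the same layout (the values are made up) to show that access by field name yields the same lists.

```python
import numpy as np

# Toy structured array with the same layout as y: event indicator first, time second
y = np.array(
    [(True, 10.0), (False, 25.0), (True, 40.0)],
    dtype=[("status", "?"), ("time", "<f8")],
)

# Positional access, as in the explainer cells
events_by_position = [tmp[0] for tmp in y]
times_by_position = [tmp[1] for tmp in y]

# Equivalent, more explicit access by field name
events_by_name = y["status"].tolist()
times_by_name = y["time"].tolist()

print(events_by_position == events_by_name)
print(times_by_position == times_by_name)
```

Field-name access (`y["status"]`, `y["time"]`) does not depend on the order of the fields, which makes it slightly more robust if the column order passed to `get_x_y` ever changes.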

## SurvLIME

- based on LIME (Local Interpretable Model-agnostic Explanations) method 
- local explanation method (for one observation & prediction)
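- **main idea:** approximate the black-box model locally with a simple, interpretable surrogate, namely the Cox proportional hazards model
- **the explanation:** coefficients of the locally fitted Cox model
- the coefficients are found by minimizing a weighted distance between the cumulative hazard functions of the black-box model and of the surrogate model:

$$\min_{\mathbf{b}}\sum_{k=1}^{N}w_{k}\sum_{j=0}^{m}v_{kj}^{2}\left(\ln H_{j}(\mathbf{x}_{k})-\ln H_{0j}-\mathbf{b}^{\mathrm{T}}\mathbf{x}_{k}\right)^{2}\left(t_{j+1}-t_{j}\right)$$

where $\mathbf{x}_k$, $k=1,\dots,N$, are the generated neighbours with weights $w_k$, $H_j(\mathbf{x}_k)$ is the black-box cumulative hazard at time $t_j$, $H_{0j}$ is the baseline cumulative hazard at $t_j$, $v_{kj}$ are additional weights defined in the original paper, and $\mathbf{b}$ is the vector of surrogate coefficients.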
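For fixed neighbours, weights, and baseline, the objective above is a weighted least-squares problem in $\mathbf{b}$: each pair $(k, j)$ contributes one residual $\ln H_j(\mathbf{x}_k) - \ln H_{0j} - \mathbf{b}^{\mathrm{T}}\mathbf{x}_k$ with weight $w_k v_{kj}^2 (t_{j+1}-t_j)$. The NumPy sketch below solves it in closed form on synthetic data generated from an exact Cox model, so it should recover the true coefficients. It is an illustration under stated assumptions, not the `survlimepy` implementation: the neighbours, the kernel-like weights $w_k$, the choice $v_{kj}=1$, and the baseline $H_{0j}$ are all made up, and `survlime_wls` is a hypothetical helper.

```python
import numpy as np


def survlime_wls(X, H, H0, times, w, v):
    """Minimize sum_k w_k sum_j v_kj^2 (ln H_kj - ln H0_j - b^T x_k)^2 (t_{j+1} - t_j) over b.

    X: (N, d) neighbours; H: (N, m+1) black-box cumulative hazards at t_0..t_m;
    H0: (m+1,) baseline cumulative hazard; times: (m+2,) grid t_0..t_{m+1};
    w: (N,) neighbour weights; v: (N, m+1) additional weights.
    """
    dt = np.diff(times)                          # t_{j+1} - t_j, shape (m+1,)
    target = np.log(H) - np.log(H0)[None, :]     # should equal b^T x_k for every j
    weights = w[:, None] * v**2 * dt[None, :]    # weight of each (k, j) term
    n_times = H.shape[1]
    A = np.repeat(X, n_times, axis=0)            # one design row x_k per (k, j) pair
    y = target.ravel()                           # row-major order matches np.repeat
    sw = np.sqrt(weights.ravel())
    b, *_ = np.linalg.lstsq(A * sw[:, None], y * sw, rcond=None)
    return b


rng = np.random.default_rng(0)
N, d, m = 200, 3, 10
b_true = np.array([0.5, -1.0, 0.25])
X = rng.normal(size=(N, d))
times = np.linspace(0.0, 5.0, m + 2)
H0 = 0.1 * (times[:-1] + 1.0)                    # made-up positive baseline so the log is defined
H = H0[None, :] * np.exp(X @ b_true)[:, None]    # "black box" that is exactly a Cox model
w = np.exp(-np.sum(X**2, axis=1))                # Gaussian-like neighbour weights (assumption)
v = np.ones((N, m + 1))                          # v_kj = 1 for simplicity (assumption)

b_hat = survlime_wls(X, H, H0, times, w, v)
print(b_hat)  # should be close to b_true
```

Because the synthetic black box is itself a Cox model, every residual can be driven to zero; for a genuinely non-linear black box the same least-squares fit only finds the best linear compromise in the weighted neighbourhood.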

<img src="images/survlime.png" width="600"/>

Kovalev, M. S., Utkin, L. V., & Kasimov, E. M. (2020). **SurvLIME: A method for explaining machine learning survival models**. Knowledge-Based Systems, 203, 106164. https://doi.org/10.1016/j.knosys.2020.106164
Pachón-García, C., Hernández-Pérez, C., Delicado, P., & Vilaplana, V. (2024). **SurvLIMEpy: A Python package implementing SurvLIME**. Expert Systems with Applications, 237, 121620. https://doi.org/10.1016/j.eswa.2023.121620


In [None]:
patientA = X.iloc[23]
patientA

In [None]:
pred_cph = cph_model.predict_survival_function(pd.DataFrame([patientA]))
pred_rsf = rsf_model.predict_survival_function(pd.DataFrame([patientA]))

plt.plot(pred_cph[0].x, pred_cph[0].y, label="CoxPH")
plt.plot(pred_rsf[0].x, pred_rsf[0].y, label="RSF")
plt.xlim(0, 300)
plt.legend()

In [None]:
cph_explanation = cph_explainer.explain_instance(
    data_row = patientA,
    predict_fn=cph_model.predict_cumulative_hazard_function,
    num_samples = 100)

print(cph_explanation)

cph_explainer.plot_weights(figsize=(6, 4))

In [None]:
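# element-wise product: each variable's contribution (coefficient times feature value) to the surrogate's log relative hazard
cph_explanation * patientA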
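In a Cox model, the log relative hazard is $\mathbf{b}^{\mathrm{T}}\mathbf{x}$, so multiplying the coefficients by the feature values element-wise splits it into per-variable contributions; this is what the cell above computes for `patientA`. The small NumPy sketch below uses made-up coefficients and feature values (the names `x1`, `x2`, `x3` are hypothetical) to show that the contributions add up to $\mathbf{b}^{\mathrm{T}}\mathbf{x}$ and that exponentiating the sum gives the hazard relative to the baseline hazard.

```python
import numpy as np

# Hypothetical surrogate coefficients and feature values for one observation
feature_names = ["x1", "x2", "x3"]
b = np.array([0.01, -0.03, 0.2])
x = np.array([60.0, 70.0, 0.0])

contributions = b * x                       # b_i * x_i, one term per variable
log_relative_hazard = contributions.sum()   # equals b^T x
relative_hazard = np.exp(log_relative_hazard)  # h(t | x) / h_0(t) in a Cox model

for name, c in zip(feature_names, contributions):
    print(f"{name}: {c:+.3f}")
print(log_relative_hazard, relative_hazard)
```

A negative contribution lowers the hazard (and so raises the predicted survival) relative to the baseline, a positive one raises it.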

### Parameters
- kernel width: how the weights of the observations generated as neighbours are calculated (**due to a bug in the current version of the package, this parameter cannot be changed**)
- number of neighbours (`num_samples` in `explain_instance`): how many neighbours to generate
- functional norm (`functional_norm` in `SurvLimeExplainer`): which distance between cumulative hazard functions is minimized
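The roles of these three parameters can be illustrated with a simplified NumPy sketch. It assumes Gaussian sampling of neighbours around the explained observation, a LIME-style exponential kernel $w_k = \sqrt{\exp(-d_k^2/\sigma^2)}$ with kernel width $\sigma$, and an $L_2$ versus supremum ($L_\infty$) distance between cumulative hazard step functions on a time grid. The function names and these modelling choices are assumptions for illustration; the exact sampling scheme, kernel, and norm definitions used in `survlimepy` may differ.

```python
import numpy as np


def sample_neighbours(x, num_samples, scale, rng):
    # Gaussian perturbations around the explained observation (assumption)
    return x + rng.normal(scale=scale, size=(num_samples, x.size))


def kernel_weights(x, neighbours, kernel_width):
    # LIME-style exponential kernel: closer neighbours get larger weights
    d = np.linalg.norm(neighbours - x, axis=1)
    return np.sqrt(np.exp(-(d**2) / kernel_width**2))


def functional_distance(H_a, H_b, times, norm):
    """Distance between two cumulative hazard step functions evaluated at `times`."""
    diff = np.abs(H_a - H_b)
    if norm == "inf":
        return float(diff.max())                   # supremum norm: the largest gap
    dt = np.diff(times, append=times[-1])          # step widths; the last step has width 0
    return float(np.sqrt(np.sum(diff**2 * dt)))    # L2 norm of the step function


rng = np.random.default_rng(42)
x = np.array([0.0, 1.0])
neighbours = sample_neighbours(x, num_samples=100, scale=1.0, rng=rng)
w_narrow = kernel_weights(x, neighbours, kernel_width=0.5)
w_wide = kernel_weights(x, neighbours, kernel_width=5.0)
print(w_narrow.mean(), w_wide.mean())  # a narrower kernel down-weights distant neighbours more

times = np.array([0.0, 1.0, 2.0, 3.0])
H_black_box = np.array([0.0, 0.2, 0.5, 0.9])
H_surrogate = np.array([0.0, 0.3, 0.5, 0.6])
print(functional_distance(H_black_box, H_surrogate, times, norm=2))
print(functional_distance(H_black_box, H_surrogate, times, norm="inf"))
```

The two norms penalize different things: the $L_2$ distance integrates squared gaps over time, while the supremum norm is driven entirely by the single largest gap, so the choice of norm changes which surrogate counts as closest.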

In [None]:
cph_explainer = SurvLimeExplainer(
    training_features= X,
    training_events = [tmp[0] for tmp in y],
    training_times = [tmp[1] for tmp in y],
    model_output_times = cph_model.event_times_,
    random_state = 42,
    functional_norm = "inf")

In [None]:
cph_explanation = cph_explainer.explain_instance(
    data_row = patientA,
    predict_fn = cph_model.predict_cumulative_hazard_function,
    num_samples = 100)

print(cph_explanation)

cph_explainer.plot_weights(figsize=(6, 4))

### Problems for non-linear models
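- Predictions of the surrogate model are often far from those of the black-box model, because the Cox surrogate is too simple to approximate more complex black-box models.
- Explanations are not stable: repeated calls for the same observation can return noticeably different coefficients, as the two identical cells below illustrate.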
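Both problems can be reproduced with a simplified sketch that is not SurvLIME itself: a linear surrogate of the log relative hazard is fitted by weighted least squares to a made-up non-linear black box, using two different random neighbourhoods of the same observation. It assumes the black box has proportional hazards (so only its log relative hazard matters); the function names, the kernel, and the black box are hypothetical.

```python
import numpy as np


def black_box_log_hazard(X):
    # made-up non-linear log relative hazard
    return np.sin(3 * X[:, 0]) + X[:, 1] ** 2


def local_linear_surrogate(x, num_samples, seed, scale=1.0, kernel_width=1.0):
    rng = np.random.default_rng(seed)
    Z = x + rng.normal(scale=scale, size=(num_samples, x.size))  # random neighbourhood
    d = np.linalg.norm(Z - x, axis=1)
    sw = np.exp(-(d**2) / (2 * kernel_width**2))  # square root of kernel weights exp(-d^2 / width^2)
    A = np.column_stack([np.ones(num_samples), Z])  # intercept + features
    f = black_box_log_hazard(Z)
    coef, *_ = np.linalg.lstsq(A * sw[:, None], f * sw, rcond=None)
    residual = f - A @ coef
    weighted_rmse = np.sqrt(np.sum(sw**2 * residual**2) / np.sum(sw**2))
    return coef[1:], weighted_rmse


x0 = np.array([0.3, 0.5])
b_run1, err1 = local_linear_surrogate(x0, num_samples=100, seed=1)
b_run2, err2 = local_linear_surrogate(x0, num_samples=100, seed=2)
print(b_run1, b_run2)  # different coefficients for the same observation
print(err1, err2)      # non-zero: a linear surrogate cannot match the black box exactly
```

Increasing `num_samples` usually reduces the run-to-run variation of the coefficients, but it cannot remove the fidelity gap, which comes from the surrogate being linear.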

In [None]:
rsf_explanation = rsf_explainer.explain_instance(
    data_row = patientA,
    predict_fn = rsf_model.predict_cumulative_hazard_function,
    num_samples = 100)

print(rsf_explanation)

rsf_explainer.plot_weights(figsize=(6, 4))

In [None]:
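# identical call: a new neighbourhood is generated, so the coefficients may differ from the previous cell
rsf_explanation = rsf_explainer.explain_instance(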
    data_row = patientA,
    predict_fn = rsf_model.predict_cumulative_hazard_function,
    num_samples = 100)

print(rsf_explanation)

rsf_explainer.plot_weights(figsize=(6, 4))