# Dancing with censored data: How to survive with explainable survival analysis?
**ML in PL Conference 2023**


Mateusz Krzyziński, Mikołaj Spytek (**MI2.AI**)


## Agenda

- 9:00 - 9:15 - Introduction, technicalities
- 9:15 - 10:00 - ML in survival analysis, creating models `(Python and R)`
- 10:00 - 10:10 - *break*

- 10:10 - 10:20 - XAI in survival analysis
- 10:20 - 10:45 - SurvLIME `(Python and R)`
- 10:45 - 11:30 - SurvSHAP(t) and its aggregations `(Python and R)`
- 11:30 - 11:40 - *break*

- 11:40 - 11:50 - creating explainers for any model `(R with Python via reticulate)`
- 11:50 - 12:15 - global performance & global variable importance explanations `(R)`
- 12:15 - 12:30 - local variable dependence explanations `(R)`
- 12:30 - 12:45 - global variable dependence explanations `(R)`
- 12:45 - 13:00 - summary and Q&A


## Technicalities

### Installing and loading packages

In [None]:
# pip install -r requirements.txt

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [None]:
import matplotlib as mpl

mpl.rcParams["axes.spines.right"] = False
mpl.rcParams["axes.spines.top"] = False

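def plot(survival_prob):
    """Plot the predicted survival functions of patients A and B.

    `survival_prob` is the array returned by `predict_survival_function`;
    each element is a step function with `.x` (time points) and `.y` (values).
    """
    fig, ax = plt.subplots()
    # survival functions are step functions, so draw them as steps
    ax.step(survival_prob[0].x, survival_prob[0].y, where="post", label="Patient A")
    ax.step(survival_prob[1].x, survival_prob[1].y, where="post", label="Patient B")
    ax.set_xlabel("Time")
    ax.set_ylabel("Survival probability")
    ax.set_ylim(0, 1)
    ax.legend()
    plt.show()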

## Intro to survival analysis

### What is survival analysis?
- type of supervised learning task
- also known as time-to-event analysis, reliability analysis, duration analysis, ... 
- **data modality:** mostly tabular, censored data (the most common case is right censoring)

<img src="images/censoring.png" width="600">

- **output:** survival probability distribution (can be represented by the survival function)

<img src="images/survival_probs.png" width="600">


### Mathematical background

#### Data format
- instance $i$: $(\mathbf{x}_i, y_i, \delta_i)$
- $\mathbf{x}_i$ is a vector of features
- $\delta_i$ is the event indicator (usually $\delta_i = 0$ for a censored observation and $\delta_i = 1$ if the event of interest occurred)
- $y_i$ is the observed time (under right censoring, $y_i = \min(T_i, C_i)$):
    - it is the survival time $T_i$ when $\delta_i = 1$
    - or the censoring time $C_i$ when $\delta_i = 0$
- we want to model the conditional distribution of $T$, not $Y$ (but we don't observe $T$ directly!)
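A small simulation shows how the observed pairs $(y_i, \delta_i)$ arise from the unobserved $T_i$ and $C_i$ under right censoring, and why naively averaging the observed times is biased. The setup is an assumption chosen for illustration: exponential event times and independent uniform censoring times, all values synthetic.

```python
import numpy as np

rng = np.random.default_rng(seed=0)
n = 1000

# Assumed generative setup: true event times T ~ Exponential(mean 100)
# and independent censoring times C ~ Uniform(0, 200).
T = rng.exponential(scale=100.0, size=n)
C = rng.uniform(0.0, 200.0, size=n)

# Under right censoring we only observe the earlier of the two times ...
y = np.minimum(T, C)
# ... and whether that time was the event (1) or the censoring (0).
delta = (T <= C).astype(int)

print(f"share of censored observations: {1 - delta.mean():.2f}")
# Censored times are only lower bounds on the true survival times,
# so the mean of the observed times underestimates the mean of T.
print(f"mean of T: {T.mean():.1f}, mean of y: {y.mean():.1f}")
```

In real data only `y` and `delta` are available, which is why survival models have to account for censoring instead of regressing directly on the observed time.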


#### Key quantities  
- **survival function:** (probability of surviving past a given time point)
 $$S(t|\mathbf{x}) = \mathbb{P}(T > t|\mathbf{X}=\mathbf{x}) = 1 - \mathbb{P}(T \leq t|\mathbf{X}=\mathbf{x})$$


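- **hazard function:** (instantaneous rate of the event occurring at time $t$, given survival up to $t$; informally, the potential for the event to occur in the infinitesimally small interval $[t, t + \Delta t)$)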
$$h(t|\mathbf{x}) = \lim_{\Delta t \to 0} \frac{\mathbb{P}(t \leq T < t + \Delta t|T \geq t, \mathbf{X}=\mathbf{x})}{\Delta t}$$

- **cumulative hazard function:**  (total accumulated risk of an event up to time $t$)
$$H(t|\mathbf{x}) = \int_0^t h(s|\mathbf{x})ds$$

- they are related by: $S(t|\mathbf{x}) = \exp(-H(t|\mathbf{x}))$
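The identity $S(t|\mathbf{x}) = \exp(-H(t|\mathbf{x}))$ can be checked numerically. The sketch below assumes a Weibull hazard $h(t) = \frac{k}{\lambda}\left(\frac{t}{\lambda}\right)^{k-1}$ without features (the shape $k$ and scale $\lambda$ are chosen arbitrarily), whose survival function has the closed form $S(t) = \exp(-(t/\lambda)^k)$. It integrates the hazard with the trapezoidal rule to get $H(t)$ and compares $\exp(-H(t))$ with the closed form.

```python
import numpy as np

# Assumed example: Weibull hazard with shape k and scale lam (no features).
k, lam = 1.5, 100.0
t = np.linspace(0.0, 300.0, 3001)

hazard = (k / lam) * (t / lam) ** (k - 1)

# Cumulative hazard H(t) = integral of h(s) ds from 0 to t,
# computed with the cumulative trapezoidal rule.
increments = 0.5 * (hazard[1:] + hazard[:-1]) * np.diff(t)
cum_hazard = np.concatenate([[0.0], np.cumsum(increments)])

# Survival function from the identity S(t) = exp(-H(t)) ...
survival_from_H = np.exp(-cum_hazard)
# ... and from the closed-form Weibull survival function.
survival_closed_form = np.exp(-(t / lam) ** k)

# Only a small numerical-integration error should remain.
print(np.max(np.abs(survival_from_H - survival_closed_form)))
```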





## First look at the data & statistical models

In [None]:
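df = pd.read_csv('datasets/lung_dataset.csv')
# drop rows with missing feature values
df = df.dropna()
# scikit-survival expects a boolean event indicator
df["status"] = df["status"].astype('bool')

from sksurv.datasets import get_x_y

# X: feature matrix, y: structured array with the event indicator and the observed time
X, y = get_x_y(df, attr_labels=["status", "time"], pos_label=True)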
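`get_x_y` returns `y` as a NumPy structured array whose first field is the boolean event indicator and whose second field is the observed time; this is the target format scikit-survival estimators expect. The sketch below builds such an array by hand from made-up values, only to show the data structure (scikit-survival also provides `sksurv.util.Surv.from_arrays` for this); in the notebook, `y` comes from `get_x_y`.

```python
import numpy as np

# Made-up toy values, not taken from the lung dataset.
status = np.array([True, False, True])   # event indicator (False = censored)
time = np.array([100.0, 250.0, 400.0])   # observed times

# Structured array: one boolean event field and one float time field.
y_toy = np.array(list(zip(status, time)), dtype=[("status", bool), ("time", float)])

# Fields are accessed by name, like columns.
print(y_toy["status"], y_toy["time"])
```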

In [None]:
df.shape

In [None]:
df.head(10)

The data contain subjects with advanced lung cancer.

- time: survival time (in days)
- status: event status, 0=censored, 1=deceased
- age: age (in years)
- sex: 1=male, 2=female
- ph.ecog: ECOG performance score as rated by the physician: 
    - 0=asymptomatic
    - 1=symptomatic but completely ambulatory
    - 2=in bed <50% of the day
    - 3=in bed >50% of the day but not bedbound
    - 4=bedbound
- ph.karno: Karnofsky performance score (0=bad, 100=good) rated by physician
- pat.karno: Karnofsky performance score (0=bad, 100=good) rated by patient
- meal.cal: calories consumed at meals
- wt.loss: weight loss in the last six months (in pounds)

### Kaplan-Meier estimator 
- non-parametric estimator for survival function 
- estimates $S(t)$ using only the $(y_i, \delta_i)$ pairs; it ignores the features $\mathbf{x}_i$
- separate estimators can be fitted for each group of instances sharing the same value of a categorical feature (stratification)
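The Kaplan-Meier estimator is the product-limit formula $\hat{S}(t) = \prod_{j:\, t_j \le t} \left(1 - \frac{d_j}{n_j}\right)$, where $t_j$ are the distinct event times, $d_j$ is the number of events at $t_j$, and $n_j$ is the number of instances still at risk (with $y_i \ge t_j$) just before $t_j$. The sketch below implements it directly in NumPy; `kaplan_meier` is a hypothetical helper name, the toy data are made up, and an instance censored at the same time as an event is kept in the risk set at that time (the usual convention).

```python
import numpy as np

def kaplan_meier(time, event):
    """Product-limit estimate of S(t) at each distinct event time.

    time  -- observed times y_i
    event -- event indicators delta_i (1 = event, 0 = censored)
    """
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=bool)

    event_times = np.unique(time[event])  # distinct times with at least one event
    survival = []
    s = 1.0
    for t_j in event_times:
        n_j = np.sum(time >= t_j)            # instances at risk just before t_j
        d_j = np.sum((time == t_j) & event)  # events at t_j
        s *= 1.0 - d_j / n_j
        survival.append(s)
    return event_times, np.array(survival)

# Toy data (made up): 6 patients, two of them censored.
y_toy = [5, 8, 8, 12, 15, 20]
delta_toy = [1, 1, 0, 1, 0, 1]
times, surv = kaplan_meier(y_toy, delta_toy)
print(times, surv)
```

Censored instances never cause a drop in $\hat{S}(t)$; they only shrink the risk set $n_j$ for later event times. The values at the event times should agree with those returned by `kaplan_meier_estimator` for the same inputs.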

Let's see what it looks like for our data, first for a small subset and then for the whole dataset.

In [None]:
# first 20 rows, selected by position (dropna() left gaps in the index labels)
tmp_df = df[["time", "status"]].iloc[:20]
tmp_df

In [None]:
from sksurv.nonparametric import kaplan_meier_estimator

time, survival_prob = kaplan_meier_estimator(tmp_df["status"], tmp_df["time"])

fig, ax = plt.subplots()
ax.step(time, survival_prob, where="post")
ax.set_xlabel("Time")
ax.set_ylabel("Survival probability")
ax.set_ylim(0, 1)
plt.show()

In [None]:
time, survival_prob, conf_interval = kaplan_meier_estimator(
    df["status"], df["time"], conf_type="log-log"
)


fig, ax = plt.subplots()
ax.step(time, survival_prob, where="post")
ax.fill_between(
    time, conf_interval[0], conf_interval[1], step="post", alpha=0.2
)
ax.set_xlabel("Time")
ax.set_ylabel("Survival probability")
ax.set_ylim(0, 1)
plt.show()


### Cox proportional hazards model 
- semi-parametric model for the hazard function
- assumes that the hazard function is a product of two parts: 
    - the baseline hazard function $h_0(t)$, which does not depend on the features $\mathbf{x}_i$
    - the exponential of a linear combination of the features $\mathbf{x}_i$, which does not depend on time
$$ h(t|\mathbf{x}) = h_0(t) \exp(\beta_1 x_1 + \beta_2 x_2 + \dots + \beta_p x_p)$$
- good baseline model (~*linear regression of survival analysis*)
- considered interpretable: the exponentiated coefficient $\exp(\beta_j)$ is the hazard ratio associated with a one-unit increase in $x_j$, with the other features held fixed
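A short worked derivation, using only the model formula above, shows why $\exp(\beta_j)$ is a hazard ratio. Take two instances $\mathbf{x}$ and $\mathbf{x}'$ that are identical except that $x'_j = x_j + 1$:

$$\frac{h(t|\mathbf{x}')}{h(t|\mathbf{x})} = \frac{h_0(t)\exp(\beta_1 x_1 + \dots + \beta_j (x_j + 1) + \dots + \beta_p x_p)}{h_0(t)\exp(\beta_1 x_1 + \dots + \beta_j x_j + \dots + \beta_p x_p)} = \exp(\beta_j)$$

The baseline hazard $h_0(t)$ cancels, so the ratio is the same at every time $t$ (hence "proportional hazards"). For any two instances the same argument gives

$$\frac{h(t|\mathbf{x}')}{h(t|\mathbf{x})} = \exp\Big(\sum_{k=1}^{p} \beta_k (x'_k - x_k)\Big)$$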

In [None]:
from sksurv.linear_model import CoxPHSurvivalAnalysis

# ties="efron": Efron's approximation for handling tied event times
cph_model = CoxPHSurvivalAnalysis(ties="efron")
cph_model.fit(X, y)
cph_model.coef_

In [None]:
np.exp(cph_model.coef_)

In [None]:
# features of one patient (row 41) and a copy of them with the ECOG score set to 2
patientA = X.iloc[41]
patientB = patientA.copy()
patientB["ph.ecog"] = 2

patients = pd.DataFrame([patientA, patientB])
patients

In [None]:
mean_X = X.mean()

In [None]:
# relative risk of each patient with respect to the mean feature vector: exp(beta^T (x - mean(x)))
risks = np.exp(np.sum((patients - mean_X) * cph_model.coef_, axis=1)).values
risks

In [None]:
# hazard ratio of B vs A: the mean cancels, leaving exp(beta_ecog * (2 - ph.ecog of A))
risks[1]/risks[0]

In [None]:
survival_prob = cph_model.predict_survival_function(patients)

plot(survival_prob)

## Machine learning models

### Survival trees
- extension of decision trees for survival analysis
- in each leaf node, one of the key quantities (e.g., the cumulative hazard function) is estimated from the training instances that fall into that leaf
- there are many splitting rules, but the most common is the log-rank rule: it chooses the split that maximizes the log-rank test statistic comparing the survival distributions of the two child nodes
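To make the log-rank splitting rule concrete, here is a from-scratch sketch (not scikit-survival's implementation). At each distinct event time $t_j$, with $n_j$ instances at risk, $d_j$ events, and $n_{1j}$, $d_{1j}$ the same counts in the first group, the statistic is $\chi^2 = \big(\sum_j (d_{1j} - d_j n_{1j}/n_j)\big)^2 / \sum_j V_j$ with $V_j = d_j \frac{n_{1j}}{n_j}\big(1 - \frac{n_{1j}}{n_j}\big)\frac{n_j - d_j}{n_j - 1}$. The hypothetical helper `best_logrank_split` (with a hypothetical `min_leaf` parameter) scans the thresholds of a single feature and keeps the one with the largest statistic; a survival tree would repeat this over all candidate features at each node. The toy data are made up.

```python
import numpy as np

def logrank_statistic(time, event, group):
    """Two-sample log-rank chi-squared statistic (1 degree of freedom).

    group -- boolean mask, True for instances in the first sample
    """
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=bool)
    group = np.asarray(group, dtype=bool)

    observed_minus_expected, variance = 0.0, 0.0
    for t_j in np.unique(time[event]):
        at_risk = time >= t_j
        n_j = at_risk.sum()                        # at risk in both groups
        n1_j = (at_risk & group).sum()             # at risk in the first group
        d_j = ((time == t_j) & event).sum()        # events in both groups
        d1_j = ((time == t_j) & event & group).sum()

        observed_minus_expected += d1_j - d_j * n1_j / n_j
        if n_j > 1:  # the variance term is undefined (and zero) for a single instance
            variance += d_j * (n1_j / n_j) * (1 - n1_j / n_j) * (n_j - d_j) / (n_j - 1)
    return observed_minus_expected**2 / variance if variance > 0 else 0.0


def best_logrank_split(x, time, event, min_leaf=2):
    """Threshold on feature x (split: x <= threshold) maximizing the log-rank statistic."""
    x = np.asarray(x)
    best_threshold, best_stat = None, -np.inf
    for threshold in np.unique(x)[:-1]:
        left = x <= threshold
        if left.sum() < min_leaf or (~left).sum() < min_leaf:
            continue  # skip splits that leave a child node too small
        stat = logrank_statistic(time, event, left)
        if stat > best_stat:
            best_threshold, best_stat = threshold, stat
    return best_threshold, best_stat

# Toy data (made up): low values of x go with long survival times.
x = np.array([1, 2, 3, 4, 5, 6, 7, 8])
time = np.array([30, 28, 25, 27, 5, 7, 4, 6])
event = np.array([1, 0, 1, 1, 1, 1, 1, 1])

threshold, stat = best_logrank_split(x, time, event)
print(threshold, stat)
```

On this toy data the split $x \le 4$ separates the short and long survivors perfectly, so it gets the largest statistic and is selected.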

In [None]:
from sksurv.tree import SurvivalTree

tree_model = SurvivalTree()
tree_model.fit(X, y)

survival_prob = tree_model.predict_survival_function(patients)

plot(survival_prob)


### Random survival forest
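- works similarly to a classical random forest (each tree is grown on a bootstrap sample, with a random subset of features considered at each split) but uses survival trees instead of classification or regression trees
- the ensemble prediction is the average of the predictions of all trees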
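The averaging step can be sketched without a full forest. One common way to define the ensemble prediction is to average the trees' cumulative hazard functions. In the sketch below, the "trees" are stand-ins with a single leaf each, so each tree's prediction is simply the Nelson-Aalen estimate $\hat{H}(t) = \sum_{t_j \le t} d_j / n_j$ computed on its own bootstrap sample; real trees would first route the patient to a leaf. The data, the helper names `nelson_aalen` and `step_eval`, and all settings are assumptions, and this is not scikit-survival's implementation.

```python
import numpy as np

def nelson_aalen(time, event):
    """Nelson-Aalen estimate of H(t): cumulative sum of d_j / n_j over event times t_j."""
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=bool)
    event_times = np.unique(time[event])
    increments = [((time == t) & event).sum() / (time >= t).sum() for t in event_times]
    return event_times, np.cumsum(increments)

def step_eval(x, y, grid):
    """Evaluate a right-continuous step function (equal to 0 before the first jump) on a grid."""
    idx = np.searchsorted(x, grid, side="right")  # number of jumps at or before each grid point
    return np.concatenate([[0.0], y])[idx]

rng = np.random.default_rng(1)
time = rng.exponential(50.0, size=200)   # synthetic observed times
event = rng.uniform(size=200) < 0.7      # synthetic event indicators
grid = np.linspace(0.0, 150.0, 151)      # common time grid for all trees

# Stand-in "trees": single-leaf trees, each fitted on a bootstrap sample.
tree_chfs = []
for _ in range(100):
    b = rng.integers(0, len(time), size=len(time))  # bootstrap indices
    tree_chfs.append(step_eval(*nelson_aalen(time[b], event[b]), grid))

ensemble_chf = np.mean(tree_chfs, axis=0)  # average over trees
ensemble_survival = np.exp(-ensemble_chf)  # S(t) = exp(-H(t))
```

Averaging on a common time grid is what makes curves from different trees comparable, since each tree's step function jumps at different event times.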

In [None]:
from sksurv.ensemble import RandomSurvivalForest

rsf_model = RandomSurvivalForest(n_estimators=100)
rsf_model.fit(X, y)

survival_prob = rsf_model.predict_survival_function(patients)

plot(survival_prob)

### Other models

<img src="images/models.png" width="1000">

**Find more here:**
- Wang, P., Li, Y., & Reddy, C. K. (2019). Machine learning for survival analysis: A survey. ACM Computing Surveys (CSUR), 51(6), 1-36. https://doi.org/10.1145/3214306
- Wiegrebe, S., Kopper, P., Sonabend, R., & Bender, A. (2023). Deep Learning for Survival Analysis: A Review. arXiv preprint arXiv:2305.14961. https://arxiv.org/abs/2305.14961