# Dancing with censored data: How to survive with explainable survival analysis?
**ML in PL Conference 2023**


Mateusz Krzyziński, Mikołaj Spytek (**MI2.AI**)


## Agenda

- 9:00 - 9:15 - Introduction, technicalities
- 9:15 - 10:00 - ML in survival analysis, creating models `(Python and R)`
- 10:00 - 10:10 - *break*

- 10:10 - 10:20 - XAI in survival analysis
- 10:20 - 10:45 - SurvLIME `(Python and R)`
- 10:45 - 11:30 - SurvSHAP(t) and its aggregations `(Python and R)`
- 11:30 - 11:40 - *break*

- 11:40 - 11:50 - creating explainers for any models `(R with Python via reticulate)`
- 11:50 - 12:15 - global performance & global variable importance explanations `(R)`
- 12:15 - 12:30 - local variable dependence explanations `(R)`
- 12:30 - 12:45 - global variable dependence explanations `(R)`
- 12:45 - 13:00 - summary and Q&A


## Loading packages

In [None]:
library(survival)
library(survex)
library(ingredients)
library(reticulate)
library(ranger)

In [None]:
options(jupyter.plot_scale=1, repr.plot.width = 6, repr.plot.height = 4)

## Creating `survex` explainer for a `scikit-survival` model
`survex` offers automatic explainer creation for `scikit-survival` models loaded into R via the `reticulate` package.




In [None]:
use_python("./tutorial-env/bin/python")

In [None]:
# Python packages imports
py <- import_builtins()
sksurv <- import("sksurv")
pickle <- import("pickle")

In [None]:
rsf_sksurv_model <- pickle$load(py$open("./rsf_sksurv_model.pkl", "rb"))

In [None]:
class(rsf_sksurv_model)

In [None]:
df <- read.csv("datasets/lung_dataset.csv")
df <- df[complete.cases(df), ]

In [None]:
rsf_sksurv_model$predict_survival_function(df[1, 3:9])

In [None]:
rsf_sksurv_explainer <- explain(rsf_sksurv_model, data = df[,3:9], y = Surv(df$time, df$status), label = "sksurv_rf")

In [None]:
# keep only the first 49 time points of the explainer's time grid
rsf_sksurv_explainer$times <- rsf_sksurv_explainer$times[1:49]

In [None]:
surv_probs <- predict(rsf_sksurv_explainer, df[1:2, 3:9], output_type = "survival")
plot(rsf_sksurv_explainer$times, surv_probs[1,], type = "l", ylab = "Survival probability", xlab = "Time", ylim = c(0, 1))
lines(rsf_sksurv_explainer$times, surv_probs[2,], col = "red")

## Creating `survex` explainer for any model
`survex` offers support for creating explainers for any model, not only those with automatic explainer creation.

To be precise, an `explainer` object is always created, but to be useful it needs to be filled with data and prediction functions. Because the different forms of survival model predictions are connected, it is enough to provide just one time-dependent function (a method for predicting survival functions or cumulative hazard functions) and `survex` will automatically derive the rest.
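The connections in question are the standard identities $S(t) = \exp(-H(t))$ and $H(t) = -\log S(t)$ between the survival function $S$ and the cumulative hazard function $H$. A minimal base-R sketch with made-up numbers is shown below; the scalar risk score computed as a sum of the cumulative hazard is only one simple choice, assumed here for illustration rather than taken from the `survex` internals.

```r
# Hypothetical survival probabilities of one observation at four time points
surv <- c(0.95, 0.80, 0.55, 0.30)

# Cumulative hazard from the survival function: H(t) = -log(S(t))
chf <- -log(surv)

# And back again: S(t) = exp(-H(t))
surv_back <- exp(-chf)

# One simple scalar risk score (an assumption for this sketch):
# the sum of the cumulative hazard over the time points
risk <- sum(chf)

all.equal(surv, surv_back)
```

So a model that exposes only one of the two time-dependent predictions still gives access to the other, and to a single-number risk.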

In [None]:
set.seed(123)
dummy_survival_function <- stepfun(x=seq(0, 1000, length.out=100), y=seq(1, 0, length.out=101))
dummy_survival_function2 <- stepfun(x=seq(0, 1000, length.out=100), y=seq(1, 0.5, length.out=101))
tmp_model <- function(observation, times){
    if (observation[1] %% 2 != 0){
        return(dummy_survival_function(times))
    } else {
        return(dummy_survival_function2(times))
    }
}

In [None]:
tmp_model(c(1, 150, 500), c(0, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000))

In [None]:
tmp_model(c(2, 150, 500), c(0, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000))

In [None]:
explain(tmp_model) # nothing useful

In [None]:
explain_survival(tmp_model) # also nothing useful

In [None]:
mock_explainer <- explain_survival(tmp_model, data = df[,3:9], y = Surv(df$time, df$status),
        predict_survival_function = function(model, newdata, times) {
            # code predicting the survival function at `times` for `newdata` goes here;
            # it should return a matrix with one row per observation and one column per time point
            t(apply(newdata, 1, function(x) model(x, times)))
        },
        label = "mock_model")

In [None]:
predict(mock_explainer, df[1:5, 3:9], output_type = "survival")

In [None]:
print(predict(mock_explainer, df[1:5, 3:9], output_type = "risk"))

## Global performance - `model_performance`
There are many performance metrics for survival models. Generally, they can be divided into two groups:
- functional - their result is a function of time (`"time-dependent"` in `survex`), e.g.:
    - cumulative/dynamic AUC <br> 
    (at time $t$, cumulative cases are individuals with the event time $\leq t$ and dynamic controls are individuals with the event time $> t$),
    - Brier score <br>
    (mean squared error between predicted survival function values and survival outcomes at time $t$).

    **Note:** inverse probability of censoring weighting (IPCW) is used for all functional metrics.

- aggregated - their result is a single number (`"scalar"` in `survex`), e.g.:
    - concordance index <br>
    (proportion of pairs of individuals where the model correctly predicted the order of event times/risks),
    - integrated Brier score,
    - integrated cumulative/dynamic AUC.

These metrics are implemented in `survex`, but any other metric can be easily added (for example, there is a utility function for adapting loss functions from the `mlr3proba` package).
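To make two of these definitions concrete, here is a hedged base-R sketch on toy data. `harrell_c` and `brier_at` are hypothetical helper names written for this illustration, not `survex` functions, and the Brier score version assumes fully observed event times, so it skips the IPCW step.

```r
# Toy data: observed times, event indicators (1 = event, 0 = censored),
# and hypothetical risk scores (higher = earlier event expected)
time   <- c(5, 8, 12, 20)
status <- c(1, 0, 1, 1)
risk   <- c(0.9, 0.3, 0.1, 0.2)

# Harrell's concordance index: among comparable pairs (the earlier time is an
# observed event), the share where the earlier event has the higher risk score
harrell_c <- function(time, status, risk) {
  concordant <- 0
  comparable <- 0
  for (i in seq_along(time)) {
    for (j in seq_along(time)) {
      if (status[i] == 1 && time[i] < time[j]) {
        comparable <- comparable + 1
        if (risk[i] > risk[j]) {
          concordant <- concordant + 1
        } else if (risk[i] == risk[j]) {
          concordant <- concordant + 0.5  # ties in risk count as half
        }
      }
    }
  }
  concordant / comparable
}

c_index <- harrell_c(time, status, risk)  # 3 of 4 comparable pairs -> 0.75

# Brier score at time t for fully observed event times (no censoring, no IPCW):
# squared distance between the indicator "event-free at t" and the predicted S(t)
brier_at <- function(t, event_time, surv_pred) {
  mean((as.numeric(event_time > t) - surv_pred)^2)
}

brier <- brier_at(10, event_time = c(5, 12, 20), surv_pred = c(0.2, 0.6, 0.9))
```

The censored observation (time 8) never starts a comparable pair, which is why only four pairs enter the concordance index.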

In [None]:
set.seed(123)
rsf_ranger_model <- ranger(Surv(time, status) ~ ., data = df)
rsf_ranger_explainer <- explain(rsf_ranger_model,
                                data = df[3:9], 
                                y = Surv(df$time, df$status),
                                times = rsf_sksurv_explainer$times)

In [None]:
rsf_sksurv_performance <- model_performance(rsf_sksurv_explainer)
rsf_ranger_performance <- model_performance(rsf_ranger_explainer)
mock_model_performance <- model_performance(mock_explainer)

In [None]:
plot(rsf_sksurv_performance, rsf_ranger_performance)

In [None]:
plot(rsf_sksurv_performance, rsf_ranger_performance, mock_model_performance)

In [None]:
plot(rsf_sksurv_performance, rsf_ranger_performance, mock_model_performance, metrics_type = "scalar")

## Global variable importance - `model_parts`

- Some models have built-in variable importance measures (e.g. random forests) but we want to have a unified method for all models.
- We have already covered global aggregations of SurvSHAP(t) that can be used as variable importance measures.
- Another model-agnostic approach is permutation-based importance. 
- Permutation-based importance uses a selected loss function $\mathcal{L}$ to measure the importance of the $j$-th variable as the difference between the loss for the data with the $j$-th variable permuted and the loss for the original data:
  $$PVI^j(t) = \frac{1}{B} \sum_{i=1}^B \mathcal{L}(y, f(X^{*j},t)) - \mathcal{L}(y, f(X, t)),$$
  where $B$ is the number of considered permutations (a fresh permutation is drawn in each repetition), $\mathcal{L}(y, f(X,t))$ is a loss function, and $X^{*j}$ is the model input with the $j$-th column permuted.
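The formula can be sketched in a few lines of base R. This is a simplified illustration, not how `model_parts` is implemented: the toy model `f`, the simulated data, and the helper `permutation_importance` are all hypothetical, the loss is evaluated at a single fixed time point, and censoring is ignored.

```r
set.seed(1)
n <- 200
X <- data.frame(x1 = rnorm(n), x2 = rnorm(n))

# Toy "model": predicted survival probability at one fixed time point,
# depending on x1 only (x2 is ignored on purpose)
f <- function(X) plogis(2 * X$x1)

# Toy outcome: 1 if the individual is still event-free at that time point
y <- rbinom(n, 1, f(X))

# Loss: squared error between outcome and predicted probability (Brier-like)
loss <- function(y, pred) mean((y - pred)^2)

permutation_importance <- function(j, B = 10) {
  baseline <- loss(y, f(X))
  permuted <- replicate(B, {
    X_perm <- X
    X_perm[[j]] <- sample(X_perm[[j]])  # break the link between variable j and y
    loss(y, f(X_perm))
  })
  mean(permuted) - baseline             # average loss increase over B permutations
}

pvi_x1 <- permutation_importance("x1")
pvi_x2 <- permutation_importance("x2")
```

Here `pvi_x1` should come out clearly positive, while `pvi_x2` is essentially zero because the toy model never uses `x2`.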

In [None]:
ranger_vimp_brier <- model_parts(rsf_ranger_explainer) 

In [None]:
plot(ranger_vimp_brier)

In [None]:
ranger_vimp_cd_auc <- model_parts(rsf_ranger_explainer, loss_function = survex::loss_one_minus_cd_auc) 

In [None]:
plot(ranger_vimp_cd_auc)

In [None]:
ranger_vimp_brier_integrated <- model_parts(rsf_ranger_explainer, loss_function = survex::loss_integrated_brier_score) 

In [None]:
plot(ranger_vimp_brier_integrated)

## Local variable dependence - `predict_profile`

- These are local explanations that show how the model prediction for the selected observation changes with the change of a single variable.
- The method is called individual conditional expectation (ICE) and is defined as:
  $$\text{ICE}^j_{\mathbf{x}_*}(z, t) = f(\mathbf{x}_*^{j=z}, t),$$
  where $\mathbf{x}_*$ is the selected observation, $z$ is the value of the $j$-th variable, and $f(\mathbf{x}_*^{j=z}, t)$ is the model prediction at time $t$ for the observation $\mathbf{x}_*$ but with the $j$-th variable set to $z$.
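
As a rough illustration of the definition, the sketch below computes an ICE profile by hand. The exponential toy model `f`, the observation `x_star`, and the helper `ice` are hypothetical and only stand in for a fitted survival model.

```r
# Toy survival model (exponential, hazard grows with age); purely illustrative
f <- function(x, t) exp(-0.001 * t * exp(0.03 * (x$age - 60)))

x_star   <- data.frame(age = 65, sex = 1)  # hypothetical selected observation
age_grid <- seq(40, 80, by = 10)           # values z of the profiled variable
times    <- c(100, 500, 1000)

# ICE: replace variable j by each grid value z, keep everything else fixed,
# and evaluate the model at every time point
ice <- function(f, x_star, variable, grid, times) {
  profile <- sapply(times, function(t) {
    sapply(grid, function(z) {
      x <- x_star
      x[[variable]] <- z
      f(x, t)
    })
  })
  dimnames(profile) <- list(grid, times)
  profile
}

ice_age <- ice(f, x_star, "age", age_grid, times)
ice_age
```

The result is a matrix with one row per grid value and one column per time point, which is the kind of surface that the plots of a profile against a variable and time visualize.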

In [None]:
patientA <- df[42, 3:9] 
patientA

In [None]:
profiles_patientA <- predict_profile(rsf_ranger_explainer, patientA)

In [None]:
plot(profiles_patientA, variables = "age")

In [None]:
plot(profiles_patientA, variables = "age", numerical_plot_type = "contours") 

In [None]:
plot(profiles_patientA, variables = "sex") 

In [None]:
profiles_patientA_catvars <- predict_profile(rsf_ranger_explainer, patientA, categorical_variables = c("sex", "ph.ecog"))
plot(profiles_patientA_catvars, variables = "sex")

In [None]:
plot(profiles_patientA, geom = "variable", variables = "age")

In [None]:
plot(profiles_patientA, geom = "variable", variables = "age", times = rsf_ranger_explainer$times[seq(5, 49, 5)])

In [None]:
plot(profiles_patientA, geom = "variable", variables = "age", times = rsf_ranger_explainer$times, marginalize_over_time = TRUE)

## Global variable dependence - `model_profile`

- These are global explanations that show how the model's average predictions depend on changes in a single variable (or two variables).
- Two specific methods are available: 
    - partial dependence plots (PDP) <br>
    (calculated as the average of ICE explanations for a given variable):
         $$\text{PDP}^j(z, t) = \frac{1}{n}\sum_{i=1}^n f(\mathbf{x}_i^{j=z}, t),$$
    - accumulated local effects (ALE) <br>
    (calculated differently, by analyzing how model predictions change in small 'windows' while taking into account dependences between variables). 
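
The PDP formula can be sketched directly in base R. The toy model `f`, the four-row dataset `X`, and the helper `pdp` below are hypothetical and serve only to show the averaging step; ALE is not covered by this sketch.

```r
# Toy survival model: hazard grows with age and is higher for sex == 1
f <- function(X, t) {
  exp(-0.001 * t * exp(0.03 * (X$age - 60)) * ifelse(X$sex == 1, 1.5, 1))
}

# Hypothetical dataset with n = 4 observations
X <- data.frame(age = c(50, 60, 70, 80), sex = c(1, 2, 1, 2))

# PDP at time t: for each grid value z, set variable j to z for every
# observation and average the resulting predictions (the n ICE values)
pdp <- function(f, X, variable, grid, t) {
  sapply(grid, function(z) {
    X_z <- X
    X_z[[variable]] <- z
    mean(f(X_z, t))
  })
}

pdp_age <- pdp(f, X, "age", grid = c(50, 60, 70), t = 500)
pdp_age
```

Each value in `pdp_age` averages over the observed `sex` values, so the curve describes the effect of age on the average prediction at the chosen time.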


In [None]:
profiles_rsf_pdp <- model_profile(rsf_ranger_explainer)

In [None]:
plot(profiles_rsf_pdp, variables = "wt.loss")

In [None]:
plot(profiles_rsf_pdp, variables = "wt.loss", numerical_plot_type = "contours")

In [None]:
profiles_rsf_ale <- model_profile(rsf_ranger_explainer, type = "accumulated")

In [None]:
plot(profiles_rsf_ale, variables = "wt.loss")

In [None]:
plot(profiles_rsf_ale, variables = "wt.loss", numerical_plot_type = "contours")

In [None]:
plot(profiles_rsf_pdp, geom="variable", variables = "wt.loss", times = rsf_ranger_explainer$times[c(5, 20, 35)], plot_type = "pdp+ice")

In [None]:
plot(profiles_rsf_pdp, geom="variable", variables = "wt.loss", times = rsf_ranger_explainer$times[seq(5, 49, 5)], plot_type = "pdp")

In [None]:
plot(profiles_rsf_ale, geom="variable", variables = "wt.loss", times = rsf_ranger_explainer$times[seq(5, 49, 5)])

In [None]:
profiles_rsf_catvars <- model_profile(rsf_ranger_explainer, type = "partial", categorical_variables = c("sex", "ph.ecog"))

In [None]:
plot(profiles_rsf_catvars, geom = "variable", variables = "ph.ecog", times = rsf_ranger_explainer$times[c(5, 20, 35)])

In [None]:
profile_2d_rsf_pdp <- model_profile_2d(rsf_ranger_explainer,
                                variables = list(c("age", "wt.loss")),
                                grid_points = 10)

profile_2d_rsf_ale <- model_profile_2d(rsf_ranger_explainer,
                                variables = list(c("age", "wt.loss")),
                                grid_points = 25,
                                type = "accumulated")

In [None]:
plot(profile_2d_rsf_pdp, times = rsf_ranger_explainer$times, marginalize_over_time = TRUE)

In [None]:
plot(profile_2d_rsf_ale, times = rsf_ranger_explainer$times, marginalize_over_time = TRUE)

## Q&A