Code snippets for survival SVM and survival GBM #88

Closed
pbiecek opened this issue Nov 2, 2023 · 1 comment

@pbiecek
Member

pbiecek commented Nov 2, 2023

Hi,
it would be great to have a vignette similar to https://modeloriented.github.io/survex/articles/mlr3proba-usage.html
but with examples for the models described in
https://journal.r-project.org/archive/2018/RJ-2018-005/RJ-2018-005.pdf
and
https://journal.r-project.org/archive/2020/RJ-2020-018/RJ-2020-018.pdf

@mikolajsp
Collaborator

Hi,

Unfortunately, neither the SVM nor the GBM model for survival data produces a distribution-type prediction; they only predict a single "risk" score per observation.

However, mlr3proba can compose the risk predictions of these models with a baseline survival distribution estimated with the Kaplan-Meier estimator. The "using mlr3proba with survex" vignette shows this for a different model; I've quickly checked it with these specific models and it seems to work.
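To illustrate what this composition does, here is a minimal sketch assuming the proportional hazards form (`form = "ph"`), where a risk score (linear predictor) η is combined with a baseline survival function as S(t | x) = S0(t)^exp(η). The helper `compose_ph()` and the risk scores are hypothetical and only the survival package is used; mlr3proba's own implementation (sign conventions, baseline handling) may differ.

```r
library(survival)

# Hypothetical helper (not part of mlr3proba or survex): turn risk scores
# into survival curves with the proportional hazards form
#   S(t | x) = S0(t)^exp(lp)
# where S0 is a Kaplan-Meier estimate fitted on the training data.
compose_ph <- function(time, status, lp, times) {
  km <- survfit(Surv(time, status) ~ 1)
  # Right-continuous KM step function; equals 1 before the first observed time
  s0 <- stepfun(km$time, c(1, km$surv))(times)
  # One row per observation, one column per requested time point
  outer(exp(lp), s0, function(e, s) s^e)
}

# Made-up risk scores for three observations (illustrative values only)
times <- c(30, 90, 180, 365)
surv_mat <- compose_ph(veteran$time, veteran$status, lp = c(-0.5, 0, 0.5), times)
round(surv_mat, 3)
```

A risk score of 0 reproduces the Kaplan-Meier baseline, while larger scores push the whole curve down, which is why a single risk value per observation is enough to obtain a full survival distribution once a baseline is fixed.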

The SVM has one problem: it refuses to predict for a single observation, which some explanations require. As a workaround, I changed the prediction function to duplicate that row, make the prediction, and keep only the first row of the result.
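The same duplicate-then-trim trick can be factored into a reusable wrapper for any prediction function with the signature `(model, newdata, times)`. This is a self-contained sketch: `pad_single_row()` and the toy model `toy_predict()` are hypothetical names, and the toy model simply stops with an error on one-row input to mimic the SVM behaviour described above.

```r
# Hypothetical wrapper (not a survex or mlr3 function): works around models
# that refuse to predict on a single row by duplicating the row, predicting,
# and keeping only the first row of the resulting survival matrix.
pad_single_row <- function(predict_fun) {
  function(model, newdata, times) {
    if (nrow(newdata) == 1) {
      doubled <- rbind(newdata, newdata)
      predict_fun(model, doubled, times)[1, , drop = FALSE]
    } else {
      predict_fun(model, newdata, times)
    }
  }
}

# Toy stand-in for such a model (assumed behaviour, for illustration only):
# exponential survival curves with a per-row rate taken from newdata$rate
toy_predict <- function(model, newdata, times) {
  if (nrow(newdata) < 2) stop("need at least two rows")
  outer(newdata$rate, times, function(r, tt) exp(-r * tt))
}

safe_predict <- pad_single_row(toy_predict)
res <- safe_predict(NULL, data.frame(rate = 0.01), times = c(10, 20))
res  # a 1 x 2 matrix with values exp(-0.1) and exp(-0.2)
```

With a real learner, `predict_fun` would be a function that calls `model$predict_newdata()`, as `svm_predict` does in the code below.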

Below is the example code. I create both explainers, then plot the result of model_performance() and a Ceteris Paribus explanation for the first observation. Both seem to work, so I assume the explainers are correct and the remaining explanations will work as well.

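library(survex)
library(survival)           # Surv() and the veteran dataset
library(survivalsvm)
library(gbm)
library(mlr3)               # lrn(), as_learner()
library(mlr3proba)          # as_task_surv()
library(mlr3extralearners)  # surv.gbm and surv.svm learners
library(mlr3pipelines)      # ppl()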

## Create an mlr3 task -- used to fit the models
veteran_task <- as_task_surv(veteran,
                             time = "time",
                             event = "status",
                             type = "right")

## Create a GBM model composed with a Kaplan-Meier
## baseline hazard prediction (proportional hazards form)
gbm_composite_learner <- as_learner(ppl(
    "distrcompositor",
    learner = lrn("surv.gbm"),
    estimator = "kaplan",
    form = "ph"
))

## Train the gbm model
gbm_composite_learner$train(veteran_task)

## Create the explainer; adding `LearnerSurv` to the class is
## necessary for automatic explainer creation
class(gbm_composite_learner) <- c(class(gbm_composite_learner), "LearnerSurv")
gbm_explainer <- explain(gbm_composite_learner,
                         data = veteran[, -c(3,4)],  # drop the time and status columns
                         y = Surv(veteran$time, veteran$status),
                         label = "Composite GBM model")

## Create an SVM model composed with a Kaplan-Meier
## baseline hazard prediction (proportional hazards form)
svm_composite_learner <- as_learner(ppl(
    "distrcompositor",
    learner = lrn("surv.svm", type = "vanbelle2", diff.meth = "makediff3", gamma.mu = 0.1),
    estimator = "kaplan",
    form = "ph"
))

## Train the SVM model
svm_composite_learner$train(veteran_task)

## Create the explainer; adding `LearnerSurv` to the class is
## necessary for automatic explainer creation
class(svm_composite_learner) <- c(class(svm_composite_learner), "LearnerSurv")


### The SVM fails to predict when there is only one observation, so I
### provide a custom prediction function as a workaround: a one-row input
### is duplicated, predicted, and only the first row of the result is kept
svm_predict <- function(model, newdata, times) {
    if (nrow(newdata) == 1) {
        newdata <- rbind(newdata, newdata)
        t(model$predict_newdata(newdata)$distr$survival(times))[1, , drop = FALSE]
    } else {
        t(model$predict_newdata(newdata)$distr$survival(times))
    }
}

## Create the svm explainer
svm_explainer <- explain(svm_composite_learner,
                         data = veteran[, -c(3,4)],
                         y = Surv(veteran$time, veteran$status),
                         predict_survival_function = svm_predict,
                         label = "Composite SVM model")


## Calculate and compare performance using `survex` 
performance_svm <- model_performance(svm_explainer)
performance_gbm <- model_performance(gbm_explainer)

plot(performance_svm, performance_gbm)


## Calculate and plot Ceteris Paribus explanations
ice_gbm <- predict_profile(gbm_explainer, veteran[1, -c(3,4)])
plot(ice_gbm, variables = "age", numerical_plot_type = "contours")


ice_svm <- predict_profile(svm_explainer, veteran[1, -c(3,4)])
plot(ice_svm, variables = "age", numerical_plot_type = "contours")
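The `predict_profile()` calls above compute Ceteris Paribus profiles for the first observation. As a rough illustration of the idea behind them (not how survex implements it), the sketch below builds a tiny profile by hand. It assumes a plain Cox model from the survival package as a stand-in for the composite learners, varies only `age` over five grid points, and reads off the predicted survival at day 100, an arbitrary choice.

```r
library(survival)

# Stand-in model (assumption): a plain Cox model instead of the composite
# GBM/SVM learners, so the idea can be shown without mlr3
cox_fit <- coxph(Surv(time, status) ~ age + karno + trt, data = veteran)

# Ceteris Paribus: copy one observation, vary only `age` over a grid and
# keep every other covariate fixed
obs <- veteran[1, ]
age_grid <- seq(min(veteran$age), max(veteran$age), length.out = 5)
profile_data <- obs[rep(1, length(age_grid)), ]
profile_data$age <- age_grid

# Predicted survival curves for the modified copies, evaluated at day 100
sf <- survfit(cox_fit, newdata = profile_data)
surv_100 <- as.vector(summary(sf, times = 100)$surv)
data.frame(age = age_grid, surv_at_day_100 = round(surv_100, 3))
```

survex evaluates whole survival curves over many time points instead of a single day, which is what the contour plots show.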

