Merge pull request #33 from ModelOriented/feat_multioutputs
Add interface for multi-output models
krzyzinskim committed Oct 24, 2023
2 parents cec10f9 + a74a357 commit 747b15e
Showing 24 changed files with 258 additions and 44 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -6,3 +6,4 @@ src/*.o
src/*.so
src/*.dll
CRAN-SUBMISSION
.DS_Store
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -1,7 +1,7 @@
Package: treeshap
Title: Compute SHAP Values for Your Tree-Based Models Using the 'TreeSHAP'
Algorithm
Version: 0.2.5.9001
Version: 0.3.0
Authors@R: c(
person("Konrad", "Komisarczyk", email = "komisarczykkonrad@gmail.com", role = "aut"),
person("Pawel", "Kozminski", email = "pkozminski99@gmail.com", role = "aut"),
4 changes: 4 additions & 0 deletions NAMESPACE
@@ -2,7 +2,11 @@

S3method(predict,model_unified)
S3method(print,model_unified)
S3method(print,model_unified_multioutput)
S3method(print,treeshap)
S3method(print,treeshap_multioutput)
S3method(treeshap,model_unified)
S3method(treeshap,model_unified_multioutput)
S3method(unify,default)
S3method(unify,gbm)
S3method(unify,lgb.Booster)
3 changes: 3 additions & 0 deletions NEWS.md
@@ -1,4 +1,7 @@
# treeshap (development version)
* Fixed `ranger_surv.unify` operation for predictions in form of survival and cumulative hazard functions.
* Added `model_unified_multioutput` and `treeshap_multioutput` classes for multi-output models and their explanations.
* Improved documentation of `ranger_surv.unify`.

# treeshap 0.2.5
* Removed `catboost.unify` function (as the `catboost` package is not on CRAN); it is available on a separate branch
45 changes: 37 additions & 8 deletions R/model_unified.R
@@ -1,6 +1,6 @@
#' Unified model representation
#'
#' \code{model_unified} object produced by \code{*.unify} function.
#' \code{model_unified} object produced by \code{*.unify} or \code{unify} function.
#'
#' @return List consisting of two elements:
#'
@@ -18,7 +18,7 @@
#' \item{Prediction}{For leaves: Value of prediction in the leaf. For internal nodes: NA}
#' \item{Cover}{Number of observations seen by the internal node or collected by the leaf for the reference dataset}
#'
#' \strong{data} - Dataset used as a reference for calculating SHAP values. A dataset passed to the \code{*.unify} or \code{\link{set_reference_dataset}} function with \code{data} argument. A \code{data.frame}.
#' \strong{data} - Dataset used as a reference for calculating SHAP values. A dataset passed to the \code{*.unify}, \code{unify} or \code{\link{set_reference_dataset}} function with \code{data} argument. A \code{data.frame}.
#'
#'
#' Object also has two attributes set:
@@ -27,21 +27,29 @@
#'
#'
#' @seealso
#' \code{\link{lightgbm.unify}} for \code{\link[lightgbm:lightgbm]{LightGBM models}}
#' \code{\link{unify}}
#'
#' \code{\link{gbm.unify}} for \code{\link[gbm:gbm]{GBM models}}
#'
#' \code{\link{xgboost.unify}} for \code{\link[xgboost:xgboost]{XGBoost models}}
#' @name model_unified.object
#'
#' \code{\link{ranger.unify}} for \code{\link[ranger:ranger]{ranger models}}
NULL


#' Unified model representations for multi-output model
#'
#' \code{\link{randomForest.unify}} for \code{\link[randomForest:randomForest]{randomForest models}}
#' \code{model_unified_multioutput} object produced by \code{*.unify} or \code{unify} function.
#'
#' @return List consisting of \code{model_unified} objects, one for each individual output of a model. For survival models, the list is named by the time points for which predictions are calculated.
#'
#' @name model_unified.object
#' @seealso
#' \code{\link{unify}}
#'
#'
#' @name model_unified_multioutput.object
#'
NULL


#' Prints model_unified objects
#'
#' @param x a model_unified object
@@ -56,6 +64,27 @@ print.model_unified <- function(x, ...){
return(invisible(NULL))
}


#' Prints model_unified_multioutput objects
#'
#' @param x a model_unified_multioutput object
#' @param ... other arguments
#'
#' @return No return value, called for printing
#'
#' @export
#'
print.model_unified_multioutput <- function(x, ...){
output_names <- names(x)
lapply(output_names, function(output_name){
cat(paste("-> for output:", output_name, "\n"))
print(x[[output_name]])
cat("\n")
})
return(invisible(NULL))
}


#' Check whether object is a valid model_unified object
#'
#' Does not check correctness of representation, only basic checks
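
The print method added in this file can be illustrated with a minimal, self-contained sketch. The classes `toy_unified` and `toy_multioutput` below are hypothetical stand-ins (not part of `treeshap`); the sketch only assumes that a multi-output object is a named list of single-output objects, as the documentation above describes.

```r
# Hypothetical single-output object; real ones are produced by unify().
toy_unified <- function(label) {
  structure(list(label = label), class = "toy_unified")
}

print.toy_unified <- function(x, ...) {
  cat("toy unified model:", x$label, "\n")
  invisible(NULL)
}

# Hypothetical multi-output container: a named list with its own class.
toy_multi <- structure(
  list("23" = toy_unified("t = 23"), "50" = toy_unified("t = 50")),
  class = "toy_multioutput"
)

# Print each element under a header that carries its name.
print.toy_multioutput <- function(x, ...) {
  for (output_name in names(x)) {
    cat("-> for output:", output_name, "\n")
    print(x[[output_name]])  # dispatches to print.toy_unified
    cat("\n")
  }
  invisible(NULL)
}

printed <- capture.output(print(toy_multi))
```

The sketch uses a `for` loop where the commit uses `lapply()`; both produce the same console output, and the loop simply avoids building a list that is then discarded.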
59 changes: 57 additions & 2 deletions R/treeshap.R
@@ -8,7 +8,7 @@
#' @param interactions Whether to calculate SHAP interaction values. By default is \code{FALSE}. Basic SHAP values are always calculated.
#' @param verbose Whether to print progress bar to the console. Should be logical. Progress bar will not be displayed on Windows.
#'
#' @return A \code{\link{treeshap.object}} object. SHAP values can be accessed with \code{$shaps}. Interaction values can be accessed with \code{$interactions}.
#' @return A \code{\link{treeshap.object}} object (for single-output models) or \code{\link{treeshap_multioutput.object}}, which is a list of \code{\link{treeshap.object}} objects (for multi-output models). SHAP values can be accessed from \code{\link{treeshap.object}} with \code{$shaps}, and interaction values can be accessed with \code{$interactions}.
#'
#'
#' @export
@@ -54,8 +54,12 @@
#' treeshap2$interactions
#' }
treeshap <- function(unified_model, x, interactions = FALSE, verbose = TRUE) {
model <- unified_model$model
UseMethod("treeshap", unified_model)
}

#' @export
treeshap.model_unified <- function(unified_model, x, interactions = FALSE, verbose = TRUE){
model <- unified_model$model
# argument check
if (!("matrix" %in% class(x) | "data.frame" %in% class(x))) {
stop("x parameter has to be data.frame or matrix.")
@@ -125,6 +129,19 @@ treeshap <- function(unified_model, x, interactions = FALSE, verbose = TRUE) {
return(treeshap_obj)
}


#' @export
treeshap.model_unified_multioutput <- function(unified_model, x, interactions = FALSE, verbose = TRUE){
treeshaps_objects <- lapply(unified_model,
treeshap.model_unified,
x = x,
interactions = interactions,
verbose = verbose)
class(treeshaps_objects) <- "treeshap_multioutput"
return(treeshaps_objects)
}


#' treeshap results
#'
#' \code{treeshap} object produced by \code{treeshap} function.
@@ -148,6 +165,23 @@ treeshap <- function(unified_model, x, interactions = FALSE, verbose = TRUE) {
NULL


#' treeshap results for multi-output model
#'
#' \code{treeshap_multioutput} object produced by \code{treeshap} function.
#'
#' @return List consisting of \code{treeshap} objects, one for each individual output of a model. For survival models, the list is named by the time points for which TreeSHAP values are calculated.
#'
#'
#' @seealso
#' \code{\link{treeshap}},
#'
#' \code{\link{treeshap.object}}
#'
#'
#' @name treeshap_multioutput.object
NULL


#' Prints treeshap objects
#'
#' @param x a treeshap object
@@ -165,6 +199,27 @@ print.treeshap <- function(x, ...){
return(invisible(NULL))
}


#' Prints treeshap_multioutput objects
#'
#' @param x a treeshap_multioutput object
#' @param ... other arguments
#'
#' @return No return value, called for printing
#'
#' @export
#'
print.treeshap_multioutput <- function(x, ...){
output_names <- names(x)
lapply(output_names, function(output_name){
cat(paste("-> for output:", output_name, "\n"))
print(x[[output_name]])
cat("\n")
})
return(invisible(NULL))
}


#' Check whether object is a valid treeshap object
#'
#' Does not check correctness of result, only basic checks
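
The change to this file turns `treeshap()` into an S3 generic whose multi-output method maps the single-output method over a list. A minimal sketch of that pattern follows; `explain()` and all `toy_*` classes are hypothetical names used for illustration only, and the single-output method merely records the number of rows instead of computing SHAP values.

```r
# Hypothetical generic standing in for treeshap().
explain <- function(unified_model, x, ...) {
  UseMethod("explain", unified_model)
}

# Single-output method: a placeholder result, not a real explanation.
explain.toy_unified <- function(unified_model, x, ...) {
  structure(list(output = unified_model$output, n_obs = nrow(x)),
            class = "toy_explanation")
}

# Multi-output method: apply the single-output method to every element,
# then tag the resulting list with its own class. lapply() keeps the names.
explain.toy_multioutput <- function(unified_model, x, ...) {
  results <- lapply(unified_model, explain.toy_unified, x = x, ...)
  class(results) <- "toy_explanation_multioutput"
  results
}

models <- structure(
  list("23" = structure(list(output = "t = 23"), class = "toy_unified"),
       "50" = structure(list(output = "t = 50"), class = "toy_unified")),
  class = "toy_multioutput"
)
x <- data.frame(a = 1:3, b = 4:6)

# One call explains every output; elements are retrieved by name.
res <- explain(models, x)
res_at_50 <- res[["50"]]
```

With this design the caller no longer loops over outputs by index, which is the same simplification the commit makes in the `ranger_surv.unify` example.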
3 changes: 2 additions & 1 deletion R/unify.R
@@ -7,7 +7,8 @@
#' @param data Reference dataset. A \code{data.frame} or \code{matrix} with the same columns as in the training set of the model. Usually dataset used to train model.
#' @param ... Additional parameters passed to the model-specific unification functions.
#'
#' @return A unified model representation - a \code{\link{model_unified.object}} object
#' @return A unified model representation - a \code{\link{model_unified.object}} object (for single-output models) or \code{\link{model_unified_multioutput.object}}, which is a list of \code{\link{model_unified.object}} objects (for multi-output models).
#'
#'
#' @seealso
#' \code{\link{lightgbm.unify}} for \code{\link[lightgbm:lightgbm]{LightGBM models}}
24 changes: 17 additions & 7 deletions R/unify_ranger_surv.R
@@ -3,16 +3,27 @@
#' Convert your ranger model into a standardized representation.
#' The returned representation is easy to be interpreted by the user and ready to be used as an argument in \code{treeshap()} function.
#'
#' @details
#' The survival forest implemented in the \code{ranger} package stores cumulative hazard
#' functions (CHFs) in the leaves of survival trees, as proposed for Random Survival Forests
#' (Ishwaran et al. 2008). The final model prediction is made by averaging these CHFs
#' from all the trees. To provide explanations in the form of a survival function,
#' the CHFs from the leaves are converted into survival functions (SFs) using
#' the formula SF(t) = exp{-CHF(t)}.
#' However, it is important to note that averaging these SFs does not yield the correct
#' model prediction, because the model prediction is the average of the CHFs, which is
#' then transformed in the same way.
#' Therefore, explanations based on the survival function
#' are only proxies and may not be fully consistent with the model predictions
#' obtained using, for example, the \code{predict} function.
#'
#'
#' @param rf_model An object of \code{ranger} class. At the moment, models built on data with categorical features
#' are not supported - please encode them before training.
#' @param data Reference dataset. A \code{data.frame} or \code{matrix} with the same columns as in the training set of the model. Usually dataset used to train model.
#' @param type A character to define the type of model prediction to use. Either `"risk"` (default), which uses the risk score calculated as a sum of cumulative hazard function values, `"survival"`, which uses the survival probability at certain time-points for each observation, or `"chf"`, which uses the cumulative hazard values at certain time-points for each observation.
#' @param times A numeric vector of unique death times at which the prediction should be evaluated. By default `unique.death.times` from model are used.
#'
#' @return For `type = "risk"` a unified model representation is returned - a \code{\link{model_unified.object}} object.
#' For `type = "survival"` or `type = "chf"` a list is returned that contains unified model representation
#' (\code{\link{model_unified.object}} objects) for each time point. In this case, the list names are the
#' `unique.death.times` (from the `ranger` object), at which the survival function was evaluated.
#' @return For `type = "risk"` a unified model representation is returned - a \code{\link{model_unified.object}} object. For `type = "survival"` or `type = "chf"` - a \code{\link{model_unified_multioutput.object}} object is returned, which is a list that contains a unified model representation (\code{\link{model_unified.object}} object) for each time point. In this case, the list names are the time points at which the survival function was evaluated.
#'
#' @import data.table
#' @importFrom stats stepfun
@@ -63,9 +74,7 @@
#'
#' # compute shaps for 3 selected time points
#' unified_model_surv <- ranger_surv.unify(rf, train_x, type = "survival", times = c(23, 50, 73))
#' for (i in 1:3) {
#' shaps <- treeshap(unified_model_surv[[i]], train_x[1:2,])
#' }
#' shaps_surv <- treeshap(unified_model_surv, train_x[1:2,])
#'
ranger_surv.unify <- function(rf_model, data, type = c("risk", "survival", "chf"), times = NULL) {
type <- match.arg(type)
@@ -123,6 +132,7 @@ ranger_surv.unify <- function(rf_model, data, type = c("risk", "survival", "chf"
ranger_unify.common(x = x, n = n, data = data, feature_names = rf_model$forest$independent.variable.names)
})
names(unified_return) <- eval_times
class(unified_return) <- "model_unified_multioutput"
}
return(unified_return)
}
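
The caveat in the `@details` section (averaging survival functions is not the same as transforming the averaged CHF) can be checked with a small numeric sketch. The per-tree CHF values below are made up for illustration and do not come from any fitted model.

```r
# Made-up cumulative hazard values at a single time point, one per tree.
chf_per_tree <- c(0.1, 0.5, 2.0)

# Prediction as described in the details: average the CHFs, then transform.
sf_from_mean_chf <- exp(-mean(chf_per_tree))

# Proxy behind survival-based explanations: transform each CHF, then average.
mean_of_sf <- mean(exp(-chf_per_tree))

# The two quantities differ whenever the per-tree CHFs are not all equal.
gap <- mean_of_sf - sf_from_mean_chf
```

Because `exp(-x)` is convex, Jensen's inequality implies that the averaged survival functions are never below the transformed average CHF, so `gap` is non-negative and grows as the trees disagree more.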
3 changes: 2 additions & 1 deletion README.Rmd
@@ -18,7 +18,8 @@ set.seed(21)
# treeshap

<!-- badges: start -->

[![R-CMD-check](https://github.com/ModelOriented/treeshap/actions/workflows/CRAN-R-CMD-check.yaml/badge.svg)](https://github.com/ModelOriented/treeshap/actions/workflows/CRAN-R-CMD-check.yaml)
[![CRAN status](https://www.r-pkg.org/badges/version/treeshap)](https://CRAN.R-project.org/package=treeshap)
<!-- badges: end -->

In the era of complicated classifiers conquering their market, sometimes even the authors of algorithms do not know the exact manner of building a tree ensemble model. The difficulties in models' structures are one of the reasons why most users use them simply like black-boxes. But, how can they know whether the prediction made by the model is reasonable? `treeshap` is an efficient answer for this question. Due to implementing an optimized algorithm for tree ensemble models (called TreeSHAP), it calculates the SHAP values in polynomial (instead of exponential) time. Currently, `treeshap` supports models produced with `xgboost`, `lightgbm`, `gbm`, `ranger`, and `randomForest` packages. Support for `catboost` is available only in [`catboost` branch](https://github.com/ModelOriented/treeshap/tree/catboost) (see why [here](#catboost)).
5 changes: 2 additions & 3 deletions README.md
@@ -6,9 +6,8 @@
<!-- badges: start -->

[![R-CMD-check](https://github.com/ModelOriented/treeshap/actions/workflows/CRAN-R-CMD-check.yaml/badge.svg)](https://github.com/ModelOriented/treeshap/actions/workflows/CRAN-R-CMD-check.yaml)
[![CRAN status](https://www.r-pkg.org/badges/version/treeshap)](https://cran.r-project.org/package=treeshap)


[![CRAN
status](https://www.r-pkg.org/badges/version/treeshap)](https://CRAN.R-project.org/package=treeshap)
<!-- badges: end -->

In the era of complicated classifiers conquering their market, sometimes
2 changes: 2 additions & 0 deletions _pkgdown.yml
@@ -7,12 +7,14 @@ reference:
- contents:
- treeshap
- treeshap.object
- treeshap_multioutput.object
- title: Unifiers
desc: Convert your model into a standardized representation
- contents:
- unify
- ends_with(".unify")
- model_unified.object
- model_unified_multioutput.object
- title: Plotting functions
desc: Plot explanation results
- contents:
14 changes: 3 additions & 11 deletions man/model_unified.object.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions man/model_unified_multioutput.object.Rd

19 changes: 19 additions & 0 deletions man/print.model_unified_multioutput.Rd

19 changes: 19 additions & 0 deletions man/print.treeshap_multioutput.Rd
