Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add stricter mshapviz interface #114

Merged
merged 1 commit into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# shapviz 0.9.3

## User-visible changes

- `mshapviz()` is more strict when combining multiple "shapviz" objects. These now need to have identical column names.

## Other changes

- Re-activate all unit tests.
Expand Down
17 changes: 14 additions & 3 deletions R/shapviz.R
Original file line number Diff line number Diff line change
Expand Up @@ -462,9 +462,9 @@ shapviz.H2OModel = function(object, X_pred, X = as.data.frame(X_pred),
)
}

#' Concatenates "shapviz" Objects
#' Combines compatible "shapviz" Objects
#'
#' This function combines a list of "shapviz" objects to an object of class
#' This function combines a list of compatible "shapviz" objects to an object of class
#' "mshapviz". The elements can be named.
#'
#' @param object List of "shapviz" objects to be concatenated.
Expand All @@ -479,10 +479,21 @@ shapviz.H2OModel = function(object, X_pred, X = as.data.frame(X_pred),
#' s <- mshapviz(c(shp1 = s1, shp2 = s2))
#' s
mshapviz <- function(object, ...) {
stopifnot("'object' must be a list of 'shapviz' objects" = is.list(object))
stopifnot("'object' must be a list" = is.list(object))
if (!all(vapply(object, is.shapviz, FUN.VALUE = logical(1)))) {
stop("Must pass list of 'shapviz' objects")
}
nms <- lapply(object, colnames)
if (!all(vapply(nms, identical, y = nms[[1L]], FUN.VALUE = logical(1)))) {
stop("'shapviz' objects need to have identical column names")
}
# Plot methods using interactions and do.call(rbind, ...) will fail, other plots are ok
# inter <- vapply(
# object, function(x) is.null(get_shap_interactions(x)), FUN.VALUE = logical(1)
# )
# if (!(all(inter) || !any(inter))) {
# stop("Some 'shapviz' objects have SHAP interactions, some not.")
# }
structure(object, class = "mshapviz")
}

Expand Down
4 changes: 2 additions & 2 deletions man/mshapviz.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 24 additions & 6 deletions tests/testthat/test-interface.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ test_that("get_* functions work", {
expect_equal(4, get_baseline(mshp)[[1L]])
expect_equal(S, get_shap_values(mshp)[[1L]])
expect_equal(X, get_feature_values(mshp)[[1L]])

expect_error(get_baseline(3))
expect_error(get_shap_values("a"))
expect_error(get_feature_values(c(3, 9)))
})

test_that("dim, nrow, ncol, colnames work", {
Expand Down Expand Up @@ -176,6 +180,8 @@ mshp_inter <- c(shp1 = shp_inter, shp2 = shp_inter + shp_inter)

test_that("get_shap_interactions, +, rbind works for interactions", {
expect_equal(S_inter, get_shap_interactions(shp_inter))
expect_equal(length(get_shap_interactions(mshp_inter)), 2L)
expect_error(get_shap_interactions(4))
expect_equal(dim((shp_inter + shp_inter)$S_inter)[1L], 2 * dim(shp_inter$S_inter)[1L])
expect_equal(
dim(rbind(shp_inter, shp_inter, shp_inter)$S_inter)[1L],
Expand Down Expand Up @@ -211,7 +217,24 @@ test_that("mshapviz object contains original shapviz objects", {
expect_equal(mshp_inter[[2L]][1:nrow(shp_inter)], shp_inter)
})

# # Multiclass with XGBoost
test_that("shapviz objects with interactions can be rowbinded", {
expect_equal(dim(rbind(shp_inter, shp_inter)), dim(shp_inter) * (2:1))
expect_error(rbind(shp_inter, shp))
})

# Check on mshapviz
test_that("combining non-shapviz objects fails", {
expect_error(c(shp, 1))
expect_error(mshapviz(list(1, 2)))
})

test_that("combining incompatible shapviz objects fails", {
shp2 <- shp[, "x"]
expect_error(mshapviz(list(shp, shp2)))
expect_error(c(shp, shp2))
})

# Multiclass with XGBoost
X_pred <- data.matrix(iris[, -5L])
dtrain <- xgboost::xgb.DMatrix(X_pred, label = as.integer(iris[, 5L]) - 1L)
fit <- xgboost::xgb.train(
Expand Down Expand Up @@ -242,8 +265,3 @@ test_that("combining shapviz on classes 1, 2, 3 equal mshapviz", {
expect_equal(mshp, mshapviz(list(Class_1 = shp1, Class_2 = shp2, Class_3 = shp3)))
})

test_that("combining non-shapviz objects fails", {
expect_error(c(shp3, 1))
expect_error(mshapviz(1, 2))
})

13 changes: 13 additions & 0 deletions tests/testthat/test-plots-mshapviz.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,18 @@ test_that("dependence plots work for interactions = TRUE", {
)
})

test_that("shapviz objects w/o interactions can be combined and used for most things", {
x_temp1 <- shapviz(fit, X_pred = dtrain, X = iris[, -1L], interactions = TRUE)
x_temp2 <- shapviz(fit, X_pred = dtrain, X = iris[, -1L])
expect_no_error(x_inter_m <- c(x_temp1, x_temp2))
expect_error(do.call(rbind, x_inter_m))
expect_error(sv_interaction(x_inter_m))
expect_s3_class(sv_importance(x_inter_m), "ggplot")
expect_s3_class(sv_dependence(x_inter_m, "Sepal.Width"), "patchwork")
expect_error(sv_dependence(x_inter_m, "Sepal.Width", interactions = TRUE))
expect_equal(sapply(get_shap_interactions(x_inter_m), is.null), c(FALSE, TRUE))
})

test_that("main effect plots equal case color_var = v", {
expect_equal(
sv_dependence(x_inter, "Petal.Length", color_var = NULL, interactions = TRUE),
Expand Down Expand Up @@ -118,3 +130,4 @@ test_that("sv_dependence() does not work with multiple v", {
expect_error(sv_dependence2D(x, x = c("Species", "Sepal.Width"), y = "Petal.Width"))
expect_error(sv_dependence2D(x, x = "Petal.Width", y = c("Species", "Sepal.Width")))
})