Skip to content

Commit

Permalink
Merge pull request #129 from ModelOriented/edge-case-2
Browse files Browse the repository at this point in the history
More edge-cases (heuristic)
  • Loading branch information
mayer79 committed Jan 2, 2024
2 parents c2a0cfc + 0f13f78 commit a50fab0
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 44 deletions.
9 changes: 5 additions & 4 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
If no SHAP interaction values are available, by default, the color feature `v'` is selected by the heuristic `potential_interaction()`, which works as follows:

1. If the feature `v` (the on the x-axis) is numeric, it is binned into `nbins` bins.
2. Per bin, the SHAP values of `v` are regressed onto `v` and the R-squared is calculated.
3. The R-squared are averaged over bins, weighted by the bin size.
2. Per bin, the SHAP values of `v` are regressed onto `v'` and the R-squared is calculated. Rows with missing `v'` are discarded.
3. The R-squared are averaged over bins, weighted by the number of non-missing `v'` values.

This measures how much variability in the SHAP values is explained by `v'`, after accounting for `v`.
This measures how much variability in the SHAP values of `v` is explained by `v'`, after accounting for `v`.

We have introduced four parameters to control the heuristic. Their defaults are in line with the old behaviour.

Expand All @@ -35,7 +35,8 @@ We will continue to experiment with the defaults, which might change in the futu

## Other user-visible changes

- `sv_dependence()`: If `color_var = "auto"` (default) and no color feature seems to be relevant (SHAP interaction is `NULL`, or heuristic returns no positive value), there won't be any color scale.
- `sv_dependence()`: If `color_var = "auto"` (default) and no color feature seems to be relevant (SHAP interaction is `NULL`, or heuristic returns no positive value), there won't be any color scale. Furthermore, in some edge cases, a different
color feature might be selected.
- `mshapviz()` objects can now be rowbinded via `rbind()` or `+`. Implemented by [@jmaspons](https://github.com/jmaspons) in [#110](https://github.com/ModelOriented/shapviz/pull/110).
- `mshapviz()` is more strict when combining multiple "shapviz" objects. These now need to have identical column names, see [#114](https://github.com/ModelOriented/shapviz/pull/114).

Expand Down
26 changes: 16 additions & 10 deletions R/potential_interactions.R
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
#' Interaction Strength
#'
#' Returns vector of interaction strengths between variable `v` and all other variables,
#' see Details.
#' Returns a vector of interaction strengths between variable `v` and all other
#' variables, see Details.
#'
#' If SHAP interaction values are available, the interaction strength
#' between feature `v` and another feature `v'` is measured by twice their
#' mean absolute SHAP interaction values.
#'
#' Otherwise, we use a heuristic calculated as follows to calculate interaction strength
#' between `v` and each other "color" feature `v':
#' Otherwise, we use a heuristic calculated as follows:
#' 1. If `v` is numeric, it is binned into `nbins` bins.
#' 2. Per bin, the SHAP values of `v` are regressed onto `v`, and the R-squared
#' is calculated.
#' 3. The R-squared are averaged over bins, weighted by the bin size.
#' is calculated. Rows with missing `v'` are discarded.
#' 3. The R-squared are averaged over bins, weighted by the number of
#' non-missing `v'` values.
#'
#' This measures how much variability in the SHAP values of `v` is explained by `v'`,
#' after accounting for `v`.
#'
#' Set `scale = TRUE` to multiply the R-squared by the within-bin variance
#' of the SHAP values. This will put higher weight to bins with larger scatter.
Expand All @@ -22,6 +25,8 @@
#'
#' Finally, set `adjusted = TRUE` to use *adjusted* R-squared.
#'
#' The algorithm does not consider observations with missing `v'` values.
#'
#' @param obj An object of class "shapviz".
#' @param v Variable name to calculate potential SHAP interactions for.
#' @param nbins Into how many quantile bins should a numeric `v` be binned?
Expand Down Expand Up @@ -60,7 +65,7 @@ potential_interactions <- function(obj, v, nbins = NULL, color_num = TRUE,
nbins <- ceiling(min(sqrt(nrow(X)), nrow(X) / 20))
}
out <- vapply(
X[v_other], # data.frame is a list
X[v_other],
FUN = heuristic,
FUN.VALUE = 1.0,
s = S[, v],
Expand Down Expand Up @@ -108,15 +113,16 @@ heuristic <- function(color, s, bins, color_num, scale, adjusted) {
#'
#' @inheritParams heuristic
#' @returns
#' A (1x2) matrix with heuristic and number of observations.
#' A (1x2) matrix with the heuristic and the number of observations with non-missing
#' `v'`.
heuristic_in_bin <- function(color, s, scale = FALSE, adjusted = FALSE) {
ok <- !is.na(color)
color <- color[ok]
s <- s[ok]
n <- length(s)
var_s <- stats::var(s)
if (n < 2L || var_s < .Machine$double.eps || length(unique(color)) < 2L) {
return(cbind(stat = NA, n = n))
return(cbind(stat = 0, n = n))
}
z <- stats::lm(s ~ color)
var_r <- sum(z$residuals^2) / (if (adjusted) z$df.residual else n - 1)
Expand All @@ -125,7 +131,7 @@ heuristic_in_bin <- function(color, s, scale = FALSE, adjusted = FALSE) {
stat <- stat * var_s
}
if (!is.finite(stat)) {
stat <- NA
stat <- 0
}
cbind(stat = stat, n = n)
}
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ dia_2000 <- diamonds[sample(nrow(diamonds), 2000), x]
shp <- shapviz(fit, X_pred = data.matrix(dia_2000), X = dia_2000)

sv_importance(shp, show_numbers = TRUE)
sv_dependence(shp, v = x)}
sv_dependence(shp, v = x)
```

![](man/figures/README-imp.svg)
Expand Down
17 changes: 11 additions & 6 deletions man/potential_interactions.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

46 changes: 23 additions & 23 deletions tests/testthat/test-potential_interactions.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,84 +66,84 @@ test_that("heuristic_in_bin() returns R-squared", {
)
})

test_that("Failing heuristic_in_bin() returns NA", {
expect_equal(heuristic_in_bin(c(NA, NA), 1:2), cbind(stat = NA, n = 0))
test_that("Failing heuristic_in_bin() returns 0", {
expect_equal(heuristic_in_bin(c(NA, NA), 1:2), cbind(stat = 0, n = 0))
})

test_that("heuristic_in_bin() returns NA for constant response", {
test_that("heuristic_in_bin() returns 0 for constant response", {
expect_equal(
heuristic_in_bin(color = 1:3, s = c(1, 1, 1)),
cbind(stat = NA, n = 3L)
cbind(stat = 0, n = 3L)
)
expect_equal(
heuristic_in_bin(color = 1:3, s = c(1, 1, 1), scale = TRUE),
cbind(stat = NA, n = 3L)
cbind(stat = 0, n = 3L)
)
expect_equal(
heuristic_in_bin(color = 1:3, s = c(1, 1, 1), adjust = TRUE),
cbind(stat = NA, n = 3L)
cbind(stat = 0, n = 3L)
)
expect_equal(
heuristic_in_bin(color = 1:3, s = c(1, 1, 1), adjust = TRUE, scale = TRUE),
cbind(stat = NA, n = 3L)
cbind(stat = 0, n = 3L)
)
})

test_that("heuristic_in_bin() returns NA for constant color", {
test_that("heuristic_in_bin() returns 0 for constant color", {
expect_equal(
heuristic_in_bin(s = 1:3, color = c(1, 1, 1)),
cbind(stat = NA, n = 3L)
cbind(stat = 0, n = 3L)
)
expect_equal(
heuristic_in_bin(s = 1:3, color = c(1, 1, 1), scale = TRUE),
cbind(stat = NA, n = 3L)
cbind(stat = 0, n = 3L)
)
expect_equal(
heuristic_in_bin(s = 1:3, color = c(1, 1, 1), adjust = TRUE),
cbind(stat = NA, n = 3L)
cbind(stat = 0, n = 3L)
)
expect_equal(
heuristic_in_bin(s = 1:3, color = c(1, 1, 1), adjust = TRUE, scale = TRUE),
cbind(stat = NA, n = 3L)
cbind(stat = 0, n = 3L)
)
})

test_that("heuristic_in_bin() returns 0 if response and color are constant", {
z <- c(1, 1)
expect_equal(
heuristic_in_bin(color = z, s = z),
cbind(stat = NA, n = 2L)
cbind(stat = 0, n = 2L)
)
expect_equal(
heuristic_in_bin(color = z, s = z, scale = TRUE),
cbind(stat = NA, n = 2L)
cbind(stat = 0, n = 2L)
)
expect_equal(
heuristic_in_bin(color = z, s = z, adjust = TRUE),
cbind(stat = NA, n = 2L)
cbind(stat = 0, n = 2L)
)
expect_equal(
heuristic_in_bin(color = z, s = z, adjust = TRUE, scale = TRUE),
cbind(stat = NA, n = 2L)
cbind(stat = 0, n = 2L)
)
})

test_that("heuristic_in_bin() returns NA for single obs", {
test_that("heuristic_in_bin() returns 0 for single obs", {
expect_equal(
heuristic_in_bin(color = 2, s = 2),
cbind(stat = NA, n = 1L)
cbind(stat = 0, n = 1L)
)
expect_equal(
heuristic_in_bin(color = 2, s = 2, scale = TRUE),
cbind(stat = NA, n = 1L)
cbind(stat = 0, n = 1L)
)
expect_equal(
heuristic_in_bin(color = 2, s = 2, adjust = TRUE),
cbind(stat = NA, n = 1L)
cbind(stat = 0, n = 1L)
)
expect_equal(
heuristic_in_bin(color = 2, s = 2, adjust = TRUE, scale = TRUE),
cbind(stat = NA, n = 1L)
cbind(stat = 0, n = 1L)
)
})

Expand All @@ -159,11 +159,11 @@ test_that("heuristic_in_bin() returns NA for single obs", {
)
expect_equal(
heuristic_in_bin(color = cc, s = 1:3, adjust = TRUE),
cbind(stat = NA, n = 3L)
cbind(stat = 0, n = 3L)
)
expect_equal(
heuristic_in_bin(color = cc, s = 2*(1:3), adjust = TRUE, scale = TRUE),
cbind(stat = NA, n = 3L)
cbind(stat = 0, n = 3L)
)
})

Expand Down

0 comments on commit a50fab0

Please sign in to comment.