
Treatment of categorical features in potential_interactions(): suggestion to use R squared instead of squared correlation #119

Closed
RoelVerbelen opened this issue Oct 24, 2023 · 15 comments
Labels
enhancement New feature or request long-term

Comments

@RoelVerbelen

Thanks for providing a great SHAP visualisation package for R!

I'm looking into fast ways to surface interaction effects in H2O GBMs. Unfortunately, unlike xgboost, H2O does not provide SHAP interaction values, so shapviz falls back on a heuristic in its potential_interactions() implementation: a weighted squared Pearson correlation between a feature's SHAP values and the other features' values. I think that's a reasonable approach, but it doesn't work well for unordered categorical features (which it converts to their arbitrarily ordered factor level numbers using data.matrix()).

A natural extension of what you are doing now, which I believe would be more appropriate for categorical features, would be to consider the R-squared of a linear regression of the SHAP values on each of the other features. For continuous features, that would give you exactly the same value you have now. For categorical features, it would measure the association between the unordered factor levels and the SHAP values in a way that is not constrained by the arbitrary factor level numbering.
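
As a small illustration (toy data, not package code): for a numeric feature the squared Pearson correlation and the lm() R-squared coincide, while for a factor only the R-squared is invariant to how the levels happen to be numbered.

# Toy illustration: squared Pearson correlation vs. lm() R-squared
set.seed(1)
s <- rnorm(100)                                    # pretend SHAP values
x_num <- rnorm(100)                                # numeric feature
x_fac <- factor(sample(c("a", "b", "c"), 100, replace = TRUE))

cor(s, x_num)^2                                    # current heuristic
summary(stats::lm(s ~ x_num))$r.squared            # identical value

cor(s, as.numeric(x_fac))^2                        # depends on the arbitrary level coding
summary(stats::lm(s ~ x_fac))$r.squared            # order-invariant association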

If you want to implement that, lines 230-233 would have to be replaced by:

  # Complicated case: we need to rely on R squared based heuristic
  r_sq <- function(s, x) {
    sapply(x, 
           function(x) {
             tryCatch({
               summary(stats::lm(s ~ x))$r.squared
             }, error = function(e) {
               return(NA)
             })
           })
  }

Here's a full example using a public H2O data set:

library(shapviz)
library(h2o)
h2o.init()

# Import the prostate dataset into H2O:
prostate <- h2o.importFile("http://s3.amazonaws.com/h2o-public-test-data/smalldata/prostate/prostate.csv")

# Set the predictors and response; set the factors:
prostate$CAPSULE <- as.factor(prostate$CAPSULE)
prostate$RACE <- as.factor(prostate$RACE)
prostate$DPROS <- as.factor(prostate$DPROS)
prostate$DCAPS <- as.factor(prostate$DCAPS)
prostate$GLEASON <- as.factor(prostate$GLEASON)
predictors <- c("AGE", "RACE", "DPROS", "DCAPS", "PSA", "VOL", "GLEASON")
response <- "CAPSULE"

# Build and train the model:
pros_gbm <- h2o.gbm(x = predictors,
                    y = response,
                    nfolds = 5,
                    seed = 1111,
                    keep_cross_validation_predictions = TRUE,
                    training_frame = prostate)

# Create shapviz object
shp <- shapviz(pros_gbm, X_pred = prostate, X = as.data.frame(prostate))

# Replace correlation measure with R squared measure
potential_interactions_rsq <- function(obj, v) {
  stopifnot(is.shapviz(obj))
  S <- get_shap_values(obj)
  S_inter <- get_shap_interactions(obj)
  X <- get_feature_values(obj)
  nms <- colnames(obj)
  v_other <- setdiff(nms, v)
  stopifnot(v %in% nms)
  
  if (ncol(obj) <= 1L) {
    return(NULL)
  }
  
  # Simple case: we have SHAP interaction values
  if (!is.null(S_inter)) {
    return(sort(2 * colMeans(abs(S_inter[, v, ]))[v_other], decreasing = TRUE))
  }
  
  # Complicated case: we need to rely on R squared based heuristic
  r_sq <- function(s, x) {
    sapply(x, 
           function(x) {
             tryCatch({
               summary(stats::lm(s ~ x))$r.squared
             }, error = function(e) {
               return(NA)
             })
           })
  }
  n_bins <- ceiling(min(sqrt(nrow(X)), nrow(X) / 20))
  v_bin <- shapviz:::.fast_bin(X[[v]], n_bins = n_bins)
  s_bin <- split(S[, v], v_bin)
  X_bin <- split(X[v_other], v_bin)
  w <- do.call(rbind, lapply(X_bin, function(z) colSums(!is.na(z))))
  cor_squared <- do.call(rbind, mapply(r_sq, s_bin, X_bin, SIMPLIFY = FALSE))
  sort(colSums(w * cor_squared, na.rm = TRUE) / colSums(w), decreasing = TRUE)
}

# Current implementation
potential_interactions(shp, v = "PSA")
#>    GLEASON      DPROS        VOL      DCAPS       RACE        AGE 
#> 0.14827267 0.10383619 0.07988404 0.07166984 0.06715848 0.05922560

# Suggested implementation
potential_interactions_rsq(shp, v = "PSA")
#>    GLEASON      DPROS        VOL       RACE      DCAPS        AGE 
#> 0.32998601 0.25517234 0.07988404 0.07827180 0.07166984 0.05922560

Created on 2023-10-24 with reprex v2.0.2

@mayer79 mayer79 self-assigned this Oct 24, 2023
@mayer79 mayer79 added the enhancement New feature or request label Oct 24, 2023
@mayer79 mayer79 removed their assignment Oct 24, 2023
@mayer79
Collaborator

mayer79 commented Oct 24, 2023

Very good idea, thanks a lot for the suggestion and the implementation.

I will add it almost unchanged (X is always a data.frame, so sapply/vapply will always work):

r_sq1 <- function(s, x) {
  tryCatch(
    summary(stats::lm(s ~ x))[["r.squared"]],
    error = function(e) return(NA)
  )
}

r_sq <- function(s, x) {
  suppressWarnings(
    vapply(x, FUN = r_sq1, FUN.VALUE = numeric(1L), s = s, USE.NAMES = FALSE)
  )
}

r_sq(1:2, data.frame(x = c("A", "A"), y = c("A", "B")))

Without tryCatch(), the above example would fail.

If you want to attempt a PR, that would be even better. You would need to add the change to the NEWS file.

@mayer79
Collaborator

mayer79 commented Oct 24, 2023

One addition: we should probably work with adjusted R-squared to have a fairer selection. The bins (in x) can be relatively small, so it seems unfair if we invest 1 df in the non-factors and >1 df in the factors. See the example:

r_sq1 <- function(s, x) {
  tryCatch(
    summary(stats::lm(s ~ x))[["adj.r.squared"]],
    error = function(e) return(NA)
  )
}

r_sq <- function(s, x) {
  suppressWarnings(
    vapply(x, FUN = r_sq1, FUN.VALUE = numeric(1L), s = s, USE.NAMES = FALSE)
  )
}

r_sq(1:2, data.frame(x = c("A", "A"), y = c("A", "B"))) # NA NaN
r_sq(1:3, data.frame(x = c(1, 2, 2), y = c("A", "B", "B"))) # 0.5 0.5

It will also have an impact on the numeric features (but not on their order).

@RoelVerbelen
Author

RoelVerbelen commented Oct 24, 2023

Yeah, I was thinking the same, but I wanted to illustrate the equivalence with what you're doing now for numerical features first, before addressing that imbalance in degrees of freedom.

I think it would be even better if we moved away from a correlation / R-squared measure to something that is comparable across different v's, so we can come up with an interaction importance ranking. See my issue and suggestion in #120. I'm not sure how to take the varying degrees of freedom into account there:

# Complicated case: we need to rely on a modelled-variation-based heuristic
mod_var1 <- function(s, x) {
  tryCatch(
    mean(abs(stats::lm(s ~ x)$fitted - mean(s))),
    error = function(e) return(NA)
  )
}
mod_var <- function(s, x) {
  suppressWarnings(
    vapply(x, FUN = mod_var1, FUN.VALUE = numeric(1L), s = s, USE.NAMES = FALSE)
  )
}

@RoelVerbelen
Author

RoelVerbelen commented Oct 24, 2023

I think moving to the model sum of squares, corrected for the residual degrees of freedom, would be the best solution.

mod_ss1 <- function(s, x) {
  tryCatch(
    {
      fit <- stats::lm(s ~ x)
      sum((fit$fitted - mean(s))^2) / fit$df.residual
    },
    error = function(e) return(NA)
  )
}
mod_ss <- function(s, x) {
  suppressWarnings(
    vapply(x, FUN = mod_ss1, FUN.VALUE = numeric(1L), s = s, USE.NAMES = FALSE)
  )
}

Or we could take the square root of that, if we want to bring it to the scale of the predictions and more in line with SHAP interaction values.

@mayer79
Collaborator

mayer79 commented Oct 25, 2023

Thanks for the suggestions. I need to think about how this plays with the weighting regime over bins. For instance, we could, alternatively, use something like R2-adj(lm(s ~ x * factor(binned values of the variable on the x axis))) and forget about the weighted average of non-missing values.

@RoelVerbelen
Author

I've thought about it a bit more and created a pull request with my suggestion.

I believe a good heuristic would be to look at the difference between the mean squared error (MSE) of the SHAP values of v regressed on the feature values of v' and the MSE of the null model. That corresponds to the numerator of $R^2_{adj}$. By not dividing by the denominator of $R^2_{adj}$ (= mean total sum of squares = MSE of the null model), we get a metric that is meaningful across features v. Taking the square root at the end brings it back to the scale of the SHAP values and in line with the interaction strength metric used when SHAP interaction values are available.
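
As a rough standalone sketch of this per-bin quantity (a hypothetical helper, not the code in the PR): the square root of the null-model MSE minus the MSE of lm(s ~ x), i.e. sqrt(SSTO/(n-1) - SSE/(n-p)).

# Sketch only: per-bin heuristic as described above (hypothetical helper name)
interaction_heuristic1 <- function(s, x) {
  tryCatch({
    fit <- stats::lm(s ~ x)
    mse_null  <- sum((s - mean(s))^2) / (length(s) - 1)          # SSTO / (n - 1)
    mse_model <- sum(stats::residuals(fit)^2) / fit$df.residual  # SSE / (n - p)
    sqrt(max(mse_null - mse_model, 0))
  }, error = function(e) NA_real_)
}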

Here's an example using the new function:

devtools::load_all(".")
library(tidyverse)
library(xgboost)
 
set.seed(3653)
x <- c("carat", "cut", "color", "clarity")
dtrain <- xgb.DMatrix(data.matrix(diamonds[x]), label = diamonds$price)
fit <- xgb.train(params = list(learning_rate = 0.1), data = dtrain, nrounds = 65)
 
# Explanation data
dia_small <- diamonds[sample(nrow(diamonds), 2000), ]
 
# shapviz object with interactions
shp_i <- shapviz(
  fit, X_pred = data.matrix(dia_small[x]), X = dia_small, interactions = TRUE
)
 
# shapviz object without interactions
shp <- shapviz(
  fit, X_pred = data.matrix(dia_small[x]), X = dia_small, interactions = FALSE
)

# Using SHAP interaction values
potential_interactions(shp_i, v = "cut")
#>   carat  clarity    color 
#> 98.98315 23.92846 17.94669 

# Using suggested heuristics 
potential_interactions(shp, v = "cut")
#>  carat  clarity    color 
#> 57.58020 21.24211 11.38107 

# Using SHAP interaction values
potential_interactions(shp_i, v = "carat")
#>  clarity     color       cut 
#> 600.07087 412.44253  98.98317 

# Using suggested heuristics 
potential_interactions(shp, v = "carat")
#> clarity    color      cut 
#> 349.4513 226.2636 193.9465 

# Function to create table with ranked interaction variables
table_potential_interactions <- function(predictor, shp) {
  pi <- potential_interactions(shp, predictor)
  tibble(var1 = predictor, var2 = names(pi), interaction_strength = pi)
}
 
# Interaction importance (SHAP interaction values)
map(x, table_potential_interactions, shp_i) %>% 
  bind_rows() %>% 
  arrange(desc(interaction_strength))
#> # A tibble: 12 × 3
#>    var1    var2    interaction_strength
#>    <chr>   <chr>                  <dbl>
#>  1 clarity carat                  600. 
#>  2 carat   clarity                600. 
#>  3 carat   color                  412. 
#>  4 color   carat                  412. 
#>  5 color   clarity                188. 
#>  6 clarity color                  188. 
#>  7 carat   cut                     99.0
#>  8 cut     carat                   99.0
#>  9 cut     clarity                 23.9
#> 10 clarity cut                     23.9
#> 11 color   cut                     17.9
#> 12 cut     color                   17.9
 
# Interaction importance (heuristics)
map(x, table_potential_interactions, shp) %>% 
  bind_rows() %>% 
  arrange(desc(interaction_strength))
#> # A tibble: 12 × 3
#>    var1    var2    interaction_strength
#>    <chr>   <chr>                  <dbl>
#>  1 carat   clarity                349. 
#>  2 clarity carat                  306. 
#>  3 carat   color                  226. 
#>  4 carat   cut                    194. 
#>  5 color   carat                  177. 
#>  6 clarity color                  117. 
#>  7 color   clarity                 85.6
#>  8 clarity cut                     65.0
#>  9 cut     carat                   57.6
#> 10 cut     clarity                 21.2
#> 11 color   cut                     16.2
#> 12 cut     color                   11.4

@RoelVerbelen
Author

R2-adj(lm(s ~ x * factor(binned values of the variable on the x axis))) would create a metric that is not comparable across v's. You would need to scale it by the total variance (and that can vary by bin).

@mayer79
Collaborator

mayer79 commented Oct 25, 2023

Thinking out loud:

  1. Within a specific bin, I would consider the adjusted R-squared a (quite) fair measure of interaction strength. It estimates the proportion of SHAP value variability explained by the color variable within the bin.
  2. The question is: how should the values be aggregated over bins? Currently, I am using the bin size as weight, so a bin with small variability gets the same weight as a bin with high variability. I think your approach challenges this logic, but we need a very good reason to change it (see the sketch after this list).
  3. Switching to lm() instead of simple Pearson correlations is definitely smart.
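
A made-up numerical sketch of the two aggregation choices in point 2 (illustration only, not package code):

# Toy per-bin summaries to contrast the aggregation ideas
r2_per_bin  <- c(0.10, 0.60, 0.20)   # (adjusted) R-squared within each bin
n_per_bin   <- c(500, 50, 450)       # bin sizes (the current weights)
var_per_bin <- c(0.2, 5.0, 0.4)      # SHAP value variance within each bin

# Current logic: weight each bin's R-squared by bin size only
weighted.mean(r2_per_bin, n_per_bin)

# Variance-aware alternative: weight the explained variance per bin instead
weighted.mean(r2_per_bin * var_per_bin, n_per_bin)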

@RoelVerbelen
Author

Weighted averages of (adjusted) R-squared (or of the current Pearson correlation) across the bins are not appropriate, I believe, as they do not take the amount of variation in the SHAP values within each bin into account (your point 2).

In the proposed new heuristic, we'd consider looking in each bin at:

$$ \frac{SSTO}{n-1} \cdot R^2_{adj} = \frac{SSTO}{n-1} \cdot \left(1 - \frac{SSE/(n-p)}{SSTO/(n-1)} \right) = \frac{SSTO}{n-1} - \frac{SSE}{n-p}$$

which I would call the explained amount of variability (the difference between the MSE of a null model and the MSE of the model regressed on v'). This also takes the degrees of freedom into account (to penalise categorical v's with many categories). The resulting metric reflects the amount of variability in the SHAP values. Loosely speaking, for large $n$, we get the regression sum of squares divided by the number of observations:

$$ \frac{SSTO}{n-1} - \frac{SSE}{n-p} \approx \frac{SSR}{n} = \frac{1}{n} \sum_{i=1}^{n} (\hat{y_i} - \bar{y})^2 $$

or the average squared linear SHAP value of that model. On this scale, aggregating by taking a weighted average over the bins makes sense in my opinion. Finally, taking the square root puts it back on the scale of the SHAP values and makes it comparable to the sum of the absolute values of the exact SHAP interaction values; see my numerical example in a previous comment.
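
To double-check the algebra numerically, a short sketch with toy data:

# Numerical check: SSTO/(n-1) * R2_adj equals SSTO/(n-1) - SSE/(n-p)
set.seed(1)
n <- 50
x <- factor(sample(letters[1:4], n, replace = TRUE))
s <- rnorm(n)
fit  <- stats::lm(s ~ x)
ssto <- sum((s - mean(s))^2)
sse  <- sum(stats::residuals(fit)^2)
p    <- length(coef(fit))
all.equal(
  ssto / (n - 1) * summary(fit)$adj.r.squared,
  ssto / (n - 1) - sse / (n - p)
)
#> [1] TRUE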

@leopjj

leopjj commented Nov 16, 2023

just like that

@mayer79
Collaborator

mayer79 commented Dec 27, 2023

After thinking about this again, switching to a completely different regime is too large a change. Additionally, people often use factor variables with many categories, leading to perfect alignment within an x bin (both for the color feature and the x feature).

We can, however, think about how to parametrize the logic in potential_interactions() and sv_dependence(). Arguments could be as follows (ih means interaction heuristic):

  • ih_nbins = NULL: Number of bins for numeric feature values on x-axis. If the number of unique values is not higher than that, no binning will be done. The default NULL uses the smaller of sqrt(n) and n/20, rounded up, where n is the number of observations.
  • ih_color_numeric = TRUE: Should non-numeric color feature values be turned into numeric values first? Default is TRUE.
  • ih_collapse_bins = c("bin_size", "your_proposal"): How should R-squared be collapsed over bins? The default "bin_size" uses a weighted mean with bin sizes as weights; "your_proposal" is your approach.

In all cases, we would switch to adjusted R-squared, using the formula in stats::summary.lm().
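
For reference, a quick check of that formula on a built-in data set (standard adjusted R-squared identity, toy example only):

# Adjusted R-squared as computed by stats::summary.lm():
# R2_adj = 1 - (1 - R2) * (n - 1) / (n - p), with p coefficients incl. intercept
fit <- stats::lm(dist ~ speed, data = cars)
sm  <- summary(fit)
n   <- nrow(cars)
p   <- length(coef(fit))
all.equal(sm$adj.r.squared, 1 - (1 - sm$r.squared) * (n - 1) / (n - p))
#> [1] TRUE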

ping @RoelVerbelen

@mayer79
Collaborator

mayer79 commented Jan 1, 2024

Partly implemented in https://github.com/ModelOriented/shapviz/pulls

The status is like this:

  • We now have four arguments to play with for the heuristic.
  • The defaults give the same results as before.
  • We can change the defaults later, when we have more experience, e.g. switch to adjusted R-squared, don't convert non-numeric color features, use a smaller bin size, or scale up the R-squared by the SHAP variance.
  • No square root is taken.

@mayer79 mayer79 closed this as completed Jan 1, 2024
@RoelVerbelen
Author

Hi @mayer79, apologies for my late reply and thanks for adding so much flexibility to the function!

Exploring the new arguments, I've confirmed that the setting replicating my proposal above is this:

sqrt(shapviz::potential_interactions(shp, 'variable_name', color_num = FALSE, scale = TRUE, adjusted = TRUE))

And I agree limiting nbins (maybe to say 10) would be wise as well. I'll be using these settings myself going forward.

My only remaining suggestion for your consideration would be for nbins to also affect categorical variables within shapviz:::.fast_bin() (e.g. by using logic similar to forcats::fct_lump() to group the small categories), as sketched below.
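
For illustration, a rough sketch of the kind of lumping meant here, using forcats::fct_lump_n() (just one possible way; not shapviz code):

library(forcats)

# Toy factor with a long tail of rare levels
set.seed(1)
x <- factor(sample(LETTERS[1:12], 1000, replace = TRUE,
                   prob = c(rep(0.15, 5), rep(0.05, 4), rep(0.0167, 3))))

# Keep the most frequent levels, lump the rest into "Other" before binning/regressing
table(fct_lump_n(x, n = 5, other_level = "Other"))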

Thanks for implementing!

@mayer79
Collaborator

mayer79 commented Jan 24, 2024

Nice! We need to play with these arguments.

Lumping small categories is a good idea, actually for both the x variable and the color variable, right?

@RoelVerbelen
Author

Yes, for both, to avoid splitting the data into too many bins and regressing on too many factor levels.
