calculate_variable_profile coerces integers to numerics #145

Closed
simonpcouch opened this issue Dec 5, 2022 · 5 comments

@simonpcouch

The tidymodels team recently introduced support for finer-grained numeric classes in recipes. A user recently pointed out on our community forums that this introduced issues with model_profile() in some cases. Here's a reprex:

library(tidymodels)
library(DALEXtra)
#> Loading required package: DALEX
#> Welcome to DALEX (version: 2.4.2).
#> Find examples and detailed introduction at: http://ema.drwhy.ai/
#> Additional features will be available after installation of: ggpubr.
#> Use 'install_dependencies()' to get all suggested dependencies
#> 
#> Attaching package: 'DALEX'
#> The following object is masked from 'package:dplyr':
#> 
#>     explain

ames_split <- initial_split(ames)
ames_train <- training(ames_split)

vip_features <- c("Neighborhood", "Gr_Liv_Area", "Year_Built", 
                  "Bldg_Type", "Latitude", "Longitude")

vip_train <- 
  ames_train %>% 
  select(all_of(vip_features))

rf_model <- 
  rand_forest(trees = 1000) %>% 
  set_engine("ranger") %>% 
  set_mode("regression")

rf_wflow <- 
  workflow() %>% 
  add_formula(
    Sale_Price ~ Neighborhood + Gr_Liv_Area + Year_Built + Bldg_Type + 
      Latitude + Longitude) %>% 
  add_model(rf_model) 

rf_fit <- fit(rf_wflow, ames_train)

explainer_rf <- 
  explain_tidymodels(
    rf_fit, 
    data = vip_train, 
    y = ames_train$Sale_Price,
    label = "random forest",
    verbose = FALSE
  )

model_profile(explainer_rf, N = 500, variables = "Year_Built")
#> Error in `scream()`:
#> ! Can't convert from `data$Year_Built` <double> to `Year_Built` <integer> due to loss of precision.
#> • Locations: 3, 13, 23, 72, 76, 86, 96, 145, 149, 159, 169, 218, 222, 232, 242,...

#> Backtrace:
#>      ▆
#>   1. └─DALEX::model_profile(explainer_rf, N = 500, variables = "Year_Built")
#>   2.   ├─ingredients::ceteris_paribus(...)
#>   3.   └─ingredients:::ceteris_paribus.explainer(...)
#>   4.     └─ingredients:::ceteris_paribus.default(...)
#>   5.       ├─ingredients:::calculate_variable_profile(...)
#>   6.       └─ingredients:::calculate_variable_profile.default(...)
#>   7.         └─base::lapply(...)
#>   8.           └─ingredients (local) FUN(X[[i]], ...)
#>   9.             ├─DALEX (local) predict_function(model, new_data, ...)
#>  10.             └─DALEXtra:::yhat.workflow(model, new_data, ...)
#>  11.               ├─stats::predict(X.model, newdata)
#>  12.               └─workflows:::predict.workflow(X.model, newdata)
#>  13.                 └─workflows:::forge_predictors(new_data, workflow)
#>  14.                   ├─hardhat::forge(new_data, blueprint = mold$blueprint)
#>  15.                   └─hardhat:::forge.data.frame(new_data, blueprint = mold$blueprint)
#>  16.                     ├─hardhat::run_forge(blueprint, new_data = new_data, outcomes = outcomes)
#>  17.                     └─hardhat:::run_forge.default_formula_blueprint(...)
#>  18.                       └─hardhat:::forge_formula_default_clean(...)
#>  19.                         └─hardhat::scream(predictors, blueprint$ptypes$predictors, allow_novel_levels = blueprint$allow_novel_levels)

Created on 2022-12-05 with reprex v2.0.2

The issue arises here, where the numeric split_points are assigned into the (possibly integer) variable:

new_data[, variable] <- rep(split_points, nrow(data))
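
For context, here is a minimal, hypothetical sketch of how the column type flips. It is not the package code; it assumes quantile-style split points and uses illustrative names:

# Minimal sketch of the type flip (illustrative, not the package code).
# Year_Built is stored as an integer in the ames training data:
year <- ames_train$Year_Built
typeof(year)
#> [1] "integer"

# Quantile-style split points computed from an integer column are doubles:
split_points <- quantile(year, probs = seq(0, 1, length.out = 101))
typeof(split_points)
#> [1] "double"

# Assigning them back silently promotes the whole column to double, and
# hardhat::scream() later refuses to cast it to the integer ptype recorded
# at fit time because precision would be lost.
new_data <- ames_train[rep(1, length(split_points)), ]
new_data[["Year_Built"]] <- unname(split_points)
typeof(new_data[["Year_Built"]])
#> [1] "double"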

@hbaniecki hbaniecki added the bug 💣 Bug to fix label Dec 5, 2022
@hbaniecki hbaniecki self-assigned this Dec 5, 2022
@hbaniecki
Member

hbaniecki commented Dec 5, 2022

Hi Simon, thanks for this report.

A quick workaround is to specify variable_splits explicitly:

# ok
ingredients::ceteris_paribus(
  explainer_rf, 
  explainer_rf$data,
  variable_splits = list(Year_Built = unique(vip_train$Year_Built))
)
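
If your version of model_profile() forwards ... to ingredients::ceteris_paribus() (its documentation suggests it does), the same workaround should also work without dropping down to ceteris_paribus(); a sketch, untested here:

# Sketch: assumes model_profile() passes `...` on to ceteris_paribus()
model_profile(
  explainer_rf, N = 500, variables = "Year_Built",
  variable_splits = list(Year_Built = unique(vip_train$Year_Built))
)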

The error occurs because of the default calculate_variable_split():

# error
ingredients::ceteris_paribus(
  explainer_rf, 
  explainer_rf$data,
  variable_splits = ingredients:::calculate_variable_split.default(explainer_rf$data, variables = c("Year_Built"))
)

# the default split points are floats, not integers
ingredients:::calculate_variable_split.default(explainer_rf$data, variables = c("Year_Built"))

Fixing this issue requires adding !is.integer(selected_column) to

if (is.numeric(selected_column)) {

which would treat integer features like categorical ones, i.e. split them with unique().
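
A rough sketch of that change (simplified names, not the actual calculate_variable_split.default() source):

# Hypothetical, simplified sketch of the proposed condition:
if (is.numeric(selected_column) && !is.integer(selected_column)) {
  # doubles keep the quantile-based grid as before
  probs <- seq(0, 1, length.out = grid_points)
  split_points <- unique(quantile(selected_column, probs = probs, na.rm = TRUE))
} else {
  # integers (and factors) use their observed values directly
  split_points <- unique(selected_column)
}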

With that change, the default split for Year_Built would match the explicit variable_splits workaround shown above.

@pbiecek what do you think?

@pbiecek
Member

pbiecek commented Dec 5, 2022

Thanks for tracking down this tricky error!

Treating an integer as a categorical variable is a good idea, as long as it doesn't have too many distinct values (e.g., if someone has an id column with 10,000 different values, that would kill our profile calculation).
So maybe an extra condition in the if statement: if a variable is an integer and the number of distinct values is under 100, treat it as categorical (i.e., use unique());
but if there are many values, convert it to float?
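
In code, that rule would look roughly like this (hypothetical; the threshold of 100 comes from the comment above, not from the package):

# Hypothetical condition for the suggested rule:
treat_as_categorical <- is.integer(selected_column) &&
  length(unique(selected_column)) < 100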

@hbaniecki
Member

I implemented the fix, but it still fails ungracefully in the above scenario: Year_Built has 113 unique values, which is above the suggested threshold of 100.

This got me thinking: for categorical variables, we don't have any threshold on how many unique values there can be.

We can either:

  1. Remove the auxiliary threshold so that all integer variables are treated as categorical. This removes the error, but users need to pay attention to the results and to why the computation can take so long.
  2. Set the threshold to the value of grid_points (101 by default), applied:
    1. only to integer variables, which leads to the same uninformative error for the user; or
    2. to both integer and categorical variables, raising an informative error that tells the user to increase grid_points when the threshold is reached. This breaks some previous code but improves the experience of interacting with our API.

@pbiecek
Member

pbiecek commented Jan 4, 2023

Great idea. Let's do option 1, with an additional warning if there are more than 201 unique values.
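
For reference, a hypothetical sketch of that decision (integers treated as categorical, with a warning above 201 unique values); the actual implementation in ingredients may differ:

# Hypothetical sketch, not the committed code:
if (is.integer(selected_column)) {
  split_points <- sort(unique(selected_column))
  if (length(split_points) > 201) {
    warning("Variable has more than 201 unique values; ",
            "the profile may take long to compute.")
  }
} else if (is.numeric(selected_column)) {
  probs <- seq(0, 1, length.out = grid_points)
  split_points <- unique(quantile(selected_column, probs = probs, na.rm = TRUE))
} else {
  split_points <- unique(selected_column)
}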

hbaniecki added a commit that referenced this issue Jan 8, 2023
pbiecek pushed a commit that referenced this issue Jan 13, 2023
@hbaniecki
Member

hbaniecki commented Jan 26, 2023

This is hopefully fixed now on CRAN

