Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Report confidence interval for permutation feature importance #326

Merged
merged 37 commits into from Feb 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
9c9c517
Calc lower & upper CI for feature importance (resolves #324)
kelly-sovacool Jan 24, 2023
28ee23c
No need for `usethis::` with `library(usethis)`
kelly-sovacool Jan 25, 2023
28b3a26
Document lower_bound() & upper_bound() together
kelly-sovacool Jan 25, 2023
ecc59cb
Update feature importance tests for Conf. intervals
kelly-sovacool Jan 25, 2023
f52db48
Recreate feature importance data with new conf. intervals
kelly-sovacool Jan 25, 2023
2053e8d
Rerun data w/ new feature importance conf. intervals
kelly-sovacool Jan 25, 2023
0cf1696
Don't run examples for internal functions
kelly-sovacool Jan 27, 2023
02fddc3
Document `alpha` parameter
kelly-sovacool Jan 27, 2023
c4d7a78
Merge 02fddc3ef915239fd3be5ae0434d46ee335c1cc8 into f73f8a84d6ff57a92…
kelly-sovacool Jan 27, 2023
9cb8da5
🎨 Style R code
github-actions[bot] Jan 27, 2023
cfcd292
📑 Build docs site
github-actions[bot] Jan 27, 2023
f7b3d35
Silence caret::train() warning about setting rownames on a tibble
kelly-sovacool Jan 27, 2023
f68d086
Use `all_of(var)` instead of `.data$var`
kelly-sovacool Jan 27, 2023
26b2104
Rename `names` column to `feat` for feat imp
kelly-sovacool Jan 27, 2023
aa2aa91
Rerun feature importance with new column name
kelly-sovacool Jan 27, 2023
72a2a01
Update vignettes with new feature importance cols
kelly-sovacool Jan 27, 2023
bb3242f
document()
kelly-sovacool Jan 27, 2023
06881ff
Merge branch 'iss-324' of https://github.com/SchlossLab/mikropml into…
kelly-sovacool Jan 27, 2023
4cad415
Update NEWS w/ feature importance improvements
kelly-sovacool Jan 27, 2023
5abd476
Merge 4cad41515b8c84ca7700fd85ea111de7ebce476d into f73f8a84d6ff57a92…
kelly-sovacool Jan 27, 2023
a40c08c
🎨 Style R code
github-actions[bot] Jan 27, 2023
d762191
Improve feature importance plots
kelly-sovacool Jan 27, 2023
8d8c4ba
Merge branch 'iss-324' of https://github.com/SchlossLab/mikropml into…
kelly-sovacool Jan 27, 2023
a270560
Merge 8d8c4ba2a52a320813e3ce7bf3a15526e7b964bf into f73f8a84d6ff57a92…
kelly-sovacool Jan 27, 2023
118be2c
🎨 Style R code
github-actions[bot] Jan 27, 2023
69ba636
pkgdown::build_article('parallel')
kelly-sovacool Jan 28, 2023
53466d5
Merge 69ba636de157f54e25eac473fdf678ee7b7302f0 into f73f8a84d6ff57a92…
kelly-sovacool Jan 28, 2023
3f34ace
Add {forcats} to suggests for parallel vignette
kelly-sovacool Jan 28, 2023
07620fd
Merge branch 'iss-324' of https://github.com/SchlossLab/mikropml into…
kelly-sovacool Jan 28, 2023
350315d
Merge 07620fd4c07b39ff9c367fb3acc774bb7f281ae1 into f73f8a84d6ff57a92…
kelly-sovacool Jan 28, 2023
d2007c1
📄 Render README.Rmd
github-actions[bot] Jan 28, 2023
ab38c01
📑 Build docs site
github-actions[bot] Jan 28, 2023
8817336
Merge branch 'main' into iss-324
kelly-sovacool Feb 1, 2023
387b74d
Merge 88173366bba5aee62d90058b3c8e56efd60d5ce6 into 2a0ec6319b4286393…
kelly-sovacool Feb 1, 2023
b002fa1
📚 Render Roxygen documentation
github-actions[bot] Feb 1, 2023
717a8cc
📄 Render README.Rmd
github-actions[bot] Feb 1, 2023
d192c41
📑 Build docs site
github-actions[bot] Feb 1, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION
Expand Up @@ -76,6 +76,7 @@ Imports:
Suggests:
assertthat,
doFuture,
forcats,
foreach,
future,
future.apply,
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Expand Up @@ -3,6 +3,10 @@
- New function `bootstrap_performance()` allows you to calculate confidence
intervals for the model performance from a single train/test split by
bootstrapping the test set (#329, @kelly-sovacool).
- Improved output from `find_feature_importance()` (#326, @kelly-sovacool).
- Renamed the column `names` to `feat` to represent each feature or group of correlated features.
- New column `lower` and `upper` to report the bounds of the empirical 95% confidence interval from the permutation test.
See `vignette('parallel')` for an example of plotting feature importance with confidence intervals.
- Minor documentation improvements (#323, @kelly-sovacool).

# mikropml 1.5.0
Expand Down
47 changes: 42 additions & 5 deletions R/feature_importance.R
Expand Up @@ -13,7 +13,7 @@
#' grouped together based on `corr_thresh`.
#'
#' @return Data frame with performance metrics for when each feature (or group
#' of correlated features; `names`) is permuted (`perf_metric`), differences
#' of correlated features; `feat`) is permuted (`perf_metric`), differences
#' between the actual test performance metric on and the permuted performance
#' metric (`perf_metric_diff`; test minus permuted performance), and the
#' p-value (`pvalue`: the probability of obtaining the actual performance
Expand Down Expand Up @@ -173,7 +173,7 @@ get_feature_importance <- function(trained_model, test_data,

return(as.data.frame(imps) %>%
dplyr::mutate(
names = factor(groups),
feat = factor(groups),
method = method,
perf_metric_name = perf_metric_name,
seed = seed
Expand All @@ -191,10 +191,13 @@ get_feature_importance <- function(trained_model, test_data,
#' @param progbar optional progress bar (default: `NULL`)
#' @inheritParams run_ml
#' @inheritParams get_feature_importance
#' @param alpha alpha level for the confidence interval
#' (default: `0.05` to obtain a 95% confidence interval)
#'
#' @return vector of mean permuted performance and mean difference between test
#' and permuted performance (test minus permuted performance)
#' @noRd
#' @keywords internal
#'
#' @author Begüm Topçuoğlu, \email{topcuoglu.begum@@gmail.com}
#' @author Zena Lapp, \email{zenalapp@@umich.edu}
#' @author Kelly Sovacool, \email{sovacool@@umich.edu}
Expand All @@ -203,7 +206,9 @@ find_permuted_perf_metric <- function(test_data, trained_model, outcome_colname,
perf_metric_function, perf_metric_name,
class_probs, feat,
test_perf_value,
nperms = 100, progbar = NULL) {
nperms = 100,
alpha = 0.05,
progbar = NULL) {
# The code below uses a bunch of base R subsetting that doesn't work with tibbles.
# We should probably refactor those to use tidyverse functions instead,
# but for now this is a temporary fix.
Expand Down Expand Up @@ -235,6 +240,38 @@ find_permuted_perf_metric <- function(test_data, trained_model, outcome_colname,
return(c(
perf_metric = mean_perm_perf,
perf_metric_diff = test_perf_value - mean_perm_perf,
pvalue = calc_pvalue(perm_perfs, test_perf_value)
pvalue = calc_pvalue(perm_perfs, test_perf_value),
lower = lower_bound(perm_perfs, alpha),
upper = upper_bound(perm_perfs, alpha)
))
}

#' @describeIn bounds Get the lower bound for an empirical confidence interval
#' @keywords internal
lower_bound <- function(x, alpha) {
x <- sort(x)
return(x[length(x) * alpha / 2])
}

#' @describeIn bounds Get the upper bound for an empirical confidence interval
#' @keywords internal
upper_bound <- function(x, alpha) {
x <- sort(x)
return(x[length(x) - length(x) * alpha / 2])
}

#' @name bounds
#' @title Get the lower and upper bounds for an empirical confidence interval
#'
#' @param x vector of test statistics, such as from permutation tests or bootstraps
#' @inheritParams find_permuted_perf_metric
#'
#' @return the value of the lower or upper bound for the confidence interval
#'
#' @examples
#' \dontrun{
#' x <- 1:10000
#' lower_bound(x, 0.05)
#' upper_bound(x, 0.05)
#' }
NULL
6 changes: 5 additions & 1 deletion R/run_ml.R
Expand Up @@ -174,7 +174,11 @@ run_ml <-
check_cat_feats(dataset %>% dplyr::select(-outcome_colname))
}

dataset <- randomize_feature_order(dataset, outcome_colname)
dataset <- dataset %>%
randomize_feature_order(outcome_colname) %>%
# convert tibble to dataframe to silence warning from caret::train():
# "Warning: Setting row names on a tibble is deprecated.."
as.data.frame()

outcomes_vctr <- dataset %>% dplyr::pull(outcome_colname)

Expand Down
6 changes: 3 additions & 3 deletions README.md
Expand Up @@ -56,9 +56,9 @@ mamba install -c conda-forge r-mikropml

- Imports: caret, dplyr, e1071, glmnet, kernlab, MLmetrics,
randomForest, rlang, rpart, stats, utils, xgboost
- Suggests: assertthat, doFuture, foreach, future, future.apply, furrr,
ggplot2, knitr, progress, progressr, purrr, rmarkdown, rsample,
testthat, tidyr
- Suggests: assertthat, doFuture, forcats, foreach, future,
future.apply, furrr, ggplot2, knitr, progress, progressr, purrr,
rmarkdown, rsample, testthat, tidyr

## Usage

Expand Down
12 changes: 6 additions & 6 deletions data-raw/otu_mini_bin.R
Expand Up @@ -4,7 +4,7 @@ library(usethis)

## code to prepare `otu_mini` dataset
otu_mini_bin <- otu_small[, 1:11]
usethis::use_data(otu_mini_bin, overwrite = TRUE)
use_data(otu_mini_bin, overwrite = TRUE)

otu_data_preproc <- preprocess_data(otu_mini_bin, "dx")
use_data(otu_data_preproc)
Expand Down Expand Up @@ -39,7 +39,7 @@ otu_mini_bin_results_glmnet <- mikropml::run_ml(otu_mini_bin, # use built-in hyp
seed = 2019,
cv_times = 2
)
usethis::use_data(otu_mini_bin_results_glmnet, overwrite = TRUE)
use_data(otu_mini_bin_results_glmnet, overwrite = TRUE)

# cv_group <- sample(LETTERS[1:5], nrow(otu_mini_bin_results_glmnet$trained_model$trainingData), replace = TRUE)
cv_group <- c(
Expand Down Expand Up @@ -71,7 +71,7 @@ otu_mini_cv <- define_cv(otu_mini_bin_results_glmnet$trained_model$trainingData,
cv_times = 2,
groups = cv_group
)
usethis::use_data(otu_mini_cv, overwrite = TRUE)
use_data(otu_mini_cv, overwrite = TRUE)

# use built-in hyperparams function for this one
otu_mini_bin_results_rf <- mikropml::run_ml(otu_mini_bin,
Expand All @@ -82,7 +82,7 @@ otu_mini_bin_results_rf <- mikropml::run_ml(otu_mini_bin,
cv_times = 2,
groups = otu_mini_group
)
usethis::use_data(otu_mini_bin_results_rf, overwrite = TRUE)
use_data(otu_mini_bin_results_rf, overwrite = TRUE)

otu_mini_bin_results_svmRadial <- mikropml::run_ml(otu_mini_bin,
"svmRadial",
Expand All @@ -91,7 +91,7 @@ otu_mini_bin_results_svmRadial <- mikropml::run_ml(otu_mini_bin,
seed = 2019,
cv_times = 2
)
usethis::use_data(otu_mini_bin_results_svmRadial, overwrite = TRUE)
use_data(otu_mini_bin_results_svmRadial, overwrite = TRUE)

otu_mini_bin_results_xgbTree <- mikropml::run_ml(otu_mini_bin,
"xgbTree",
Expand All @@ -100,7 +100,7 @@ otu_mini_bin_results_xgbTree <- mikropml::run_ml(otu_mini_bin,
seed = 2019,
cv_times = 2
)
usethis::use_data(otu_mini_bin_results_xgbTree, overwrite = TRUE)
use_data(otu_mini_bin_results_xgbTree, overwrite = TRUE)

otu_mini_bin_results_rpart2 <- mikropml::run_ml(otu_mini_bin,
"rpart2",
Expand Down
Binary file modified data/otu_mini_bin_results_rf.rda
Binary file not shown.
Binary file modified data/otu_mini_cont_results_glmnet.rda
Binary file not shown.
Binary file modified data/otu_mini_multi_results_glmnet.rda
Binary file not shown.
68 changes: 39 additions & 29 deletions docs/dev/articles/introduction.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.