Skip to content

Commit

Permalink
Merge pull request #326 from SchlossLab/iss-324
Browse files Browse the repository at this point in the history
Report confidence interval for permutation feature importance
  • Loading branch information
courtneyarmour committed Feb 15, 2023
2 parents 2a0ec63 + d192c41 commit 6bcc6d2
Show file tree
Hide file tree
Showing 29 changed files with 881 additions and 191 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Expand Up @@ -76,6 +76,7 @@ Imports:
Suggests:
assertthat,
doFuture,
forcats,
foreach,
future,
future.apply,
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Expand Up @@ -3,6 +3,10 @@
- New function `bootstrap_performance()` allows you to calculate confidence
intervals for the model performance from a single train/test split by
bootstrapping the test set (#329, @kelly-sovacool).
- Improved output from `get_feature_importance()` (#326, @kelly-sovacool).
- Renamed the column `names` to `feat` to represent each feature or group of correlated features.
- New columns `lower` and `upper` report the bounds of the empirical 95% confidence interval from the permutation test.
See `vignette('parallel')` for an example of plotting feature importance with confidence intervals.
- Minor documentation improvements (#323, @kelly-sovacool).

# mikropml 1.5.0
Expand Down
47 changes: 42 additions & 5 deletions R/feature_importance.R
Expand Up @@ -13,7 +13,7 @@
#' grouped together based on `corr_thresh`.
#'
#' @return Data frame with performance metrics for when each feature (or group
#' of correlated features; `names`) is permuted (`perf_metric`), differences
#' of correlated features; `feat`) is permuted (`perf_metric`), differences
#' between the actual test performance metric on and the permuted performance
#' metric (`perf_metric_diff`; test minus permuted performance), and the
#' p-value (`pvalue`: the probability of obtaining the actual performance
Expand Down Expand Up @@ -173,7 +173,7 @@ get_feature_importance <- function(trained_model, test_data,

return(as.data.frame(imps) %>%
dplyr::mutate(
names = factor(groups),
feat = factor(groups),
method = method,
perf_metric_name = perf_metric_name,
seed = seed
Expand All @@ -191,10 +191,13 @@ get_feature_importance <- function(trained_model, test_data,
#' @param progbar optional progress bar (default: `NULL`)
#' @inheritParams run_ml
#' @inheritParams get_feature_importance
#' @param alpha alpha level for the confidence interval
#' (default: `0.05` to obtain a 95% confidence interval)
#'
#' @return vector of mean permuted performance and mean difference between test
#' and permuted performance (test minus permuted performance)
#' @noRd
#' @keywords internal
#'
#' @author Begüm Topçuoğlu, \email{topcuoglu.begum@@gmail.com}
#' @author Zena Lapp, \email{zenalapp@@umich.edu}
#' @author Kelly Sovacool, \email{sovacool@@umich.edu}
Expand All @@ -203,7 +206,9 @@ find_permuted_perf_metric <- function(test_data, trained_model, outcome_colname,
perf_metric_function, perf_metric_name,
class_probs, feat,
test_perf_value,
nperms = 100, progbar = NULL) {
nperms = 100,
alpha = 0.05,
progbar = NULL) {
# The code below uses a bunch of base R subsetting that doesn't work with tibbles.
# We should probably refactor those to use tidyverse functions instead,
# but for now this is a temporary fix.
Expand Down Expand Up @@ -235,6 +240,38 @@ find_permuted_perf_metric <- function(test_data, trained_model, outcome_colname,
return(c(
perf_metric = mean_perm_perf,
perf_metric_diff = test_perf_value - mean_perm_perf,
pvalue = calc_pvalue(perm_perfs, test_perf_value)
pvalue = calc_pvalue(perm_perfs, test_perf_value),
lower = lower_bound(perm_perfs, alpha),
upper = upper_bound(perm_perfs, alpha)
))
}

#' @describeIn bounds Get the lower bound for an empirical confidence interval
#' @keywords internal
lower_bound <- function(x, alpha) {
  x <- sort(x)
  # R truncates fractional indices toward zero (x[2.5] is x[2]), so
  # floor() here reproduces the original indexing behavior exactly.
  # The max(1L, ...) guard prevents an index of 0 when
  # length(x) * alpha / 2 < 1 (e.g. few permutations), which would
  # silently return an empty vector instead of a bound.
  idx <- max(1L, floor(length(x) * alpha / 2))
  return(x[idx])
}

#' @describeIn bounds Get the upper bound for an empirical confidence interval
#' @keywords internal
upper_bound <- function(x, alpha) {
  x <- sort(x)
  # floor() makes the truncation of fractional indices explicit and
  # matches the original behavior for all valid inputs. The max(1L, ...)
  # guard covers tiny vectors (e.g. length(x) == 1), where the computed
  # index truncates to 0 and x[0] would return an empty vector.
  idx <- max(1L, floor(length(x) - length(x) * alpha / 2))
  return(x[idx])
}

#' @name bounds
#' @title Get the lower and upper bounds for an empirical confidence interval
#'
#' @description
#' The bounds are taken directly from the sorted vector of test statistics
#' (i.e. they are order statistics); no interpolation between values is
#' performed.
#'
#' @param x vector of test statistics, such as from permutation tests or bootstraps
#' @inheritParams find_permuted_perf_metric
#'
#' @return the value of the lower or upper bound for the confidence interval
#'
#' @examples
#' \dontrun{
#' x <- 1:10000
#' lower_bound(x, 0.05)
#' upper_bound(x, 0.05)
#' }
NULL
6 changes: 5 additions & 1 deletion R/run_ml.R
Expand Up @@ -174,7 +174,11 @@ run_ml <-
check_cat_feats(dataset %>% dplyr::select(-outcome_colname))
}

dataset <- randomize_feature_order(dataset, outcome_colname)
dataset <- dataset %>%
randomize_feature_order(outcome_colname) %>%
# convert tibble to dataframe to silence warning from caret::train():
# "Warning: Setting row names on a tibble is deprecated.."
as.data.frame()

outcomes_vctr <- dataset %>% dplyr::pull(outcome_colname)

Expand Down
6 changes: 3 additions & 3 deletions README.md
Expand Up @@ -56,9 +56,9 @@ mamba install -c conda-forge r-mikropml

- Imports: caret, dplyr, e1071, glmnet, kernlab, MLmetrics,
randomForest, rlang, rpart, stats, utils, xgboost
- Suggests: assertthat, doFuture, foreach, future, future.apply, furrr,
ggplot2, knitr, progress, progressr, purrr, rmarkdown, rsample,
testthat, tidyr
- Suggests: assertthat, doFuture, forcats, foreach, future,
future.apply, furrr, ggplot2, knitr, progress, progressr, purrr,
rmarkdown, rsample, testthat, tidyr

## Usage

Expand Down
12 changes: 6 additions & 6 deletions data-raw/otu_mini_bin.R
Expand Up @@ -4,7 +4,7 @@ library(usethis)

## code to prepare `otu_mini` dataset
otu_mini_bin <- otu_small[, 1:11]
usethis::use_data(otu_mini_bin, overwrite = TRUE)
use_data(otu_mini_bin, overwrite = TRUE)

otu_data_preproc <- preprocess_data(otu_mini_bin, "dx")
use_data(otu_data_preproc)
Expand Down Expand Up @@ -39,7 +39,7 @@ otu_mini_bin_results_glmnet <- mikropml::run_ml(otu_mini_bin, # use built-in hyp
seed = 2019,
cv_times = 2
)
usethis::use_data(otu_mini_bin_results_glmnet, overwrite = TRUE)
use_data(otu_mini_bin_results_glmnet, overwrite = TRUE)

# cv_group <- sample(LETTERS[1:5], nrow(otu_mini_bin_results_glmnet$trained_model$trainingData), replace = TRUE)
cv_group <- c(
Expand Down Expand Up @@ -71,7 +71,7 @@ otu_mini_cv <- define_cv(otu_mini_bin_results_glmnet$trained_model$trainingData,
cv_times = 2,
groups = cv_group
)
usethis::use_data(otu_mini_cv, overwrite = TRUE)
use_data(otu_mini_cv, overwrite = TRUE)

# use built-in hyperparams function for this one
otu_mini_bin_results_rf <- mikropml::run_ml(otu_mini_bin,
Expand All @@ -82,7 +82,7 @@ otu_mini_bin_results_rf <- mikropml::run_ml(otu_mini_bin,
cv_times = 2,
groups = otu_mini_group
)
usethis::use_data(otu_mini_bin_results_rf, overwrite = TRUE)
use_data(otu_mini_bin_results_rf, overwrite = TRUE)

otu_mini_bin_results_svmRadial <- mikropml::run_ml(otu_mini_bin,
"svmRadial",
Expand All @@ -91,7 +91,7 @@ otu_mini_bin_results_svmRadial <- mikropml::run_ml(otu_mini_bin,
seed = 2019,
cv_times = 2
)
usethis::use_data(otu_mini_bin_results_svmRadial, overwrite = TRUE)
use_data(otu_mini_bin_results_svmRadial, overwrite = TRUE)

otu_mini_bin_results_xgbTree <- mikropml::run_ml(otu_mini_bin,
"xgbTree",
Expand All @@ -100,7 +100,7 @@ otu_mini_bin_results_xgbTree <- mikropml::run_ml(otu_mini_bin,
seed = 2019,
cv_times = 2
)
usethis::use_data(otu_mini_bin_results_xgbTree, overwrite = TRUE)
use_data(otu_mini_bin_results_xgbTree, overwrite = TRUE)

otu_mini_bin_results_rpart2 <- mikropml::run_ml(otu_mini_bin,
"rpart2",
Expand Down
Binary file modified data/otu_mini_bin_results_rf.rda
Binary file not shown.
Binary file modified data/otu_mini_cont_results_glmnet.rda
Binary file not shown.
Binary file modified data/otu_mini_multi_results_glmnet.rda
Binary file not shown.
68 changes: 39 additions & 29 deletions docs/dev/articles/introduction.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 6bcc6d2

Please sign in to comment.