Skip to content

Commit

Permalink
Merge pull request #329 from SchlossLab/iss-325_bootstrap
Browse files Browse the repository at this point in the history
Create `bootstrap_performance()` for single train/test splits
  • Loading branch information
kelly-sovacool committed Feb 1, 2023
2 parents f73f8a8 + 047360f commit 2a0ec63
Show file tree
Hide file tree
Showing 24 changed files with 747 additions and 46 deletions.
9 changes: 4 additions & 5 deletions .github/workflows/pr_build.yml
Expand Up @@ -26,11 +26,10 @@ jobs:
run: |
git config --local user.email "noreply@github.com"
git config --local user.name "GitHub"
- name: Check
- name: Document
run: |
Rscript -e 'devtools::check()'
git add man/ NAMESPACE
git commit \
Rscript -e 'devtools::document()'
git commit -a \
--author="github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>" \
-m '📚 Render Roxygen documentation' || echo "No changes to commit"
- name: Style
Expand All @@ -54,7 +53,7 @@ jobs:
branch: ${{ github.head_ref }}
- name: Docs
run: |
Rscript -e "pkgdown::build_reference(); devtools::build_site()"
Rscript -e "devtools::build_site(lazy = TRUE)"
git add docs
git commit \
--author="github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>" \
Expand Down
5 changes: 4 additions & 1 deletion DESCRIPTION
Expand Up @@ -74,22 +74,25 @@ Imports:
utils,
xgboost
Suggests:
assertthat,
doFuture,
foreach,
future,
future.apply,
furrr,
ggplot2,
knitr,
progress,
progressr,
purrr,
rmarkdown,
rsample,
testthat,
tidyr
VignetteBuilder:
knitr
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.1
RoxygenNote: 7.2.3
Config/testthat/edition: 3
1 change: 1 addition & 0 deletions NAMESPACE
Expand Up @@ -4,6 +4,7 @@ export("!!")
export("%>%")
export(":=")
export(.data)
export(bootstrap_performance)
export(calc_baseline_precision)
export(calc_mean_prc)
export(calc_mean_roc)
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
@@ -1,5 +1,8 @@
# mikropml development version

- New function `bootstrap_performance()` allows you to calculate confidence
intervals for the model performance from a single train/test split by
bootstrapping the test set (#329, @kelly-sovacool).
- Minor documentation improvements (#323, @kelly-sovacool).

# mikropml 1.5.0
Expand Down
104 changes: 103 additions & 1 deletion R/performance.R
Expand Up @@ -168,6 +168,7 @@ get_performance_tbl <- function(trained_model,
class_probs,
method,
seed = NA) {
cv_metric <- NULL
test_perf_metrics <- calc_perf_metrics(
test_data,
trained_model,
Expand Down Expand Up @@ -201,11 +202,112 @@ get_performance_tbl <- function(trained_model,
)) %>%
dplyr::rename_with(
function(x) paste0("cv_metric_", perf_metric_name),
.data$cv_metric
cv_metric
) %>%
change_to_num())
}

#' Calculate a bootstrap confidence interval for the performance on a single train/test split
#'
#' Uses [rsample::bootstraps()], [rsample::int_pctl()], and [furrr::future_map()]
#'
#' @param ml_result result returned from a single [run_ml()] call
#' @inheritParams run_ml
#' @param bootstrap_times the number of boostraps to create (default: `10000`)
#' @param alpha the alpha level for the confidence interval (default `0.05` to create a 95% confidence interval)
#'
#' @return a data frame with an estimate (`.estimate`), lower bound (`.lower`),
#' and upper bound (`.upper`) for each performance metric (`term`).
#' @export
#' @author Kelly Sovacool, \email{sovacool@@umich.edu}
#'
#' @examples
#' bootstrap_performance(otu_mini_bin_results_glmnet, "dx",
#' bootstrap_times = 10, alpha = 0.10
#' )
#' \dontrun{
#' outcome_colname <- "dx"
#' run_ml(otu_mini_bin, "rf", outcome_colname = "dx") %>%
#' bootstrap_performance(outcome_colname,
#' bootstrap_times = 10000,
#' alpha = 0.05
#' )
#' }
bootstrap_performance <- function(ml_result,
outcome_colname,
bootstrap_times = 10000,
alpha = 0.05) {
abort_packages_not_installed("assertthat", "rsample", "furrr")
splits <- perf <- NULL

model <- ml_result$trained_model
test_dat <- ml_result$test_data
outcome_type <- get_outcome_type(test_dat %>% dplyr::pull(outcome_colname))
class_probs <- outcome_type != "continuous"
method <- model$modelInfo$label
seed <- ml_result$performance %>% dplyr::pull(seed)
assertthat::are_equal(length(seed), 1)
return(
rsample::bootstraps(test_dat, times = bootstrap_times) %>%
dplyr::mutate(perf = furrr::future_map(
splits,
~ calc_perf_bootstrap_split(
.x,
trained_model = model,
outcome_colname = outcome_colname,
perf_metric_function = get_perf_metric_fn(outcome_type),
perf_metric_name = model$metric,
class_probs = outcome_type != "continuous",
method = model$trained_model$modelInfo$label,
seed = seed
)
)) %>%
rsample::int_pctl(perf, alpha = alpha)
)
}

#' Calculate performance for a single split from [rsample::bootstraps()]
#'
#' Used by [bootstrap_performance()].
#'
#' @param test_data_split a single bootstrap of the test set from [rsample::bootstraps()]
#' @inheritParams get_performance_tbl
#' @return a long data frame of performance metrics for [rsample::int_pctl()]
#'
#' @keywords internal
#' @author Kelly Sovacool, \email{sovacool@@umich.edu}
#'
calc_perf_bootstrap_split <- function(test_data_split,
trained_model,
outcome_colname,
perf_metric_function,
perf_metric_name,
class_probs,
method,
seed) {
abort_packages_not_installed("rsample")
return(
get_performance_tbl(
trained_model,
rsample::analysis(test_data_split),
outcome_colname,
perf_metric_function,
perf_metric_name,
class_probs,
method,
seed
) %>%
dplyr::select(-dplyr::all_of(c(method)), -seed) %>%
dplyr::mutate(dplyr::across(dplyr::everything(), as.numeric)) %>%
tidyr::pivot_longer(
dplyr::everything(),
names_to = "term",
values_to = "estimate"
)
)
}


#' @describeIn sensspec Get sensitivity, specificity, and precision for a model.
#'
#' @inheritParams calc_perf_metrics
Expand Down
5 changes: 3 additions & 2 deletions README.md
Expand Up @@ -56,8 +56,9 @@ mamba install -c conda-forge r-mikropml

- Imports: caret, dplyr, e1071, glmnet, kernlab, MLmetrics,
randomForest, rlang, rpart, stats, utils, xgboost
- Suggests: doFuture, foreach, future, future.apply, ggplot2, knitr,
progress, progressr, purrr, rmarkdown, testthat, tidyr
- Suggests: assertthat, doFuture, foreach, future, future.apply, furrr,
ggplot2, knitr, progress, progressr, purrr, rmarkdown, rsample,
testthat, tidyr

## Usage

Expand Down
1 change: 1 addition & 0 deletions _pkgdown.yml
Expand Up @@ -57,6 +57,7 @@ reference:
- sensspec
- compare_models
- permute_p_value
- bootstrap_performance
- title: Plotting helpers
desc: >
Visualize results to help you tune hyperparameters and choose model methods.
Expand Down
6 changes: 5 additions & 1 deletion docs/dev/articles/introduction.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 2a0ec63

Please sign in to comment.