diff --git a/R/performance.R b/R/performance.R index 0bad30a5..afc9e07d 100644 --- a/R/performance.R +++ b/R/performance.R @@ -214,8 +214,8 @@ get_performance_tbl <- function(trained_model, #' @export calc_model_sensspec <- function(trained_model, test_data, outcome_colname = NULL) { # adapted from https://github.com/SchlossLab/2021-08-09_ROCcurves/blob/8e62ff8b6fe1b691450c953a9d93b2c11ce3369a/ROCcurves.Rmd#L95-L109 - outcome_colname = check_outcome_column(test_data, outcome_colname) - pos_outcome = trained_model$levels[1] + outcome_colname <- check_outcome_column(test_data, outcome_colname) + pos_outcome <- trained_model$levels[1] actual <- is_pos <- tp <- fp <- fpr <- NULL probs <- stats::predict(trained_model, newdata = test_data, @@ -389,21 +389,22 @@ NULL #' #' #' calc_baseline_precision(otu_mini_bin, -#' outcome_colname = "dx", -#' pos_outcome = "cancer") +#' outcome_colname = "dx", +#' pos_outcome = "cancer" +#' ) #' #' #' # if you're not sure which outcome was used as the 'positive' outcome during #' # model training, you can access it from the trained model and pass it along: #' calc_baseline_precision(otu_mini_bin, -#' outcome_colname = "dx", -#' pos_outcome = otu_mini_bin_results_glmnet$trained_model$levels[1]) -#' +#' outcome_colname = "dx", +#' pos_outcome = otu_mini_bin_results_glmnet$trained_model$levels[1] +#' ) #' calc_baseline_precision <- function(dataset, outcome_colname = NULL, pos_outcome = NULL) { - outcome_colname = check_outcome_column(dataset, outcome_colname) + outcome_colname <- check_outcome_column(dataset, outcome_colname) npos <- dataset %>% dplyr::filter(!!rlang::sym(outcome_colname) == pos_outcome) %>% nrow() diff --git a/tests/testthat/fixtures/train-multi.R b/tests/testthat/fixtures/train-multi.R index 87913189..9abdbe7e 100644 --- a/tests/testthat/fixtures/train-multi.R +++ b/tests/testthat/fixtures/train-multi.R @@ -6,13 +6,13 @@ future::plan(future::multicore, workers = 8) otu_data_preproc <- mikropml::otu_data_preproc$dat_transformed results_list <- future.apply::future_lapply(seq(100, 102), function(seed) { - run_ml(otu_data_preproc, "glmnet", seed = seed) + run_ml(otu_data_preproc, "glmnet", seed = seed) }, future.seed = TRUE) saveRDS(results_list, testthat::test_path("fixtures", "results_list.Rds")) param_grid <- expand.grid( - seeds = seq(100, 110), - methods = c("glmnet", "rf") + seeds = seq(100, 110), + methods = c("glmnet", "rf") ) results_mtx <- future.apply::future_mapply( function(seed, method) { diff --git a/vignettes/parallel.Rmd b/vignettes/parallel.Rmd index 2f5a89a0..66b69576 100644 --- a/vignettes/parallel.Rmd +++ b/vignettes/parallel.Rmd @@ -82,9 +82,9 @@ results_multi <- future.apply::future_lapply(seq(100, 102), function(seed) { ``` ```{r multi_seeds_load, echo = FALSE} -results_multi <- readRDS(system.file("tests", "testthat", "fixtures", - "results_list.Rds", - package = "mikropml" +results_multi <- readRDS(system.file("tests", "testthat", "fixtures", + "results_list.Rds", + package = "mikropml" )) ``` @@ -140,9 +140,9 @@ results_mtx <- future.apply::future_mapply( ) ``` ```{r results_mtx, echo = FALSE} -results_mtx <- readRDS(system.file("tests", "testthat", "fixtures", - "results_mtx.Rds", - package = "mikropml" +results_mtx <- readRDS(system.file("tests", "testthat", "fixtures", + "results_mtx.Rds", + package = "mikropml" )) ``` @@ -224,8 +224,10 @@ sensspec_1 %>% theme_bw() + theme(legend.title = element_blank()) -baseline_precision_otu <- calc_baseline_precision(otu_data_preproc, - "dx", "cancer") +baseline_precision_otu <- calc_baseline_precision( + otu_data_preproc, + "dx", "cancer" +) sensspec_1 %>% rename(recall = sensitivity) %>% ggplot(aes(x = recall, y = precision, )) +