Skip to content

Commit

Permalink
Merge branch 'iss-262_roc' of https://github.com/SchlossLab/mikropml
Browse files Browse the repository at this point in the history
…into iss-262_roc
  • Loading branch information
kelly-sovacool committed Jan 8, 2023
2 parents e90a24c + 3f1140f commit efac7a3
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 19 deletions.
17 changes: 9 additions & 8 deletions R/performance.R
Expand Up @@ -214,8 +214,8 @@ get_performance_tbl <- function(trained_model,
#' @export
calc_model_sensspec <- function(trained_model, test_data, outcome_colname = NULL) {
# adapted from https://github.com/SchlossLab/2021-08-09_ROCcurves/blob/8e62ff8b6fe1b691450c953a9d93b2c11ce3369a/ROCcurves.Rmd#L95-L109
outcome_colname = check_outcome_column(test_data, outcome_colname)
pos_outcome = trained_model$levels[1]
outcome_colname <- check_outcome_column(test_data, outcome_colname)
pos_outcome <- trained_model$levels[1]
actual <- is_pos <- tp <- fp <- fpr <- NULL
probs <- stats::predict(trained_model,
newdata = test_data,
Expand Down Expand Up @@ -389,21 +389,22 @@ NULL
#'
#'
#' calc_baseline_precision(otu_mini_bin,
#' outcome_colname = "dx",
#' pos_outcome = "cancer")
#' outcome_colname = "dx",
#' pos_outcome = "cancer"
#' )
#'
#'
#' # if you're not sure which outcome was used as the 'positive' outcome during
#' # model training, you can access it from the trained model and pass it along:
#' calc_baseline_precision(otu_mini_bin,
#' outcome_colname = "dx",
#' pos_outcome = otu_mini_bin_results_glmnet$trained_model$levels[1])
#'
#' outcome_colname = "dx",
#' pos_outcome = otu_mini_bin_results_glmnet$trained_model$levels[1]
#' )
#'
calc_baseline_precision <- function(dataset,
outcome_colname = NULL,
pos_outcome = NULL) {
outcome_colname = check_outcome_column(dataset, outcome_colname)
outcome_colname <- check_outcome_column(dataset, outcome_colname)
npos <- dataset %>%
dplyr::filter(!!rlang::sym(outcome_colname) == pos_outcome) %>%
nrow()
Expand Down
6 changes: 3 additions & 3 deletions tests/testthat/fixtures/train-multi.R
Expand Up @@ -6,13 +6,13 @@ future::plan(future::multicore, workers = 8)
otu_data_preproc <- mikropml::otu_data_preproc$dat_transformed

results_list <- future.apply::future_lapply(seq(100, 102), function(seed) {
run_ml(otu_data_preproc, "glmnet", seed = seed)
run_ml(otu_data_preproc, "glmnet", seed = seed)
}, future.seed = TRUE)
saveRDS(results_list, testthat::test_path("fixtures", "results_list.Rds"))

param_grid <- expand.grid(
seeds = seq(100, 110),
methods = c("glmnet", "rf")
seeds = seq(100, 110),
methods = c("glmnet", "rf")
)
results_mtx <- future.apply::future_mapply(
function(seed, method) {
Expand Down
18 changes: 10 additions & 8 deletions vignettes/parallel.Rmd
Expand Up @@ -82,9 +82,9 @@ results_multi <- future.apply::future_lapply(seq(100, 102), function(seed) {
```

```{r multi_seeds_load, echo = FALSE}
results_multi <- readRDS(system.file("tests", "testthat", "fixtures",
"results_list.Rds",
package = "mikropml"
results_multi <- readRDS(system.file("tests", "testthat", "fixtures",
"results_list.Rds",
package = "mikropml"
))
```

Expand Down Expand Up @@ -140,9 +140,9 @@ results_mtx <- future.apply::future_mapply(
)
```
```{r results_mtx, echo = FALSE}
results_mtx <- readRDS(system.file("tests", "testthat", "fixtures",
"results_mtx.Rds",
package = "mikropml"
results_mtx <- readRDS(system.file("tests", "testthat", "fixtures",
"results_mtx.Rds",
package = "mikropml"
))
```

Expand Down Expand Up @@ -224,8 +224,10 @@ sensspec_1 %>%
theme_bw() +
theme(legend.title = element_blank())
baseline_precision_otu <- calc_baseline_precision(otu_data_preproc,
"dx", "cancer")
baseline_precision_otu <- calc_baseline_precision(
otu_data_preproc,
"dx", "cancer"
)
sensspec_1 %>%
rename(recall = sensitivity) %>%
ggplot(aes(x = recall, y = precision, )) +
Expand Down

0 comments on commit efac7a3

Please sign in to comment.