Skip to content

Commit

Permalink
Preserve factor levels when given
Browse files Browse the repository at this point in the history
To ensure the user's preferred positive class is preserved.
We should probably add an argument for this?
  • Loading branch information
kelly-sovacool committed May 29, 2023
1 parent 3dcc9bc commit eba7ded
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion R/performance.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,14 @@ calc_perf_metrics <- function(test_data, trained_model, outcome_colname, perf_me
if (class_probs) pred_type <- "prob"
preds <- stats::predict(trained_model, test_data, type = pred_type)
if (class_probs) {
uniq_obs <- unique(c(test_data %>% dplyr::pull(outcome_colname), as.character(trained_model$pred$obs)))
if (is.factor(test_data %>% dplyr::pull(outcome_colname))) {
uniq_obs <- test_data %>% dplyr::pull(outcome_colname) %>% levels()
} else {
uniq_obs <- unique(c(test_data %>% dplyr::pull(outcome_colname),
as.character(trained_model$pred$obs)
)
)
}
obs <- factor(test_data %>% dplyr::pull(outcome_colname), levels = uniq_obs)
pred_class <- factor(names(preds)[apply(preds, 1, which.max)], levels = uniq_obs)
perf_met <- perf_metric_function(data.frame(obs = obs, pred = pred_class, preds), lev = uniq_obs)
Expand Down

0 comments on commit eba7ded

Please sign in to comment.