Skip to content

Commit

Permalink
another layer
Browse files Browse the repository at this point in the history
  • Loading branch information
Polkas committed Oct 17, 2023
1 parent 9ea143f commit b496210
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions R/cat2cat_ml.R
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,10 @@ delayed_package_load <- function(package, name) {
#' )
#' mappings <- list(trans = trans, direction = "backward")
#' res <- cat2cat_ml_run( mappings, ml_setup, test_prop = 0.2)
#' mean(unlist(res), na.rm = TRUE)
#' sum(is.na(res)) / length(res)
#' # Average accurecy - please take into account it is multi-level classification
#' mean(unlist(lapply(res, function(x) x$acc)), na.rm = T)
#' # How often accurecy is bigger than naive guess
#' mean(unlist(lapply(res, function(x) x$naive < x$acc)), na.rm = T)
#' }
#'
cat2cat_ml_run <- function(mappings, ml, ...) {
Expand Down Expand Up @@ -239,25 +241,28 @@ cat2cat_ml_run <- function(mappings, ml, ...) {
)

res <- list()
res_dummy <- list()
for (cat in unique(names(mapp))) {
try(
{

matched_cat <- mapp[[match(cat, names(mapp))]]
res_dummy <- c(list(1/length(matched_cat)), res_dummy)
g_name <- paste(matched_cat, collapse = "&")
res[[g_name]][["ncat"]] <- length(matched_cat)
res[[g_name]][["naive"]] <- 1 / length(matched_cat)
res[[g_name]][["acc"]] <- NA

data_small_g <- do.call(rbind, train_g[matched_cat])

if (isTRUE(is.null(data_small_g) || nrow(data_small_g) < 5 || length(matched_cat) < 2)) {
res <- c(list(NA), res)
next
}


index_tt <- sample(c(0, 1), nrow(data_small_g), prob = c(1 - elargs$test_prop, elargs$test_prop), replace = TRUE)
data_test_small <- data_small_g[index_tt == 1, ]
data_train_small <- data_small_g[index_tt == 0, ]

if (isTRUE(nrow(data_test_small) == 0 || nrow(data_train_small) < 5)) {
res <- c(list(NA), res)
next
}

Expand Down Expand Up @@ -304,10 +309,11 @@ cat2cat_ml_run <- function(mappings, ml, ...) {
)$class
}
}
res <- c(list(mean(pred == data_test_small[[ml$cat_var]])), res)
res[[g_name]][["acc"]] <- mean(pred == data_test_small[[ml$cat_var]])

}, silent = TRUE
)
}

list(res, res_dummy)
res
}

0 comments on commit b496210

Please sign in to comment.