diff --git a/DESCRIPTION b/DESCRIPTION
index 2077300..aab7bdb 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -20,5 +20,7 @@ Suggests:
     testthat,
     dplyr,
     knitr,
-    rmarkdown
+    rmarkdown,
+    data.table
+FasterWith: data.table
 VignetteBuilder: knitr
diff --git a/R/fabricate.R b/R/fabricate.R
index a54026d..1cd97db 100644
--- a/R/fabricate.R
+++ b/R/fabricate.R
@@ -152,24 +152,18 @@ fabricate_data_single_level <- function(data = NULL,
 
   if (!is.null(N)) {
     if (length(N) != 1) {
-      if(is.null(ID_label)) {
-        stop(
-          "At the top level, you must provide a single number to N."
-        )
-      } else {
-        stop(
-          "At the top level, ",
-          ID_label,
-          ", you must provide a single number to N."
-        )
-      }
-    }
-
-    if(is.numeric(N) & any(!N%%1 == 0)) {
-      stop(paste0(
+      stop(
+        "At the top level, ",
+        # Only name the level when an ID_label was supplied; plain if/else
+        # because the condition is a scalar (ifelse() is meant for vectors).
+        if (is.null(ID_label)) "" else paste0(ID_label, ", "),
+        "you must provide a single number to N"
+      )
+    } else if(is.numeric(N) && any(!N%%1 == 0)) {
+      stop(
         "The provided N must be an integer number. Provided N was of type ",
         typeof(N)
-      ))
+      )
     }
 
     if(!is.numeric(N)) {
diff --git a/R/resample_data.R b/R/resample_data.R
index ee385b0..655fe4a 100644
--- a/R/resample_data.R
+++ b/R/resample_data.R
@@ -28,83 +28,135 @@
 #'
 #'
 #' @export
+#'
+
 resample_data = function(data, N, ID_labels=NULL) {
-  # User didn't provide an N or an ID label, it's clear they just want a regular bootstrap
-  if (missing(N) & is.null(ID_labels)) {
-    N <- nrow(data)
-    return(bootstrap_single_level(data, nrow(data), ID_label=NULL))
-  }
+  # Mask internal outer_level and use_dt arguments from view.
+  .resample_data_internal(data, N, ID_labels)
+}
 
-  # Error handling
-  if (!is.null(ID_labels) & (length(N) != length(ID_labels))) {
-    stop(
-      "If you provide more than one ID_labels to resample data for multilevel data, please provide a vector for N of the same length representing the number to resample at each level."
-    )
-  }
+.resample_data_internal = function(data, N, ID_labels=NULL, outer_level=1, use_dt = NA) {
+  # Handle all the data sanity checks in outer_level so we don't have redundant error
+  # checks further down the recursion.
+  if(outer_level) {
+    # Optional usage of data.table to speed up functionality. Only probe for the
+    # package when the caller left use_dt as NA; an explicit choice is honoured.
+    if(is.na(use_dt)) {
+      use_dt = requireNamespace("data.table", quietly=TRUE)
+    } else {
+      use_dt = as.logical(use_dt)
+    }
 
-  if (any(!ID_labels %in% names(data))) {
-    stop(
-      "One or more of the ID labels you provided are not columns in the data frame provided."
-    )
-  }
+    # User didn't provide an N or an ID label, it's clear they just want a regular bootstrap
+    # of N units by row.
+    if (missing(N) && is.null(ID_labels)) {
+      return(bootstrap_single_level(data, dim(data)[1], ID_label=NULL))
+    }
 
-  if(length(N) > 10) {
-    stop(
-      "Multi-level bootstrap with more than 10 levels is not advised."
-    )
+    # No negative, fractional, NA or non-numeric Ns; || keeps %% away from non-numerics.
+    # Note: this should be rewritten when we implement the "ALL" option for a level.
+    if (!is.numeric(N) || anyNA(N) || any(N%%1 != 0 | N<=0)) {
+      stop(
+        "All specified Ns must be numeric and at least 1."
+      )
+    }
+
+    # N doesn't match ID labels
+    if (!is.null(ID_labels) && (length(N) != length(ID_labels))) {
+      stop(
+        "If you provide more than one ID_labels to resample data for multilevel data, please provide a vector for N of the same length representing the number to resample at each level."
+      )
+    }
+
+    # ID_labels looking for some columns we don't have
+    if (any(!ID_labels %in% names(data))) {
+      stop(
+        "One or more of the ID labels you provided are not columns in the data frame provided."
+      )
+    }
+
+    # Excessive recursion depth
+    if(length(N) > 10) {
+      stop(
+        "Multi-level bootstrap with more than 10 levels is not advised."
+      )
+    }
   }
 
   # Single level bootstrap with explicit bootstrapping on a particular cluster variable
+  # this is the inner-most recursion
   if(length(N)==1) {
     return(bootstrap_single_level(data, N[1], ID_label=ID_labels[1]))
-  } else {
-    # Do the current bootstrap level
-    current_boot_values = unique(data[, ID_labels[1]])
-    sampled_boot_values = sample(1:length(current_boot_values), N[1], replace=TRUE)
-    app = 0
-
-    # Iterate over each thing chosen at the current level
-    results_all = lapply(sampled_boot_values, function(i) {
-      new_results = resample_data(
-        data[data[, ID_labels[1]] == i, ],
-        N=N[2:length(N)],
-        ID_labels=ID_labels[2:length(ID_labels)]
-      )
-    })
-    #res = rbindlist(results_all)
+  }
+
+  # OK, if not, we need to recurse
+
+  # Split row indices by the current ID; drop=TRUE skips unused factor levels
+  split_data_on_boot_id = split(seq_len(dim(data)[1]), data[,ID_labels[1]], drop=TRUE)
+
+  # Do the current bootstrap level
+  # sample.int is faster than sample(1:length(.)) or sample(seq_len(length(.)))
+  sampled_boot_values = sample.int(length(split_data_on_boot_id), N[1], replace=TRUE)
+
+  # Iterate over each thing chosen at the current level
+  results_all = lapply(sampled_boot_values, function(i) {
+    # Get rowids from current bootstrap index, subset based on that
+    # pass through the recursed Ns and labels, and remind the inner
+    # layer that it doesn't need to sanity check and we already know
+    # if data.table is around.
+    # The list subset on the split is faster than unlisting
+    .resample_data_internal(
+      data[split_data_on_boot_id[[i]], , drop=FALSE],
+      N=N[-1],
+      ID_labels=ID_labels[-1],
+      outer_level=0,
+      use_dt = use_dt
+    )
+  })
+
+  # We could probably gain slight efficiency by only doing the rbind on the
+  # outermost loop.
+  if(!use_dt) {
+    # With no data.table, we need to rbind and then remove row names.
+    # Removing row names is as fast this way as other ways to do the same thing
     res = do.call(rbind, results_all)
     rownames(res) = NULL
-    # Return to preceding level
-    return(res)
+  } else {
+    # User has data.table, give them a speed benefit for it
+    res = data.table::rbindlist(results_all)
+    # Strip the things that differentiate data.table from data.frame
+    # so we hand back something identical.
+    class(res) = "data.frame"
+    attr(res, ".internal.selfref") = NULL
   }
+  # Return to preceding level
+  return(res)
 }
 
 bootstrap_single_level <- function(data, ID_label = NULL, N) {
-  if(dim(data)[1] == 0) {
-    stop("Data being bootstrapped has no rows.")
-  }
-  if (is.null(ID_label)) {
-    # Simple bootstrap
-    boot_indices <- sample(1:nrow(data), N, replace = TRUE)
-  } else if(!ID_label %in% colnames(data)) {
-    stop("ID label provided is not a column in the data being bootstrapped.")
-  } else {
-    # Bootstrapping unique values of ID_label (i.e. cluster selection when data
-    # are observations, not clusters
-    boot_ids <-
-      sample(unique(data[, ID_label]), size = N, replace = TRUE)
-    # Need to do the unlist-apply approach to ensure each row
-    # is appropriately duplicated. Faster than other ways to map
-    # cluster ids to row ids.
-    boot_indices <- unlist(lapply(boot_ids, function(i) {
-      which(data[, ID_label] == i)
-    }))
-  }
-  # Grab the relevant rows
-  new_data <- data[boot_indices, , drop = FALSE]
+  # dim slightly faster than nrow
+  if(dim(data)[1] == 0) {
+    stop("Data being bootstrapped has no rows.")
+  }
+
+  if (is.null(ID_label)) {
+    # Simple bootstrap
+    return(data[sample(seq_len(dim(data)[1]), N, replace = TRUE), , drop = FALSE])
+  } else if(!ID_label %in% colnames(data)) {
+    stop("ID label provided is not a column in the data being bootstrapped.")
+  }
 
-  return(new_data)
+  # Split data by cluster ID, storing all row indices associated with that cluster ID
+  # drop=TRUE so unused factor levels never become empty, sampleable clusters
+  indices_split = split(seq_len(dim(data)[1]), data[, ID_label], drop=TRUE)
+  # Get cluster IDs (not the actual cluster values, the indices of the clusters)
+  # sample.int is slightly faster than sample(1:length(.)) or sample(seq_len(length(.)))
+  boot_ids = sample.int(length(indices_split), size=N, replace=TRUE)
+  # Get all row indices associated with every cluster ID combined
+  boot_indices = unlist(indices_split[boot_ids], recursive=FALSE, use.names=FALSE)
+  # Only take the indices we want (repeats will be handled properly)
+  return(data[boot_indices, , drop=FALSE])
 }
diff --git a/R/variable_creation_functions.R b/R/variable_creation_functions.R
index bb2de0c..848b33d 100644
--- a/R/variable_creation_functions.R
+++ b/R/variable_creation_functions.R
@@ -154,7 +154,7 @@ draw_discrete <-
     if (length(breaks) < 3) {
       stop("Numeric breaks for ordered data must be of at least length 3.")
     }
-    if (!all(sort(breaks) == breaks)) {
+    if (is.unsorted(breaks)) {
       stop("Numeric breaks must be in ascending order.")
     }
     if(any(breaks[1] > x) | any(breaks[length(breaks)] < x)) {
diff --git a/tests/testthat/test-bootstrap.R b/tests/testthat/test-bootstrap.R
index f65bb50..89d5f10 100644
--- a/tests/testthat/test-bootstrap.R
+++ b/tests/testthat/test-bootstrap.R
@@ -6,7 +6,14 @@ test_that("Bootstrap", {
     cities = level(N = 5, subways = rnorm(N, mean = gdp))
   )
 
-  resampled_two_levels <- resample_data(two_levels, N = c(2, 2), ID_labels = c("regions", "cities"))
+  # Example with data.table codepath (taken when data.table is installed)
+  resampled_two_levels <- resample_data(two_levels, N = c(2, 2),
+                                        ID_labels = c("regions", "cities"))
+  expect_equal(nrow(resampled_two_levels), 4)
+  # Example without data.table codepath
+  resampled_two_levels <- .resample_data_internal(two_levels, N = c(2, 2),
+                                        ID_labels = c("regions", "cities"),
+                                        use_dt=0)
 
   expect_equal(nrow(resampled_two_levels), 4)
 
@@ -22,5 +29,44 @@ test_that("Error handling of Bootstrap", {
   )
 
   resampled_two_levels <- resample_data(two_levels) # Missing N
-  expect_error(resample_data(two_levels, c(100, 10), ID_labels = c("Invalid_ID")))
+
+  # Invalid ID
+  expect_error(resample_data(two_levels, c(100, 10), ID_labels = c("Invalid_ID", "Invalid_ID_2")))
+  # ID length doesn't match n length
+  expect_error(resample_data(two_levels, c(100, 10), ID_labels = c("regions")))
+  # Negative N
+  expect_error(resample_data(two_levels, c(-1), ID_labels = c("regions")))
+  # Non-numeric
+  expect_error(resample_data(two_levels, c("hello world"), ID_labels = c("regions")), "All specified Ns")
+})
+
+test_that("Direct bootstrap_single_level", {
+  two_levels <- fabricate(
+    regions = level(N = 5, gdp = rnorm(N)),
+    cities = level(N = sample(1:5), subways = rnorm(N, mean = gdp))
+  )
+
+  null_data = two_levels[two_levels$gdp > 100, ]
+  # Trying to bootstrap null data
+  expect_equal(dim(null_data)[1], 0)
+  expect_error(bootstrap_single_level(null_data, ID_label="regions", N=10))
+
+  # Trying to bootstrap single level with an invalid ID.
+  expect_error(bootstrap_single_level(two_levels, ID_label="invalid-id", N=10))
+})
+
+test_that("Extremely high volume data creation.", {
+  skip("Slows build substantially.")
+  deep_dive_data = fabricate(
+    countries = level(N = 100, gdp = rlnorm(N)),
+    states = level(N = 50, population = rlnorm(N)),
+    cities = level(N = 50, holiday = runif(N, 1, 365)),
+    neighborhoods = level(N = 5, stoplights = draw_binary(x=0.5, N)),
+    houses = level(N = 5, population = runif(N, 1, 5)),
+    people = level(N = population, sex = ifelse(draw_binary(x=0.5, N), "M", "F"))
+  )
+
+  test_resample = resample_data(deep_dive_data,
+                                ID_labels=c("countries", "states", "cities"),
+                                N=c(100, 50, 50))
 })