Skip to content

Commit

Permalink
Merge pull request #24 from DeclareDesign/profile_and_coverage
Browse files Browse the repository at this point in the history
Resampling speed improvements; continued coverage improvements
  • Loading branch information
graemeblair committed Oct 31, 2017
2 parents 38a2f02 + 910d51b commit ca85333
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 80 deletions.
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,7 @@ Suggests:
testthat,
dplyr,
knitr,
rmarkdown
rmarkdown,
data.table
FasterWith: data.table
VignetteBuilder: knitr
26 changes: 10 additions & 16 deletions R/fabricate.R
Original file line number Diff line number Diff line change
Expand Up @@ -152,24 +152,18 @@ fabricate_data_single_level <- function(data = NULL,

if (!is.null(N)) {
if (length(N) != 1) {
if(is.null(ID_label)) {
stop(
"At the top level, you must provide a single number to N."
)
} else {
stop(
"At the top level, ",
ID_label,
", you must provide a single number to N."
)
}
}

if(is.numeric(N) & any(!N%%1 == 0)) {
stop(paste0(
stop(
"At the top level, ",
ifelse(!is.null(ID_label),
paste0(ID_label, ", "),
""),
"you must provide a single number to N"
)
} else if(is.numeric(N) & any(!N%%1 == 0)) {
stop(
"The provided N must be an integer number. Provided N was of type ",
typeof(N)
))
)
}

if(!is.numeric(N)) {
Expand Down
172 changes: 112 additions & 60 deletions R/resample_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,83 +28,135 @@
#'
#'
#' @export
#'

resample_data <- function(data, N, ID_labels = NULL) {
  # Thin public wrapper: all validation and the resampling itself live in
  # .resample_data_internal(). Delegating keeps that worker's recursion-only
  # arguments (outer_level, use_dt) out of the user-facing signature.
  #
  # When the caller omits N it is forwarded still "missing", so the worker's
  # own missing(N) check handles the plain-bootstrap case; no separate early
  # return is needed here.
  .resample_data_internal(data, N, ID_labels)
}

# Error handling
# Stops when ID_labels is supplied and N does not have one entry per label.
# NOTE(review): in this view the check sits at top level, outside any function
# body, and an identical check appears inside .resample_data_internal --
# presumably this is a removed line of the diff; confirm before relying on it.
if (!is.null(ID_labels) & (length(N) != length(ID_labels))) {
stop(
"If you provide more than one ID_labels to resample data for multilevel data, please provide a vector for N of the same length representing the number to resample at each level."
)
}
# Recursive worker behind resample_data().
#
# Arguments:
#   data        Data frame to resample.
#   N           Numeric vector of whole numbers >= 1: how many units to draw
#               at each level, outermost level first.
#   ID_labels   Character vector naming the ID column of each level (same
#               length as N), or NULL for a plain row bootstrap.
#   outer_level Truthy only on the first (user-facing) call. The recursive
#               calls pass 0 so the input checks run exactly once.
#   use_dt      NA to auto-detect data.table; otherwise truthy/falsy to force
#               the data.table or the base rbind code path.
#
# Returns a data.frame holding the resampled rows. Stops with an error when
# N or ID_labels fail validation.
.resample_data_internal <- function(data, N, ID_labels = NULL,
                                    outer_level = 1, use_dt = NA) {
  # Handle all the sanity checks on the outer call so they are not repeated
  # further down the recursion.
  if (outer_level) {
    # Probe for data.table only when the caller has not already picked a code
    # path; an explicit use_dt value is honoured as given.
    if (is.na(use_dt)) {
      use_dt <- requireNamespace("data.table", quietly = TRUE)
    }

    # No N and no ID labels: the caller wants a regular bootstrap of as many
    # rows as the data has.
    if (missing(N) && is.null(ID_labels)) {
      return(bootstrap_single_level(data, ID_label = NULL, N = dim(data)[1]))
    }

    # Test is.numeric() first and short-circuit, so a non-numeric N reports
    # this message instead of failing inside the %% arithmetic.
    # Note: this should be rewritten when the "ALL" option for a level exists.
    if (!is.numeric(N) || length(N) == 0 || anyNA(N) ||
        any(N %% 1 != 0 | N <= 0)) {
      stop(
        "All specified Ns must be numeric and at least 1."
      )
    }

    # N must have one entry per ID label. A multi-level N without labels is
    # rejected here too, because the recursion below needs a label per level.
    if ((!is.null(ID_labels) || length(N) > 1) &&
        length(N) != length(ID_labels)) {
      stop(
        "If you provide more than one ID_labels to resample data for multilevel data, please provide a vector for N of the same length representing the number to resample at each level."
      )
    }

    # ID_labels naming columns the data does not have
    if (any(!ID_labels %in% names(data))) {
      stop(
        "One or more of the ID labels you provided are not columns in the data frame provided."
      )
    }

    # Excessive recursion depth
    if (length(N) > 10) {
      stop(
        "Multi-level bootstrap with more than 10 levels is not advised."
      )
    }
  }

  # Innermost recursion: a single-level bootstrap, optionally clustered on
  # the one remaining ID column.
  if (length(N) == 1) {
    return(bootstrap_single_level(data, ID_label = ID_labels[1], N = N[1]))
  }

  if (dim(data)[1] == 0) {
    stop("Data being bootstrapped has no rows.")
  }

  # Row indices grouped by the ID we are bootstrapping on at this level.
  # drop = TRUE discards unused factor levels, which would otherwise become
  # empty clusters that can be drawn (row subsets keep all factor levels).
  rows_by_cluster <- split(
    seq_len(dim(data)[1]),
    data[[ID_labels[1]]],
    drop = TRUE
  )

  # Draw cluster positions (not cluster values) with replacement.
  sampled_clusters <- sample.int(length(rows_by_cluster), N[1], replace = TRUE)

  # Recurse into each drawn cluster with the remaining levels. Inner calls
  # skip validation and reuse the already-resolved use_dt.
  results_all <- lapply(sampled_clusters, function(i) {
    .resample_data_internal(
      data[rows_by_cluster[[i]], , drop = FALSE],
      N = N[-1],
      ID_labels = ID_labels[-1],
      outer_level = 0,
      use_dt = use_dt
    )
  })

  if (use_dt) {
    # data.table is available: faster row binding, then convert back by
    # reference so callers receive a plain data.frame.
    res <- data.table::rbindlist(results_all)
    data.table::setDF(res)
  } else {
    # Base path: rbind and reset the row names.
    res <- do.call(rbind, results_all)
    rownames(res) <- NULL
  }

  # Return to preceding level
  res
}

# Bootstrap one level of a data frame.
#
# Arguments:
#   data     Data frame to resample; must have at least one row.
#   ID_label NULL for a simple row bootstrap, or the name of a column whose
#            distinct values define clusters to resample as whole units.
#   N        Number of rows (ID_label NULL) or clusters (ID_label given) to
#            draw, with replacement.
#
# Returns the selected rows of `data` as a data frame. With an ID_label,
# every row of each drawn cluster is included, once per time it is drawn.
# Stops with an error when `data` has no rows or ID_label is not a column.
bootstrap_single_level <- function(data, ID_label = NULL, N) {
  # dim() is slightly faster than nrow(), which just forwards to it
  n_rows <- dim(data)[1]
  if (n_rows == 0) {
    stop("Data being bootstrapped has no rows.")
  }

  if (is.null(ID_label)) {
    # Simple bootstrap: draw N row positions with replacement
    return(data[sample.int(n_rows, N, replace = TRUE), , drop = FALSE])
  }

  if (!ID_label %in% colnames(data)) {
    stop("ID label provided is not a column in the data being bootstrapped.")
  }

  # Row indices grouped by cluster ID. drop = TRUE discards unused factor
  # levels so that an empty cluster can never be drawn.
  rows_by_cluster <- split(seq_len(n_rows), data[[ID_label]], drop = TRUE)

  # Draw cluster positions (not the cluster values themselves)
  boot_ids <- sample.int(length(rows_by_cluster), size = N, replace = TRUE)

  # Flatten to one vector of row indices; a cluster drawn k times contributes
  # its rows k times.
  boot_indices <- unlist(
    rows_by_cluster[boot_ids],
    recursive = FALSE,
    use.names = FALSE
  )

  data[boot_indices, , drop = FALSE]
}
2 changes: 1 addition & 1 deletion R/variable_creation_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ draw_discrete <-
if (length(breaks) < 3) {
stop("Numeric breaks for ordered data must be of at least length 3.")
}
if (!all(sort(breaks) == breaks)) {
if (is.unsorted(breaks)) {
stop("Numeric breaks must be in ascending order.")
}
if(any(breaks[1] > x) | any(breaks[length(breaks)] < x)) {
Expand Down
50 changes: 48 additions & 2 deletions tests/testthat/test-bootstrap.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@ test_that("Bootstrap", {
cities = level(N = 5, subways = rnorm(N, mean = gdp))
)

resampled_two_levels <- resample_data(two_levels, N = c(2, 2), ID_labels = c("regions", "cities"))
# Example with data.table codepath
resampled_two_levels <- resample_data(two_levels, N = c(2, 2),
ID_labels = c("regions", "cities"))

# Example without data.table codepath
resampled_two_levels <- .resample_data_internal(two_levels, N = c(2, 2),
ID_labels = c("regions", "cities"),
use_dt=0)

expect_equal(nrow(resampled_two_levels), 4)

Expand All @@ -22,5 +29,44 @@ test_that("Error handling of Bootstrap", {
)

resampled_two_levels <- resample_data(two_levels) # Missing N
expect_error(resample_data(two_levels, c(100, 10), ID_labels = c("Invalid_ID")))

# Invalid ID
expect_error(resample_data(two_levels, c(100, 10), ID_labels = c("Invalid_ID", "Invalid_ID_2")))
# ID length doesn't match n length
expect_error(resample_data(two_levels, c(100, 10), ID_labels = c("regions")))
# Negative N
expect_error(resample_data(two_levels, c(-1), ID_labels = c("regions")))
# Non-numeric
expect_error(resample_data(two_levels, c("hello world"), ID_labels = c("regions")))
})

test_that("Direct bootstrap_single_level", {
  # Fixture built with fabricate(): a regions level and a cities level
  fixture <- fabricate(
    regions = level(N = 5, gdp = rnorm(N)),
    cities = level(N = sample(1:5), subways = rnorm(N, mean = gdp))
  )

  # The gdp > 100 filter is expected to match no rows (asserted right below),
  # giving an empty data frame to feed to the bootstrap.
  empty_fixture <- fixture[fixture$gdp > 100, ]
  expect_equal(nrow(empty_fixture), 0)
  expect_error(
    bootstrap_single_level(empty_fixture, ID_label = "regions", N = 10)
  )

  # An ID label that is not a column of the data must be rejected.
  expect_error(
    bootstrap_single_level(fixture, ID_label = "invalid-id", N = 10)
  )
})

test_that("Extremely high volume data creation.", {
  # Skipped unconditionally; everything below only runs if the skip is removed.
  skip("Slows build substantially.")

  # Six nested levels, from countries down to people
  big_hierarchy <- fabricate(
    countries = level(N = 100, gdp = rlnorm(N)),
    states = level(N = 50, population = rlnorm(N)),
    cities = level(N = 50, holiday = runif(N, 1, 365)),
    neighborhoods = level(N = 5, stoplights = draw_binary(x = 0.5, N)),
    houses = level(N = 5, population = runif(N, 1, 5)),
    people = level(
      N = population,
      sex = ifelse(draw_binary(x = 0.5, N), "M", "F")
    )
  )

  # Resample on the three outermost levels
  resampled <- resample_data(
    big_hierarchy,
    N = c(100, 50, 50),
    ID_labels = c("countries", "states", "cities")
  )
})

0 comments on commit ca85333

Please sign in to comment.