Skip to content

Commit

Permalink
Fix test and improve messages
Browse files Browse the repository at this point in the history
  • Loading branch information
nealrichardson committed Oct 12, 2022
1 parent d20afba commit 8e99edb
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
8 changes: 4 additions & 4 deletions r/R/dplyr-slice.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ slice_sample.arrow_dplyr_query <- function(.data,
# If we want n rows sampled, we have to convert n to prop, oversample some
# just to make sure we get enough, then head(n)
sampling_n <- missing(prop)
if (missing(prop)) {
if (sampling_n) {
prop <- min(n_to_prop(.data, n) + .05, 1)
}
validate_prop(prop)
Expand Down Expand Up @@ -137,22 +137,22 @@ slice_sample.Dataset <- slice_sample.ArrowTabular <- slice_sample.RecordBatchRea
# Convert a sampling proportion into a row count for `.data`.
#
# .data: object whose row count is obtained with `nrow()`.
# prop:  proportion of rows wanted; checked by `validate_prop()`.
#
# Returns `nrow(.data) * prop`. The product is returned as-is (no rounding
# happens in this function).
prop_to_n <- function(.data, prop) {
nrows <- nrow(.data)
# An NA row count cannot be multiplied into a usable answer, so it is handed
# to `arrow_not_supported()` with a message describing the unsupported case.
if (is.na(nrows)) {
# NOTE(review): the two calls below look like the removed and added lines of
# the commit diff this text was taken from; presumably only the second
# ("joins or aggregations") exists in the committed file -- confirm.
arrow_not_supported("Slicing with `prop` when `nrow()` requires evaluating the query")
arrow_not_supported("Slicing with `prop` when the query has joins or aggregations")
}
# Reject a malformed `prop` before it is used in arithmetic.
validate_prop(prop)
nrows * prop
}

# Check that `prop` is a usable proportion.
#
# prop: value to check.
#
# Raises an error via `stop(..., call. = FALSE)` unless `prop` is numeric, has
# length 1, is not NA, and lies between 0 and 1 inclusive. Nothing is computed
# from `prop` when it passes the check.
validate_prop <- function(prop) {
# `||` short-circuits left to right, so the NA test and the range comparisons
# are only evaluated once `prop` is known to be a length-1 numeric.
if (!is.numeric(prop) || length(prop) != 1 || is.na(prop) || prop < 0 || prop > 1) {
# NOTE(review): the two `stop()` calls below look like the removed and added
# lines of the commit diff this text was taken from; presumably only the
# second ("between 0 and 1") exists in the committed file -- confirm.
stop("`prop` must be a single numeric value in [0, 1]", call. = FALSE)
stop("`prop` must be a single numeric value between 0 and 1", call. = FALSE)
}
}

# Convert a requested row count into a proportion of the rows in `.data`.
#
# .data: object whose row count is obtained with `nrow()`.
# n:     number of rows wanted; it is not validated in this function.
#
# Returns `n / nrow(.data)`. The quotient is not capped here, so it exceeds 1
# when `n` is larger than the row count.
n_to_prop <- function(.data, n) {
nrows <- nrow(.data)
# An NA row count cannot be divided into a usable answer, so it is handed to
# `arrow_not_supported()` with a message describing the unsupported case.
if (is.na(nrows)) {
# NOTE(review): the two calls below look like the removed and added lines of
# the commit diff this text was taken from; presumably only the second
# ("joins or aggregations") exists in the committed file -- confirm.
arrow_not_supported("slice_sample() with `n` when `nrow()` requires evaluating the query")
arrow_not_supported("slice_sample() with `n` when the query has joins or aggregations")
}
n / nrows
}
6 changes: 3 additions & 3 deletions r/tests/testthat/test-dplyr-slice.R
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,15 @@ test_that("slice_sample, ungrouped", {
# With a larger dataset, we would be more confident to get exactly n
# but with this dataset, we should at least not get >n rows
sampled_n <- tab %>%
slice_sample(prop = .2) %>%
slice_sample(n = 2) %>%
collect() %>%
nrow()
expect_lte(sampled_n, 2)

# Test with dataset, which matters for the UDF HACK
sampled_n <- tab %>%
InMemoryDataset$create() %>%
slice_sample(prop = .2) %>%
slice_sample(n = 2) %>%
collect() %>%
nrow()
expect_lte(sampled_n, 2)
Expand Down Expand Up @@ -158,7 +158,7 @@ test_that("input validation", {
for (p in list("a", -1, 2, c(.01, .02), NA_real_)) {
expect_error(
slice_head(tab, prop = !!p),
"`prop` must be a single numeric value in [0, 1]",
"`prop` must be a single numeric value between 0 and 1",
fixed = TRUE
)
}
Expand Down

0 comments on commit 8e99edb

Please sign in to comment.