Skip to content

Commit

Permalink
Move random UDF definition to inside of slice_sample
Browse files (browse the repository at this point in the history)
  • Loading branch information
nealrichardson committed Oct 14, 2022
1 parent 8835808 commit 810f74b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
12 changes: 0 additions & 12 deletions r/R/dplyr-funcs.R
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ call_binding_agg <- function(fun_name, ...) {
agg_funcs[[fun_name]](...)
}

#' @importFrom stats runif
create_binding_cache <- function() {
# Called in .onLoad()
.cache$docs <- list()
Expand All @@ -161,17 +160,6 @@ create_binding_cache <- function() {
register_bindings_type()
register_bindings_augmented()

# HACK because random() doesn't work (ARROW-17974)
register_scalar_function(
"_random_along",
function(context, x) {
Array$create(runif(length(x)))
},
in_type = schema(x = boolean()),
out_type = float64(),
auto_convert = FALSE
)

# We only create the cache for nse_funcs and not agg_funcs
.cache$functions <- c(as.list(nse_funcs), arrow_funcs)
}
Expand Down
14 changes: 13 additions & 1 deletion r/R/dplyr-slice.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ slice_max.arrow_dplyr_query <- function(.data, order_by, ..., n, prop, with_ties
}
slice_max.Dataset <- slice_max.ArrowTabular <- slice_max.RecordBatchReader <- slice_max.arrow_dplyr_query

#' @importFrom stats runif
slice_sample.arrow_dplyr_query <- function(.data,
...,
n,
Expand Down Expand Up @@ -116,10 +117,21 @@ slice_sample.arrow_dplyr_query <- function(.data,
if (prop < 1) {
.data <- as_adq(.data)
# TODO(ARROW-17974): use Expression$create("random") instead of UDF hack
# HACK: use our UDF to generate random. It needs an input column because
# HACK: use a UDF to generate random numbers. It needs an input column because
# nullary functions don't work, and that column has to be typed. We've
# chosen the boolean() type because it's compact and can always be created:
# pick any column and call is.na on it; the result will be boolean.
if (is.null(.cache$functions[["_random_along"]])) {
register_scalar_function(
"_random_along",
function(context, x) {
Array$create(runif(length(x)))
},
in_type = schema(x = boolean()),
out_type = float64(),
auto_convert = FALSE
)
}
# TODO: get an actual FieldRef because the first col could be derived
ref <- Expression$create("is_null", .data$selected_columns[[1]])
expr <- Expression$create("_random_along", ref) < prop
Expand Down

0 comments on commit 810f74b

Please sign in to comment.