Skip to content

Commit

Permalink
ARROW-6982: [R] Add bindings for compare and boolean kernels
Browse files Browse the repository at this point in the history
The scope of this has grown to something larger than the description. In addition to adding bindings to boolean kernels, it also changes how the dplyr filter expressions are generated and evaluated for RecordBatch and Table. Previously, any R function could be used to `filter()` because evaluation happened in R by calling `as.vector` on any Arrays referenced. Now, `filter()` translates R function names to Arrow function names, and evaluation passes the function and arguments to `call_function`. The benefit is that filtering a RecordBatch/Table happens all in Arrow, no pulling data into R and then sending back to Arrow to filter it. The cost is that only functions supported in Arrow can be used now.

In addition to these improvements, the patch includes some extra validation, testing, and print method upgrades.

There are a number of less-than-ideal design choices in here. Some are related to https://issues.apache.org/jira/browse/ARROW-9001 because we have to track/make a guess as to whether the result of `call_function` should be an Array, ChunkedArray, etc.

There's also a bit of duplication here between the two Arrow expression classes: this R-specific parse tree of array/compute expressions, and the Dataset filter expressions. I think that's unavoidable at this time, but we should (and I expect we will) rationalize this in the near future.

Closes #7668 from nealrichardson/r-kernels

Authored-by: Neal Richardson <neal.p.richardson@gmail.com>
Signed-off-by: Neal Richardson <neal.p.richardson@gmail.com>
  • Loading branch information
nealrichardson authored and kszucs committed Jul 24, 2020
1 parent 8251cc9 commit a284504
Show file tree
Hide file tree
Showing 15 changed files with 322 additions and 84 deletions.
10 changes: 2 additions & 8 deletions r/R/array.R
Original file line number Diff line number Diff line change
Expand Up @@ -270,13 +270,7 @@ FixedSizeListArray <- R6Class("FixedSizeListArray", inherit = Array,
length.Array <- function(x) {
  # Delegate to the Array object's own length() method
  x$length()
}

#' @export
is.na.Array <- function(x) {
if (x$type == null()) {
rep(TRUE, length(x))
} else {
!Array__Mask(x)
}
}
is.na.Array <- function(x) {
  # Run the "is_null" compute function on x, then wrap the returned pointer
  # as an Array
  is_null_ptr <- call_function("is_null", x)
  shared_ptr(Array, is_null_ptr)
}

#' @export
as.vector.Array <- function(x, mode) {
  # `mode` is part of the as.vector() signature but is not used here
  x$as_vector()
}
Expand All @@ -287,7 +281,7 @@ filter_rows <- function(x, i, keep_na = TRUE, ...) {
nrows <- x$num_rows %||% x$length() # Depends on whether Array or Table-like
if (inherits(i, "array_expression")) {
# Evaluate it
i <- as.vector(i)
i <- eval_array_expression(i)
}
if (is.logical(i)) {
if (isTRUE(i)) {
Expand Down
4 changes: 0 additions & 4 deletions r/R/arrowExports.R

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion r/R/chunked-array.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ length.ChunkedArray <- function(x) x$length()
as.vector.ChunkedArray <- function(x, mode) {
  # `mode` is part of the as.vector() signature but is not used here
  x$as_vector()
}

#' @export
is.na.ChunkedArray <- function(x) unlist(lapply(x$chunks, is.na))
is.na.ChunkedArray <- function(x) {
  # Run the "is_null" compute function on x, then wrap the returned pointer
  # as a ChunkedArray
  is_null_ptr <- call_function("is_null", x)
  shared_ptr(ChunkedArray, is_null_ptr)
}

#' @export
`[.ChunkedArray` <- filter_rows
Expand Down
14 changes: 12 additions & 2 deletions r/R/compute.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,19 @@
#' @include chunked-array.R
#' @include scalar.R

call_function <- function(function_name, ..., options = list()) {
call_function <- function(function_name, ..., args = list(...), options = empty_named_list()) {
assert_that(is.string(function_name))
compute__CallFunction(function_name, list(...), options)
assert_that(is.list(options), !is.null(names(options)))

datum_classes <- c("Array", "ChunkedArray", "RecordBatch", "Table", "Scalar")
valid_args <- map_lgl(args, ~inherits(., datum_classes))
if (!all(valid_args)) {
# Lame, just pick one to report
first_bad <- min(which(!valid_args))
stop("Argument ", first_bad, " is of class ", head(class(args[[first_bad]]), 1), " but it must be one of ", oxford_paste(datum_classes, "or"), call. = FALSE)
}

compute__CallFunction(function_name, args, options)
}

#' @export
Expand Down
11 changes: 8 additions & 3 deletions r/R/dplyr.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,12 @@ print.arrow_dplyr_query <- function(x, ...) {
cat(fields, "\n", sep = "")
cat("\n")
if (!isTRUE(x$filtered_rows)) {
cat("* Filter: ", x$filtered_rows$ToString(), "\n", sep = "")
if (query_on_dataset(x)) {
filter_string <- x$filtered_rows$ToString()
} else {
filter_string <- .format_array_expression(x$filtered_rows)
}
cat("* Filter: ", filter_string, "\n", sep = "")
}
if (length(x$group_by_vars)) {
cat("* Grouped by ", paste(x$group_by_vars, collapse = ", "), "\n", sep = "")
Expand Down Expand Up @@ -202,13 +207,13 @@ filter_mask <- function(.data) {
} else {
comp_func <- function(operator) {
force(operator)
function(e1, e2) array_expression(operator, e1, e2)
function(e1, e2) build_array_expression(operator, e1, e2)
}
var_binder <- function(x) .data$.data[[x]]
}

# First add the functions
func_names <- set_names(c(names(comparison_function_map), "&", "|", "%in%"))
func_names <- set_names(names(.array_function_map))
env_bind(f_env, !!!lapply(func_names, comp_func))
# Then add the column references
# Renaming is handled automatically by the named list
Expand Down
143 changes: 135 additions & 8 deletions r/R/expression.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,148 @@

#' @include arrowExports.R

array_expression <- function(FUN, ...) {
structure(list(fun = FUN, args = list(...)), class = "array_expression")
# Construct an "array_expression": a classed list recording a function name,
# its arguments, its options, and the class the evaluated result should be
# wrapped in.
#
# FUN:          stored as-is in the `fun` field
# ...:          arguments, collected into `args` unless `args` is given
# args:         list of arguments; defaults to list(...)
# options:      stored as-is; defaults to empty_named_list()
# result_class: defaults to .guess_result_class() applied to the first argument
array_expression <- function(FUN,
                             ...,
                             args = list(...),
                             options = empty_named_list(),
                             result_class = .guess_result_class(args[[1]])) {
  expr <- list(
    fun = FUN,
    args = args,
    options = options,
    result_class = result_class
  )
  class(expr) <- "array_expression"
  expr
}

#' @export
Ops.Array <- function(e1, e2) array_expression(.Generic, e1, e2)
Ops.Array <- function(e1, e2) {
  # Only operators with an entry in .array_function_map can be translated
  if (!.Generic %in% names(.array_function_map)) {
    stop("Unsupported operation on Array: ", .Generic, call. = FALSE)
  }
  # Build the expression and evaluate it right away; the result is wrapped
  # as an Array
  eval_array_expression(
    build_array_expression(.Generic, e1, e2, result_class = "Array")
  )
}

#' @export
Ops.ChunkedArray <- Ops.Array
Ops.ChunkedArray <- function(e1, e2) {
  # Only operators with an entry in .array_function_map can be translated
  if (!.Generic %in% names(.array_function_map)) {
    stop("Unsupported operation on ChunkedArray: ", .Generic, call. = FALSE)
  }
  # Build the expression and evaluate it right away; the result is wrapped
  # as a ChunkedArray
  eval_array_expression(
    build_array_expression(.Generic, e1, e2, result_class = "ChunkedArray")
  )
}

#' @export
Ops.array_expression <- Ops.Array
Ops.array_expression <- function(e1, e2) {
  # Operators on array_expressions stay lazy: this returns a new, unevaluated
  # array_expression rather than calling into Arrow.

  # Fail now for operators that have no mapping, instead of building an
  # expression with a NULL function name that only errors when evaluated.
  if (!.Generic %in% names(.array_function_map)) {
    stop("Unsupported operation on array_expression: ", .Generic, call. = FALSE)
  }

  if (.Generic == "!") {
    # Unary: e2 is missing and must not be touched
    return(build_array_expression(.Generic, e1, result_class = e1$result_class))
  }

  # This method is also dispatched when only the right-hand operand is an
  # array_expression (e.g. `5 == expr`), in which case e1 is a plain R value
  # and has no `result_class` to read. Take it from whichever operand is the
  # expression.
  expr_operand <- if (inherits(e1, "array_expression")) e1 else e2
  build_array_expression(
    .Generic, e1, e2,
    result_class = expr_operand$result_class
  )
}

# Translate an R function/operator name plus its operand(s) into an
# array_expression whose `fun` is the mapped function name.
#
# .Generic: R-level name; looked up in .unary_function_map first, otherwise in
#           .binary_function_map
# e1, e2:   operands; e2 is never touched for unary functions, so it may be
#           missing in that case
# ...:      forwarded to array_expression() (e.g. result_class) for both the
#           unary and the binary case
build_array_expression <- function(.Generic, e1, e2, ...) {
  if (.Generic %in% names(.unary_function_map)) {
    return(array_expression(.unary_function_map[[.Generic]], e1, ...))
  }
  # .wrap_arrow() only uses its `type` argument when the operand is not
  # already an ArrowObject/array_expression, so `e2$type` / `e1$type` are only
  # evaluated when the *other* operand is a plain R value that needs wrapping.
  e1 <- .wrap_arrow(e1, .Generic, e2$type)
  e2 <- .wrap_arrow(e2, .Generic, e1$type)
  array_expression(.binary_function_map[[.Generic]], e1, e2, ...)
}

# Ensure `arg` is something that can sit inside an array_expression.
#
# ArrowObjects and array_expressions are returned untouched (and `type` is
# never evaluated for them). Anything else is converted with Array$create()
# when `fun` is "%in%", and with Scalar$create() otherwise, passing `type`
# through in both cases.
.wrap_arrow <- function(arg, fun, type) {
  if (inherits(arg, c("ArrowObject", "array_expression"))) {
    return(arg)
  }
  # TODO: Array$create if lengths are equal?
  # TODO: these kernels should autocast like the dataset ones do (e.g. int vs. float)
  if (fun == "%in%") {
    Array$create(arg, type = type)
  } else {
    Scalar$create(arg, type = type)
  }
}

# Single-operand R functions/operators, keyed by R name. Each value is the
# function name that build_array_expression() stores in the expression.
.unary_function_map <- list(
  `!` = "invert",
  is.na = "is_null"
)

# Two-operand R operators, keyed by R name. Each value is the function name
# that build_array_expression() stores in the expression.
.binary_function_map <- list(
  `==` = "equal",
  `!=` = "not_equal",
  `>` = "greater",
  `>=` = "greater_equal",
  `<` = "less",
  `<=` = "less_equal",
  `&` = "and_kleene",
  `|` = "or_kleene",
  `%in%` = "is_in_meta_binary"
)

# Every translatable R name: unary entries first, then binary
.array_function_map <- c(.unary_function_map, .binary_function_map)

# Guess which class an expression's result should be wrapped in, based on one
# of its arguments: the first class of an ArrowObject, or the result_class
# already recorded on a nested array_expression. Anything else is an error.
.guess_result_class <- function(arg) {
  # HACK HACK HACK delete this when call_function returns an ArrowObject itself
  if (inherits(arg, "ArrowObject")) {
    return(class(arg)[1])
  }
  if (inherits(arg, "array_expression")) {
    return(arg$result_class)
  }
  stop("Not implemented")
}

# Evaluate an array_expression: nested array_expressions among its arguments
# are evaluated first (recursively), then the expression's function is called
# via call_function() and the returned pointer is wrapped in the class named
# by x$result_class.
eval_array_expression <- function(x) {
  evaluated_args <- lapply(x$args, function(arg) {
    if (!inherits(arg, "array_expression")) {
      return(arg)
    }
    eval_array_expression(arg)
  })
  # Fall back to an empty named list when the expression carries no options
  opts <- x$options %||% empty_named_list()
  result_ptr <- call_function(x$fun, args = evaluated_args, options = opts)
  # result_class is a class name (string); look up the object of that name
  shared_ptr(get(x$result_class), result_ptr)
}

#' @export
is.na.array_expression <- function(x) {
  # Go through build_array_expression() so that the R name "is.na" is
  # translated via .unary_function_map. Storing "is.na" directly as the
  # expression's function would later be handed to call_function() unchanged
  # when the expression is evaluated.
  build_array_expression("is.na", x)
}

#' @export
as.vector.array_expression <- function(x, ...) {
x$args <- lapply(x$args, as.vector)
do.call(x$fun, x$args)
as.vector(eval_array_expression(x))
}

#' @export
print.array_expression <- function(x, ...) print(as.vector(x))
print.array_expression <- function(x, ...) {
  # Print the formatted (unevaluated) expression on its own line and return
  # x invisibly
  formatted <- .format_array_expression(x)
  cat(formatted, "\n", sep = "")
  invisible(x)
}

# Render an array_expression as a single string of the form "fun(arg1, arg2)".
# Scalars are shown by deparsing their R value, other ArrowObjects as
# "<ClassName>", and nested array_expressions are rendered recursively.
.format_array_expression <- function(x) {
  describe_arg <- function(arg) {
    if (inherits(arg, "Scalar")) {
      return(deparse(as.vector(arg)))
    }
    if (inherits(arg, "ArrowObject")) {
      return(paste0("<", class(arg)[1], ">"))
    }
    if (inherits(arg, "array_expression")) {
      return(.format_array_expression(arg))
    }
    # Should not happen
    deparse(arg)
  }
  arg_strings <- map_chr(x$args, describe_arg)
  # Prune this for readability
  fun_label <- sub("_kleene", "", x$fun)
  paste0(fun_label, "(", paste(arg_strings, collapse = ", "), ")")
}

###########

Expand Down Expand Up @@ -130,6 +248,15 @@ make_expression <- function(operator, e1, e2) {
# In doesn't take Scalar, it takes Array
return(Expression$in_(e1, e2))
}

# Handle unary functions before touching e2
if (operator == "is.na") {
return(is.na(e1))
}
if (operator == "!") {
return(Expression$not(e1))
}

# Check for non-expressions and convert to Expressions
if (!inherits(e1, "Expression")) {
e1 <- Expression$scalar(e1)
Expand Down
1 change: 1 addition & 0 deletions r/R/record-batch.R
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ RecordBatch <- R6Class("RecordBatch", inherit = ArrowObject,
if (is.logical(i)) {
i <- Array$create(i)
}
assert_that(is.Array(i, "bool"))
shared_ptr(RecordBatch, call_function("filter", self, i, options = list(keep_na = keep_na)))
},
serialize = function() ipc___SerializeRecordBatch__Raw(self),
Expand Down
16 changes: 0 additions & 16 deletions r/src/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,22 +149,6 @@ std::shared_ptr<arrow::Array> Array__View(const std::shared_ptr<arrow::Array>& a
return ValueOrStop(array->View(type));
}

// Returns a logical vector with one entry per element of `array`, read from
// the array's null bitmap: an entry is true where the corresponding bit is set.
// The R caller shown in this commit negates the result to obtain is.na(), and
// handles null-typed arrays itself before calling this.
// [[arrow::export]]
LogicalVector Array__Mask(const std::shared_ptr<arrow::Array>& array) {
// Fast path: with no nulls every entry is true, so the bitmap is not read
if (array->null_count() == 0) {
return LogicalVector(array->length(), true);
}

auto n = array->length();
// Allocated without initialization; every element is assigned in the loop below
LogicalVector res(no_init(n));
// Read n bits of the null bitmap, starting at the array's offset
arrow::internal::BitmapReader bitmap_reader(array->null_bitmap()->data(),
array->offset(), n);
for (int64_t i = 0; i < n; i++, bitmap_reader.Next()) {
res[i] = bitmap_reader.IsSet();
}
return res;
}

// [[arrow::export]]
void Array__Validate(const std::shared_ptr<arrow::Array>& array) {
StopIfNotOk(array->Validate());
Expand Down
16 changes: 0 additions & 16 deletions r/src/arrowExports.cpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions r/tests/testthat/test-Array.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ expect_array_roundtrip <- function(x, type, as = NULL) {
# TODO: revisit how missingness works with ListArrays
# R list objects don't handle missingness the same way as other vectors.
# Is there some vctrs thing we should do on the roundtrip back to R?
expect_identical(is.na(a), is.na(x))
expect_equal(as.vector(is.na(a)), is.na(x))
}
expect_equivalent(as.vector(a), x)
# Make sure the storage mode is the same on roundtrip (esp. integer vs. numeric)
Expand All @@ -37,7 +37,7 @@ expect_array_roundtrip <- function(x, type, as = NULL) {
expect_type_equal(a_sliced$type, type)
expect_identical(length(a_sliced), length(x_sliced))
if (!inherits(type, c("ListType", "LargeListType"))) {
expect_identical(is.na(a_sliced), is.na(x_sliced))
expect_equal(as.vector(is.na(a_sliced)), is.na(x_sliced))
}
expect_equivalent(as.vector(a_sliced), x_sliced)
}
Expand Down Expand Up @@ -182,8 +182,8 @@ test_that("Array supports NA", {
expect_true(x_int$IsNull(10L))
expect_true(x_dbl$IsNull(10))

expect_equal(is.na(x_int), c(rep(FALSE, 10), TRUE))
expect_equal(is.na(x_dbl), c(rep(FALSE, 10), TRUE))
expect_equal(as.vector(is.na(x_int)), c(rep(FALSE, 10), TRUE))
expect_equal(as.vector(is.na(x_dbl)), c(rep(FALSE, 10), TRUE))

# Input validation
expect_error(x_int$IsValid("ten"), class = "Rcpp::not_compatible")
Expand Down Expand Up @@ -354,7 +354,7 @@ test_that("integer types casts (ARROW-3741)", {
for (type in c(int_types, uint_types)) {
casted <- a$cast(type)
expect_equal(casted$type, type)
expect_identical(is.na(casted), c(rep(FALSE, 10), TRUE))
expect_identical(as.vector(is.na(casted)), c(rep(FALSE, 10), TRUE))
}
})

Expand All @@ -372,7 +372,7 @@ test_that("float types casts (ARROW-3741)", {
for (type in float_types) {
casted <- a$cast(type)
expect_equal(casted$type, type)
expect_identical(is.na(casted), c(rep(FALSE, 3), TRUE))
expect_identical(as.vector(is.na(casted)), c(rep(FALSE, 3), TRUE))
expect_identical(as.vector(casted), x)
}
})
Expand Down

0 comments on commit a284504

Please sign in to comment.