Skip to content

Commit

Permalink
ARROW-17462: [R] Cast scalars to type of field in Expression building (
Browse files Browse the repository at this point in the history
…#13985)

The logic is encapsulated in `wrap_scalars()` in expression.R. Most of the test changes (other than lint fixes) update printed output types, because `int * 2` now stays `int32` and the printed ExecPlans no longer contain as many `cast`s. The tests added in `test-dplyr-query.R` are the explicit tests of the feature.

Authored-by: Neal Richardson <neal.p.richardson@gmail.com>
Signed-off-by: Neal Richardson <neal.p.richardson@gmail.com>
  • Loading branch information
nealrichardson committed Oct 31, 2022
1 parent 8066c5e commit d045fc5
Show file tree
Hide file tree
Showing 8 changed files with 228 additions and 33 deletions.
2 changes: 1 addition & 1 deletion r/R/compute.R
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ register_scalar_function <- function(name, fun, in_type, out_type,
RegisterScalarUDF(name, scalar_function)

# register with dplyr binding (enables its use in mutate(), filter(), etc.)
binding_fun <- function(...) build_expr(name, ...)
binding_fun <- function(...) Expression$create(name, ...)

# inject the value of `name` into the expression to avoid saving this
# execution environment in the binding, which eliminates a warning when the
Expand Down
125 changes: 114 additions & 11 deletions r/R/expression.R
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,13 @@ Expression$create <- function(function_name,
args = list(...),
options = empty_named_list()) {
assert_that(is.string(function_name))
assert_that(is_list_of(args, "Expression"), msg = "Expression arguments must be Expression objects")
# Make sure all inputs are Expressions
args <- lapply(args, function(x) {
if (!inherits(x, "Expression")) {
x <- Expression$scalar(x)
}
x
})
expr <- compute___expr__call(function_name, args, options)
if (length(args)) {
expr$schema <- unify_schemas(schemas = lapply(args, function(x) x$schema))
Expand All @@ -187,7 +193,10 @@ Expression$field_ref <- function(name) {
compute___expr__field_ref(name)
}
Expression$scalar <- function(x) {
expr <- compute___expr__scalar(Scalar$create(x))
if (!inherits(x, "Scalar")) {
x <- Scalar$create(x)
}
expr <- compute___expr__scalar(x)
expr$schema <- schema()
expr
}
Expand All @@ -208,21 +217,20 @@ build_expr <- function(FUN,
}
if (FUN == "%in%") {
# Special-case %in%, which is different from the Array function name
value_set <- Array$create(args[[2]])
try(
value_set <- cast_or_parse(value_set, args[[1]]$type()),
silent = TRUE
)

expr <- Expression$create("is_in", args[[1]],
options = list(
# If args[[2]] is already an Arrow object (like a scalar),
# this wouldn't work
value_set = Array$create(args[[2]]),
value_set = value_set,
skip_nulls = TRUE
)
)
} else {
args <- lapply(args, function(x) {
if (!inherits(x, "Expression")) {
x <- Expression$scalar(x)
}
x
})
args <- wrap_scalars(args, FUN)

# In Arrow, "divide" is one function, which does integer division on
# integer inputs and floating-point division on floats
Expand Down Expand Up @@ -258,6 +266,101 @@ build_expr <- function(FUN,
expr
}

# Turn every non-Expression element of `args` into a scalar Expression.
#
# Where possible, the scalars are first cast to the type shared by the
# Expression arguments, so that e.g. an R double compared against an int32
# field is passed as int32. When that cast cannot be done, the scalars keep
# the type that `Scalar$create()` gave them.
#
# `args` is the list of arguments for the call being built; `FUN` is the
# function name as written in R. Returns `args` with all elements as
# Expressions.
wrap_scalars <- function(args, FUN) {
  # Translate the R name to the Arrow function name when a mapping exists
  arrow_fun <- .array_function_map[[FUN]] %||% FUN

  if (arrow_fun == "if_else") {
    # The first argument of if_else should be the boolean condition, so it is
    # left alone and must not take part in picking a type for the others
    args[-1] <- wrap_scalars(args[-1], FUN = "")
    return(args)
  }

  is_expr <- map_lgl(args, function(arg) inherits(arg, "Expression"))
  if (all(is_expr)) {
    # Everything is already an Expression: nothing to wrap
    return(args)
  }
  needs_wrap <- !is_expr

  args[needs_wrap] <- lapply(args[needs_wrap], Scalar$create)

  # Cases where no cast to the field type is attempted:
  # * binary_repeat, list_element: 2nd arg must be integer, Acero handles it
  # * %/%: behavior is switched on int vs. dbl in R (see build_expr)
  # * no Expression among the args, so there is no type to match
  second_arg_is_int <- arrow_fun %in% c("binary_repeat", "list_element")
  cast_allowed <- any(is_expr) && !second_arg_is_int && !(FUN %in% "%/%")

  if (cast_allowed) {
    # Best effort only. This errors (and is deliberately silenced) when the
    # Expressions carry no Schema so their type cannot be resolved, when
    # they do not share a single type, or when a scalar cannot be cast.
    recast <- try(
      {
        target <- common_type(args[is_expr])
        lapply(args[needs_wrap], cast_or_parse, type = target)
      },
      silent = TRUE
    )
    if (!inherits(recast, "try-error")) {
      # Only replace the scalars when every cast succeeded
      args[needs_wrap] <- recast
    }
  }

  args[needs_wrap] <- lapply(args[needs_wrap], Expression$scalar)
  args
}

# Return the single type shared by all of `exprs`, or raise an error if the
# expressions do not all have the same type.
#
# `exprs` is a list of objects exposing a `$type()` method; the types returned
# are compared with their `$Equals()` method.
common_type <- function(exprs) {
  types <- map(exprs, function(e) e$type())
  reference <- types[[1]]

  # Functions (in our tests) that pass more than one expr here:
  # * case_when
  # * pmin/pmax
  all_same <- length(types) == 1 ||
    all(map_lgl(types, function(ty) ty$Equals(reference)))

  if (!all_same) {
    stop("There is no common type in these expressions")
  }
  reference
}

# Convert `x` to `type`, parsing strings when the target is a date or
# timestamp, and leaving `x` untouched when the target is a decimal type.
#
# `x` must expose `$type_id()` and `$cast()`; `type` must expose `$id`.
# Returns the converted object, or `x` itself for decimal targets.
cast_or_parse <- function(x, type) {
  target_id <- type$id

  if (target_id %in% c(Type[["DECIMAL128"]], Type[["DECIMAL256"]])) {
    # TODO: determine the minimum size of decimal (or integer) required to
    # accommodate x.
    # Ideally the computation would stay in decimal when the data is decimal,
    # so that no precision is lost. Because of current limitations, x is kept
    # as it is (probably a double, coming from R) and Acero casts the decimal
    # to double to compute. A query can still cast x to decimal or integer
    # explicitly when that is known to be safe.
    # * ARROW-17601: multiply(decimal, decimal) can fail to make output type
    return(x)
  }

  # For most source types a plain cast is all that is needed
  if (!(x$type_id() %in% c(Type[["STRING"]], Type[["LARGE_STRING"]]))) {
    return(x$cast(type))
  }

  # Strings headed for a date/time type go through a parsing function first.
  # NOTE(review): `unit` is passed as an integer code (0L / 1L), presumably
  # an Arrow time-unit code; confirm which unit each value selects.
  if (target_id %in% c(Type[["DATE32"]], Type[["DATE64"]])) {
    date_opts <- list(format = "%Y-%m-%d", unit = 0L)
    x <- call_function("strptime", x, options = date_opts)
  } else if (target_id == Type[["TIMESTAMP"]]) {
    timestamp_opts <- list(format = "%Y-%m-%d %H:%M:%S", unit = 1L)
    x <- call_function("strptime", x, options = timestamp_opts)
    # R treats timestamps without a timezone as local time, whereas Arrow
    # assumes UTC. Attach the local timezone to stay consistent with R.
    tz_opts <- list(timezone = Sys.timezone())
    x <- call_function("assume_timezone", x, options = tz_opts)
  }

  x$cast(type)
}

#' @export
Ops.Expression <- function(e1, e2) {
if (.Generic == "!") {
Expand Down
6 changes: 3 additions & 3 deletions r/tests/testthat/test-dataset-dplyr.R
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ test_that("mutate()", {
chr: string
dbl: double
int: int32
twice: double (multiply_checked(int, 2))
twice: int32 (multiply_checked(int, 2))
* Filter: ((multiply_checked(dbl, 2) > 14) and (subtract_checked(dbl, 50) < 3))
See $.data for the source Arrow object",
Expand Down Expand Up @@ -219,7 +219,7 @@ test_that("arrange()", {
chr: string
dbl: double
int: int32
twice: double (multiply_checked(int, 2))
twice: int32 (multiply_checked(int, 2))
* Filter: ((multiply_checked(dbl, 2) > 14) and (subtract_checked(dbl, 50) < 3))
* Sorted by chr [asc], multiply_checked(int, 2) [desc], add_checked(dbl, int) [asc]
Expand Down Expand Up @@ -368,7 +368,7 @@ test_that("show_exec_plan(), show_query() and explain() with datasets", {
"ExecPlan with .* nodes:.*", # boiler plate for ExecPlan
"ProjectNode.*", # output columns
"FilterNode.*", # filter node
"int > 6.*cast.*", # filtering expressions + auto-casting of part
"int > 6.*", # filtering expressions
"SourceNode" # entry point
)
)
Expand Down
4 changes: 2 additions & 2 deletions r/tests/testthat/test-dplyr-collapse.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ test_that("implicit_schema with mutate", {
words = as.character(int)
) %>%
implicit_schema(),
schema(numbers = float64(), words = utf8())
schema(numbers = int32(), words = utf8())
)
})

Expand Down Expand Up @@ -163,7 +163,7 @@ test_that("Properties of collapsed query", {
"Table (query)
lgl: bool
total: int32
extra: double (multiply_checked(total, 5))
extra: int32 (multiply_checked(total, 5))
See $.data for the source Arrow object",
fixed = TRUE
Expand Down
30 changes: 17 additions & 13 deletions r/tests/testthat/test-dplyr-filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -217,25 +217,29 @@ test_that("filter() with between()", {
filter(dbl >= int, dbl <= dbl2)
)

expect_error(
tbl %>%
record_batch() %>%
compare_dplyr_binding(
.input %>%
filter(between(dbl, 1, "2")) %>%
collect()
collect(),
tbl
)

expect_error(
tbl %>%
record_batch() %>%
compare_dplyr_binding(
.input %>%
filter(between(dbl, 1, NA)) %>%
collect()
collect(),
tbl
)

expect_error(
tbl %>%
record_batch() %>%
filter(between(chr, 1, 2)) %>%
collect()
expect_warning(
compare_dplyr_binding(
.input %>%
filter(between(chr, 1, 2)) %>%
collect(),
tbl
),
# the dplyr version warns:
"NAs introduced by coercion"
)
})

Expand Down
2 changes: 1 addition & 1 deletion r/tests/testthat/test-dplyr-mutate.R
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ test_that("print a mutated table", {
print(),
"Table (query)
int: int32
twice: double (multiply_checked(int, 2))
twice: int32 (multiply_checked(int, 2))
See $.data for the source Arrow object",
fixed = TRUE
Expand Down
87 changes: 87 additions & 0 deletions r/tests/testthat/test-dplyr-query.R
Original file line number Diff line number Diff line change
Expand Up @@ -631,3 +631,90 @@ test_that("collect() is identical to compute() %>% collect()", {
collect()
)
})

# Explicit tests of scalar type matching: literals written in a query should
# take the type of the field they are combined with when that is possible.
# Each case checks the printed ExecPlan and/or the query result.
test_that("Scalars in expressions match the type of the field, if possible", {
# Fixture: `tbl` (defined outside this block) extended with a Date column and
# a datetime column, then converted to an Arrow Table
tbl_with_datetime <- tbl
tbl_with_datetime$dates <- as.Date("2022-08-28") + 1:10
tbl_with_datetime$times <- lubridate::ymd_hms("2018-10-07 19:04:05") + 1:10
tab <- Table$create(tbl_with_datetime)

# 5 is double in R but is properly interpreted as int, no cast is added
expect_output(
tab %>%
filter(int == 5) %>%
show_exec_plan(),
"int == 5"
)

# Because 5.2 can't cast to int32 without truncation, we pass as is
# and Acero will cast int to float64
expect_output(
tab %>%
filter(int == 5.2) %>%
show_exec_plan(),
"filter=(cast(int, {to_type=double",
fixed = TRUE
)
# No row is expected to match int == 5.2
expect_equal(
tab %>%
filter(int == 5.2) %>%
nrow(),
0
)

# int == string, this works in dplyr and here too
expect_output(
tab %>%
filter(int == "5") %>%
show_exec_plan(),
"int == 5",
fixed = TRUE
)
# Exactly one row is expected to match int == "5"
expect_equal(
tab %>%
filter(int == "5") %>%
nrow(),
1
)

# Strings automatically parsed to date/timestamp
expect_output(
tab %>%
filter(dates > "2022-09-01") %>%
show_exec_plan(),
"dates > 2022-09-01"
)
# The result must also match what dplyr returns on the data frame
compare_dplyr_binding(
.input %>%
filter(dates > "2022-09-01") %>%
collect(),
tbl_with_datetime
)

# The expected pattern uses `.` wildcards for part of the date and time;
# presumably the printed value depends on the local timezone -- TODO confirm
expect_output(
tab %>%
filter(times > "2018-10-07 19:04:05") %>%
show_exec_plan(),
"times > 2018-10-0. ..:..:05"
)
compare_dplyr_binding(
.input %>%
filter(times > "2018-10-07 19:04:05") %>%
collect(),
tbl_with_datetime
)

# Add a decimal(15, 2) copy of `dbl` to exercise the decimal case
tab_with_decimal <- tab %>%
mutate(dec = cast(dbl, decimal(15, 2))) %>%
compute()

# This reproduces the issue on ARROW-17601, found in the TPC-H query 1
# In ARROW-17462, we chose not to auto-cast to decimal to avoid that issue
result <- tab_with_decimal %>%
summarize(
tpc_h_1 = sum(dec * (1 - dec) * (1 + dec), na.rm = TRUE),
as_dbl = sum(dbl * (1 - dbl) * (1 + dbl), na.rm = TRUE)
) %>%
collect()
# The same formula on the decimal and the double column must give equal sums
expect_equal(result$tpc_h_1, result$as_dbl)
})
5 changes: 3 additions & 2 deletions r/tests/testthat/test-expression.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,10 @@ test_that("C++ expressions", {
# Interprets that as a list type
expect_r6_class(f == c(1L, 2L), "Expression")

expect_error(
# Non-Expression inputs are wrapped in Expression$scalar()
expect_equal(
Expression$create("add", 1, 2),
"Expression arguments must be Expression objects"
Expression$create("add", Expression$scalar(1), Expression$scalar(2))
)
})

Expand Down

0 comments on commit d045fc5

Please sign in to comment.