diff --git a/r/DESCRIPTION b/r/DESCRIPTION index 95c1405869836..a728be37734f7 100644 --- a/r/DESCRIPTION +++ b/r/DESCRIPTION @@ -31,6 +31,7 @@ Biarch: true Imports: assertthat, bit64 (>= 0.9-7), + glue, methods, purrr, R6, @@ -91,6 +92,7 @@ Collate: 'dataset-scan.R' 'dataset-write.R' 'dictionary.R' + 'dplyr-across.R' 'dplyr-arrange.R' 'dplyr-collect.R' 'dplyr-count.R' diff --git a/r/NAMESPACE b/r/NAMESPACE index c4c18ba16d744..49db309b8e862 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -390,6 +390,7 @@ importFrom(assertthat,assert_that) importFrom(assertthat,is.string) importFrom(bit64,print.integer64) importFrom(bit64,str.integer64) +importFrom(glue,glue) importFrom(methods,as) importFrom(purrr,as_mapper) importFrom(purrr,flatten) @@ -413,6 +414,7 @@ importFrom(rlang,as_function) importFrom(rlang,as_label) importFrom(rlang,as_quosure) importFrom(rlang,call2) +importFrom(rlang,call_args) importFrom(rlang,caller_env) importFrom(rlang,dots_n) importFrom(rlang,enexpr) @@ -425,20 +427,25 @@ importFrom(rlang,eval_tidy) importFrom(rlang,exec) importFrom(rlang,expr) importFrom(rlang,is_bare_character) +importFrom(rlang,is_call) importFrom(rlang,is_character) importFrom(rlang,is_empty) importFrom(rlang,is_false) +importFrom(rlang,is_formula) importFrom(rlang,is_integerish) importFrom(rlang,is_interactive) importFrom(rlang,is_list) importFrom(rlang,is_quosure) +importFrom(rlang,is_symbol) importFrom(rlang,list2) importFrom(rlang,new_data_mask) importFrom(rlang,new_environment) importFrom(rlang,quo_get_env) importFrom(rlang,quo_get_expr) +importFrom(rlang,quo_is_call) importFrom(rlang,quo_is_null) importFrom(rlang,quo_name) +importFrom(rlang,quo_set_env) importFrom(rlang,quo_set_expr) importFrom(rlang,quos) importFrom(rlang,seq2) diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index f3e0b817d5f42..e8aa93f95346a 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -23,8 +23,10 @@ #' @importFrom rlang eval_tidy new_data_mask syms env new_environment env_bind set_names exec #' @importFrom rlang is_bare_character quo_get_expr quo_get_env quo_set_expr .data seq2 is_interactive #' @importFrom rlang expr caller_env is_character quo_name is_quosure enexpr enexprs as_quosure -#' @importFrom rlang is_list call2 is_empty as_function as_label arg_match +#' @importFrom rlang is_list call2 is_empty as_function as_label arg_match is_symbol is_call call_args +#' @importFrom rlang quo_set_env quo_get_env is_formula quo_is_call #' @importFrom tidyselect vars_pull vars_rename vars_select eval_select +#' @importFrom glue glue #' @useDynLib arrow, .registration = TRUE #' @keywords internal "_PACKAGE" diff --git a/r/R/dplyr-across.R b/r/R/dplyr-across.R new file mode 100644 index 0000000000000..01a9262b81ec5 --- /dev/null +++ b/r/R/dplyr-across.R @@ -0,0 +1,177 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +expand_across <- function(.data, quos_in) { + quos_out <- list() + # retrieve items using their values to preserve naming of quos other than across + for (quo_i in seq_along(quos_in)) { + quo_in <- quos_in[quo_i] + quo_expr <- quo_get_expr(quo_in[[1]]) + quo_env <- quo_get_env(quo_in[[1]]) + + if (is_call(quo_expr, "across")) { + new_quos <- list() + + across_call <- match.call( + definition = dplyr::across, + call = quo_expr, + expand.dots = FALSE, + envir = quo_env + ) + + if (!all(names(across_call[-1]) %in% c(".cols", ".fns", ".names"))) { + abort("`...` argument to `across()` is deprecated in dplyr and not supported in Arrow") + } + + if (!is.null(across_call[[".cols"]])) { + cols <- across_call[[".cols"]] + } else { + cols <- quote(everything()) + } + + setup <- across_setup( + cols = !!as_quosure(cols, quo_env), + fns = across_call[[".fns"]], + names = across_call[[".names"]], + .caller_env = quo_env, + mask = .data, + inline = TRUE + ) + + if (!is_list(setup$fns) && !is.null(setup$fns) && as.character(setup$fns)[[1]] == "~") { + abort( + paste( + "purrr-style lambda functions as `.fns` argument to `across()`", + "not yet supported in Arrow" + ) + ) + } + + new_quos <- quosures_from_setup(setup, quo_env) + + quos_out <- append(quos_out, new_quos) + } else { + quos_out <- append(quos_out, quo_in) + } + } + + quos_out +} + +# given a named list of functions and column names, create a list of new quosures +quosures_from_setup <- function(setup, quo_env) { + if (!is.null(setup$fns)) { + func_list_full <- rep(setup$fns, length(setup$vars)) + cols_list_full <- rep(setup$vars, each = length(setup$fns)) + + # get new quosures + new_quo_list <- map2( + func_list_full, cols_list_full, + ~ quo(!!call2(.x, sym(.y))) + ) + } else { + # if there's no functions, just map to variables themselves + new_quo_list <- map( + setup$vars, + ~ quo(!!sym(.x)) + ) + } + + quosures <- set_names(new_quo_list, setup$names) + map(quosures, ~ quo_set_env(.x, quo_env)) +} + +across_setup <- function(cols, fns, names, .caller_env, mask, inline = FALSE) { + cols <- enquo(cols) + + vars <- names(dplyr::select(mask, !!cols)) + + if (is.null(fns)) { + if (!is.null(names)) { + glue_mask <- across_glue_mask(.caller_env, .col = vars, .fn = "1") + names <- vctrs::vec_as_names(glue::glue(names, .envir = glue_mask), repair = "check_unique") + } else { + names <- vars + } + + value <- list(vars = vars, fns = fns, names = names) + return(value) + } + + # apply `.names` smart default + if (is.function(fns) || is_formula(fns) || is.name(fns)) { + names <- names %||% "{.col}" + fns <- list("1" = fns) + } else { + names <- names %||% "{.col}_{.fn}" + fns <- call_args(fns) + } + + if (any(map_lgl(fns, is_formula))) { + abort( + paste( + "purrr-style lambda functions as `.fns` argument to `across()`", + "not yet supported in Arrow" + ) + ) + } + + if (!is.list(fns)) { + msg <- c("`.fns` must be NULL, a function, a formula, or a list of functions/formulas.") + abort(msg) + } + + # make sure fns has names, use number to replace unnamed + if (is.null(names(fns))) { + names_fns <- seq_along(fns) + } else { + names_fns <- names(fns) + empties <- which(names_fns == "") + if (length(empties)) { + names_fns[empties] <- empties + } + } + + glue_mask <- across_glue_mask(.caller_env, + .col = rep(vars, each = length(fns)), + .fn = rep(names_fns, length(vars)) + ) + names <- vctrs::vec_as_names(glue::glue(names, .envir = glue_mask), repair = "check_unique") + + if (!inline) { + fns <- map(fns, as_function) + } + + # ensure .names argument has resulted in + if (length(names) != (length(vars) * length(fns))) { + abort( + c( + "`.names` specification must produce (number of columns * number of functions) names.", + x = paste0( + length(vars) * length(fns), " names required (", length(vars), " columns * ", length(fns), " functions)\n ", + length(names), " name(s) produced: ", paste(names, collapse = ",") + ) + ) + ) + } + + list(vars = vars, fns = fns, names = names) +} + +across_glue_mask <- function(.col, .fn, .caller_env) { + env(.caller_env, .col = .col, .fn = .fn, col = .col, fn = .fn) +} diff --git a/r/R/dplyr-mutate.R b/r/R/dplyr-mutate.R index 653c1e6f25a02..ac555fafe0b50 100644 --- a/r/R/dplyr-mutate.R +++ b/r/R/dplyr-mutate.R @@ -24,7 +24,9 @@ mutate.arrow_dplyr_query <- function(.data, .before = NULL, .after = NULL) { call <- match.call() - exprs <- ensure_named_exprs(quos(...)) + + expression_list <- expand_across(.data, quos(...)) + exprs <- ensure_named_exprs(expression_list) .keep <- match.arg(.keep) .before <- enquo(.before) diff --git a/r/tests/testthat/helper-expectation.R b/r/tests/testthat/helper-expectation.R index eb2e6b02195e1..ba11700ab6216 100644 --- a/r/tests/testthat/helper-expectation.R +++ b/r/tests/testthat/helper-expectation.R @@ -321,3 +321,7 @@ split_vector_as_list <- function(vec) { vec2 <- vec[seq(from = min(length(vec), vec_split + 1), to = length(vec), by = 1)] list(vec1, vec2) } + +expect_across_equal <- function(actual, expected, tbl) { + expect_identical(expand_across(tbl, actual), as.list(expected)) +} diff --git a/r/tests/testthat/test-dplyr-across.R b/r/tests/testthat/test-dplyr-across.R new file mode 100644 index 0000000000000..8945c2a5f3ba8 --- /dev/null +++ b/r/tests/testthat/test-dplyr-across.R @@ -0,0 +1,226 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +library(dplyr, warn.conflicts = FALSE) + +test_that("expand_across correctly expands quosures", { + + # single unnamed function + expect_across_equal( + quos(across(c(dbl, dbl2), round)), + quos( + dbl = round(dbl), + dbl2 = round(dbl2) + ), + example_data + ) + + # multiple unnamed functions + expect_across_equal( + quos(across(c(dbl, dbl2), list(exp, sqrt))), + quos( + dbl_1 = exp(dbl), + dbl_2 = sqrt(dbl), + dbl2_1 = exp(dbl2), + dbl2_2 = sqrt(dbl2) + ), + example_data + ) + + # single named function + expect_across_equal( + quos(across(c(dbl, dbl2), list("fun1" = round))), + quos( + dbl_fun1 = round(dbl), + dbl2_fun1 = round(dbl2) + ), + example_data + ) + + # multiple named functions + expect_across_equal( + quos(across(c(dbl, dbl2), list("fun1" = round, "fun2" = sqrt))), + quos( + dbl_fun1 = round(dbl), + dbl_fun2 = sqrt(dbl), + dbl2_fun1 = round(dbl2), + dbl2_fun2 = sqrt(dbl2) + ), + example_data + ) + + # mix of named and unnamed functions + expect_across_equal( + quos(across(c(dbl, dbl2), list(round, "fun2" = sqrt))), + quos( + dbl_1 = round(dbl), + dbl_fun2 = sqrt(dbl), + dbl2_1 = round(dbl2), + dbl2_fun2 = sqrt(dbl2) + ), + example_data + ) + + # across() with no functions returns columns unchanged + expect_across_equal( + quos(across(starts_with("dbl"))), + quos( + dbl = dbl, + dbl2 = dbl2 + ), + example_data + ) + + # across() arguments not in default order + expect_across_equal( + quos(across(.fns = round, c(dbl, dbl2))), + quos( + dbl = round(dbl), + dbl2 = round(dbl2) + ), + example_data + ) + + # across() with no columns named + expect_across_equal( + quos(across(.fns = round)), + quos( + int = round(int), + dbl = round(dbl), + dbl2 = round(dbl2) + ), + example_data %>% select(int, dbl, dbl2) + ) + + # column selection via dynamic variable name + int <- c("dbl", "dbl2") + expect_across_equal( + quos(across(all_of(int), sqrt)), + quos( + dbl = sqrt(dbl), + dbl2 = sqrt(dbl2) + ), + example_data + ) + + # ellipses (...) are a deprecated argument + expect_error( + expand_across( + example_data, + quos(across(c(dbl, dbl2), round, digits = -1)) + ), + regexp = "`...` argument to `across()` is deprecated in dplyr and not supported in Arrow", + fixed = TRUE + ) + + # alternative ways of specifying .fns - as a list + expect_across_equal( + quos(across(1:dbl2, list(round))), + quos( + int_1 = round(int), + dbl_1 = round(dbl), + dbl2_1 = round(dbl2) + ), + example_data + ) + + # supply .fns as a one-item vector + expect_across_equal( + quos(across(1:dbl2, c(round))), + quos( + int_1 = round(int), + dbl_1 = round(dbl), + dbl2_1 = round(dbl2) + ), + example_data + ) + + # ARROW-17366: purrr-style lambda functions not yet supported + expect_error( + expand_across( + example_data, + quos(across(1:dbl2, ~ round(.x, digits = -1))) + ), + regexp = "purrr-style lambda functions as `.fns` argument to `across()` not yet supported in Arrow", + fixed = TRUE + ) + + # .names argument + expect_across_equal( + quos(across(c(dbl, dbl2), round, .names = "{.col}.{.fn}")), + quos( + dbl.1 = round(dbl), + dbl2.1 = round(dbl2) + ), + example_data + ) + + # names argument with custom text + expect_across_equal( + quos(across(c(dbl, dbl2), round, .names = "round_{.col}")), + quos( + round_dbl = round(dbl), + round_dbl2 = round(dbl2) + ), + example_data + ) + + # names argument supplied but no functions + expect_across_equal( + quos(across(starts_with("dbl"), .names = "new_{.col}")), + quos( + new_dbl = dbl, + new_dbl2 = dbl2 + ), + example_data + ) + + # .names argument and functions named + expect_across_equal( + quos(across(c(dbl, dbl2), list("my_round" = round, "my_exp" = exp), .names = "{.col}.{.fn}")), + quos( + dbl.my_round = round(dbl), + dbl.my_exp = exp(dbl), + dbl2.my_round = round(dbl2), + dbl2.my_exp = exp(dbl2) + ), + example_data + ) + + # .names argument and mix of named and unnamed functions + expect_across_equal( + quos(across(c(dbl, dbl2), list(round, "my_exp" = exp), .names = "{.col}.{.fn}")), + quos( + dbl.1 = round(dbl), + dbl.my_exp = exp(dbl), + dbl2.1 = round(dbl2), + dbl2.my_exp = exp(dbl2) + ), + example_data + ) + + # dodgy .names specification + expect_error( + expand_across( + example_data, + quos(across(c(dbl, dbl2), list(round, "my_exp" = exp), .names = "zarg")) + ), + regexp = "`.names` specification must produce (number of columns * number of functions) names.", + fixed = TRUE + ) + +}) diff --git a/r/tests/testthat/test-dplyr-mutate.R b/r/tests/testthat/test-dplyr-mutate.R index 66e3b4edf0d1e..f1de5c70454a2 100644 --- a/r/tests/testthat/test-dplyr-mutate.R +++ b/r/tests/testthat/test-dplyr-mutate.R @@ -279,14 +279,13 @@ test_that("dplyr::mutate's examples", { # Examples we don't support should succeed # but warn that they're pulling data into R to do so - # across and autosplicing: ARROW-11699 + # test modified from version in dplyr::mutate due to ARROW-12632 compare_dplyr_binding( .input %>% - select(name, homeworld, species) %>% - mutate(across(!name, as.factor)) %>% + select(name, height, mass) %>% + mutate(across(!name, as.character)) %>% collect(), starwars, - warning = "Expression across.*not supported in Arrow" ) # group_by then mutate @@ -589,3 +588,57 @@ test_that("mutate() and transmute() with namespaced functions", { tbl ) }) + +test_that("Can use across() within mutate()", { + + # expressions work in the right order + compare_dplyr_binding( + .input %>% + mutate( + dbl2 = dbl * 2, + across(c(dbl, dbl2), round), + int2 = int * 2, + dbl = dbl + 3 + ) %>% + collect(), + example_data + ) + + # this is valid is neither R nor Arrow + expect_error( + expect_warning( + compare_dplyr_binding( + .input %>% + arrow_table() %>% + mutate(across(c(dbl, dbl2), list("fun1" = round(sqrt(dbl))))) %>% + collect(), + example_data, + warning = TRUE + ) + ) + ) + + # ARROW-12778 - `where()` is not yet supported + expect_error( + compare_dplyr_binding( + .input %>% + mutate(across(where(is.double))) %>% + collect(), + example_data + ), + "Unsupported selection helper" + ) + + # gives the right error with window functions + expect_warning( + arrow_table(example_data) %>% + mutate( + x = int + 2, + across(c("int", "dbl"), list(mean = mean, sd = sd, round)), + exp(dbl2) + ) %>% + collect(), + "window functions not currently supported in Arrow; pulling data into R", + fixed = TRUE + ) +})