Skip to content

Commit

Permalink
ARROW-11699: [R] Implement dplyr::across() for mutate()
Browse files Browse the repository at this point in the history
This PR introduces a partial implementation of `dplyr::across()` for use within `dplyr::mutate()`.

``` r
arrow_table(iris) %>%
  mutate(across(starts_with("Sepal"), list(round, sqrt)))
#> Table (query)
#> Sepal.Length: double
#> Sepal.Width: double
#> Petal.Length: double
#> Petal.Width: double
#> Species: dictionary<values=string, indices=int8>
#> Sepal.Length_1: double (round(Sepal.Length, {ndigits=0, round_mode=HALF_TO_EVEN}))
#> Sepal.Length_2: double (sqrt_checked(Sepal.Length))
#> Sepal.Width_1: double (round(Sepal.Width, {ndigits=0, round_mode=HALF_TO_EVEN}))
#> Sepal.Width_2: double (sqrt_checked(Sepal.Width))
#>
#> See $.data for the source Arrow object
```

I've opened several follow-up tickets covering the work still needed for a more complete implementation:
* [ARROW-17362: [R] Implement dplyr::across() inside summarise()](https://issues.apache.org/jira/browse/ARROW-17362)
* [ARROW-17387: [R] Implement dplyr::across() inside filter()](https://issues.apache.org/jira/browse/ARROW-17387)
* ~[ARROW-17364: [R] Implement .names argument inside across()](https://issues.apache.org/jira/browse/ARROW-17364)~ (now done in this PR, will close it once this is merged)
* [ARROW-17366: [R] Support purrr-style lambda functions in .fns argument to across()](https://issues.apache.org/jira/browse/ARROW-17366)

Closes #13786 from thisisnic/ARROW-11699_across

Authored-by: Nic Crane <thisisnic@gmail.com>
Signed-off-by: Nic Crane <thisisnic@gmail.com>
  • Loading branch information
thisisnic committed Sep 1, 2022
1 parent fe6e902 commit d5f80cb
Show file tree
Hide file tree
Showing 8 changed files with 479 additions and 6 deletions.
2 changes: 2 additions & 0 deletions r/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Biarch: true
Imports:
assertthat,
bit64 (>= 0.9-7),
glue,
methods,
purrr,
R6,
Expand Down Expand Up @@ -91,6 +92,7 @@ Collate:
'dataset-scan.R'
'dataset-write.R'
'dictionary.R'
'dplyr-across.R'
'dplyr-arrange.R'
'dplyr-collect.R'
'dplyr-count.R'
Expand Down
7 changes: 7 additions & 0 deletions r/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ importFrom(assertthat,assert_that)
importFrom(assertthat,is.string)
importFrom(bit64,print.integer64)
importFrom(bit64,str.integer64)
importFrom(glue,glue)
importFrom(methods,as)
importFrom(purrr,as_mapper)
importFrom(purrr,flatten)
Expand All @@ -413,6 +414,7 @@ importFrom(rlang,as_function)
importFrom(rlang,as_label)
importFrom(rlang,as_quosure)
importFrom(rlang,call2)
importFrom(rlang,call_args)
importFrom(rlang,caller_env)
importFrom(rlang,dots_n)
importFrom(rlang,enexpr)
Expand All @@ -425,20 +427,25 @@ importFrom(rlang,eval_tidy)
importFrom(rlang,exec)
importFrom(rlang,expr)
importFrom(rlang,is_bare_character)
importFrom(rlang,is_call)
importFrom(rlang,is_character)
importFrom(rlang,is_empty)
importFrom(rlang,is_false)
importFrom(rlang,is_formula)
importFrom(rlang,is_integerish)
importFrom(rlang,is_interactive)
importFrom(rlang,is_list)
importFrom(rlang,is_quosure)
importFrom(rlang,is_symbol)
importFrom(rlang,list2)
importFrom(rlang,new_data_mask)
importFrom(rlang,new_environment)
importFrom(rlang,quo_get_env)
importFrom(rlang,quo_get_expr)
importFrom(rlang,quo_is_call)
importFrom(rlang,quo_is_null)
importFrom(rlang,quo_name)
importFrom(rlang,quo_set_env)
importFrom(rlang,quo_set_expr)
importFrom(rlang,quos)
importFrom(rlang,seq2)
Expand Down
4 changes: 3 additions & 1 deletion r/R/arrow-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
#' @importFrom rlang eval_tidy new_data_mask syms env new_environment env_bind set_names exec
#' @importFrom rlang is_bare_character quo_get_expr quo_get_env quo_set_expr .data seq2 is_interactive
#' @importFrom rlang expr caller_env is_character quo_name is_quosure enexpr enexprs as_quosure
#' @importFrom rlang is_list call2 is_empty as_function as_label arg_match
#' @importFrom rlang is_list call2 is_empty as_function as_label arg_match is_symbol is_call call_args
#' @importFrom rlang quo_set_env quo_get_env is_formula quo_is_call
#' @importFrom tidyselect vars_pull vars_rename vars_select eval_select
#' @importFrom glue glue
#' @useDynLib arrow, .registration = TRUE
#' @keywords internal
"_PACKAGE"
Expand Down
177 changes: 177 additions & 0 deletions r/R/dplyr-across.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Expand every `across()` call found in a list of quosures into the individual
# quosures it stands for; all other quosures are passed through unchanged.
#
# @param .data object the `across()` column selection is resolved against
#   (handed to `across_setup()` as `mask`)
# @param quos_in list of quosures to inspect
# @return the accumulated quosures, with each `across()` entry replaced by the
#   quosures produced by `quosures_from_setup()`
expand_across <- function(.data, quos_in) {
  supported_args <- c(".cols", ".fns", ".names")
  lambda_msg <- paste(
    "purrr-style lambda functions as `.fns` argument to `across()`",
    "not yet supported in Arrow"
  )

  quos_out <- list()

  for (idx in seq_along(quos_in)) {
    # single-bracket subsetting keeps the entry's name, so quosures other than
    # across() retain their naming when appended below
    entry <- quos_in[idx]
    inner_expr <- quo_get_expr(entry[[1]])
    inner_env <- quo_get_env(entry[[1]])

    # anything that is not an across() call goes straight to the output
    if (!is_call(inner_expr, "across")) {
      quos_out <- append(quos_out, entry)
      next
    }

    # line the user's arguments up against the formals of dplyr::across()
    matched <- match.call(
      definition = dplyr::across,
      call = inner_expr,
      expand.dots = FALSE,
      envir = inner_env
    )

    supplied <- names(matched[-1])
    if (any(!(supplied %in% supported_args))) {
      abort("`...` argument to `across()` is deprecated in dplyr and not supported in Arrow")
    }

    # no `.cols` supplied means "all columns"
    selection <- matched[[".cols"]]
    if (is.null(selection)) {
      selection <- quote(everything())
    }

    setup <- across_setup(
      cols = !!as_quosure(selection, inner_env),
      fns = matched[[".fns"]],
      names = matched[[".names"]],
      .caller_env = inner_env,
      mask = .data,
      inline = TRUE
    )

    fns <- setup$fns
    is_bare_lambda <- !is_list(fns) && !is.null(fns) && as.character(fns)[[1]] == "~"
    if (is_bare_lambda) {
      abort(lambda_msg)
    }

    quos_out <- append(quos_out, quosures_from_setup(setup, inner_env))
  }

  quos_out
}

# Turn the result of `across_setup()` into a named list of quosures.
#
# @param setup list with elements `vars` (column names), `fns` (functions to
#   apply, or NULL) and `names` (names for the resulting quosures)
# @param quo_env environment assigned to every quosure that is created
# @return list of quosures named by `setup$names`; one per (column, function)
#   pair, or one per column when `setup$fns` is NULL
quosures_from_setup <- function(setup, quo_env) {
  fns <- setup$fns
  vars <- setup$vars

  if (is.null(fns)) {
    # no functions supplied: each column simply maps to itself
    exprs <- map(vars, function(col) sym(col))
  } else {
    # columns vary slowest, functions fastest: col1/fn1, col1/fn2, col2/fn1, ...
    fn_each <- rep(fns, length(vars))
    col_each <- rep(vars, each = length(fns))
    exprs <- map2(fn_each, col_each, function(fn, col) call2(fn, sym(col)))
  }

  # wrap each expression in a quosure and re-home it in the requested env
  map(
    set_names(exprs, setup$names),
    function(e) quo_set_env(quo(!!e), quo_env)
  )
}

# Resolve the pieces of an `across()` call: which columns are selected, which
# functions are applied, and what the resulting columns are named.
#
# @param cols column selection; captured with `enquo()` and resolved by
#   `dplyr::select()` against `mask`
# @param fns the `.fns` argument as passed by the caller: NULL, a single
#   function (a function object, a symbol, `pkg::fun`, an inline
#   `function(...)` definition, or a formula), or a call such as
#   `list(f, g)` whose arguments are the functions
# @param names the `.names` glue specification, or NULL for the default
# @param .caller_env parent environment for evaluating the glue specification
# @param mask object the column selection is applied to
# @param inline if FALSE, each entry of `fns` is passed through `as_function()`
# @return list with elements `vars`, `fns` and `names`
across_setup <- function(cols, fns, names, .caller_env, mask, inline = FALSE) {
  cols <- enquo(cols)

  vars <- names(dplyr::select(mask, !!cols))

  if (is.null(fns)) {
    if (!is.null(names)) {
      glue_mask <- across_glue_mask(.caller_env, .col = vars, .fn = "1")
      names <- vctrs::vec_as_names(glue::glue(names, .envir = glue_mask), repair = "check_unique")
    } else {
      names <- vars
    }

    return(list(vars = vars, fns = fns, names = names))
  }

  # `pkg::fun`, `pkg:::fun` and `function(x) ...` are calls syntactically, but
  # each denotes ONE function; splitting them with call_args() would turn
  # their pieces (e.g. `base` and `round`) into separate "functions"
  fn_head <- if (is.call(fns)) fns[[1]] else NULL
  is_single_fn_call <- is.symbol(fn_head) &&
    as.character(fn_head) %in% c("::", ":::", "function")

  # apply `.names` smart default
  if (is.function(fns) || is_formula(fns) || is.name(fns) || is_single_fn_call) {
    names <- names %||% "{.col}"
    fns <- list("1" = fns)
  } else if (is.call(fns)) {
    # e.g. `list(f, g)`: the call's arguments are the functions
    names <- names %||% "{.col}_{.fn}"
    fns <- call_args(fns)
  } else {
    # validate before call_args(), which would otherwise fail on non-calls
    # with an error unrelated to `.fns`
    abort("`.fns` must be NULL, a function, a formula, or a list of functions/formulas.")
  }

  if (any(map_lgl(fns, is_formula))) {
    abort(
      paste(
        "purrr-style lambda functions as `.fns` argument to `across()`",
        "not yet supported in Arrow"
      )
    )
  }

  # make sure fns has names, use number to replace unnamed
  if (is.null(names(fns))) {
    names_fns <- seq_along(fns)
  } else {
    names_fns <- names(fns)
    empties <- which(names_fns == "")
    if (length(empties)) {
      names_fns[empties] <- empties
    }
  }

  glue_mask <- across_glue_mask(.caller_env,
    .col = rep(vars, each = length(fns)),
    .fn = rep(names_fns, length(vars))
  )
  names <- vctrs::vec_as_names(glue::glue(names, .envir = glue_mask), repair = "check_unique")

  if (!inline) {
    fns <- map(fns, as_function)
  }

  # ensure the `.names` specification produced exactly one name for every
  # (column, function) combination
  n_expected <- length(vars) * length(fns)
  if (length(names) != n_expected) {
    abort(
      c(
        "`.names` specification must produce (number of columns * number of functions) names.",
        x = paste0(
          n_expected, " names required (",
          length(vars), " columns * ",
          length(fns), " functions)\n ",
          length(names), " name(s) produced: ",
          paste(names, collapse = ",")
        )
      )
    )
  }

  list(vars = vars, fns = fns, names = names)
}

# Build the environment in which a `.names` glue specification is evaluated.
#
# @param .col column name(s) exposed to glue as `.col` and `col`
# @param .fn function name(s) exposed to glue as `.fn` and `fn`
# @param .caller_env parent of the new environment
# @return a new environment holding the four bindings above
across_glue_mask <- function(.col, .fn, .caller_env) {
  bindings <- list(.col = .col, .fn = .fn, col = .col, fn = .fn)
  new_environment(bindings, parent = .caller_env)
}
4 changes: 3 additions & 1 deletion r/R/dplyr-mutate.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ mutate.arrow_dplyr_query <- function(.data,
.before = NULL,
.after = NULL) {
call <- match.call()
exprs <- ensure_named_exprs(quos(...))

expression_list <- expand_across(.data, quos(...))
exprs <- ensure_named_exprs(expression_list)

.keep <- match.arg(.keep)
.before <- enquo(.before)
Expand Down
4 changes: 4 additions & 0 deletions r/tests/testthat/helper-expectation.R
Original file line number Diff line number Diff line change
Expand Up @@ -321,3 +321,7 @@ split_vector_as_list <- function(vec) {
vec2 <- vec[seq(from = min(length(vec), vec_split + 1), to = length(vec), by = 1)]
list(vec1, vec2)
}

# Test helper: expects that expanding `actual` with `expand_across()` against
# `tbl` gives a result identical to `expected` coerced to a plain list.
expect_across_equal <- function(actual, expected, tbl) {
  expect_identical(
    object = expand_across(tbl, actual),
    expected = as.list(expected)
  )
}
Loading

0 comments on commit d5f80cb

Please sign in to comment.