Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARROW-17387: [R] Implement dplyr::across() inside filter() #14281

Merged
merged 15 commits into from Oct 11, 2022
Merged
3 changes: 3 additions & 0 deletions r/NAMESPACE
Expand Up @@ -405,6 +405,7 @@ importFrom(purrr,map_dbl)
importFrom(purrr,map_dfr)
importFrom(purrr,map_int)
importFrom(purrr,map_lgl)
importFrom(purrr,reduce)
importFrom(rlang,"%||%")
importFrom(rlang,":=")
importFrom(rlang,.data)
Expand All @@ -426,6 +427,7 @@ importFrom(rlang,env_bind)
importFrom(rlang,eval_tidy)
importFrom(rlang,exec)
importFrom(rlang,expr)
importFrom(rlang,expr_text)
importFrom(rlang,f_env)
importFrom(rlang,f_rhs)
importFrom(rlang,is_bare_character)
Expand All @@ -443,6 +445,7 @@ importFrom(rlang,list2)
importFrom(rlang,new_data_mask)
importFrom(rlang,new_environment)
importFrom(rlang,new_quosure)
importFrom(rlang,new_quosures)
importFrom(rlang,parse_expr)
importFrom(rlang,quo)
importFrom(rlang,quo_get_env)
Expand Down
4 changes: 3 additions & 1 deletion r/R/arrow-package.R
Expand Up @@ -17,14 +17,16 @@

#' @importFrom stats quantile median na.omit na.exclude na.pass na.fail
#' @importFrom R6 R6Class
#' @importFrom purrr as_mapper map map2 map_chr map2_chr map_dbl map_dfr map_int map_lgl keep imap imap_chr flatten
#' @importFrom purrr as_mapper map map2 map_chr map2_chr map_dbl map_dfr map_int map_lgl keep imap imap_chr
#' @importFrom purrr flatten reduce
#' @importFrom assertthat assert_that is.string
#' @importFrom rlang list2 %||% is_false abort dots_n warn enquo quo_is_null enquos is_integerish quos quo
#' @importFrom rlang eval_tidy new_data_mask syms env new_environment env_bind set_names exec
#' @importFrom rlang is_bare_character quo_get_expr quo_get_env quo_set_expr .data seq2 is_interactive
#' @importFrom rlang expr caller_env is_character quo_name is_quosure enexpr enexprs as_quosure
#' @importFrom rlang is_list call2 is_empty as_function as_label arg_match is_symbol is_call call_args
#' @importFrom rlang quo_set_env quo_get_env is_formula quo_is_call f_rhs parse_expr f_env new_quosure
#' @importFrom rlang new_quosures expr_text
#' @importFrom tidyselect vars_pull vars_rename vars_select eval_select
#' @importFrom glue glue
#' @useDynLib arrow, .registration = TRUE
Expand Down
35 changes: 32 additions & 3 deletions r/R/dplyr-across.R
Expand Up @@ -23,7 +23,7 @@ expand_across <- function(.data, quos_in) {
quo_expr <- quo_get_expr(quo_in[[1]])
quo_env <- quo_get_env(quo_in[[1]])

if (is_call(quo_expr, "across")) {
if (is_call(quo_expr, c("across", "if_any", "if_all"))) {
nealrichardson marked this conversation as resolved.
Show resolved Hide resolved
new_quos <- list()

across_call <- match.call(
Expand Down Expand Up @@ -58,9 +58,38 @@ expand_across <- function(.data, quos_in) {
} else {
quos_out <- append(quos_out, quo_in)
}

if (is_call(quo_expr, "if_any")) {
quos_out <- append(list(), purrr::reduce(quos_out, combine_if, op = "|", envir = quo_get_env(quos_out[[1]])))
}

if (is_call(quo_expr, "if_all")) {
quos_out <- append(list(), purrr::reduce(quos_out, combine_if, op = "&", envir = quo_get_env(quos_out[[1]])))
}
}

quos_out
new_quosures(quos_out)
}

# Combines two quosures into one quosure whose expression is `lhs <op> rhs`.
# Used as the reducer when collapsing expanded if_any()/if_all() expressions.
#
# The call is built directly from the two expressions instead of pasting their
# deparsed text together and re-parsing it. The text round trip loses grouping
# whenever an operand contains an operator that binds more loosely than `op`:
# joining `a | b` and `c | d` with "&" as text yields `a | b & c | d`, which
# parses as `a | (b & c) | d` rather than `(a | b) & (c | d)`.
#
# @param lhs,rhs Quosures holding the expressions to combine.
# @param op String naming the binary operator to join them with ("&" or "|").
# @param envir Environment attached to the resulting quosure.
# @return A quosure wrapping the call `op(lhs_expr, rhs_expr)`.
combine_if <- function(lhs, rhs, op, envir) {
  combined <- call2(op, quo_get_expr(lhs), quo_get_expr(rhs))

  new_quosure(combined, envir)
}

# Placeholder definition of if_any(); the body is intentionally empty, so a
# direct call returns NULL. expand_across() detects `if_any` calls by name and
# rewrites them into the per-column expressions combined with `|`.
# NOTE(review): presumably defined only so the name and its arguments resolve
# inside the package — confirm against the callers.
if_any <- function(.cols = everything(), .fns = NULL, ..., .names = NULL) {

}

# Placeholder definition of if_all(); the body is intentionally empty, so a
# direct call returns NULL. expand_across() detects `if_all` calls by name and
# rewrites them into the per-column expressions combined with `&`.
# NOTE(review): presumably defined only so the name and its arguments resolve
# inside the package — confirm against the callers.
if_all <- function(.cols = everything(), .fns = NULL, ..., .names = NULL) {

thisisnic marked this conversation as resolved.
Show resolved Hide resolved
}

# given a named list of functions and column names, create a list of new quosures
Expand Down Expand Up @@ -157,7 +186,7 @@ across_glue_mask <- function(.col, .fn, .caller_env) {
env(.caller_env, .col = .col, .fn = .fn, col = .col, fn = .fn)
}

# Substitutes instances of `.` and `.x` with the variable in question
# Substitutes instances of "." and ".x" with `var`
as_across_fn_call <- function(fn, var, quo_env) {
if (is_formula(fn, lhs = FALSE)) {
expr <- f_rhs(fn)
Expand Down
2 changes: 1 addition & 1 deletion r/R/dplyr-filter.R
Expand Up @@ -20,7 +20,7 @@

filter.arrow_dplyr_query <- function(.data, ..., .preserve = FALSE) {
# TODO something with the .preserve argument
filts <- quos(...)
filts <- expand_across(.data, quos(...))
if (length(filts) == 0) {
# Nothing to do
return(.data)
Expand Down
10 changes: 7 additions & 3 deletions r/data-raw/docgen.R
Expand Up @@ -128,11 +128,15 @@ docs <- arrow:::.cache$docs

# across() is handled by manipulating the quosures, not by nse_funcs
docs[["dplyr::across"]] <- c(
# TODO(ARROW-17387): do filter
"not yet supported inside `filter()`;",
# TODO(ARROW-17384): implement where
"and use of `where()` selection helper not yet supported"
"Use of `where()` selection helper not yet supported"
)

# if_any() and if_all() are used instead of across() in filter()
# they are both handled by manipulating the quosures, not by nse_funcs
docs[["dplyr::if_any"]] <- character(0)
docs[["dplyr::if_all"]] <- character(0)

# desc() is a special helper handled inside of arrange()
docs[["dplyr::desc"]] <- character(0)

Expand Down
2 changes: 1 addition & 1 deletion r/tests/testthat/helper-expectation.R
Expand Up @@ -323,5 +323,5 @@ split_vector_as_list <- function(vec) {
}

# Test helper: wraps `tbl` with as_adq(), expands `across_expr` against it with
# expand_across(), and expects the result to be identical to `expected`.
#
# across_expr: quosures containing the across()/if_any()/if_all() call.
# expected:    quosures holding the expected expanded expressions.
# tbl:         data used to resolve the column selection.
#
# NOTE(review): two expectations appear here, one against as.list(expected)
# and one against new_quosures(expected). Presumably these are the before and
# after lines of a diff and only one is meant to remain — confirm which.
expect_across_equal <- function(across_expr, expected, tbl) {
expect_identical(expand_across(as_adq(tbl), across_expr), as.list(expected))
expect_identical(expand_across(as_adq(tbl), across_expr), new_quosures(expected))
}
22 changes: 22 additions & 0 deletions r/tests/testthat/test-dplyr-across.R
Expand Up @@ -278,3 +278,25 @@ test_that("ARROW-14071 - function(x)-style lambda functions are not supported",
regexp = "Anonymous functions are not yet supported in Arrow"
)
})

# Checks that expand_across() rewrites if_any()/if_all() into a single
# expression joining the per-column calls with `|` or `&` respectively.
test_that("if_all() and if_any() are supported", {

# if_any(): per-column expressions are combined with `|`
expect_across_equal(
quos(if_any(everything(), ~is.na(.x))),
quos(is.na(int) | is.na(dbl) | is.na(dbl2) | is.na(lgl) | is.na(false) | is.na(chr) | is.na(fct)),
example_data
)

# if_all(): per-column expressions are combined with `&`
expect_across_equal(
quos(if_all(everything(), ~is.na(.x))),
quos(is.na(int) & is.na(dbl) & is.na(dbl2) & is.na(lgl) & is.na(false) & is.na(chr) & is.na(fct)),
example_data
)

# NOTE(review): this expectation is identical to the if_all() case above;
# presumably a different scenario was intended or it can be dropped — confirm.
expect_across_equal(
quos(if_all(everything(), ~is.na(.x))),
quos(is.na(int) & is.na(dbl) & is.na(dbl2) & is.na(lgl) & is.na(false) & is.na(chr) & is.na(fct)),
example_data
)
thisisnic marked this conversation as resolved.
Show resolved Hide resolved

})
22 changes: 22 additions & 0 deletions r/tests/testthat/test-dplyr-filter.R
Expand Up @@ -417,3 +417,25 @@ test_that("filter() with namespaced functions", {
tbl
)
})

# Compares filter() results between dplyr and the Arrow binding when the
# filter conditions use if_any()/if_all().
test_that("filter() with across()", {

# if_any() as the only filter condition, with a tidyselect helper
compare_dplyr_binding(
.input %>%
filter(if_any(ends_with("l"), ~ is.na(.))) %>%
collect(),
tbl
)

# if_all() placed between ordinary filter expressions
compare_dplyr_binding(
.input %>%
filter(
false == FALSE,
if_all(everything(), ~ !is.na(.)),
int > 2
) %>%
collect(),
tbl
)

})