Skip to content

Commit

Permalink
ARROW-15622: [R] Implement union_all and union for arrow_dplyr_query
Browse files Browse the repository at this point in the history
This PR adds support for `dplyr::union` and `dplyr::union_all`. I am not sure why, but I have to call `dplyr::union` by its fully qualified name; otherwise I get an error.

Closes #13090 from wjones127/ARROW-15622-union-all

Lead-authored-by: Will Jones <willjones127@gmail.com>
Co-authored-by: Neal Richardson <neal.p.richardson@gmail.com>
Signed-off-by: Neal Richardson <neal.p.richardson@gmail.com>
  • Loading branch information
wjones127 and nealrichardson committed May 24, 2022
1 parent 6576aa0 commit d889ade
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 2 deletions.
1 change: 1 addition & 0 deletions r/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ Collate:
'dplyr-mutate.R'
'dplyr-select.R'
'dplyr-summarize.R'
'dplyr-union.R'
'record-batch.R'
'table.R'
'dplyr.R'
Expand Down
2 changes: 1 addition & 1 deletion r/R/arrow-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
"group_vars", "group_by_drop_default", "ungroup", "mutate", "transmute",
"arrange", "rename", "pull", "relocate", "compute", "collapse",
"distinct", "left_join", "right_join", "inner_join", "full_join",
"semi_join", "anti_join", "count", "tally", "rename_with"
"semi_join", "anti_join", "count", "tally", "rename_with", "union", "union_all"
)
)
for (cl in c("Dataset", "ArrowTabular", "RecordBatchReader", "arrow_dplyr_query")) {
Expand Down
5 changes: 4 additions & 1 deletion r/R/arrowExports.R

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

37 changes: 37 additions & 0 deletions r/R/dplyr-union.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# The following S3 methods are registered on load if dplyr is present

# union() method for Arrow queries: both inputs are coerced with as_adq(),
# combined with union_all(), and the combined query is passed to distinct().
# Arguments in `...` are accepted but not used.
union.arrow_dplyr_query <- function(x, y, ...) {
  combined <- union_all(as_adq(x), as_adq(y))
  distinct(combined)
}

# Dataset, ArrowTabular and RecordBatchReader all reuse the
# arrow_dplyr_query implementation of union().
union.RecordBatchReader <- union.arrow_dplyr_query
union.ArrowTabular <- union.arrow_dplyr_query
union.Dataset <- union.arrow_dplyr_query

# union_all() method for Arrow queries. Both inputs are coerced with
# as_adq(); the right-hand query is recorded on the left-hand one as
# `union_all$right_data`, and the result is passed through
# collapse.arrow_dplyr_query(). Arguments in `...` are accepted but not used.
union_all.arrow_dplyr_query <- function(x, y, ...) {
  left <- as_adq(x)
  left$union_all <- list(right_data = as_adq(y))
  collapse.arrow_dplyr_query(left)
}

# Dataset, ArrowTabular and RecordBatchReader all reuse the
# arrow_dplyr_query implementation of union_all().
union_all.RecordBatchReader <- union_all.arrow_dplyr_query
union_all.ArrowTabular <- union_all.arrow_dplyr_query
union_all.Dataset <- union_all.arrow_dplyr_query
7 changes: 7 additions & 0 deletions r/R/query-engine.R
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ ExecPlan <- R6Class("ExecPlan",
right_suffix = .data$join$suffix[[2]]
)
}

if (!is.null(.data$union_all)) {
node <- node$UnionAll(self$Build(.data$union_all$right_data))
}
}

# Apply sorting: this is currently not an ExecNode itself, it is a
Expand Down Expand Up @@ -271,6 +275,9 @@ ExecNode <- R6Class("ExecNode",
output_suffix_for_right = right_suffix
)
)
},
UnionAll = function(right_node) {
  # Create the union node from this node and `right_node`, then hand it to
  # preserve_sort() and return whatever that returns.
  union_node <- ExecNode_Union(self, right_node)
  self$preserve_sort(union_node)
}
),
active = list(
Expand Down
10 changes: 10 additions & 0 deletions r/src/arrowExports.cpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions r/src/compute-exec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,13 @@ std::shared_ptr<compute::ExecNode> ExecNode_Join(
std::move(output_suffix_for_left), std::move(output_suffix_for_right)});
}

// Creates a "union" exec node in the plan that owns `input`, with `input`
// and `right_data` as its two inputs and default-constructed options.
// [[arrow::export]]
std::shared_ptr<compute::ExecNode> ExecNode_Union(
    const std::shared_ptr<compute::ExecNode>& input,
    const std::shared_ptr<compute::ExecNode>& right_data) {
  auto* left = input.get();
  auto* right = right_data.get();
  return MakeExecNodeOrStop("union", left->plan(), {left, right}, {});
}

// [[arrow::export]]
std::shared_ptr<compute::ExecNode> ExecNode_SourceNode(
const std::shared_ptr<compute::ExecPlan>& plan,
Expand Down
74 changes: 74 additions & 0 deletions r/tests/testthat/test-dplyr-union.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Skip the tests below when on_old_windows() returns TRUE
skip_if(on_old_windows())

# Attach dplyr without emitting warnings about masked objects
library(dplyr, warn.conflicts = FALSE)

# Set the arrow.summarise.sort option to FALSE for the enclosing scope
# NOTE(review): presumably this keeps summarise/distinct output unsorted so
# results can be compared as-is; confirm against the option's documentation
withr::local_options(list(arrow.summarise.sort = FALSE))

test_that("union_all", {
  # Run union_all() of example_data with itself through
  # compare_dplyr_binding()
  compare_dplyr_binding(
    .input %>%
      union_all(example_data) %>%
      collect(),
    example_data
  )

  test_table <- arrow_table(x = 1:10)

  # Union with empty table produces same dataset
  expect_equal(
    test_table %>%
      union_all(test_table$Slice(0, 0)) %>%
      compute(),
    test_table
  )

  # Inputs with different column names (x vs. y) must be rejected.
  # `regexp` is spelled out in full rather than relying on partial
  # argument matching of `regex`.
  expect_error(
    test_table %>%
      union_all(arrow_table(y = 1:10)) %>%
      collect(),
    regexp = "input schemas must all match"
  )
})

test_that("union", {
  # Run dplyr::union() of example_data with itself through
  # compare_dplyr_binding()
  compare_dplyr_binding(
    .input %>%
      dplyr::union(example_data) %>%
      collect(),
    example_data
  )

  test_table <- arrow_table(x = 1:10)

  # Union with empty table produces same dataset
  expect_equal(
    test_table %>%
      dplyr::union(test_table$Slice(0, 0)) %>%
      compute(),
    test_table
  )

  # Inputs with different column names (x vs. y) must be rejected.
  # `regexp` is spelled out in full rather than relying on partial
  # argument matching of `regex`.
  expect_error(
    test_table %>%
      dplyr::union(arrow_table(y = 1:10)) %>%
      collect(),
    regexp = "input schemas must all match"
  )
})

0 comments on commit d889ade

Please sign in to comment.