Skip to content

Commit

Permalink
Change behavior of pull to compute instead of collect
Browse files Browse the repository at this point in the history
  • Loading branch information
amoeba committed Oct 5, 2022
1 parent 69249c3 commit 1923bd2
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 21 deletions.
1 change: 1 addition & 0 deletions r/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ export(null)
export(num_range)
export(one_of)
export(open_dataset)
export(pull.ArrowTabular)
export(read_csv_arrow)
export(read_delim_arrow)
export(read_feather)
Expand Down
5 changes: 5 additions & 0 deletions r/R/arrow-tabular.R
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,8 @@ na.omit.ArrowTabular <- function(object, ...) {

#' @export
na.exclude.ArrowTabular <- na.omit.ArrowTabular

#' @export
pull.ArrowTabular <- function(x, var = -1) {
  # Resolve `var` against the column names of `x`. `var` is captured with
  # enquo() and unquoted so that vars_pull() sees the caller's expression;
  # the default of -1 is passed through to vars_pull() unchanged.
  column <- vars_pull(names(x), !!enquo(var))

  # Extract that single column with `[[`; no conversion is applied here.
  x[[column]]
}
4 changes: 2 additions & 2 deletions r/R/dplyr-collect.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ pull.arrow_dplyr_query <- function(.data, var = -1) {
.data <- as_adq(.data)
var <- vars_pull(names(.data), !!enquo(var))
.data$selected_columns <- set_names(.data$selected_columns[var], var)
dplyr::collect(.data)[[1]]
dplyr::compute(.data)[[1]]
}
pull.Dataset <- pull.ArrowTabular <- pull.RecordBatchReader <- pull.arrow_dplyr_query
pull.Dataset <- pull.RecordBatchReader <- pull.arrow_dplyr_query

restore_dplyr_features <- function(df, query) {
# An arrow_dplyr_query holds some attributes that Arrow doesn't know about
Expand Down
4 changes: 3 additions & 1 deletion r/tests/testthat/test-dataset-write.R
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,8 @@ test_that("Dataset min_rows_per_group", {

row_group_sizes <- ds %>%
map_batches(~ record_batch(nrows = .$num_rows)) %>%
pull(nrows)
pull(nrows) %>%
as.vector()
index <- 1

# We expect there to be 3 row groups since 11/5 = 2.2 and 11/4 = 2.75
Expand Down Expand Up @@ -778,6 +779,7 @@ test_that("Dataset write max rows per group", {
row_group_sizes <- ds %>%
map_batches(~ record_batch(nrows = .$num_rows)) %>%
pull(nrows) %>%
as.vector() %>%
sort()

expect_equal(row_group_sizes, c(12, 18))
Expand Down
41 changes: 29 additions & 12 deletions r/tests/testthat/test-dataset.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ test_that("IPC/Feather format data", {

# Collecting virtual partition column works
expect_equal(
ds %>% arrange(part) %>% pull(part),
ds %>% arrange(part) %>% pull(part) %>% as.vector(),
c(rep(3, 10), rep(4, 10))
)
})
Expand Down Expand Up @@ -306,7 +306,7 @@ test_that("Simple interface for datasets", {

# Collecting virtual partition column works
expect_equal(
ds %>% arrange(part) %>% pull(part),
ds %>% arrange(part) %>% pull(part) %>% as.vector(),
c(rep(1, 10), rep(2, 10))
)
})
Expand Down Expand Up @@ -625,8 +625,16 @@ test_that("scalar aggregates with many batches (ARROW-16904)", {
ds <- open_dataset(tf)
replicate(100, ds %>% summarize(min(x)) %>% pull())

expect_true(all(replicate(100, ds %>% summarize(min(x)) %>% pull()) == 1))
expect_true(all(replicate(100, ds %>% summarize(max(x)) %>% pull()) == 100))
expect_true(
all(
replicate(100, ds %>% summarize(min(x)) %>% pull() %>% as.vector()) == 1
)
)
expect_true(
all(
replicate(100, ds %>% summarize(max(x)) %>% pull() %>% as.vector()) == 100
)
)
})

test_that("map_batches", {
Expand All @@ -650,6 +658,7 @@ test_that("map_batches", {
select(int, lgl) %>%
map_batches(~ record_batch(nrows = .$num_rows)) %>%
pull(nrows) %>%
as.vector() %>%
sort(),
c(5, 10)
)
Expand Down Expand Up @@ -1170,7 +1179,8 @@ test_that("FileSystemFactoryOptions with DirectoryPartitioning", {
expect_equal(
ds %>%
arrange(cyl) %>%
pull(cyl),
pull(cyl) %>%
as.vector(),
sort(mtcars$cyl)
)

Expand All @@ -1188,7 +1198,8 @@ test_that("FileSystemFactoryOptions with DirectoryPartitioning", {
expect_equal(
ds %>%
arrange(cyl) %>%
pull(cyl),
pull(cyl) %>%
as.vector(),
sort(mtcars$cyl)
)

Expand All @@ -1204,7 +1215,8 @@ test_that("FileSystemFactoryOptions with DirectoryPartitioning", {
expect_equal(
ds %>%
arrange(cyl) %>%
pull(cyl),
pull(cyl) %>%
as.vector(),
sort(mtcars$cyl)
)

Expand All @@ -1222,7 +1234,8 @@ test_that("FileSystemFactoryOptions with DirectoryPartitioning", {
expect_equal(
ds %>%
arrange(cyl) %>%
pull(cyl),
pull(cyl) %>%
as.vector(),
sort(mtcars$cyl)
)

Expand Down Expand Up @@ -1256,7 +1269,8 @@ test_that("FileSystemFactoryOptions with HivePartitioning", {
expect_equal(
ds %>%
arrange(cyl) %>%
pull(cyl),
pull(cyl) %>%
as.vector(),
sort(mtcars$cyl)
)

Expand All @@ -1272,7 +1286,8 @@ test_that("FileSystemFactoryOptions with HivePartitioning", {
expect_equal(
ds %>%
arrange(cyl) %>%
pull(cyl),
pull(cyl) %>%
as.vector(),
sort(mtcars$cyl)
)

Expand All @@ -1286,7 +1301,8 @@ test_that("FileSystemFactoryOptions with HivePartitioning", {
expect_equal(
ds %>%
arrange(cyl) %>%
pull(cyl),
pull(cyl) %>%
as.vector(),
sort(mtcars$cyl)
)

Expand All @@ -1302,7 +1318,8 @@ test_that("FileSystemFactoryOptions with HivePartitioning", {
expect_equal(
ds %>%
arrange(cyl) %>%
pull(cyl),
pull(cyl) %>%
as.vector(),
sort(mtcars$cyl)
)
})
Expand Down
4 changes: 3 additions & 1 deletion r/tests/testthat/test-dplyr-arrange.R
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ test_that("arrange() on integer, double, and character columns", {
.input %>%
group_by(grp) %>%
arrange(.by_group = TRUE) %>%
pull(grp),
ungroup() %>%
pull(grp) %>%
as.vector(),
tbl
)
compare_dplyr_binding(
Expand Down
3 changes: 2 additions & 1 deletion r/tests/testthat/test-dplyr-funcs-datetime.R
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ test_that("strptime", {
mutate(
x = strptime(x, format = "%m-%d-%Y")
) %>%
pull(),
pull() %>%
as.vector(),
# R's strptime returns POSIXlt (list type)
as.POSIXct(tstamp),
ignore_attr = "tzone"
Expand Down
9 changes: 5 additions & 4 deletions r/tests/testthat/test-dplyr-query.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,22 +70,23 @@ See $.data for the source Arrow object',

test_that("pull", {
compare_dplyr_binding(
.input %>% pull(),
.input %>% pull() %>% as.vector(),
tbl
)
compare_dplyr_binding(
.input %>% pull(1),
.input %>% pull(1) %>% as.vector(),
tbl
)
compare_dplyr_binding(
.input %>% pull(chr),
.input %>% pull(chr) %>% as.vector(),
tbl
)
compare_dplyr_binding(
.input %>%
filter(int > 4) %>%
rename(strng = chr) %>%
pull(strng),
pull(strng) %>%
as.vector(),
tbl
)
})
Expand Down

0 comments on commit 1923bd2

Please sign in to comment.