ARROW-14029: [R] Repair map_batches()

Updates the `map_batches()` function to use a `RecordBatchReader` instead of `Scanner$ScanBatches()` so that only one record batch is held in memory at a time.

~As part of this, I refactored `do_exec_plan` to always return a RecordBatchReader (RBR) instead of a materialized Table.~ I don't think I can refactor `do_exec_plan` to always return an RBR until we get `arrange`, `head`, and `tail` operations to work outside of a sink node. See: https://issues.apache.org/jira/browse/ARROW-15271
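
For context, the batch-at-a-time pattern the new implementation adopts looks roughly like this (a minimal sketch built from the internal `ExecPlan`/`RecordBatchReader` calls in the diff below; `query` is a hypothetical `arrow_dplyr_query` built elsewhere):

```r
# Sketch of the streaming pattern: build an exec plan, run it to get a
# RecordBatchReader, then pull batches one at a time. `ExecPlan` is an
# internal arrow API (as shown in the diff below), not a public interface.
plan <- ExecPlan$create()
final_node <- plan$Build(query)
reader <- plan$Run(final_node)

batch <- reader$read_next_batch()
while (!is.null(batch)) {
  # process `batch` here; only this one RecordBatch is in memory
  batch <- reader$read_next_batch()
}
```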

Closes #11894 from wjones127/ARROW-14029-r-map-batches

Authored-by: Will Jones <willjones127@gmail.com>
Signed-off-by: Jonathan Keane <jkeane@gmail.com>
wjones127 authored and jonkeane committed Jan 7, 2022
1 parent e64480d commit f054440
Showing 3 changed files with 114 additions and 12 deletions.
39 changes: 29 additions & 10 deletions r/R/dataset-scan.R
```diff
@@ -185,17 +185,36 @@ ScanTask <- R6Class("ScanTask",
 #' `data.frame`? Default `TRUE`
 #' @export
 map_batches <- function(X, FUN, ..., .data.frame = TRUE) {
-  if (.data.frame) {
-    lapply <- map_dfr
-  }
-  scanner <- Scanner$create(ensure_group_vars(X))
+  # TODO: ARROW-15271 possibly refactor do_exec_plan to return a RecordBatchReader
+  plan <- ExecPlan$create()
+  final_node <- plan$Build(X)
+  reader <- plan$Run(final_node)
   FUN <- as_mapper(FUN)
-  lapply(scanner$ScanBatches(), function(batch) {
-    # TODO: wrap batch in arrow_dplyr_query with X$selected_columns,
-    # X$temp_columns, and X$group_by_vars
-    # if X is arrow_dplyr_query, if some other arg (.dplyr?) == TRUE
-    FUN(batch, ...)
-  })
+
+  # TODO: wrap batch in arrow_dplyr_query with X$selected_columns,
+  # X$temp_columns, and X$group_by_vars
+  # if X is arrow_dplyr_query, if some other arg (.dplyr?) == TRUE
+  batch <- reader$read_next_batch()
+  res <- vector("list", 1024)
+  i <- 0L
+  while (!is.null(batch)) {
+    i <- i + 1L
+    res[[i]] <- FUN(batch, ...)
+    batch <- reader$read_next_batch()
+  }
+
+  # Trim list back
+  if (i < length(res)) {
+    res <- res[seq_len(i)]
+  }
+
+  if (.data.frame & inherits(res[[1]], "arrow_dplyr_query")) {
+    res <- dplyr::bind_rows(map(res, collect))
+  } else if (.data.frame) {
+    res <- dplyr::bind_rows(map(res, as.data.frame))
+  }
+
+  res
 }

 #' @usage NULL
```
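For reference, a usage sketch of the repaired function (hypothetical: it assumes a dataset at `path/to/dataset` with an `int` column, mirroring the tests below):

```r
library(arrow)
library(dplyr)

# Hypothetical dataset path and column name, for illustration only
ds <- open_dataset("path/to/dataset")

# Each batch streams off the RecordBatchReader one at a time, so the
# dataset never has to fit in memory; returns a list of per-batch row
# counts because .data.frame = FALSE
ds %>%
  filter(int > 5) %>%
  map_batches(~ .$num_rows, .data.frame = FALSE)
```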
27 changes: 25 additions & 2 deletions r/tests/testthat/test-dataset.R
```diff
@@ -453,15 +453,38 @@ test_that("Creating UnionDataset", {
 })

 test_that("map_batches", {
-  skip("map_batches() is broken (ARROW-14029)")
   ds <- open_dataset(dataset_dir, partitioning = "part")

+  # summarize returns arrow_dplyr_query, which gets collected into a tibble
   expect_equal(
     ds %>%
       filter(int > 5) %>%
       select(int, lgl) %>%
-      map_batches(~ summarize(., min_int = min(int))),
+      map_batches(~ summarize(., min_int = min(int))) %>%
+      arrange(min_int),
     tibble(min_int = c(6L, 101L))
   )
+
+  # $num_rows returns an integer vector
+  expect_equal(
+    ds %>%
+      filter(int > 5) %>%
+      select(int, lgl) %>%
+      map_batches(~ .$num_rows, .data.frame = FALSE) %>%
+      unlist() %>% # Returns a list because .data.frame is FALSE
+      sort(),
+    c(5, 10)
+  )
+
+  # $Take returns a RecordBatch, which gets bound into a tibble
+  expect_equal(
+    ds %>%
+      filter(int > 5) %>%
+      select(int, lgl) %>%
+      map_batches(~ .$Take(0)) %>%
+      arrange(int),
+    tibble(int = c(6, 101), lgl = c(TRUE, TRUE))
+  )
 })

 test_that("partitioning = NULL to ignore partition information (but why?)", {
```
60 changes: 60 additions & 0 deletions r/vignettes/dataset.Rmd
```diff
@@ -290,6 +290,66 @@
 rows match the filter. Relatedly, since Parquet files contain row groups with
 statistics on the data within, there may be entire chunks of data you can
 avoid scanning because they have no rows where `total_amount > 100`.

+### Processing data in batches
+
+Sometimes you want to run R code on the entire dataset, but that dataset is much
+larger than memory. You can use `map_batches` on a dataset query to process
+it batch-by-batch.
+
+**Note**: `map_batches` is experimental and not recommended for production use.
+
+As an example, to randomly sample a dataset, use `map_batches` to sample a
+percentage of rows from each batch:
+
+```{r, eval = file.exists("nyc-taxi")}
+sampled_data <- ds %>%
+  filter(year == 2015) %>%
+  select(tip_amount, total_amount, passenger_count) %>%
+  map_batches(~ sample_frac(as.data.frame(.), 1e-4)) %>%
+  mutate(tip_pct = tip_amount / total_amount)
+str(sampled_data)
+```
+
+```{r, echo = FALSE, eval = !file.exists("nyc-taxi")}
+cat("
+'data.frame': 15603 obs. of 4 variables:
+ $ tip_amount     : num  0 0 1.55 1.45 5.2 ...
+ $ total_amount   : num  5.8 16.3 7.85 8.75 26 ...
+ $ passenger_count: int  1 1 1 1 1 6 5 1 2 1 ...
+ $ tip_pct        : num  0 0 0.197 0.166 0.2 ...
+")
+```
+
+This function can also be used to aggregate summary statistics over a dataset by
+computing partial results for each batch and then aggregating those partial
+results. Extending the example above, you could fit a model to the sample data
+and then use `map_batches` to compute the MSE on the full dataset.
+
+```{r, eval = file.exists("nyc-taxi")}
+model <- lm(tip_pct ~ total_amount + passenger_count, data = sampled_data)
+
+ds %>%
+  filter(year == 2015) %>%
+  select(tip_amount, total_amount, passenger_count) %>%
+  mutate(tip_pct = tip_amount / total_amount) %>%
+  map_batches(function(batch) {
+    batch %>%
+      as.data.frame() %>%
+      mutate(pred_tip_pct = predict(model, newdata = .)) %>%
+      filter(!is.nan(tip_pct)) %>%
+      summarize(sse_partial = sum((pred_tip_pct - tip_pct)^2), n_partial = n())
+  }) %>%
+  summarize(mse = sum(sse_partial) / sum(n_partial)) %>%
+  pull(mse)
+```
+
+```{r, echo = FALSE, eval = !file.exists("nyc-taxi")}
+cat("
+[1] 0.1304284
+")
+```
+
 ## More dataset options

 There are a few ways you can control the Dataset creation to adapt to special use cases.
```
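The partial-aggregation pattern in the vignette addition generalizes beyond MSE. As a sketch under the same assumptions (the `nyc-taxi` dataset `ds` with a `total_amount` column), a global mean can be assembled from per-batch sums and counts:

```r
library(arrow)
library(dplyr)

# Sketch: global mean of total_amount from per-batch partial results.
# Assumes `ds` is the nyc-taxi dataset opened as in the vignette.
ds %>%
  filter(year == 2015) %>%
  select(total_amount) %>%
  map_batches(function(batch) {
    batch %>%
      as.data.frame() %>%
      summarize(sum_partial = sum(total_amount), n_partial = n())
  }) %>%
  # map_batches binds the per-batch data frames, so one more summarize
  # combines the partials into the global statistic
  summarize(mean_total = sum(sum_partial) / sum(n_partial)) %>%
  pull(mean_total)
```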
