ARROW-11704: [R] Wire up dplyr::mutate() for datasets
Closes #9586 from nealrichardson/r-dataset-projection

Authored-by: Neal Richardson <neal.p.richardson@gmail.com>
Signed-off-by: Neal Richardson <neal.p.richardson@gmail.com>
nealrichardson committed Mar 5, 2021
1 parent e2c7d95 commit 906331c
Showing 8 changed files with 157 additions and 23 deletions.
2 changes: 1 addition & 1 deletion r/NEWS.md
@@ -21,7 +21,7 @@

 ## dplyr methods

-* `dplyr::mutate()` on Arrow `Table` and `RecordBatch` is now supported in Arrow for many applications. Where not yet supported, the implementation falls back to pulling data into an R `data.frame` first.
+* `dplyr::mutate()` is now supported in Arrow for many applications. For queries on `Table` and `RecordBatch` that are not yet supported in Arrow, the implementation falls back to pulling data into an R `data.frame` first, as in the previous release. For queries on `Dataset`, it raises an error if the feature is not implemented.
 * String functions `nchar()`, `tolower()`, and `toupper()`, along with their `stringr` spellings `str_length()`, `str_to_lower()`, and `str_to_upper()`, are supported in Arrow `dplyr` calls. `str_trim()` is also supported.

 ## Other improvements
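
To make the NEWS entry concrete, here is a minimal sketch of the new behavior (the dataset path and column names are hypothetical, not part of this commit):

library(arrow)
library(dplyr)

ds <- open_dataset("path/to/dataset")   # hypothetical partitioned dataset
ds %>%
  mutate(twice = int * 2) %>%   # kept in Arrow as an Expression; nothing evaluated yet
  collect()                     # the projection is applied when the data is collected

On a Table or RecordBatch, an unsupported mutate() expression falls back to pulling the data into an R data.frame first; on a Dataset it instead raises an error suggesting collect().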
8 changes: 6 additions & 2 deletions r/R/arrowExports.R

Some generated files are not rendered by default.

18 changes: 12 additions & 6 deletions r/R/dataset-scan.R
@@ -157,13 +157,19 @@ ScannerBuilder <- R6Class("ScannerBuilder", inherit = ArrowObject,
   public = list(
     Project = function(cols) {
       # cols is either a character vector or a named list of Expressions
-      if (!is.character(cols)) {
-        # We don't yet support mutate() on datasets, so this is just a list
-        # of FieldRefs, and we need to back out the field names
-        cols <- get_field_names(cols)
+      if (is.character(cols)) {
+        dataset___ScannerBuilder__ProjectNames(self, cols)
+      } else {
+        # If we have expressions, but they all turn out to be field_refs,
+        # we can still call the simple method
+        field_names <- get_field_names(cols)
+        if (all(nzchar(field_names))) {
+          dataset___ScannerBuilder__ProjectNames(self, field_names)
+        } else {
+          # Else, we are projecting/mutating
+          dataset___ScannerBuilder__ProjectExprs(self, cols, names(cols))
+        }
       }
-      assert_is(cols, "character")
-      dataset___ScannerBuilder__Project(self, cols)
       self
     },
     Filter = function(expr) {
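
The comments in the new Project() describe a three-way dispatch. A sketch of the three call shapes, assuming a ScannerBuilder `sb` and the package's Expression$field_ref() constructor (illustrative only, not part of this diff):

# 1. Plain column names: the simple name-based projection
sb$Project(c("chr", "int"))

# 2. Expressions that are all bare field refs: names are backed out and
#    the same name-based path is used
sb$Project(list(chr = Expression$field_ref("chr")))

# 3. Any computed expression: falls through to ProjectExprs
sb$Project(list(twice = Expression$field_ref("int") * 2))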
15 changes: 11 additions & 4 deletions r/R/dplyr.R
@@ -506,9 +506,6 @@ mutate.arrow_dplyr_query <- function(.data,
   }

   .data <- arrow_dplyr_query(.data)
-  if (query_on_dataset(.data)) {
-    not_implemented_for_dataset("mutate()")
-  }

   .keep <- match.arg(.keep)
   .before <- enquo(.before)

@@ -529,6 +526,7 @@
   # Deparse and take the first element in case they're long expressions
   names(exprs)[unnamed] <- map_chr(exprs[unnamed], as_label)

+  is_dataset <- query_on_dataset(.data)
   mask <- arrow_mask(.data)
   results <- list()
   for (i in seq_along(exprs)) {

@@ -539,6 +537,15 @@
     if (inherits(results[[new_var]], "try-error")) {
       msg <- paste('Expression', as_label(exprs[[i]]), 'not supported in Arrow')
       return(abandon_ship(call, .data, msg))
+    } else if (is_dataset &&
+               !inherits(results[[new_var]], "Expression") &&
+               !is.null(results[[new_var]])) {
+      # We need some wrapping to handle literal values
+      if (length(results[[new_var]]) != 1) {
+        msg <- paste0('In ', new_var, " = ", as_label(exprs[[i]]), ", only values of size one are recycled")
+        return(abandon_ship(call, .data, msg))
+      }
+      results[[new_var]] <- Expression$scalar(results[[new_var]])
     }
     # Put it in the data mask too
     mask[[new_var]] <- mask$.data[[new_var]] <- results[[new_var]]

@@ -583,7 +590,7 @@ abandon_ship <- function(call, .data, msg = NULL) {
     # Default message: function not implemented
     not_implemented_for_dataset(paste0(dplyr_fun_name, "()"))
   } else {
-    stop(msg, call. = FALSE)
+    stop(msg, "\nCall collect() first to pull data into R.", call. = FALSE)
   }
 }

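
The new branch in mutate() runs only for Dataset queries: a plain R value that survives translation has to be wrapped as a scalar Expression so it can ride along in the projection. A sketch of the resulting user-visible behavior, with `ds` a hypothetical Dataset:

ds %>% mutate(the_answer = 42)   # length-1 literal: wrapped via Expression$scalar(42)
ds %>% mutate(pair = c(1, 2))    # error: only values of size one are recycled
ds %>% mutate(int = NULL)        # NULL is left alone and drops the column as usual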
20 changes: 16 additions & 4 deletions r/src/arrowExports.cpp

Some generated files are not rendered by default.

17 changes: 15 additions & 2 deletions r/src/dataset.cpp
@@ -308,11 +308,24 @@ std::shared_ptr<ds::PartitioningFactory> dataset___HivePartitioning__MakeFactory
 // ScannerBuilder, Scanner

 // [[arrow::export]]
-void dataset___ScannerBuilder__Project(const std::shared_ptr<ds::ScannerBuilder>& sb,
-                                       const std::vector<std::string>& cols) {
+void dataset___ScannerBuilder__ProjectNames(const std::shared_ptr<ds::ScannerBuilder>& sb,
+                                            const std::vector<std::string>& cols) {
   StopIfNotOk(sb->Project(cols));
 }

+// [[arrow::export]]
+void dataset___ScannerBuilder__ProjectExprs(
+    const std::shared_ptr<ds::ScannerBuilder>& sb,
+    const std::vector<std::shared_ptr<ds::Expression>>& exprs,
+    const std::vector<std::string>& names) {
+  // We have shared_ptrs of expressions but need the Expressions
+  std::vector<ds::Expression> expressions;
+  for (auto expr : exprs) {
+    expressions.push_back(*expr);
+  }
+  StopIfNotOk(sb->Project(expressions, names));
+}
+
 // [[arrow::export]]
 void dataset___ScannerBuilder__Filter(const std::shared_ptr<ds::ScannerBuilder>& sb,
                                       const std::shared_ptr<ds::Expression>& expr) {
7 changes: 5 additions & 2 deletions r/src/expression.cpp
@@ -50,8 +50,11 @@ std::shared_ptr<ds::Expression> dataset___expr__field_ref(std::string name) {
 // [[arrow::export]]
 std::string dataset___expr__get_field_ref_name(
     const std::shared_ptr<ds::Expression>& ref) {
-  auto refname = ref->field_ref()->name();
-  return *refname;
+  auto field_ref = ref->field_ref();
+  if (field_ref == nullptr) {
+    return "";
+  }
+  return *field_ref->name();
 }

 // [[arrow::export]]
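
Returning an empty string instead of dereferencing a null field_ref is what lets the R side probe expressions safely: get_field_names() in dataset-scan.R yields "" for anything that is not a bare field reference, and all(nzchar(field_names)) then identifies pure selections. A sketch of that contract (get_field_names() is the internal helper used above; return shapes inferred from its call site):

get_field_names(list(x = Expression$field_ref("x")))       # "x": a bare field ref
get_field_names(list(y = Expression$field_ref("x") * 2))   # "":  a computed expression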
93 changes: 91 additions & 2 deletions r/tests/testthat/test-dataset.R
@@ -717,6 +717,96 @@ test_that("filter() with expressions", {
   )
 })

+test_that("mutate()", {
+  ds <- open_dataset(dataset_dir, partitioning = schema(part = uint8()))
+  mutated <- ds %>%
+    select(chr, dbl, int) %>%
+    filter(dbl * 2 > 14 & dbl - 50 < 3L) %>%
+    mutate(twice = int * 2)
+  expect_output(
+    print(mutated),
+    "FileSystemDataset (query)
+chr: string
+dbl: double
+int: int32
+twice: expr
+
+* Filter: ((multiply_checked(dbl, 2) > 14) and (subtract_checked(dbl, 50) < 3))
+See $.data for the source Arrow object",
+    fixed = TRUE
+  )
+  expect_equivalent(
+    mutated %>%
+      collect() %>%
+      arrange(dbl),
+    rbind(
+      df1[8:10, c("chr", "dbl", "int")],
+      df2[1:2, c("chr", "dbl", "int")]
+    ) %>%
+      mutate(
+        twice = int * 2
+      )
+  )
+})
+
+test_that("transmute()", {
+  ds <- open_dataset(dataset_dir, partitioning = schema(part = uint8()))
+  mutated <-
+    expect_equivalent(
+      ds %>%
+        select(chr, dbl, int) %>%
+        filter(dbl * 2 > 14 & dbl - 50 < 3L) %>%
+        transmute(twice = int * 2) %>%
+        collect() %>%
+        arrange(twice),
+      rbind(
+        df1[8:10, "int", drop = FALSE],
+        df2[1:2, "int", drop = FALSE]
+      ) %>%
+        transmute(
+          twice = int * 2
+        )
+    )
+})
+
+test_that("mutate() features not yet implemented", {
+  expect_error(
+    ds %>%
+      group_by(int) %>%
+      mutate(avg = mean(int)),
+    "mutate() on grouped data not supported in Arrow\nCall collect() first to pull data into R.",
+    fixed = TRUE
+  )
+})
+
+
+test_that("mutate() with scalar (length 1) literal inputs", {
+  expect_equal(
+    ds %>%
+      mutate(the_answer = 42) %>%
+      collect() %>%
+      pull(the_answer),
+    rep(42, nrow(ds))
+  )
+
+  expect_error(
+    ds %>% mutate(the_answer = c(42, 42)),
+    "In the_answer = c(42, 42), only values of size one are recycled\nCall collect() first to pull data into R.",
+    fixed = TRUE
+  )
+})
+
+test_that("mutate() with NULL inputs", {
+  expect_equal(
+    ds %>%
+      mutate(int = NULL) %>%
+      collect(),
+    ds %>%
+      select(-int) %>%
+      collect()
+  )
+})
+
 test_that("filter scalar validation doesn't crash (ARROW-7772)", {
   expect_error(
     ds %>%
@@ -832,7 +922,6 @@ test_that("dplyr method not implemented messages", {
     expect_error(x, "is not currently implemented for Arrow Datasets")
   }
   expect_not_implemented(ds %>% arrange(int))
-  expect_not_implemented(ds %>% mutate(int = int + 2))
   expect_not_implemented(ds %>% filter(int == 1) %>% summarize(n()))
 })

@@ -1137,7 +1226,7 @@ test_that("Dataset writing: no partitioning", {
 test_that("Dataset writing: partition on null", {
   skip_on_os("windows") # https://issues.apache.org/jira/browse/ARROW-9651
   ds <- open_dataset(hive_dir)
-
+
   dst_dir <- tempfile()
   partitioning = hive_partition(lgl = boolean())
   write_dataset(ds, dst_dir, partitioning = partitioning)
