ARROW-15260: [R] open_dataset - add file_name as column (#12826)
Authored-by: Nic Crane <thisisnic@gmail.com>
Signed-off-by: Neal Richardson <neal.p.richardson@gmail.com>
thisisnic committed Aug 10, 2022
1 parent b3116fa commit 8386871
Showing 9 changed files with 164 additions and 8 deletions.
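In short: a dplyr pipeline on a Dataset can now materialize each row's source file path as a column via the new add_filename() binding. A minimal usage sketch (the dataset path is hypothetical):

library(arrow)
library(dplyr)

ds <- open_dataset("path/to/dataset")      # any multi-file Dataset
ds %>%
  mutate(file_name = add_filename()) %>%   # adds each row's source file path
  collect()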
1 change: 1 addition & 0 deletions r/DESCRIPTION
@@ -98,6 +98,7 @@ Collate:
'dplyr-distinct.R'
'dplyr-eval.R'
'dplyr-filter.R'
'dplyr-funcs-augmented.R'
'dplyr-funcs-conditional.R'
'dplyr-funcs-datetime.R'
'dplyr-funcs-math.R'
1 change: 1 addition & 0 deletions r/R/dataset.R
@@ -224,6 +224,7 @@ open_dataset <- function(sources,
# and not handle_parquet_io_error()
error = function(e, call = caller_env(n = 4)) {
handle_parquet_io_error(e, format, call)
abort(conditionMessage(e), call = call)
}
)
}
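The pattern above (mirrored in dplyr-collect.R below) is that the handler now aborts only when it recognizes the message, so the caller must re-raise unconditionally afterwards. A sketch of the control flow, with do_scan() as a placeholder for the wrapped operation:

tryCatch(
  do_scan(),  # placeholder for the real work
  error = function(e, call = caller_env(n = 4)) {
    handle_parquet_io_error(e, format, call)  # aborts only on a recognized message
    abort(conditionMessage(e), call = call)   # otherwise re-raise the original error
  }
)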
11 changes: 11 additions & 0 deletions r/R/dplyr-collect.R
@@ -25,6 +25,8 @@ collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) {
# and not handle_csv_read_error()
error = function(e, call = caller_env(n = 4)) {
handle_csv_read_error(e, x$.data$schema, call)
handle_augmented_field_misuse(e, call)
abort(conditionMessage(e), call = call)
}
)

@@ -104,10 +106,18 @@ add_suffix <- function(fields, common_cols, suffix) {
}

implicit_schema <- function(.data) {
# Get the source data schema so that we can evaluate expressions to determine
# the output schema. Note that we don't use source_data() because we only
# want to go one level up (where we may have called implicit_schema() before)
.data <- ensure_group_vars(.data)
old_schm <- .data$.data$schema
# Add in any augmented fields that may exist in the query but not in the
# real data, in case we have FieldRefs to them
old_schm[["__filename"]] <- string()

if (is.null(.data$aggregations)) {
# .data$selected_columns is a named list of Expressions (FieldRefs or
# something more complex). Bind them in order to determine their output type
new_fields <- map(.data$selected_columns, ~ .$type(old_schm))
if (!is.null(.data$join) && !(.data$join$type %in% JoinType[1:4])) {
# Add cols from right side, except for semi/anti joins
@@ -128,6 +138,7 @@ implicit_schema <- function(.data) {
new_fields <- c(left_fields, right_fields)
}
} else {
# The output schema is based on the aggregations and any group_by vars
new_fields <- map(summarize_projection(.data), ~ .$type(old_schm))
# * Put group_by_vars first (this can't be done by summarize,
# they have to be last per the aggregate node signature,
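Because implicit_schema() now injects "__filename" into the working schema, a query that references the augmented field can still derive an output schema for downstream steps such as a join. A sketch based on the new tests (ds and lookup are assumed to exist):

ds %>%
  mutate(file = add_filename()) %>%   # FieldRef to "__filename"
  left_join(lookup, by = "int") %>%   # implicit_schema() types `file` as string
  collect()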
22 changes: 22 additions & 0 deletions r/R/dplyr-funcs-augmented.R
@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

register_bindings_augmented <- function() {
register_binding("add_filename", function() {
Expression$field_ref("__filename")
})
}
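This binding is the entire R-side surface of the feature: a call to add_filename() inside a dplyr verb translates into a FieldRef expression on the reserved "__filename" field, which the scanner populates per file. Roughly what the translation produces (a sketch of internals):

expr <- Expression$field_ref("__filename")  # what add_filename() returns
expr$field_name                             # "__filename"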
1 change: 1 addition & 0 deletions r/R/dplyr-funcs.R
@@ -151,6 +151,7 @@ create_binding_cache <- function() {
register_bindings_math()
register_bindings_string()
register_bindings_type()
register_bindings_augmented()

# We only create the cache for nse_funcs and not agg_funcs
.cache$functions <- c(as.list(nse_funcs), arrow_funcs)
3 changes: 3 additions & 0 deletions r/R/dplyr.R
@@ -110,6 +110,9 @@ make_field_refs <- function(field_names) {
#' @export
print.arrow_dplyr_query <- function(x, ...) {
schm <- x$.data$schema
# If we are using this augmented field, it won't be in the schema
schm[["__filename"]] <- string()

types <- map_chr(x$selected_columns, function(expr) {
name <- expr$field_name
if (nzchar(name)) {
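Augmenting the schema here keeps print() from failing before any data is read: a selected "__filename" FieldRef has no match in the on-disk schema, so its type is resolved from the augmented one. Sketch (path hypothetical):

q <- open_dataset("path/to/dataset") %>%
  mutate(file = add_filename())
print(q)  # reports `file` as string without evaluating the query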
31 changes: 29 additions & 2 deletions r/R/util.R
@@ -134,6 +134,10 @@ read_compressed_error <- function(e) {
stop(e)
}

# As of ARROW-15260 this function only raises an error if the matching string
# is found in the message, so callers must raise the error themselves after
# calling it when no match is found
# TODO: Refactor as part of ARROW-17355 to prevent potential missed errors
handle_parquet_io_error <- function(e, format, call) {
msg <- conditionMessage(e)
if (grepl("Parquet magic bytes not found in footer", msg) && length(format) > 1 && is_character(format)) {
@@ -143,8 +147,8 @@ handle_parquet_io_error <- function(e, format, call) {
msg,
i = "Did you mean to specify a 'format' other than the default (parquet)?"
)
+ abort(msg, call = call)
}
- abort(msg, call = call)
}

as_writable_table <- function(x) {
@@ -205,6 +209,10 @@ repeat_value_as_array <- function(object, n) {
return(Scalar$create(object)$as_array(n))
}

# As of ARROW-15260 this function only raises an error if the matching string
# is found in the message, so callers must raise the error themselves after
# calling it when no match is found
# TODO: Refactor as part of ARROW-17355 to prevent potential missed errors
handle_csv_read_error <- function(e, schema, call) {
msg <- conditionMessage(e)

@@ -217,8 +225,27 @@ handle_csv_read_error <- function(e, schema, call) {
"header being read in as data."
)
)
+ abort(msg, call = call)
}
}

# This function only raises an error if the matching string is found in the
# message, so callers must raise the error themselves after calling it when
# no match is found
# TODO: Refactor as part of ARROW-17355 to prevent potential missed errors
handle_augmented_field_misuse <- function(e, call) {
msg <- conditionMessage(e)
if (grepl("No match for FieldRef.Name(__filename)", msg, fixed = TRUE)) {
msg <- c(
msg,
i = paste(
"`add_filename()` or use of the `__filename` augmented field can only",
"be used with with Dataset objects, and can only be added before doing",
"an aggregation or a join."
)
)
abort(msg, call = call)
}
- abort(msg, call = call)
}

is_compressed <- function(compression) {
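For users, the new handler turns a bare Acero error into an actionable hint. Misusing the augmented field on an in-memory table now fails along these lines (sketched from the test expectations below):

arrow_table(mtcars) %>%
  mutate(file = add_filename()) %>%
  collect()
# Error: No match for FieldRef.Name(__filename)
# i `add_filename()` or use of the `__filename` augmented field can only be
#   used with Dataset objects, and can only be added before doing an
#   aggregation or a join.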
8 changes: 3 additions & 5 deletions r/src/compute-exec.cpp
@@ -222,8 +222,7 @@ std::shared_ptr<compute::ExecNode> ExecNode_Scan(

options->dataset_schema = dataset->schema();

- // ScanNode needs the filter to do predicate pushdown and skip partitions
- options->filter = ValueOrStop(filter->Bind(*dataset->schema()));
+ options->filter = *filter;

// ScanNode needs to know which fields to materialize (and which are unnecessary)
std::vector<compute::Expression> exprs;
@@ -232,9 +231,8 @@ std::shared_ptr<compute::ExecNode> ExecNode_Scan(
}

options->projection =
- ValueOrStop(call("make_struct", std::move(exprs),
-   compute::MakeStructOptions{std::move(materialized_field_names)})
-   .Bind(*dataset->schema()));
+ call("make_struct", std::move(exprs),
+   compute::MakeStructOptions{std::move(materialized_field_names)});

return MakeExecNodeOrStop("scan", plan.get(), {},
ds::ScanNodeOptions{dataset, options});
94 changes: 93 additions & 1 deletion r/tests/testthat/test-dataset.R
@@ -1131,7 +1131,6 @@ test_that("dataset to C-interface to arrow_dplyr_query with proj/filter", {
delete_arrow_array_stream(stream_ptr)
})


test_that("Filter parquet dataset with is.na ARROW-15312", {
ds_path <- make_temp_dir()

@@ -1349,3 +1348,96 @@ test_that("FileSystemFactoryOptions input validation", {
fixed = TRUE
)
})

test_that("can add in augmented fields", {
ds <- open_dataset(hive_dir)

observed <- ds %>%
mutate(file_name = add_filename()) %>%
collect()

expect_named(
observed,
c("int", "dbl", "lgl", "chr", "fct", "ts", "group", "other", "file_name")
)

expect_equal(
sort(unique(observed$file_name)),
list.files(hive_dir, full.names = TRUE, recursive = TRUE)
)

error_regex <- paste(
"`add_filename()` or use of the `__filename` augmented field can only",
"be used with with Dataset objects, and can only be added before doing",
"an aggregation or a join."
)

# errors appropriately with ArrowTabular objects
expect_error(
arrow_table(mtcars) %>%
mutate(file = add_filename()) %>%
collect(),
regexp = error_regex,
fixed = TRUE
)

# errors appropriately with aggregation
expect_error(
ds %>%
summarise(max_int = max(int)) %>%
mutate(file_name = add_filename()) %>%
collect(),
regexp = error_regex,
fixed = TRUE
)

# errors appropriately on joins to tables
another_table <- select(example_data, int, dbl2)
expect_error(
ds %>%
left_join(another_table, by = "int") %>%
mutate(file = add_filename()) %>%
collect(),
regexp = error_regex,
fixed = TRUE
)

# and on joins to datasets
another_dataset <- write_dataset(another_table, "another_dataset")
expect_error(
ds %>%
left_join(open_dataset("another_dataset"), by = "int") %>%
mutate(file = add_filename()) %>%
collect(),
regexp = error_regex,
fixed = TRUE
)

# this hits the implicit_schema path by joining afterwards
join_after <- ds %>%
mutate(file = add_filename()) %>%
left_join(open_dataset("another_dataset"), by = "int") %>%
collect()

expect_named(
join_after,
c("int", "dbl", "lgl", "chr", "fct", "ts", "group", "other", "file", "dbl2")
)

expect_equal(
sort(unique(join_after$file)),
list.files(hive_dir, full.names = TRUE, recursive = TRUE)
)

# another test of the implicit_schema path, this time via aggregation
summarise_after <- ds %>%
mutate(file = add_filename()) %>%
group_by(file) %>%
summarise(max_int = max(int)) %>%
collect()

expect_equal(
sort(summarise_after$file),
list.files(hive_dir, full.names = TRUE, recursive = TRUE)
)
})
