From 80bba299612d2a8a92968fd15e81343f0783e600 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Sat, 27 Aug 2022 17:44:53 -0400 Subject: [PATCH] ARROW-17463: [R] Avoid unnecessary projections (#13954) Before: ``` > mtcars |> arrow_table() |> count(cyl) |> explain() ExecPlan with 6 nodes: 5:SinkNode{} 4:ProjectNode{projection=[cyl, n]} 3:ProjectNode{projection=[cyl, n]} 2:GroupByNode{keys=["cyl"], aggregates=[ hash_sum(n, {skip_nulls=true, min_count=1}), ]} 1:ProjectNode{projection=["n": 1, cyl]} 0:TableSourceNode{} ``` After: ``` ExecPlan with 5 nodes: 4:SinkNode{} 3:ProjectNode{projection=[cyl, n]} 2:GroupByNode{keys=["cyl"], aggregates=[ hash_sum(n, {skip_nulls=true, min_count=1}), ]} 1:ProjectNode{projection=["n": 1, cyl]} 0:TableSourceNode{} ``` Authored-by: Neal Richardson Signed-off-by: Neal Richardson --- r/R/query-engine.R | 24 ++++++-- r/tests/testthat/test-dplyr-collapse.R | 36 +++++++++++ r/tests/testthat/test-dplyr-query.R | 82 ++++++++++++++++--------- r/tests/testthat/test-dplyr-summarize.R | 41 ++++++++++++- 4 files changed, 147 insertions(+), 36 deletions(-) diff --git a/r/R/query-engine.R b/r/R/query-engine.R index 84360490fdbe7..c132b291b872b 100644 --- a/r/R/query-engine.R +++ b/r/R/query-engine.R @@ -142,12 +142,14 @@ ExecPlan <- R6Class("ExecPlan", } } else { # If any columns are derived, reordered, or renamed we need to Project - # If there are aggregations, the projection was already handled above + # If there are aggregations, the projection was already handled above. # We have to project at least once to eliminate some junk columns # that the ExecPlan adds: # __fragment_index, __batch_index, __last_in_fragment - # Presumably extraneous repeated projection of the same thing - # (as when we've done collapse() and not projected after) is cheap/no-op + # + # $Project() will check whether we actually need to project, so that + # repeated projection of the same thing + # (as when we've done collapse() and not projected after) is avoided projection <- c(.data$selected_columns, .data$temp_columns) node <- node$Project(projection) if (!is.null(.data$join)) { @@ -349,7 +351,11 @@ ExecNode <- R6Class("ExecNode", Project = function(cols) { if (length(cols)) { assert_is_list_of(cols, "Expression") - self$preserve_extras(ExecNode_Project(self, cols, names(cols))) + if (needs_projection(cols, self$schema)) { + self$preserve_extras(ExecNode_Project(self, cols, names(cols))) + } else { + self + } } else { self$preserve_extras(ExecNode_Project(self, character(0), character(0))) } @@ -402,3 +408,13 @@ do_exec_plan_substrait <- function(substrait_plan) { plan <- ExecPlan$create() ExecPlan_run_substrait(plan, substrait_plan) } + +needs_projection <- function(projection, schema) { + # Check whether `projection` would do anything to data with the given `schema` + field_names <- set_names(map_chr(projection, ~ .$field_name), NULL) + + # We need to apply `projection` if: + !all(nzchar(field_names)) || # Any of the Expressions are not FieldRefs + !identical(field_names, names(projection)) || # Any fields are renamed + !identical(field_names, names(schema)) # The fields are reordered +} diff --git a/r/tests/testthat/test-dplyr-collapse.R b/r/tests/testthat/test-dplyr-collapse.R index 3c121780da64d..f1b4f9cea3a46 100644 --- a/r/tests/testthat/test-dplyr-collapse.R +++ b/r/tests/testthat/test-dplyr-collapse.R @@ -242,3 +242,39 @@ test_that("query_on_dataset handles collapse()", { select(int) )) }) + +test_that("collapse doesn't unnecessarily add ProjectNodes", { + plan <- capture.output( + tab %>% + collapse() %>% + collapse() %>% + show_query() + ) + # There should be no projections + expect_length(grep("ProjectNode", plan), 0) + + plan <- capture.output( + tab %>% + select(int, chr) %>% + collapse() %>% + collapse() %>% + show_query() + ) + # There should be just one projection + expect_length(grep("ProjectNode", plan), 1) + + skip_if_not_available("dataset") + # We need one ProjectNode on dataset queries to handle augmented fields + + tf <- tempfile() + write_dataset(tab, tf, partitioning = "lgl") + ds <- open_dataset(tf) + + plan <- capture.output( + ds %>% + collapse() %>% + collapse() %>% + show_query() + ) + expect_length(grep("ProjectNode", plan), 1) +}) diff --git a/r/tests/testthat/test-dplyr-query.R b/r/tests/testthat/test-dplyr-query.R index 37ab178cbb40f..1a5b6ec8a7c76 100644 --- a/r/tests/testthat/test-dplyr-query.R +++ b/r/tests/testthat/test-dplyr-query.R @@ -448,9 +448,9 @@ test_that("show_exec_plan(), show_query() and explain()", { arrow_table() %>% show_exec_plan(), regexp = paste0( - "ExecPlan with .* nodes:.*", # boiler plate for ExecPlan - "ProjectNode.*", # output columns - "TableSourceNode" # entry point + "ExecPlan with 2 nodes:.*", # boiler plate for ExecPlan + "SinkNode.*", # output + "TableSourceNode" # entry point ) ) @@ -463,12 +463,12 @@ test_that("show_exec_plan(), show_query() and explain()", { mutate(int_plus_ten = int + 10) %>% show_exec_plan(), regexp = paste0( - "ExecPlan with .* nodes:.*", # boiler plate for ExecPlan - "chr, int, lgl, \"int_plus_ten\".*", # selected columns - "FilterNode.*", # filter node - "(dbl > 2).*", # filter expressions + "ExecPlan with .* nodes:.*", # boiler plate for ExecPlan + "chr, int, lgl, \"int_plus_ten\".*", # selected columns + "FilterNode.*", # filter node + "(dbl > 2).*", # filter expressions "chr != \"e\".*", - "TableSourceNode" # entry point + "TableSourceNode" # entry point ) ) @@ -481,11 +481,11 @@ test_that("show_exec_plan(), show_query() and explain()", { mutate(int_plus_ten = int + 10) %>% show_exec_plan(), regexp = paste0( - "ExecPlan with .* nodes:.*", # boiler plate for ExecPlan - "chr, int, lgl, \"int_plus_ten\".*", # selected columns - "(dbl > 2).*", # the filter expressions + "ExecPlan with .* nodes:.*", # boiler plate for ExecPlan + "chr, int, lgl, \"int_plus_ten\".*", # selected columns + "(dbl > 2).*", # the filter expressions "chr != \"e\".*", - "TableSourceNode" # the entry point" + "TableSourceNode" # the entry point" ) ) @@ -497,13 +497,13 @@ test_that("show_exec_plan(), show_query() and explain()", { summarise(avg = mean(dbl, na.rm = TRUE)) %>% show_exec_plan(), regexp = paste0( - "ExecPlan with .* nodes:.*", # boiler plate for ExecPlan - "ProjectNode.*", # output columns - "GroupByNode.*", # the group_by statement - "keys=.*lgl.*", # the key for the aggregations - "aggregates=.*hash_mean.*avg.*", # the aggregations - "ProjectNode.*", # the input columns - "TableSourceNode" # the entry point + "ExecPlan with .* nodes:.*", # boiler plate for ExecPlan + "ProjectNode.*", # output columns + "GroupByNode.*", # the group_by statement + "keys=.*lgl.*", # the key for the aggregations + "aggregates=.*hash_mean.*avg.*", # the aggregations + "ProjectNode.*", # the input columns + "TableSourceNode" # the entry point ) ) @@ -521,14 +521,13 @@ test_that("show_exec_plan(), show_query() and explain()", { select(int, verses, doubled_dbl) %>% show_exec_plan(), regexp = paste0( - "ExecPlan with .* nodes:.*", # boiler plate for ExecPlan - "ProjectNode.*", # output columns - "HashJoinNode.*", # the join - "ProjectNode.*", # input columns for the second table + "ExecPlan with .* nodes:.*", # boiler plate for ExecPlan + "ProjectNode.*", # output columns + "HashJoinNode.*", # the join + "ProjectNode.*", # input columns for the second table "\"doubled_dbl\"\\: multiply_checked\\(dbl, 2\\).*", # mutate - "TableSourceNode.*", # second table - "ProjectNode.*", # input columns for the first table - "TableSourceNode" # first table + "TableSourceNode.*", # second table + "TableSourceNode" # first table ) ) @@ -539,11 +538,10 @@ test_that("show_exec_plan(), show_query() and explain()", { arrange(desc(wt)) %>% show_exec_plan(), regexp = paste0( - "ExecPlan with .* nodes:.*", # boiler plate for ExecPlan + "ExecPlan with .* nodes:.*", # boiler plate for ExecPlan "OrderBySinkNode.*wt.*DESC.*", # arrange goes via the OrderBy sink node - "ProjectNode.*", # output columns - "FilterNode.*", # filter node - "TableSourceNode.*" # entry point + "FilterNode.*", # filter node + "TableSourceNode.*" # entry point ) ) @@ -559,3 +557,27 @@ test_that("show_exec_plan(), show_query() and explain()", { "The `ExecPlan` cannot be printed for a nested query." ) }) + +test_that("needs_projection unit tests", { + tab <- Table$create(tbl) + # Wrapper to simplify tests + query_needs_projection <- function(query) { + needs_projection(query$selected_columns, tab$schema) + } + expect_false(query_needs_projection(as_adq(tab))) + expect_false(query_needs_projection( + tab %>% collapse() %>% collapse() + )) + expect_true(query_needs_projection( + tab %>% mutate(int = int + 2) + )) + expect_true(query_needs_projection( + tab %>% select(int, chr) + )) + expect_true(query_needs_projection( + tab %>% rename(int2 = int) + )) + expect_true(query_needs_projection( + tab %>% relocate(lgl) + )) +}) diff --git a/r/tests/testthat/test-dplyr-summarize.R b/r/tests/testthat/test-dplyr-summarize.R index f799fcbf38487..0ee0c5739dbb6 100644 --- a/r/tests/testthat/test-dplyr-summarize.R +++ b/r/tests/testthat/test-dplyr-summarize.R @@ -243,8 +243,10 @@ test_that("n_distinct() with many batches", { write_parquet(dplyr::starwars, tf, chunk_size = 20) ds <- open_dataset(tf) - expect_equal(ds %>% summarise(n_distinct(sex, na.rm = FALSE)) %>% collect(), - ds %>% collect() %>% summarise(n_distinct(sex, na.rm = FALSE))) + expect_equal( + ds %>% summarise(n_distinct(sex, na.rm = FALSE)) %>% collect(), + ds %>% collect() %>% summarise(n_distinct(sex, na.rm = FALSE)) + ) }) test_that("n_distinct() on dataset", { @@ -1089,3 +1091,38 @@ test_that("summarise() supports namespacing", { tbl ) }) + +test_that("We don't add unnecessary ProjectNodes when aggregating", { + tab <- Table$create(tbl) + + # Wrapper to simplify the tests + expect_project_nodes <- function(query, n) { + plan <- capture.output(query %>% show_query()) + expect_length(grep("ProjectNode", plan), n) + } + + # 1 Projection: select int as `mean(int)` before aggregation + expect_project_nodes( + tab %>% summarize(mean(int)), + 1 + ) + + # 0 Projections only if + # (a) input only contains the col you're aggregating, and + # (b) the output col name is the same as the input name, and + # (c) no grouping + expect_project_nodes( + tab[, "int"] %>% summarize(int = mean(int, na.rm = TRUE)), + 0 + ) + + # 2 projections: one before, and one after in order to put grouping cols first + expect_project_nodes( + tab %>% group_by(lgl) %>% summarize(mean(int)), + 2 + ) + expect_project_nodes( + tab %>% count(lgl), + 2 + ) +})