diff --git a/r/R/query-engine.R b/r/R/query-engine.R index 8d87fa1be1678..860d947cd48e4 100644 --- a/r/R/query-engine.R +++ b/r/R/query-engine.R @@ -224,13 +224,7 @@ ExecPlan <- R6Class("ExecPlan", Stop = function() ExecPlan_StopProducing(self) ), private = list( - .set_aggregation = function(node, data, grouped, group_vars) { - # Project to include just the data required for each aggregation, - # plus group_by_vars (last) - # TODO: validate that none of names(aggregations) are the same as names(group_by_vars) - # dplyr does not error on this but the result it gives isn't great - node <- node$Project(summarize_projection(data)) - + .set_aggregate_func_names = function(data, grouped) { if (grouped) { # We need to prefix all of the aggregation function names with "hash_" data$aggregations <- lapply(data$aggregations, function(x) { @@ -238,16 +232,9 @@ ExecPlan <- R6Class("ExecPlan", x }) } - target_names <- names(data$aggregations) - for (i in seq_len(length(target_names))) { - data$aggregations[[i]][["name"]] <- data$aggregations[[i]][["target"]] <- target_names[i] - } - - node <- node$Aggregate( - options = data$aggregations, - key_names = group_vars - ) - + data + }, + .set_group_by = function(node, data, group_vars, grouped) { if (grouped) { # The result will have result columns first then the grouping cols. # dplyr orders group cols first, so adapt the result to meet that expectation. @@ -264,6 +251,27 @@ ExecPlan <- R6Class("ExecPlan", } } return(list(node = node, data = data)) + }, + .set_aggregation = function(node, data, grouped, group_vars) { + # Project to include just the data required for each aggregation, + # plus group_by_vars (last) + # TODO: validate that none of names(aggregations) are the same as names(group_by_vars) + # dplyr does not error on this but the result it gives isn't great + node <- node$Project(summarize_projection(data)) + + data <- private$.set_aggregate_func_names(data, grouped) + + target_names <- names(data$aggregations) + for (i in seq_len(length(target_names))) { + data$aggregations[[i]][["name"]] <- data$aggregations[[i]][["target"]] <- target_names[i] + } + + node <- node$Aggregate( + options = data$aggregations, + key_names = group_vars + ) + group_config <- private$.set_group_by(node, data, group_vars, grouped) + return(list(node = group_config$node, data = group_config$data)) } ) )