Skip to content

Commit

Permalink
partition the function further
Browse files Browse the repository at this point in the history
  • Loading branch information
vibhatha committed Jun 23, 2022
1 parent dfef8ec commit 038595b
Showing 1 changed file with 25 additions and 17 deletions.
42 changes: 25 additions & 17 deletions r/R/query-engine.R
Original file line number Diff line number Diff line change
Expand Up @@ -224,30 +224,17 @@ ExecPlan <- R6Class("ExecPlan",
Stop = function() ExecPlan_StopProducing(self)
),
private = list(
.set_aggregation = function(node, data, grouped, group_vars) {
# Project to include just the data required for each aggregation,
# plus group_by_vars (last)
# TODO: validate that none of names(aggregations) are the same as names(group_by_vars)
# dplyr does not error on this but the result it gives isn't great
node <- node$Project(summarize_projection(data))

.set_aggregate_func_names = function(data, grouped) {
if (grouped) {
# We need to prefix all of the aggregation function names with "hash_"
data$aggregations <- lapply(data$aggregations, function(x) {
x[["fun"]] <- paste0("hash_", x[["fun"]])
x
})
}
target_names <- names(data$aggregations)
for (i in seq_len(length(target_names))) {
data$aggregations[[i]][["name"]] <- data$aggregations[[i]][["target"]] <- target_names[i]
}

node <- node$Aggregate(
options = data$aggregations,
key_names = group_vars
)

data
},
.set_group_by = function(node, data, group_vars, grouped) {
if (grouped) {
# The result will have result columns first then the grouping cols.
# dplyr orders group cols first, so adapt the result to meet that expectation.
Expand All @@ -264,6 +251,27 @@ ExecPlan <- R6Class("ExecPlan",
}
}
return(list(node = node, data = data))
},
.set_aggregation = function(node, data, grouped, group_vars) {
# Project to include just the data required for each aggregation,
# plus group_by_vars (last)
# TODO: validate that none of names(aggregations) are the same as names(group_by_vars)
# dplyr does not error on this but the result it gives isn't great
node <- node$Project(summarize_projection(data))

data <- private$.set_aggregate_func_names(data, grouped)

target_names <- names(data$aggregations)
for (i in seq_len(length(target_names))) {
data$aggregations[[i]][["name"]] <- data$aggregations[[i]][["target"]] <- target_names[i]
}

node <- node$Aggregate(
options = data$aggregations,
key_names = group_vars
)
group_config <- private$.set_group_by(node, data, group_vars, grouped)
return(list(node = group_config$node, data = group_config$data))
}
)
)
Expand Down

0 comments on commit 038595b

Please sign in to comment.