apacheGH-34911: [C++] Add first and last aggregator (apache#34912)
### Rationale for this change
This PR adds "first" and "last" aggregators and supports using them with Acero's segmented aggregation.
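
For illustration, here is a minimal sketch of how the new scalar kernels could be exercised through the compute API. It assumes the kernels are registered as `"first"` and `"last"` and honor `ScalarAggregateOptions` for null handling like the existing scalar aggregates; treat it as a sketch rather than the exact API surface of this PR.

```cpp
#include <arrow/api.h>
#include <arrow/compute/api.h>

#include <iostream>

arrow::Status RunFirstLast() {
  // Build a small input array; the leading null illustrates skip-null handling.
  arrow::Int64Builder builder;
  ARROW_RETURN_NOT_OK(builder.AppendNull());
  ARROW_RETURN_NOT_OK(builder.AppendValues({10, 20, 30}));
  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::Array> values, builder.Finish());

  // With skip_nulls=true (the assumed default), "first" should yield the first
  // non-null value, i.e. 10.
  ARROW_ASSIGN_OR_RAISE(arrow::Datum first,
                        arrow::compute::CallFunction("first", {values}));
  std::cout << "first: " << first.scalar()->ToString() << std::endl;

  // With skip_nulls=false, the leading null would be reported instead.
  arrow::compute::ScalarAggregateOptions opts(/*skip_nulls=*/false);
  ARROW_ASSIGN_OR_RAISE(arrow::Datum first_with_null,
                        arrow::compute::CallFunction("first", {values}, &opts));
  std::cout << "first (skip_nulls=false): " << first_with_null.scalar()->ToString()
            << std::endl;

  // "last" mirrors "first" and should yield 30 here.
  ARROW_ASSIGN_OR_RAISE(arrow::Datum last,
                        arrow::compute::CallFunction("last", {values}));
  std::cout << "last: " << last.scalar()->ToString() << std::endl;
  return arrow::Status::OK();
}

int main() {
  arrow::Status st = RunFirstLast();
  if (!st.ok()) {
    std::cerr << st.ToString() << std::endl;
    return 1;
  }
  return 0;
}
```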

### What changes are included in this PR?
- [x] Numeric Scalar Aggregator (bool, int types, floating types)
- [x] Numeric Hash Aggregator (bool, int types, floating types)
- [x] Docstring
- [x] Non-Numeric Scalar Aggregator (string, binary, fixed binary, temporal)
- [x] Non-Numeric Hash Aggregator (string, binary, fixed binary, temporal)
- [x] Add `ordered` flag in aggregate kernels
- [x] Implement and test skip null
- [x] Update compute.rst

### Are these changes tested?
- [x] Compute Kernel Test (Scalar Kernels, all supported datatypes)
- [x] Hash Aggregate Test (Hash Kernels, all supported datatypes)
- [x] Segmented Aggregation Test (Both Scalar and Hash Kernels)

### Are there any user-facing changes?
Yes. Added the First and Last aggregators.
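
As a rough usage sketch of the hash variants: they would typically be reached through an Acero aggregate declaration. The `"hash_first"` function name, the column names `"key"`/`"value"`, and running the plan single-threaded are assumptions based on this diff (ordered aggregators are rejected in multi-threaded plans), not verified API details.

```cpp
#include <arrow/acero/exec_plan.h>
#include <arrow/acero/options.h>
#include <arrow/api.h>

// Group `table` by column "key" and take the first "value" seen per group.
arrow::Result<std::shared_ptr<arrow::Table>> FirstPerKey(
    const std::shared_ptr<arrow::Table>& table) {
  namespace ac = arrow::acero;

  ac::Declaration plan = ac::Declaration::Sequence({
      {"table_source", ac::TableSourceNodeOptions(table)},
      {"aggregate",
       ac::AggregateNodeOptions(
           /*aggregates=*/{{"hash_first", nullptr, "value", "first(value)"}},
           /*keys=*/{"key"})},
  });
  // Ordered aggregators are not supported in multi-threaded execution,
  // so run the plan single-threaded.
  return ac::DeclarationToTable(std::move(plan), /*use_threads=*/false);
}
```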

Authored-by: Li Jin <ice.xelloss@gmail.com>
Signed-off-by: Weston Pace <weston.pace@gmail.com>
icexelloss authored and liujiacheng777 committed May 11, 2023
1 parent 2f87365 commit 5a80507
Showing 13 changed files with 1,648 additions and 88 deletions.
60 changes: 43 additions & 17 deletions cpp/src/arrow/acero/aggregate_node.cc
@@ -350,7 +350,14 @@ class ScalarAggregateNode : public ExecNode, public TracedNode {
kernel_intypes[i] = in_types;
ARROW_ASSIGN_OR_RAISE(const Kernel* kernel,
function->DispatchExact(kernel_intypes[i]));
kernels[i] = static_cast<const ScalarAggregateKernel*>(kernel);
const ScalarAggregateKernel* agg_kernel =
static_cast<const ScalarAggregateKernel*>(kernel);
if (concurrency > 1 && agg_kernel->ordered) {
return Status::NotImplemented(
"Using ordered aggregator in multiple threaded execution is not supported");
}

kernels[i] = agg_kernel;

if (aggregates[i].options == nullptr) {
DCHECK(!function->doc().options_required);
@@ -391,12 +398,13 @@ class ScalarAggregateNode : public ExecNode, public TracedNode {
auto aggregates = aggregate_options.aggregates;
const auto& keys = aggregate_options.keys;
const auto& segment_keys = aggregate_options.segment_keys;
const auto concurreny =
plan->query_context()->exec_context()->executor()->GetCapacity();

if (keys.size() > 0) {
return Status::Invalid("Scalar aggregation with some key");
}
if (plan->query_context()->exec_context()->executor()->GetCapacity() > 1 &&
segment_keys.size() > 0) {
if (concurreny > 1 && segment_keys.size() > 0) {
return Status::NotImplemented("Segmented aggregation in a multi-threaded plan");
}

@@ -406,7 +414,16 @@
ARROW_ASSIGN_OR_RAISE(
auto args,
MakeAggregateNodeArgs(input_schema, keys, segment_keys, aggregates, exec_ctx,
/*concurrency=*/plan->query_context()->max_concurrency()));
/*concurrency=*/concurreny));

if (concurreny > 1) {
for (auto& kernel : args.kernels) {
if (kernel->ordered) {
return Status::NotImplemented(
"Using ordered aggregator in multiple threaded execution is not supported");
}
}
}

return plan->EmplaceNode<ScalarAggregateNode>(
plan, std::move(inputs), std::move(args.output_schema), std::move(args.segmenter),
@@ -599,7 +616,7 @@ class GroupByNode : public ExecNode, public TracedNode {
static Result<AggregateNodeArgs<HashAggregateKernel>> MakeAggregateNodeArgs(
const std::shared_ptr<Schema>& input_schema, const std::vector<FieldRef>& keys,
const std::vector<FieldRef>& segment_keys, const std::vector<Aggregate>& aggs,
ExecContext* ctx) {
ExecContext* ctx, const int concurrency) {
// Find input field indices for key fields
std::vector<int> key_field_ids(keys.size());
for (size_t i = 0; i < keys.size(); ++i) {
@@ -656,6 +673,20 @@
// Construct aggregates
ARROW_ASSIGN_OR_RAISE(auto agg_kernels, GetKernels(ctx, aggs, agg_src_types));

if (concurrency > 1) {
if (segment_keys.size() > 0) {
return Status::NotImplemented(
"Segmented aggregation in a multi-threaded execution context");
}

for (auto kernel : agg_kernels) {
if (kernel->ordered) {
return Status::NotImplemented(
"Using ordered aggregator in multiple threaded execution is not supported");
}
}
}

ARROW_ASSIGN_OR_RAISE(auto agg_states,
InitKernels(agg_kernels, ctx, aggs, agg_src_types));

@@ -703,18 +734,13 @@
const auto& keys = aggregate_options.keys;
const auto& segment_keys = aggregate_options.segment_keys;
auto aggs = aggregate_options.aggregates;

if (plan->query_context()->exec_context()->executor()->GetCapacity() > 1 &&
segment_keys.size() > 0) {
return Status::NotImplemented(
"Segmented aggregation in a multi-threaded execution context");
}
auto concurrency = plan->query_context()->exec_context()->executor()->GetCapacity();

const auto& input_schema = input->output_schema();
auto exec_ctx = plan->query_context()->exec_context();

ARROW_ASSIGN_OR_RAISE(auto args, MakeAggregateNodeArgs(input_schema, keys,
segment_keys, aggs, exec_ctx));
ARROW_ASSIGN_OR_RAISE(
auto args, MakeAggregateNodeArgs(input_schema, keys, segment_keys, aggs, exec_ctx,
concurrency));

return input->plan()->EmplaceNode<GroupByNode>(
input, std::move(args.output_schema), std::move(args.grouping_key_field_ids),
@@ -1042,9 +1068,9 @@ Result<std::shared_ptr<Schema>> MakeOutputSchema(
exec_ctx, /*concurrency=*/1));
return std::move(args.output_schema);
} else {
ARROW_ASSIGN_OR_RAISE(
auto args, GroupByNode::MakeAggregateNodeArgs(input_schema, keys, segment_keys,
aggregates, exec_ctx));
ARROW_ASSIGN_OR_RAISE(auto args, GroupByNode::MakeAggregateNodeArgs(
input_schema, keys, segment_keys, aggregates,
exec_ctx, /*concurrency=*/1));
return std::move(args.output_schema);
}
}