Skip to content

Commit

Permalink
Part 2 of refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
save-buffer committed Oct 3, 2022
1 parent 5fd54bc commit 279bf83
Show file tree
Hide file tree
Showing 30 changed files with 610 additions and 921 deletions.
6 changes: 3 additions & 3 deletions c_glib/arrow-glib/compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1914,18 +1914,18 @@ garrow_execute_plan_start(GArrowExecutePlan *plan,
}

/**
* garrow_execute_plan_stop:
* garrow_execute_plan_abort:
* @plan: A #GArrowExecutePlan.
*
* Aborts this plan.
*
* Since: 6.0.0
*/
void
garrow_execute_plan_stop(GArrowExecutePlan *plan)
garrow_execute_plan_abort(GArrowExecutePlan *plan)
{
auto arrow_plan = garrow_execute_plan_get_raw(plan);
arrow_plan->StopProducing();
arrow_plan->Abort();
}

/**
Expand Down
21 changes: 10 additions & 11 deletions cpp/examples/arrow/compute_register_example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,30 +74,29 @@ class ExampleNode : public cp::ExecNode {
ExampleNode(ExecNode* input, const ExampleNodeOptions&)
: ExecNode(/*plan=*/input->plan(), /*inputs=*/{input},
/*input_labels=*/{"ignored"},
/*output_schema=*/input->output_schema(), /*num_outputs=*/1) {}
/*output_schema=*/input->output_schema()) {}

const char* kind_name() const override { return "ExampleNode"; }

arrow::Status StartProducing() override {
outputs_[0]->InputFinished(this, 0);
return arrow::Status::OK();
}
arrow::Status StartProducing() override { return output_->InputFinished(this, 0); }

void ResumeProducing(ExecNode* output, int32_t counter) override {
inputs_[0]->ResumeProducing(this, counter);
}

void PauseProducing(ExecNode* output, int32_t counter) override {
inputs_[0]->PauseProducing(this, counter);
}

void StopProducing(ExecNode* output) override { inputs_[0]->StopProducing(this); }
void StopProducing() override { inputs_[0]->StopProducing(); }
arrow::Status InputReceived(ExecNode* input, cp::ExecBatch batch) override {
return arrow::Status::OK();
}

void InputReceived(ExecNode* input, cp::ExecBatch batch) override {}
void ErrorReceived(ExecNode* input, arrow::Status error) override {}
void InputFinished(ExecNode* input, int total_batches) override {}
arrow::Status InputFinished(ExecNode* input, int total_batches) override {
return arrow::Status::OK();
}

arrow::Future<> finished() override { return inputs_[0]->finished(); }
void Abort() override {}
};

arrow::Result<cp::ExecNode*> ExampleExecNodeFactory(cp::ExecPlan* plan,
Expand Down
106 changes: 28 additions & 78 deletions cpp/src/arrow/compute/exec/aggregate_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ class ScalarAggregateNode : public ExecNode {
std::vector<const ScalarAggregateKernel*> kernels,
std::vector<std::vector<std::unique_ptr<KernelState>>> states)
: ExecNode(plan, std::move(inputs), {"target"},
/*output_schema=*/std::move(output_schema),
/*num_outputs=*/1),
/*output_schema=*/std::move(output_schema)),
target_field_ids_(std::move(target_field_ids)),
aggs_(std::move(aggs)),
kernels_(std::move(kernels)),
Expand Down Expand Up @@ -159,7 +158,7 @@ class ScalarAggregateNode : public ExecNode {
return Status::OK();
}

void InputReceived(ExecNode* input, ExecBatch batch) override {
Status InputReceived(ExecNode* input, ExecBatch batch) override {
EVENT(span_, "InputReceived", {{"batch.length", batch.length}});
util::tracing::Span span;
START_COMPUTE_SPAN_WITH_PARENT(span, span_, "InputReceived",
Expand All @@ -170,25 +169,21 @@ class ScalarAggregateNode : public ExecNode {

auto thread_index = plan_->GetThreadIndex();

if (ErrorIfNotOk(DoConsume(ExecSpan(batch), thread_index))) return;
RETURN_NOT_OK(DoConsume(ExecSpan(batch), thread_index));

if (input_counter_.Increment()) {
ErrorIfNotOk(Finish());
return Finish();
}
return Status::OK();
}

void ErrorReceived(ExecNode* input, Status error) override {
EVENT(span_, "ErrorReceived", {{"error", error.message()}});
DCHECK_EQ(input, inputs_[0]);
outputs_[0]->ErrorReceived(this, std::move(error));
}

void InputFinished(ExecNode* input, int total_batches) override {
Status InputFinished(ExecNode* input, int total_batches) override {
EVENT(span_, "InputFinished", {{"batches.length", total_batches}});
DCHECK_EQ(input, inputs_[0]);
if (input_counter_.SetTotal(total_batches)) {
ErrorIfNotOk(Finish());
return Finish();
}
return Status::OK();
}

Status StartProducing() override {
Expand All @@ -197,8 +192,7 @@ class ScalarAggregateNode : public ExecNode {
{"node.detail", ToString()},
{"node.kind", kind_name()}});
// Scalar aggregates will only output a single batch
outputs_[0]->InputFinished(this, 1);
return Status::OK();
return output_->InputFinished(this, 1);
}

void PauseProducing(ExecNode* output, int32_t counter) override {
Expand All @@ -209,18 +203,7 @@ class ScalarAggregateNode : public ExecNode {
inputs_[0]->ResumeProducing(this, counter);
}

void StopProducing(ExecNode* output) override {
DCHECK_EQ(output, outputs_[0]);
StopProducing();
}

void StopProducing() override {
EVENT(span_, "StopProducing");
if (input_counter_.Cancel()) {
finished_.MarkFinished();
}
inputs_[0]->StopProducing(this);
}
void Abort() override {}

protected:
std::string ToStringExtra(int indent = 0) const override {
Expand Down Expand Up @@ -251,9 +234,7 @@ class ScalarAggregateNode : public ExecNode {
RETURN_NOT_OK(kernels_[i]->finalize(&ctx, &batch.values[i]));
}

outputs_[0]->InputReceived(this, std::move(batch));
finished_.MarkFinished();
return Status::OK();
return output_->InputReceived(this, std::move(batch));
}

const std::vector<int> target_field_ids_;
Expand All @@ -271,24 +252,19 @@ class GroupByNode : public ExecNode {
std::vector<int> key_field_ids, std::vector<int> agg_src_field_ids,
std::vector<Aggregate> aggs,
std::vector<const HashAggregateKernel*> agg_kernels)
: ExecNode(input->plan(), {input}, {"groupby"}, std::move(output_schema),
/*num_outputs=*/1),
: ExecNode(input->plan(), {input}, {"groupby"}, std::move(output_schema)),
ctx_(ctx),
key_field_ids_(std::move(key_field_ids)),
agg_src_field_ids_(std::move(agg_src_field_ids)),
aggs_(std::move(aggs)),
agg_kernels_(std::move(agg_kernels)) {}

Status Init() override {
RETURN_NOT_OK(ExecNode::Init());
output_task_group_id_ = plan_->RegisterTaskGroup(
[this](size_t, int64_t task_id) {
OutputNthBatch(task_id);
return Status::OK();
},
[this](size_t) {
finished_.MarkFinished();
return Status::OK();
});
[this](size_t, int64_t task_id) { return OutputNthBatch(task_id); },
[](size_t) { return Status::OK(); });
local_states_.resize(plan_->max_concurrency());
return Status::OK();
}

Expand Down Expand Up @@ -475,12 +451,9 @@ class GroupByNode : public ExecNode {
return out_data;
}

void OutputNthBatch(int64_t n) {
// bail if StopProducing was called
if (finished_.is_finished()) return;

Status OutputNthBatch(int64_t n) {
int64_t batch_size = output_batch_size();
outputs_[0]->InputReceived(this, out_data_.Slice(batch_size * n, batch_size));
return output_->InputReceived(this, out_data_.Slice(batch_size * n, batch_size));
}

Status OutputResult() {
Expand All @@ -496,50 +469,36 @@ class GroupByNode : public ExecNode {
ARROW_ASSIGN_OR_RAISE(out_data_, Finalize());

int64_t num_output_batches = bit_util::CeilDiv(out_data_.length, output_batch_size());
outputs_[0]->InputFinished(this, static_cast<int>(num_output_batches));
RETURN_NOT_OK(plan_->StartTaskGroup(output_task_group_id_, num_output_batches));
return Status::OK();
RETURN_NOT_OK(output_->InputFinished(this, static_cast<int>(num_output_batches)));
return plan_->StartTaskGroup(output_task_group_id_, num_output_batches);
}

void InputReceived(ExecNode* input, ExecBatch batch) override {
Status InputReceived(ExecNode* input, ExecBatch batch) override {
EVENT(span_, "InputReceived", {{"batch.length", batch.length}});
util::tracing::Span span;
START_COMPUTE_SPAN_WITH_PARENT(span, span_, "InputReceived",
{{"group_by", ToStringExtra()},
{"node.label", label()},
{"batch.length", batch.length}});

// bail if StopProducing was called
if (finished_.is_finished()) return;

DCHECK_EQ(input, inputs_[0]);

if (ErrorIfNotOk(Consume(ExecSpan(batch)))) return;

RETURN_NOT_OK(Consume(ExecSpan(batch)));
if (input_counter_.Increment()) {
ErrorIfNotOk(OutputResult());
return OutputResult();
}
return Status::OK();
}

void ErrorReceived(ExecNode* input, Status error) override {
EVENT(span_, "ErrorReceived", {{"error", error.message()}});

DCHECK_EQ(input, inputs_[0]);

outputs_[0]->ErrorReceived(this, std::move(error));
}

void InputFinished(ExecNode* input, int total_batches) override {
Status InputFinished(ExecNode* input, int total_batches) override {
EVENT(span_, "InputFinished", {{"batches.length", total_batches}});

// bail if StopProducing was called
if (finished_.is_finished()) return;

DCHECK_EQ(input, inputs_[0]);

if (input_counter_.SetTotal(total_batches)) {
ErrorIfNotOk(OutputResult());
return OutputResult();
}
return Status::OK();
}

Status StartProducing() override {
Expand All @@ -548,7 +507,6 @@ class GroupByNode : public ExecNode {
{"node.detail", ToString()},
{"node.kind", kind_name()}});

local_states_.resize(plan_->max_concurrency());
return Status::OK();
}

Expand All @@ -562,15 +520,7 @@ class GroupByNode : public ExecNode {
// Without spillover there is no way to handle backpressure in this node
}

void StopProducing(ExecNode* output) override {
EVENT(span_, "StopProducing");
DCHECK_EQ(output, outputs_[0]);

if (input_counter_.Cancel()) finished_.MarkFinished();
inputs_[0]->StopProducing(this);
}

void StopProducing() override { StopProducing(outputs_[0]); }
void Abort() override { input_counter_.Cancel(); }

protected:
std::string ToStringExtra(int indent = 0) const override {
Expand Down
Loading

0 comments on commit 279bf83

Please sign in to comment.