Skip to content

Commit

Permalink
GH-35838: [C++] Backpressure broken in asof join node
Browse files Browse the repository at this point in the history
  • Loading branch information
rtpsw committed Jun 2, 2023
1 parent 018e7d3 commit 7d40e92
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 13 deletions.
7 changes: 4 additions & 3 deletions cpp/src/arrow/acero/asof_join_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -668,18 +668,19 @@ class InputState {

static Result<std::unique_ptr<InputState>> Make(
size_t index, TolType tolerance, bool must_hash, bool may_rehash,
KeyHasher* key_hasher, ExecNode* node, AsofJoinNode* output,
KeyHasher* key_hasher, ExecNode* input, AsofJoinNode* node,
std::atomic<int32_t>& backpressure_counter,
const std::shared_ptr<arrow::Schema>& schema, const col_index_t time_col_index,
const std::vector<col_index_t>& key_col_index) {
constexpr size_t low_threshold = 4, high_threshold = 8;
std::unique_ptr<BackpressureControl> backpressure_control =
std::make_unique<BackpressureController>(node, output, backpressure_counter);
std::make_unique<BackpressureController>(/*node=*/input, /*output=*/node,
backpressure_counter);
ARROW_ASSIGN_OR_RAISE(auto handler,
BackpressureHandler::Make(low_threshold, high_threshold,
std::move(backpressure_control)));
return std::make_unique<InputState>(index, tolerance, must_hash, may_rehash,
key_hasher, output, std::move(handler), schema,
key_hasher, node, std::move(handler), schema,
time_col_index, key_col_index);
}

Expand Down
103 changes: 93 additions & 10 deletions cpp/src/arrow/acero/asof_join_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#ifndef NDEBUG
#include "arrow/acero/options_internal.h"
#endif
#include "arrow/acero/map_node.h"
#include "arrow/acero/test_nodes.h"
#include "arrow/acero/test_util_internal.h"
#include "arrow/acero/util.h"
Expand Down Expand Up @@ -1381,18 +1382,88 @@ void TestBackpressure(BatchesMaker maker, int num_batches, int batch_size,
ASSERT_OK_AND_ASSIGN(auto r0_batches, make_shift(r0_schema, 1));
ASSERT_OK_AND_ASSIGN(auto r1_batches, make_shift(r1_schema, 2));

Declaration l_src = {
"source", SourceNodeOptions(
l_schema, MakeDelayedGen(l_batches, "0:fast", fast_delay, noisy))};
Declaration r0_src = {
"source", SourceNodeOptions(
r0_schema, MakeDelayedGen(r0_batches, "1:slow", slow_delay, noisy))};
Declaration r1_src = {
"source", SourceNodeOptions(
r1_schema, MakeDelayedGen(r1_batches, "2:fast", fast_delay, noisy))};
struct BackpressureCounters {
int32_t pause_count = 0;
int32_t resume_count = 0;
};

struct BackpressureTestNodeOptions : public ExecNodeOptions {
BackpressureCounters* counters;
};

struct BackpressureTestNode : public MapNode {
BackpressureTestNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
std::shared_ptr<Schema> output_schema,
const BackpressureTestNodeOptions& options)
: MapNode(plan, inputs, output_schema), counters(options.counters) {}

static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
const ExecNodeOptions& options) {
RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "BackpressureTestNode"));
auto bp_options = static_cast<const BackpressureTestNodeOptions&>(options);
return plan->EmplaceNode<BackpressureTestNode>(
plan, inputs, inputs[0]->output_schema(), bp_options);
}

const char* kind_name() const override { return "BackpressureTestNode"; }
Result<ExecBatch> ProcessBatch(ExecBatch batch) override { return batch; }

void PauseProducing(ExecNode* output, int32_t counter) override {
++counters->pause_count;
inputs()[0]->PauseProducing(this, counter);
}
void ResumeProducing(ExecNode* output, int32_t counter) override {
++counters->resume_count;
inputs()[0]->ResumeProducing(this, counter);
}

BackpressureCounters* counters;
};

auto exec_reg = default_exec_factory_registry();
std::string bp_test = "backpressure_test";
if (!exec_reg->GetFactory(bp_test).ok()) {
ASSERT_OK(exec_reg->AddFactory(bp_test, BackpressureTestNode::Make));
}

struct SourceConfig {
std::string name_prefix;
bool is_fast;
std::shared_ptr<Schema> schema;
decltype(l_batches) batches;

std::string name() const { return name_prefix + ";" + (is_fast ? "fast" : "slow"); }
};

// must have at least one fast and one slow
std::vector<SourceConfig> source_configs = {
{"0", true, l_schema, l_batches},
{"1", false, r0_schema, r0_batches},
{"2", true, r1_schema, r1_batches},
};

std::vector<BackpressureCounters> bp_counters(source_configs.size());
std::vector<Declaration> src_decls;
std::vector<std::shared_ptr<BackpressureTestNodeOptions>> bp_options;
std::vector<Declaration::Input> bp_decls;
for (size_t i = 0; i < source_configs.size(); i++) {
const auto& config = source_configs[i];
src_decls.emplace_back(
"source", SourceNodeOptions(
config.schema,
MakeDelayedGen(config.batches, config.name(),
config.is_fast ? fast_delay : slow_delay, noisy)));
bp_options.push_back(std::make_shared<BackpressureTestNodeOptions>());
bp_options.back()->counters = &bp_counters[i];
std::shared_ptr<ExecNodeOptions> options = bp_options.back();
std::vector<Declaration::Input> bp_in = {src_decls.back()};
Declaration bp_decl = {bp_test, bp_in, std::move(options)};
bp_decls.push_back(bp_decl);
}

Declaration asofjoin = {
"asofjoin", {l_src, r0_src, r1_src}, GetRepeatedOptions(3, "time", {"key"}, 1000)};
"asofjoin", bp_decls,
GetRepeatedOptions(source_configs.size(), "time", {"key"}, 1000)};

ASSERT_OK_AND_ASSIGN(std::unique_ptr<RecordBatchReader> batch_reader,
DeclarationToReader(asofjoin, /*use_threads=*/false));
Expand All @@ -1406,6 +1477,18 @@ void TestBackpressure(BatchesMaker maker, int num_batches, int batch_size,
total_length += batch->num_rows();
}
ASSERT_EQ(static_cast<int64_t>(num_batches * batch_size), total_length);

std::unordered_map<bool, BackpressureCounters> counters_by_is_fast;
for (size_t i = 0; i < source_configs.size(); i++) {
BackpressureCounters& counters = counters_by_is_fast[source_configs[i].is_fast];
counters.pause_count += bp_counters[i].pause_count;
counters.resume_count += bp_counters[i].resume_count;
}
ASSERT_EQ(counters_by_is_fast.size(), 2);
ASSERT_EQ(counters_by_is_fast[false].pause_count, 0);
ASSERT_EQ(counters_by_is_fast[false].resume_count, 0);
ASSERT_GT(counters_by_is_fast[true].pause_count, 0);
ASSERT_GT(counters_by_is_fast[true].resume_count, 0);
}

TEST(AsofJoinTest, BackpressureWithBatches) {
Expand Down

0 comments on commit 7d40e92

Please sign in to comment.