From 20607a83774dda92abcfe3611bd250749a549bdf Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 7 Aug 2022 07:52:18 -0400 Subject: [PATCH 01/26] AsOfJoin support for integer, floating, and timestamp types --- cpp/src/arrow/compute/exec/asof_join_node.cc | 276 +++++++++++++----- .../arrow/compute/exec/asof_join_node_test.cc | 88 +++++- 2 files changed, 284 insertions(+), 80 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 3da612aa03e41..2cd5644c26123 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -17,9 +17,9 @@ #include #include -#include #include #include +#include #include "arrow/array/builder_primitive.h" #include "arrow/compute/exec/exec_plan.h" @@ -37,14 +37,45 @@ namespace arrow { namespace compute { -// Remove this when multiple keys and/or types is supported -typedef int32_t KeyType; +typedef uint64_t KeyType; +typedef uint64_t TimeType; // Maximum number of tables that can be joined #define MAX_JOIN_TABLES 64 typedef uint64_t row_index_t; typedef int col_index_t; +#define VAL_SIGNED(val, w) \ + static inline uint64_t val(int##w##_t t, uint64_t bias = (uint64_t)1 << (w - 1)) { \ + return t < 0 ? static_cast(t + bias) : static_cast(t) + bias; \ + } + +#define VAL_UNSIGNED(val, w) \ + static inline uint64_t val(uint##w##_t t, uint64_t bias = 0) { \ + return static_cast(t); \ + } + +VAL_SIGNED(time_value, 8) +VAL_SIGNED(time_value, 16) +VAL_SIGNED(time_value, 32) +VAL_SIGNED(time_value, 64) +VAL_UNSIGNED(time_value, 8) +VAL_UNSIGNED(time_value, 16) +VAL_UNSIGNED(time_value, 32) +VAL_UNSIGNED(time_value, 64) + +VAL_SIGNED(key_value, 8) +VAL_SIGNED(key_value, 16) +VAL_SIGNED(key_value, 32) +VAL_SIGNED(key_value, 64) +VAL_UNSIGNED(key_value, 8) +VAL_UNSIGNED(key_value, 16) +VAL_UNSIGNED(key_value, 32) +VAL_UNSIGNED(key_value, 64) + +#undef VAL_SIGNED +#undef VAL_UNSIGNED + /** * Simple implementation for an unbound concurrent queue */ @@ -99,7 +130,7 @@ struct MemoStore { struct Entry { // Timestamp associated with the entry - int64_t time; + TimeType time; // Batch associated with the entry (perf is probably OK for this; batches change // rarely) @@ -111,7 +142,7 @@ struct MemoStore { std::unordered_map entries_; - void Store(const std::shared_ptr& batch, row_index_t row, int64_t time, + void Store(const std::shared_ptr& batch, row_index_t row, TimeType time, KeyType key) { auto& e = entries_[key]; // that we can do this assignment optionally, is why we @@ -128,7 +159,7 @@ struct MemoStore { return util::optional(&e->second); } - void RemoveEntriesWithLesserTime(int64_t ts) { + void RemoveEntriesWithLesserTime(TimeType ts) { for (auto e = entries_.begin(); e != entries_.end();) if (e->second.time < ts) e = entries_.erase(e); @@ -148,7 +179,9 @@ class InputState { : queue_(), schema_(schema), time_col_index_(schema->GetFieldIndex(time_col_name)), - key_col_index_(schema->GetFieldIndex(key_col_name)) {} + key_col_index_(schema->GetFieldIndex(key_col_name)), + time_type_id_(schema_->fields()[time_col_index_]->type()->id()), + key_type_id_(schema_->fields()[key_col_index_]->type()->id()) {} col_index_t InitSrcToDstMapping(col_index_t dst_offset, bool skip_time_and_key_fields) { src_to_dst_.resize(schema_->num_fields()); @@ -184,18 +217,48 @@ class InputState { return queue_.UnsyncFront(); } +#define LATEST_VAL_CASE(id, val) \ + case Type::id: { \ + using T = typename TypeIdTraits::Type; \ + using CType = typename TypeTraits::CType; \ + return val(data->GetValues(1)[latest_ref_row_]); \ + } + KeyType GetLatestKey() const { - return queue_.UnsyncFront() - ->column_data(key_col_index_) - ->GetValues(1)[latest_ref_row_]; + auto data = queue_.UnsyncFront()->column_data(key_col_index_); + switch (key_type_id_) { + LATEST_VAL_CASE(INT8, key_value) + LATEST_VAL_CASE(INT16, key_value) + LATEST_VAL_CASE(INT32, key_value) + LATEST_VAL_CASE(INT64, key_value) + LATEST_VAL_CASE(UINT8, key_value) + LATEST_VAL_CASE(UINT16, key_value) + LATEST_VAL_CASE(UINT32, key_value) + LATEST_VAL_CASE(UINT64, key_value) + default: + return 0; // cannot happen + } } - int64_t GetLatestTime() const { - return queue_.UnsyncFront() - ->column_data(time_col_index_) - ->GetValues(1)[latest_ref_row_]; + TimeType GetLatestTime() const { + auto data = queue_.UnsyncFront()->column_data(time_col_index_); + switch (time_type_id_) { + LATEST_VAL_CASE(INT8, time_value) + LATEST_VAL_CASE(INT16, time_value) + LATEST_VAL_CASE(INT32, time_value) + LATEST_VAL_CASE(INT64, time_value) + LATEST_VAL_CASE(UINT8, time_value) + LATEST_VAL_CASE(UINT16, time_value) + LATEST_VAL_CASE(UINT32, time_value) + LATEST_VAL_CASE(UINT64, time_value) + LATEST_VAL_CASE(TIMESTAMP, time_value) + default: + return 0; // cannot happen + } } +#undef LATEST_VAL_CASE + bool Finished() const { return batches_processed_ == total_batches_; } bool Advance() { @@ -222,28 +285,25 @@ class InputState { // latest_time and latest_ref_row to the value that immediately pass the // specified timestamp. // Returns true if updates were made, false if not. - bool AdvanceAndMemoize(int64_t ts) { + bool AdvanceAndMemoize(TimeType ts) { // Advance the right side row index until we reach the latest right row (for each key) // for the given left timestamp. // Check if already updated for TS (or if there is no latest) if (Empty()) return false; // can't advance if empty - auto latest_time = GetLatestTime(); - if (latest_time > ts) return false; // already advanced // Not updated. Try to update and possibly advance. bool updated = false; do { - latest_time = GetLatestTime(); + auto latest_time = GetLatestTime(); // if Advance() returns true, then the latest_ts must also be valid // Keep advancing right table until we hit the latest row that has // timestamp <= ts. This is because we only need the latest row for the // match given a left ts. - if (latest_time <= ts) { - memo_.Store(GetLatestBatch(), latest_ref_row_, latest_time, GetLatestKey()); - } else { + if (latest_time > ts) { break; // hit a future timestamp -- done updating for now } + memo_.Store(GetLatestBatch(), latest_ref_row_, latest_time, GetLatestKey()); updated = true; } while (Advance()); return updated; @@ -261,7 +321,7 @@ class InputState { return memo_.GetEntryForKey(key); } - util::optional GetMemoTimeForKey(KeyType key) { + util::optional GetMemoTimeForKey(KeyType key) { auto r = GetMemoEntryForKey(key); if (r.has_value()) { return (*r)->time; @@ -270,7 +330,7 @@ class InputState { } } - void RemoveMemoEntriesWithLesserTime(int64_t ts) { + void RemoveMemoEntriesWithLesserTime(TimeType ts) { memo_.RemoveEntriesWithLesserTime(ts); } @@ -295,6 +355,10 @@ class InputState { col_index_t time_col_index_; // Index of the key col col_index_t key_col_index_; + // Type id of the time column + Type::type time_type_id_; + // Type id of the key column + Type::type key_type_id_; // Index of the latest row reference within; if >0 then queue_ cannot be empty // Must be < queue_.front()->num_rows() if queue_ is non-empty row_index_t latest_ref_row_ = 0; @@ -336,7 +400,7 @@ class CompositeReferenceTable { // Adds the latest row from the input state as a new composite reference row // - LHS must have a valid key,timestep,and latest rows // - RHS must have valid data memo'ed for the key - void Emplace(std::vector>& in, int64_t tolerance) { + void Emplace(std::vector>& in, TimeType tolerance) { DCHECK_EQ(in.size(), n_tables_); // Get the LHS key @@ -347,7 +411,7 @@ class CompositeReferenceTable { DCHECK(!in[0]->Empty()); const std::shared_ptr& lhs_latest_batch = in[0]->GetLatestBatch(); row_index_t lhs_latest_row = in[0]->GetLatestRow(); - int64_t lhs_latest_time = in[0]->GetLatestTime(); + TimeType lhs_latest_time = in[0]->GetLatestTime(); if (0 == lhs_latest_row) { // On the first row of the batch, we resize the destination. // The destination size is dictated by the size of the LHS batch. @@ -407,29 +471,34 @@ class CompositeReferenceTable { DCHECK_EQ(src_field->name(), dst_field->name()); const auto& field_type = src_field->type(); - if (field_type->Equals(arrow::int32())) { - ARROW_ASSIGN_OR_RAISE( - arrays.at(i_dst_col), - (MaterializePrimitiveColumn( - memory_pool, i_table, i_src_col))); - } else if (field_type->Equals(arrow::int64())) { - ARROW_ASSIGN_OR_RAISE( - arrays.at(i_dst_col), - (MaterializePrimitiveColumn( - memory_pool, i_table, i_src_col))); - } else if (field_type->Equals(arrow::float32())) { - ARROW_ASSIGN_OR_RAISE(arrays.at(i_dst_col), - (MaterializePrimitiveColumn( - memory_pool, i_table, i_src_col))); - } else if (field_type->Equals(arrow::float64())) { - ARROW_ASSIGN_OR_RAISE( - arrays.at(i_dst_col), - (MaterializePrimitiveColumn( - memory_pool, i_table, i_src_col))); - } else { - ARROW_RETURN_NOT_OK( - Status::Invalid("Unsupported data type: ", src_field->name())); +#define ASOFJOIN_MATERIALIZE_CASE(id) \ + case Type::id: { \ + using T = typename TypeIdTraits::Type; \ + ARROW_ASSIGN_OR_RAISE( \ + arrays.at(i_dst_col), \ + MaterializeColumn(memory_pool, field_type, i_table, i_src_col)); \ + break; \ + } + + switch (field_type->id()) { + ASOFJOIN_MATERIALIZE_CASE(INT8) + ASOFJOIN_MATERIALIZE_CASE(INT16) + ASOFJOIN_MATERIALIZE_CASE(INT32) + ASOFJOIN_MATERIALIZE_CASE(INT64) + ASOFJOIN_MATERIALIZE_CASE(UINT8) + ASOFJOIN_MATERIALIZE_CASE(UINT16) + ASOFJOIN_MATERIALIZE_CASE(UINT32) + ASOFJOIN_MATERIALIZE_CASE(UINT64) + ASOFJOIN_MATERIALIZE_CASE(FLOAT) + ASOFJOIN_MATERIALIZE_CASE(DOUBLE) + ASOFJOIN_MATERIALIZE_CASE(TIMESTAMP) + default: + return Status::Invalid("Unsupported data type ", + src_field->type()->ToString(), " for field ", + src_field->name()); } + +#undef ASOFJOIN_MATERIALIZE_CASE } } } @@ -459,11 +528,13 @@ class CompositeReferenceTable { if (!_ptr2ref.count((uintptr_t)ref.get())) _ptr2ref[(uintptr_t)ref.get()] = ref; } - template - Result> MaterializePrimitiveColumn(MemoryPool* memory_pool, - size_t i_table, - col_index_t i_col) { - Builder builder(memory_pool); + template ::BuilderType, + class PrimitiveType = typename TypeTraits::CType> + Result> MaterializeColumn(MemoryPool* memory_pool, + const std::shared_ptr& type, + size_t i_table, col_index_t i_col) { + ARROW_ASSIGN_OR_RAISE(auto a_builder, MakeBuilder(type, memory_pool)); + Builder& builder = *checked_cast(a_builder.get()); ARROW_RETURN_NOT_OK(builder.Reserve(rows_.size())); for (row_index_t i_row = 0; i_row < rows_.size(); ++i_row) { const auto& ref = rows_[i_row].refs[i_table]; @@ -495,7 +566,7 @@ class AsofJoinNode : public ExecNode { bool IsUpToDateWithLhsRow() const { auto& lhs = *state_[0]; if (lhs.Empty()) return false; // can't proceed if nothing on the LHS - int64_t lhs_ts = lhs.GetLatestTime(); + TimeType lhs_ts = lhs.GetLatestTime(); for (size_t i = 1; i < state_.size(); ++i) { auto& rhs = *state_[i]; if (!rhs.Finished()) { @@ -531,7 +602,7 @@ class AsofJoinNode : public ExecNode { // the LHS and adding joined row to rows_ (done by Emplace). Finally, // input batches that are no longer needed are removed to free up memory. if (IsUpToDateWithLhsRow()) { - dst.Emplace(state_, options_.tolerance); + dst.Emplace(state_, time_value(options_.tolerance, 0)); if (!lhs.Advance()) break; // if we can't advance LHS, we're done for this batch } else { if (!any_rhs_advanced) break; // need to wait for new data @@ -542,7 +613,7 @@ class AsofJoinNode : public ExecNode { if (!lhs.Empty()) { for (size_t i = 1; i < state_.size(); ++i) { state_[i]->RemoveMemoEntriesWithLesserTime(lhs.GetLatestTime() - - options_.tolerance); + time_value(options_.tolerance, 0)); } } @@ -610,6 +681,16 @@ class AsofJoinNode : public ExecNode { process_thread_.join(); } + static bool find_type(std::unordered_set> type_set, + const std::shared_ptr& type) { + for (auto ty : type_set) { + if (*ty.get() == *type.get()) { + return true; + } + } + return false; + } + static arrow::Result> MakeOutputSchema( const std::vector& inputs, const AsofJoinNodeOptions& options) { std::vector> fields; @@ -617,6 +698,8 @@ class AsofJoinNode : public ExecNode { const auto& on_field_name = *options.on_key.name(); const auto& by_field_name = *options.by_key.name(); + const DataType* on_key_type = NULLPTR; + const DataType* by_key_type = NULLPTR; // Take all non-key, non-time RHS fields for (size_t j = 0; j < inputs.size(); ++j) { const auto& input_schema = inputs[j]->output_schema(); @@ -627,27 +710,47 @@ class AsofJoinNode : public ExecNode { return Status::Invalid("Missing join key on table ", j); } + const auto& on_field_type = input_schema->fields()[on_field_ix]->type(); + const auto& by_field_type = input_schema->fields()[by_field_ix]->type(); + if (on_key_type == NULLPTR) { + on_key_type = on_field_type.get(); + } else if (*on_key_type != *on_field_type) { + return Status::Invalid("Expected on-key type ", *on_key_type, " but got ", + *on_field_type, " for field ", on_field_name, " in input ", + j); + } + if (by_key_type == NULLPTR) { + by_key_type = by_field_type.get(); + } else if (*by_key_type != *by_field_type) { + return Status::Invalid("Expected on-key type ", *by_key_type, " but got ", + *by_field_type, " for field ", by_field_name, " in input ", + j); + } + for (int i = 0; i < input_schema->num_fields(); ++i) { const auto field = input_schema->field(i); if (field->name() == on_field_name) { - if (kSupportedOnTypes_.find(field->type()) == kSupportedOnTypes_.end()) { - return Status::Invalid("Unsupported type for on key: ", field->name()); + if (!find_type(kSupportedOnTypes_, field->type())) { + return Status::Invalid("Unsupported type ", field->type()->ToString(), + " for on-key ", field->name()); } // Only add on field from the left table if (j == 0) { fields.push_back(field); } } else if (field->name() == by_field_name) { - if (kSupportedByTypes_.find(field->type()) == kSupportedByTypes_.end()) { - return Status::Invalid("Unsupported type for by key: ", field->name()); + if (!find_type(kSupportedByTypes_, field->type())) { + return Status::Invalid("Unsupported type ", field->type()->ToString(), + " for by-key ", field->name()); } // Only add by field from the left table if (j == 0) { fields.push_back(field); } } else { - if (kSupportedDataTypes_.find(field->type()) == kSupportedDataTypes_.end()) { - return Status::Invalid("Unsupported data type: ", field->name()); + if (!find_type(kSupportedDataTypes_, field->type())) { + return Status::Invalid("Unsupported data type ", field->type()->ToString(), + " for field ", field->name()); } fields.push_back(field); @@ -718,9 +821,9 @@ class AsofJoinNode : public ExecNode { arrow::Future<> finished() override { return finished_; } private: - static const std::set> kSupportedOnTypes_; - static const std::set> kSupportedByTypes_; - static const std::set> kSupportedDataTypes_; + static const std::unordered_set> kSupportedOnTypes_; + static const std::unordered_set> kSupportedByTypes_; + static const std::unordered_set> kSupportedDataTypes_; arrow::Future<> finished_; // InputStates @@ -760,10 +863,47 @@ AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, } // Currently supported types -const std::set> AsofJoinNode::kSupportedOnTypes_ = {int64()}; -const std::set> AsofJoinNode::kSupportedByTypes_ = {int32()}; -const std::set> AsofJoinNode::kSupportedDataTypes_ = { - int32(), int64(), float32(), float64()}; +const std::unordered_set> AsofJoinNode::kSupportedOnTypes_ = { + int8(), + int16(), + int32(), + int64(), + uint8(), + uint16(), + uint32(), + uint64(), + timestamp(TimeUnit::NANO, "UTC"), + timestamp(TimeUnit::MICRO, "UTC"), + timestamp(TimeUnit::MILLI, "UTC"), + timestamp(TimeUnit::SECOND, "UTC")}; +const std::unordered_set> AsofJoinNode::kSupportedByTypes_ = { + int8(), + int16(), + int32(), + int64(), + uint8(), + uint16(), + uint32(), + uint64(), + timestamp(TimeUnit::NANO, "UTC"), + timestamp(TimeUnit::MICRO, "UTC"), + timestamp(TimeUnit::MILLI, "UTC"), + timestamp(TimeUnit::SECOND, "UTC")}; +const std::unordered_set> AsofJoinNode::kSupportedDataTypes_ = { + int8(), + int16(), + int32(), + int64(), + uint8(), + uint16(), + uint32(), + uint64(), + timestamp(TimeUnit::NANO, "UTC"), + timestamp(TimeUnit::MICRO, "UTC"), + timestamp(TimeUnit::MILLI, "UTC"), + timestamp(TimeUnit::SECOND, "UTC"), + float32(), + float64()}; namespace internal { void RegisterAsofJoinNode(ExecFactoryRegistry* registry) { diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index 8b993764abe7f..c780f4b5d4200 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -74,23 +74,29 @@ void CheckRunOutput(const BatchesWithSchema& l_batches, /*same_chunk_layout=*/true, /*flatten=*/true); } -void DoRunBasicTest(const std::vector& l_data, - const std::vector& r0_data, - const std::vector& r1_data, - const std::vector& exp_data, int64_t tolerance) { +struct BasicTestTypes { + std::shared_ptr time, key, l_val, r0_val, r1_val; +}; + +void DoRunBasicTestTypes(const std::vector& l_data, + const std::vector& r0_data, + const std::vector& r1_data, + const std::vector& exp_data, + int64_t tolerance, BasicTestTypes basic_test_types) { + const BasicTestTypes& b = basic_test_types; auto l_schema = - schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}); + schema({field("time", b.time), field("key", b.key), field("l_v0", b.l_val)}); auto r0_schema = - schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())}); + schema({field("time", b.time), field("key", b.key), field("r0_v0", b.r0_val)}); auto r1_schema = - schema({field("time", int64()), field("key", int32()), field("r1_v0", float32())}); + schema({field("time", b.time), field("key", b.key), field("r1_v0", b.r1_val)}); auto exp_schema = schema({ - field("time", int64()), - field("key", int32()), - field("l_v0", float64()), - field("r0_v0", float64()), - field("r1_v0", float32()), + field("time", b.time), + field("key", b.key), + field("l_v0", b.l_val), + field("r0_v0", b.r0_val), + field("r1_v0", b.r1_val), }); // Test three table join @@ -103,6 +109,64 @@ void DoRunBasicTest(const std::vector& l_data, tolerance); } +static inline void init_types( + const std::vector>& all_types, + std::vector>& types, + std::function)> type_cond) { + if (types.size() == 0) { + for (auto type : all_types) { + if (type_cond(type)) { + types.push_back(type); + } + } + } +} + +void DoRunBasicTest(const std::vector& l_data, + const std::vector& r0_data, + const std::vector& r1_data, + const std::vector& exp_data, int64_t tolerance, + std::vector> time_types = {}, + std::vector> key_types = {}, + std::vector> l_types = {}, + std::vector> r0_types = {}, + std::vector> r1_types = {}) { + std::vector> all_types = {int8(), + int16(), + int32(), + int64(), + uint8(), + uint16(), + uint32(), + uint64(), + timestamp(TimeUnit::NANO, "UTC"), + timestamp(TimeUnit::MICRO, "UTC"), + timestamp(TimeUnit::MILLI, "UTC"), + timestamp(TimeUnit::SECOND, "UTC"), + float32(), + float64()}; + using T = const std::shared_ptr; + init_types(all_types, time_types, + [](T& t) { return t->byte_width() > 1 && !is_floating(t->id()); }); + init_types(all_types, key_types, [](T& t) { return is_integer(t->id()); }); + init_types(all_types, l_types, + [](T& t) { return t->byte_width() > 1 && t->id() != Type::TIMESTAMP; }); + init_types(all_types, r0_types, [](T& t) { return is_floating(t->id()); }); + init_types(all_types, r1_types, [](T& t) { return is_floating(t->id()); }); + for (auto time_type : time_types) { + for (auto key_type : key_types) { + for (auto l_type : l_types) { + for (auto r0_type : r0_types) { + for (auto r1_type : r1_types) { + DoRunBasicTestTypes(l_data, r0_data, r1_data, exp_data, tolerance, + {time_type, key_type, float64(), float64(), float32()}); + } + } + } + } + } +} + void DoRunInvalidTypeTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { BatchesWithSchema l_batches = MakeBatchesFromString(l_schema, {R"([])"}); From f4c450d02d2fae1dd0cedff95453d0a863786a52 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 7 Aug 2022 13:02:47 -0400 Subject: [PATCH 02/26] fix basic test coverage --- cpp/src/arrow/compute/exec/asof_join_node_test.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index c780f4b5d4200..cccb718e27eb8 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -149,8 +149,7 @@ void DoRunBasicTest(const std::vector& l_data, init_types(all_types, time_types, [](T& t) { return t->byte_width() > 1 && !is_floating(t->id()); }); init_types(all_types, key_types, [](T& t) { return is_integer(t->id()); }); - init_types(all_types, l_types, - [](T& t) { return t->byte_width() > 1 && t->id() != Type::TIMESTAMP; }); + init_types(all_types, l_types, [](T& t) { return is_floating(t->id()); }); init_types(all_types, r0_types, [](T& t) { return is_floating(t->id()); }); init_types(all_types, r1_types, [](T& t) { return is_floating(t->id()); }); for (auto time_type : time_types) { @@ -159,7 +158,7 @@ void DoRunBasicTest(const std::vector& l_data, for (auto r0_type : r0_types) { for (auto r1_type : r1_types) { DoRunBasicTestTypes(l_data, r0_data, r1_data, exp_data, tolerance, - {time_type, key_type, float64(), float64(), float32()}); + {time_type, key_type, l_type, r0_type, r1_type}); } } } From 0dc0971999d591d9f3cf1ca32c773bb5d7f974b7 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Mon, 8 Aug 2022 02:02:18 -0400 Subject: [PATCH 03/26] test more types --- .../arrow/compute/exec/asof_join_node_test.cc | 121 +++++++++--------- 1 file changed, 61 insertions(+), 60 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index cccb718e27eb8..bd0b157c5cfd3 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -149,7 +149,8 @@ void DoRunBasicTest(const std::vector& l_data, init_types(all_types, time_types, [](T& t) { return t->byte_width() > 1 && !is_floating(t->id()); }); init_types(all_types, key_types, [](T& t) { return is_integer(t->id()); }); - init_types(all_types, l_types, [](T& t) { return is_floating(t->id()); }); + init_types(all_types, l_types, + [](T& t) { return t->byte_width() > 1 && t->id() != Type::TIMESTAMP; }); init_types(all_types, r0_types, [](T& t) { return is_floating(t->id()); }); init_types(all_types, r1_types, [](T& t) { return is_floating(t->id()); }); for (auto time_type : time_types) { @@ -189,94 +190,94 @@ class AsofJoinTest : public testing::Test {}; TEST(AsofJoinTest, TestBasic1) { // Single key, single batch DoRunBasicTest( - /*l*/ {R"([[0, 1, 1.0], [1000, 1, 2.0]])"}, - /*r0*/ {R"([[0, 1, 11.0]])"}, - /*r1*/ {R"([[1000, 1, 101.0]])"}, - /*exp*/ {R"([[0, 1, 1.0, 11.0, null], [1000, 1, 2.0, 11.0, 101.0]])"}, 1000); + /*l*/ {R"([[0, 1, 1], [1000, 1, 2]])"}, + /*r0*/ {R"([[0, 1, 11]])"}, + /*r1*/ {R"([[1000, 1, 101]])"}, + /*exp*/ {R"([[0, 1, 1, 11, null], [1000, 1, 2, 11, 101]])"}, 1000); } TEST(AsofJoinTest, TestBasic2) { // Single key, multiple batches DoRunBasicTest( - /*l*/ {R"([[0, 1, 1.0]])", R"([[1000, 1, 2.0]])"}, - /*r0*/ {R"([[0, 1, 11.0]])", R"([[1000, 1, 12.0]])"}, - /*r1*/ {R"([[0, 1, 101.0]])", R"([[1000, 1, 102.0]])"}, - /*exp*/ {R"([[0, 1, 1.0, 11.0, 101.0], [1000, 1, 2.0, 12.0, 102.0]])"}, 1000); + /*l*/ {R"([[0, 1, 1]])", R"([[1000, 1, 2]])"}, + /*r0*/ {R"([[0, 1, 11]])", R"([[1000, 1, 12]])"}, + /*r1*/ {R"([[0, 1, 101]])", R"([[1000, 1, 102]])"}, + /*exp*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"}, 1000); } TEST(AsofJoinTest, TestBasic3) { // Single key, multiple left batches, single right batches DoRunBasicTest( - /*l*/ {R"([[0, 1, 1.0]])", R"([[1000, 1, 2.0]])"}, - /*r0*/ {R"([[0, 1, 11.0], [1000, 1, 12.0]])"}, - /*r1*/ {R"([[0, 1, 101.0], [1000, 1, 102.0]])"}, - /*exp*/ {R"([[0, 1, 1.0, 11.0, 101.0], [1000, 1, 2.0, 12.0, 102.0]])"}, 1000); + /*l*/ {R"([[0, 1, 1]])", R"([[1000, 1, 2]])"}, + /*r0*/ {R"([[0, 1, 11], [1000, 1, 12]])"}, + /*r1*/ {R"([[0, 1, 101], [1000, 1, 102]])"}, + /*exp*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"}, 1000); } TEST(AsofJoinTest, TestBasic4) { // Multi key, multiple batches, misaligned batches DoRunBasicTest( /*l*/ - {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", - R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, + {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", + R"([[2000, 1, 4], [2000, 2, 24]])"}, /*r0*/ - {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", - R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, + {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", + R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, /*exp*/ - {R"([[0, 1, 1.0, 11.0, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, 11.0, 101.0], [1000, 2, 22.0, 31.0, 1001.0], [1500, 1, 3.0, 12.0, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", - R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}, + {R"([[0, 1, 1, 11, null], [0, 2, 21, null, 1001], [500, 1, 2, 11, 101], [1000, 2, 22, 31, 1001], [1500, 1, 3, 12, 102], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, 1000); } TEST(AsofJoinTest, TestBasic5) { // Multi key, multiple batches, misaligned batches, smaller tolerance DoRunBasicTest(/*l*/ - {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", - R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, + {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", + R"([[2000, 1, 4], [2000, 2, 24]])"}, /*r0*/ - {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", - R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, + {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", + R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, /*exp*/ - {R"([[0, 1, 1.0, 11.0, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, 11.0, 101.0], [1000, 2, 22.0, 31.0, null], [1500, 1, 3.0, 12.0, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", - R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}, + {R"([[0, 1, 1, 11, null], [0, 2, 21, null, 1001], [500, 1, 2, 11, 101], [1000, 2, 22, 31, null], [1500, 1, 3, 12, 102], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, 500); } TEST(AsofJoinTest, TestBasic6) { // Multi key, multiple batches, misaligned batches, zero tolerance DoRunBasicTest(/*l*/ - {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", - R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, + {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", + R"([[2000, 1, 4], [2000, 2, 24]])"}, /*r0*/ - {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", - R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, + {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", + R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, /*exp*/ - {R"([[0, 1, 1.0, 11.0, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, null, 101.0], [1000, 2, 22.0, null, null], [1500, 1, 3.0, null, null], [1500, 2, 23.0, 32.0, 1002.0]])", - R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, null, null]])"}, + {R"([[0, 1, 1, 11, null], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, null], [1500, 1, 3, null, null], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, null, null]])"}, 0); } TEST(AsofJoinTest, TestEmpty1) { // Empty left batch DoRunBasicTest(/*l*/ - {R"([])", R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, + {R"([])", R"([[2000, 1, 4], [2000, 2, 24]])"}, /*r0*/ - {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", - R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, + {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", + R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, /*exp*/ - {R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}, + {R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, 1000); } @@ -285,11 +286,11 @@ TEST(AsofJoinTest, TestEmpty2) { DoRunBasicTest(/*l*/ {R"([])"}, /*r0*/ - {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", - R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, + {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", + R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, /*exp*/ {R"([])"}, 1000); } @@ -297,32 +298,32 @@ TEST(AsofJoinTest, TestEmpty2) { TEST(AsofJoinTest, TestEmpty3) { // Empty right batch DoRunBasicTest(/*l*/ - {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", - R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, + {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", + R"([[2000, 1, 4], [2000, 2, 24]])"}, /*r0*/ - {R"([])", R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, + {R"([])", R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, /*exp*/ - {R"([[0, 1, 1.0, null, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, null, 101.0], [1000, 2, 22.0, null, 1001.0], [1500, 1, 3.0, null, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", - R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}, + {R"([[0, 1, 1, null, null], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, 1001], [1500, 1, 3, null, 102], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, 1000); } TEST(AsofJoinTest, TestEmpty4) { // Empty right input DoRunBasicTest(/*l*/ - {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", - R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, + {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", + R"([[2000, 1, 4], [2000, 2, 24]])"}, /*r0*/ {R"([])"}, /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, /*exp*/ - {R"([[0, 1, 1.0, null, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, null, 101.0], [1000, 2, 22.0, null, 1001.0], [1500, 1, 3.0, null, 102.0], [1500, 2, 23.0, null, 1002.0]])", - R"([[2000, 1, 4.0, null, 103.0], [2000, 2, 24.0, null, 1002.0]])"}, + {R"([[0, 1, 1, null, null], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, 1001], [1500, 1, 3, null, 102], [1500, 2, 23, null, 1002]])", + R"([[2000, 1, 4, null, 103], [2000, 2, 24, null, 1002]])"}, 1000); } From c768376a7c1cbdb499684b2273cd19e223457ec2 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Tue, 9 Aug 2022 15:33:22 -0400 Subject: [PATCH 04/26] AsofJoinNode error checks, test cases, docs --- cpp/src/arrow/compute/exec/asof_join_node.cc | 112 +++++++++++------ .../arrow/compute/exec/asof_join_node_test.cc | 116 +++++++++++++++++- cpp/src/arrow/compute/exec/hash_join.cc | 1 - cpp/src/arrow/compute/exec/options.h | 10 +- 4 files changed, 187 insertions(+), 52 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 2cd5644c26123..1da8db7e32c83 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -175,11 +175,11 @@ class InputState { public: InputState(const std::shared_ptr& schema, - const std::string& time_col_name, const std::string& key_col_name) + const col_index_t time_col_index, const col_index_t key_col_index) : queue_(), schema_(schema), - time_col_index_(schema->GetFieldIndex(time_col_name)), - key_col_index_(schema->GetFieldIndex(key_col_name)), + time_col_index_(time_col_index), + key_col_index_(key_col_index), time_type_id_(schema_->fields()[time_col_index_]->type()->id()), key_type_id_(schema_->fields()[key_col_index_]->type()->id()) {} @@ -602,7 +602,7 @@ class AsofJoinNode : public ExecNode { // the LHS and adding joined row to rows_ (done by Emplace). Finally, // input batches that are no longer needed are removed to free up memory. if (IsUpToDateWithLhsRow()) { - dst.Emplace(state_, time_value(options_.tolerance, 0)); + dst.Emplace(state_, tolerance_); if (!lhs.Advance()) break; // if we can't advance LHS, we're done for this batch } else { if (!any_rhs_advanced) break; // need to wait for new data @@ -612,8 +612,7 @@ class AsofJoinNode : public ExecNode { // Prune memo entries that have expired (to bound memory consumption) if (!lhs.Empty()) { for (size_t i = 1; i < state_.size(); ++i) { - state_[i]->RemoveMemoEntriesWithLesserTime(lhs.GetLatestTime() - - time_value(options_.tolerance, 0)); + state_[i]->RemoveMemoEntriesWithLesserTime(lhs.GetLatestTime() - tolerance_); } } @@ -673,8 +672,9 @@ class AsofJoinNode : public ExecNode { public: AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, - const AsofJoinNodeOptions& join_options, - std::shared_ptr output_schema); + const std::vector& on_key_col_indices, + const std::vector& by_key_col_indices, + TimeType tolerance, std::shared_ptr output_schema); virtual ~AsofJoinNode() { process_.Push(false); // poison pill @@ -692,56 +692,57 @@ class AsofJoinNode : public ExecNode { } static arrow::Result> MakeOutputSchema( - const std::vector& inputs, const AsofJoinNodeOptions& options) { + const std::vector& inputs, + const std::vector& on_key_col_indices, + const std::vector& by_key_col_indices) { std::vector> fields; - const auto& on_field_name = *options.on_key.name(); - const auto& by_field_name = *options.by_key.name(); - const DataType* on_key_type = NULLPTR; const DataType* by_key_type = NULLPTR; // Take all non-key, non-time RHS fields for (size_t j = 0; j < inputs.size(); ++j) { const auto& input_schema = inputs[j]->output_schema(); - const auto& on_field_ix = input_schema->GetFieldIndex(on_field_name); - const auto& by_field_ix = input_schema->GetFieldIndex(by_field_name); + const auto& on_field_ix = on_key_col_indices[j]; + const auto& by_field_ix = by_key_col_indices[j]; if ((on_field_ix == -1) | (by_field_ix == -1)) { return Status::Invalid("Missing join key on table ", j); } - const auto& on_field_type = input_schema->fields()[on_field_ix]->type(); - const auto& by_field_type = input_schema->fields()[by_field_ix]->type(); + const auto& on_field = input_schema->fields()[on_field_ix]; + const auto& by_field = input_schema->fields()[by_field_ix]; + const auto& on_field_type = on_field->type(); + const auto& by_field_type = by_field->type(); if (on_key_type == NULLPTR) { on_key_type = on_field_type.get(); } else if (*on_key_type != *on_field_type) { return Status::Invalid("Expected on-key type ", *on_key_type, " but got ", - *on_field_type, " for field ", on_field_name, " in input ", - j); + *on_field_type, " for field ", on_field->name(), + " in input ", j); } if (by_key_type == NULLPTR) { by_key_type = by_field_type.get(); } else if (*by_key_type != *by_field_type) { return Status::Invalid("Expected on-key type ", *by_key_type, " but got ", - *by_field_type, " for field ", by_field_name, " in input ", - j); + *by_field_type, " for field ", by_field->name(), + " in input ", j); } for (int i = 0; i < input_schema->num_fields(); ++i) { const auto field = input_schema->field(i); - if (field->name() == on_field_name) { + if (i == on_field_ix) { if (!find_type(kSupportedOnTypes_, field->type())) { - return Status::Invalid("Unsupported type ", field->type()->ToString(), - " for on-key ", field->name()); + return Status::Invalid("Unsupported type for on-key ", field->name(), " : ", + field->type()->ToString()); } // Only add on field from the left table if (j == 0) { fields.push_back(field); } - } else if (field->name() == by_field_name) { + } else if (i == by_field_ix) { if (!find_type(kSupportedByTypes_, field->type())) { - return Status::Invalid("Unsupported type ", field->type()->ToString(), - " for by-key ", field->name()); + return Status::Invalid("Unsupported type for by-key ", field->name(), " : ", + field->type()->ToString()); } // Only add by field from the left table if (j == 0) { @@ -749,8 +750,8 @@ class AsofJoinNode : public ExecNode { } } else { if (!find_type(kSupportedDataTypes_, field->type())) { - return Status::Invalid("Unsupported data type ", field->type()->ToString(), - " for field ", field->name()); + return Status::Invalid("Unsupported type for field ", field->name(), " : ", + field->type()->ToString()); } fields.push_back(field); @@ -765,17 +766,46 @@ class AsofJoinNode : public ExecNode { DCHECK_GE(inputs.size(), 2) << "Must have at least two inputs"; const auto& join_options = checked_cast(options); - ARROW_ASSIGN_OR_RAISE(std::shared_ptr output_schema, - MakeOutputSchema(inputs, join_options)); + if (join_options.tolerance < 0) { + return Status::Invalid("AsOfJoin tolerance must be non-negative but is ", + join_options.tolerance); + } std::vector input_labels(inputs.size()); - input_labels[0] = "left"; - for (size_t i = 1; i < inputs.size(); ++i) { - input_labels[i] = "right_" + std::to_string(i); + std::vector on_key_col_indices(inputs.size()); + std::vector by_key_col_indices(inputs.size()); + for (size_t i = 0; i < inputs.size(); ++i) { + input_labels[i] = i == 0 ? "left" : "right_" + std::to_string(i); + const auto& input_schema = inputs[i]->output_schema(); + +#define ASOFJOIN_KEY_MATCH(k) \ + auto k##_key_match_res = join_options.k##_key.FindOne(*input_schema); \ + if (!k##_key_match_res.ok()) { \ + return Status::Invalid("Bad join key on table : ", \ + k##_key_match_res.status().message()); \ + } \ + auto k##_key_match = k##_key_match_res.ValueOrDie(); \ + if (k##_key_match.indices().size() != 1) { \ + return Status::Invalid("AsOfJoinNode does not support a nested " #k "-key ", \ + join_options.k##_key.ToString()); \ + } \ + k##_key_col_indices[i] = k##_key_match.indices()[0]; + + ASOFJOIN_KEY_MATCH(on) + ASOFJOIN_KEY_MATCH(by) + +#undef ASOFJOIN_KEY_MATCH } + ARROW_ASSIGN_OR_RAISE(std::shared_ptr output_schema, + MakeOutputSchema(inputs, on_key_col_indices, + by_key_col_indices)); + return plan->EmplaceNode(plan, inputs, std::move(input_labels), - join_options, std::move(output_schema)); + std::move(on_key_col_indices), + std::move(by_key_col_indices), + time_value(join_options.tolerance, 0), + std::move(output_schema)); } const char* kind_name() const override { return "AsofJoinNode"; } @@ -830,7 +860,7 @@ class AsofJoinNode : public ExecNode { // Each input state correponds to an input table std::vector> state_; std::mutex gate_; - AsofJoinNodeOptions options_; + TimeType tolerance_; // Queue for triggering processing of a given input // (a false value is a poison pill) @@ -844,17 +874,19 @@ class AsofJoinNode : public ExecNode { AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, - const AsofJoinNodeOptions& join_options, - std::shared_ptr output_schema) + const std::vector& on_key_col_indices, + const std::vector& by_key_col_indices, + TimeType tolerance, std::shared_ptr output_schema) : ExecNode(plan, inputs, input_labels, /*output_schema=*/std::move(output_schema), /*num_outputs=*/1), - options_(join_options), + tolerance_(tolerance), process_(), process_thread_(&AsofJoinNode::ProcessThreadWrapper, this) { - for (size_t i = 0; i < inputs.size(); ++i) + for (size_t i = 0; i < inputs.size(); ++i) { state_.push_back(::arrow::internal::make_unique( - inputs[i]->output_schema(), *options_.on_key.name(), *options_.by_key.name())); + inputs[i]->output_schema(), on_key_col_indices[i], by_key_col_indices[i])); + } col_index_t dst_offset = 0; for (auto& state : state_) dst_offset = state->InitSrcToDstMapping(dst_offset, !!dst_offset); diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index bd0b157c5cfd3..614b0e3848391 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -167,22 +167,81 @@ void DoRunBasicTest(const std::vector& l_data, } } -void DoRunInvalidTypeTest(const std::shared_ptr& l_schema, - const std::shared_ptr& r_schema) { +void DoRunInvalidPlanTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema, + FieldRef on_key, FieldRef by_key, int64_t tolerance, + const std::string& expected_error_str) { BatchesWithSchema l_batches = MakeBatchesFromString(l_schema, {R"([])"}); BatchesWithSchema r_batches = MakeBatchesFromString(r_schema, {R"([])"}); ExecContext exec_ctx; ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_ctx)); - AsofJoinNodeOptions join_options("time", "key", 0); + AsofJoinNodeOptions join_options(on_key, by_key, tolerance); Declaration join{"asofjoin", join_options}; join.inputs.emplace_back(Declaration{ "source", SourceNodeOptions{l_batches.schema, l_batches.gen(false, false)}}); join.inputs.emplace_back(Declaration{ "source", SourceNodeOptions{r_batches.schema, r_batches.gen(false, false)}}); - ASSERT_RAISES(Invalid, join.AddToPlan(plan.get())); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr(expected_error_str), join.AddToPlan(plan.get())); +} + +void DoRunInvalidPlanTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema, + int64_t tolerance, const std::string& expected_error_str) { + DoRunInvalidPlanTest(l_schema, r_schema, "time", "key", tolerance, expected_error_str); +} + +void DoRunInvalidTypeTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, 0, "Unsupported type for "); +} + +void DoRunInvalidToleranceTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, -1, + "AsOfJoin tolerance must be non-negative but is "); +} + +void DoRunMissingKeysTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table : No match"); +} + +void DoRunMissingOnKeyTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, "invalid_time", "key", 0, + "Bad join key on table : No match"); +} + +void DoRunMissingByKeyTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, "time", "invalid_key", 0, + "Bad join key on table : No match"); +} + +void DoRunNestedOnKeyTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, {0, "time"}, "key", 0, + "Bad join key on table : No match"); +} + +void DoRunNestedByKeyTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, "time", {0, "key"}, 0, + "Bad join key on table : No match"); +} + +void DoRunAmbiguousOnKeyTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table : Multiple matches"); +} + +void DoRunAmbiguousByKeyTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table : Multiple matches"); } class AsofJoinTest : public testing::Test {}; @@ -359,16 +418,61 @@ TEST(AsofJoinTest, TestUnsupportedDatatype) { } TEST(AsofJoinTest, TestMissingKeys) { - DoRunInvalidTypeTest( + DoRunMissingKeysTest( schema({field("time1", int64()), field("key", int32()), field("l_v0", float64())}), schema( {field("time1", int64()), field("key", int32()), field("r0_v0", float64())})); - DoRunInvalidTypeTest( + DoRunMissingKeysTest( schema({field("time", int64()), field("key1", int32()), field("l_v0", float64())}), schema( {field("time", int64()), field("key1", int32()), field("r0_v0", float64())})); } +TEST(AsofJoinTest, TestUnsupportedTolerance) { + // Utf8 is unsupported + DoRunInvalidToleranceTest( + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +} + +TEST(AsofJoinTest, TestMissingOnKey) { + DoRunMissingOnKeyTest( + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +} + +TEST(AsofJoinTest, TestMissingByKey) { + DoRunMissingByKeyTest( + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +} + +TEST(AsofJoinTest, TestNestedOnKey) { + DoRunNestedOnKeyTest( + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +} + +TEST(AsofJoinTest, TestNestedByKey) { + DoRunNestedByKeyTest( + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +} + +TEST(AsofJoinTest, TestAmbiguousOnKey) { + DoRunAmbiguousOnKeyTest( + schema({field("time", int64()), field("time", int64()), field("key", int32()), + field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +} + +TEST(AsofJoinTest, TestAmbiguousByKey) { + DoRunAmbiguousByKeyTest( + schema({field("time", int64()), field("key", int64()), field("key", int32()), + field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index 07a3083fb92ad..da27d15b105b1 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -26,7 +26,6 @@ #include #include "arrow/compute/exec/hash_join_dict.h" -#include "arrow/compute/exec/key_hash.h" #include "arrow/compute/exec/task_util.h" #include "arrow/compute/kernels/row_encoder.h" #include "arrow/compute/row/encode_internal.h" diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index 4a0cd602efb54..1e7ba01b2f2b3 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -400,18 +400,18 @@ class ARROW_EXPORT AsofJoinNodeOptions : public ExecNodeOptions { /// \brief "on" key for the join. Each /// - /// All inputs tables must be sorted by the "on" key. Inexact - /// match is used on the "on" key. i.e., a row is considiered match iff + /// All inputs tables must be sorted by the "on" key. Must be a single field of a common + /// type. Inexact match is used on the "on" key. i.e., a row is considered match iff /// left_on - tolerance <= right_on <= left_on. - /// Currently, "on" key must be an int64 field + /// Currently, the "on" key must be of an integer or timestamp type FieldRef on_key; /// \brief "by" key for the join. /// /// All input tables must have the "by" key. Exact equality /// is used for the "by" key. - /// Currently, the "by" key must be an int32 field + /// Currently, the "by" key must be of an integer or timestamp type FieldRef by_key; - /// Tolerance for inexact "on" key matching + /// Tolerance for inexact "on" key matching. Must be non-negative. int64_t tolerance; }; From cde1d258107cd98d294b15fa6756edd114356dde Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Mon, 15 Aug 2022 06:32:34 -0400 Subject: [PATCH 05/26] AsofJoinNode multi-key support --- cpp/src/arrow/compute/exec/asof_join_node.cc | 302 ++++++--- .../arrow/compute/exec/asof_join_node_test.cc | 581 +++++++++++------- cpp/src/arrow/compute/exec/options.h | 16 +- cpp/src/arrow/compute/light_array.cc | 6 + cpp/src/arrow/compute/light_array.h | 11 + 5 files changed, 607 insertions(+), 309 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 1da8db7e32c83..2319ba8249910 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -23,9 +23,11 @@ #include "arrow/array/builder_primitive.h" #include "arrow/compute/exec/exec_plan.h" +#include "arrow/compute/exec/key_hash.h" #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/schema_util.h" #include "arrow/compute/exec/util.h" +#include "arrow/compute/light_array.h" #include "arrow/record_batch.h" #include "arrow/result.h" #include "arrow/status.h" @@ -37,44 +39,41 @@ namespace arrow { namespace compute { +template +inline typename T::const_iterator std_find(const T& container, const V& val) { + return std::find(container.begin(), container.end(), val); +} + +template +inline bool std_has(const T& container, const V& val) { + return container.end() != std_find(container, val); +} + typedef uint64_t KeyType; typedef uint64_t TimeType; +typedef uint64_t HashType; // Maximum number of tables that can be joined #define MAX_JOIN_TABLES 64 typedef uint64_t row_index_t; typedef int col_index_t; +typedef std::vector vec_col_index_t; -#define VAL_SIGNED(val, w) \ - static inline uint64_t val(int##w##_t t, uint64_t bias = (uint64_t)1 << (w - 1)) { \ - return t < 0 ? static_cast(t + bias) : static_cast(t) + bias; \ - } - -#define VAL_UNSIGNED(val, w) \ - static inline uint64_t val(uint##w##_t t, uint64_t bias = 0) { \ - return static_cast(t); \ - } - -VAL_SIGNED(time_value, 8) -VAL_SIGNED(time_value, 16) -VAL_SIGNED(time_value, 32) -VAL_SIGNED(time_value, 64) -VAL_UNSIGNED(time_value, 8) -VAL_UNSIGNED(time_value, 16) -VAL_UNSIGNED(time_value, 32) -VAL_UNSIGNED(time_value, 64) +template ::value, bool> = true> +static inline uint64_t norm_value(T t) { + uint64_t bias = std::is_signed::value ? (uint64_t)1 << (8 * sizeof(T) - 1) : 0; + return t < 0 ? static_cast(t + bias) : static_cast(t); +} -VAL_SIGNED(key_value, 8) -VAL_SIGNED(key_value, 16) -VAL_SIGNED(key_value, 32) -VAL_SIGNED(key_value, 64) -VAL_UNSIGNED(key_value, 8) -VAL_UNSIGNED(key_value, 16) -VAL_UNSIGNED(key_value, 32) -VAL_UNSIGNED(key_value, 64) +template ::value, bool> = true> +static inline uint64_t time_value(T t) { + return norm_value(t); +} -#undef VAL_SIGNED -#undef VAL_UNSIGNED +template ::value, bool> = true> +static inline uint64_t key_value(T t) { + return norm_value(t); +} /** * Simple implementation for an unbound concurrent queue @@ -168,20 +167,84 @@ struct MemoStore { } }; +// a specialized higher-performance variation of Hashing64 logic +class KeyHasher { + static constexpr int kMiniBatchLength = util::MiniBatch::kMiniBatchLength; + + public: + explicit KeyHasher(const vec_col_index_t& indices) + : indices_(indices), + metadata_(indices.size()), + batch_(NULLPTR), + hashes_(), + ctx_(), + column_arrays_(), + stack_() { + ctx_.stack = &stack_; + column_arrays_.resize(indices.size()); + } + + Status Init(ExecContext* exec_context, const std::shared_ptr& schema) { + ctx_.hardware_flags = exec_context->cpu_info()->hardware_flags(); + const auto& fields = schema->fields(); + for (size_t k = 0; k < metadata_.size(); k++) { + ARROW_ASSIGN_OR_RAISE(metadata_[k], + ColumnMetadataFromDataType(fields[indices_[k]]->type())); + } + return stack_.Init(exec_context->memory_pool(), + 4 * kMiniBatchLength * sizeof(uint32_t)); + } + + const std::vector& HashesFor(const RecordBatch* batch) { + if (batch_ == batch) { + return hashes_; + } + batch_ = NULLPTR; // invalidate cached hashes for batch + size_t batch_length = batch->num_rows(); + hashes_.resize(batch_length); + for (int64_t i = 0; i < static_cast(batch_length); i += kMiniBatchLength) { + int64_t length = std::min(static_cast(batch_length - i), + static_cast(kMiniBatchLength)); + for (size_t k = 0; k < indices_.size(); k++) { + auto array_data = batch->column_data(indices_[k]); + column_arrays_[k] = + ColumnArrayFromArrayDataAndMetadata(array_data, metadata_[k], i, length); + } + Hashing64::HashMultiColumn(column_arrays_, &ctx_, hashes_.data() + i); + } + batch_ = batch; + return hashes_; + } + + private: + vec_col_index_t indices_; + std::vector metadata_; + const RecordBatch* batch_; + std::vector hashes_; + LightContext ctx_; + std::vector column_arrays_; + util::TempVectorStack stack_; +}; + class InputState { // InputState correponds to an input // Input record batches are queued up in InputState until processed and // turned into output record batches. public: - InputState(const std::shared_ptr& schema, - const col_index_t time_col_index, const col_index_t key_col_index) + InputState(KeyHasher* key_hasher, const std::shared_ptr& schema, + const col_index_t time_col_index, const vec_col_index_t& key_col_index) : queue_(), schema_(schema), time_col_index_(time_col_index), key_col_index_(key_col_index), time_type_id_(schema_->fields()[time_col_index_]->type()->id()), - key_type_id_(schema_->fields()[key_col_index_]->type()->id()) {} + key_type_id_(schema_->num_fields()), + key_hasher_(key_hasher) { + for (size_t k = 0; k < key_col_index_.size(); k++) { + key_type_id_[k] = schema_->fields()[key_col_index_[k]]->type()->id(); + } + } col_index_t InitSrcToDstMapping(col_index_t dst_offset, bool skip_time_and_key_fields) { src_to_dst_.resize(schema_->num_fields()); @@ -197,7 +260,7 @@ class InputState { bool IsTimeOrKeyColumn(col_index_t i) const { DCHECK_LT(i, schema_->num_fields()); - return (i == time_col_index_) || (i == key_col_index_); + return (i == time_col_index_) || std_has(key_col_index_, i); } // Gets the latest row index, assuming the queue isn't empty @@ -225,8 +288,11 @@ class InputState { } KeyType GetLatestKey() const { - auto data = queue_.UnsyncFront()->column_data(key_col_index_); - switch (key_type_id_) { + if (key_hasher_ != NULLPTR) { + return key_hasher_->HashesFor(queue_.UnsyncFront().get())[latest_ref_row_]; + } + auto data = queue_.UnsyncFront()->column_data(key_col_index_[0]); + switch (key_type_id_[0]) { LATEST_VAL_CASE(INT8, key_value) LATEST_VAL_CASE(INT16, key_value) LATEST_VAL_CASE(INT32, key_value) @@ -354,11 +420,13 @@ class InputState { // Index of the time col col_index_t time_col_index_; // Index of the key col - col_index_t key_col_index_; + vec_col_index_t key_col_index_; // Type id of the time column Type::type time_type_id_; // Type id of the key column - Type::type key_type_id_; + std::vector key_type_id_; + // Buffer for key elements + mutable KeyHasher* key_hasher_; // Index of the latest row reference within; if >0 then queue_ cannot be empty // Must be < queue_.front()->num_rows() if queue_ is non-empty row_index_t latest_ref_row_ = 0; @@ -672,15 +740,35 @@ class AsofJoinNode : public ExecNode { public: AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, - const std::vector& on_key_col_indices, - const std::vector& by_key_col_indices, - TimeType tolerance, std::shared_ptr output_schema); + const vec_col_index_t& indices_of_on_key, + const std::vector& indices_of_by_key, TimeType tolerance, + std::shared_ptr output_schema); + + Status Init(std::vector> key_hashers) { + key_hashers_.swap(key_hashers); + bool has_kp = key_hashers_.size() > 0; + auto inputs = this->inputs(); + for (size_t i = 0; i < inputs.size(); ++i) { + state_.push_back(::arrow::internal::make_unique( + has_kp ? key_hashers_[i].get() : NULLPTR, inputs[i]->output_schema(), + indices_of_on_key_[i], indices_of_by_key_[i])); + } + + col_index_t dst_offset = 0; + for (auto& state : state_) + dst_offset = state->InitSrcToDstMapping(dst_offset, !!dst_offset); + + return Status::OK(); + } virtual ~AsofJoinNode() { process_.Push(false); // poison pill process_thread_.join(); } + const vec_col_index_t& indices_of_on_key() { return indices_of_on_key_; } + const std::vector& indices_of_by_key() { return indices_of_by_key_; } + static bool find_type(std::unordered_set> type_set, const std::shared_ptr& type) { for (auto ty : type_set) { @@ -692,40 +780,44 @@ class AsofJoinNode : public ExecNode { } static arrow::Result> MakeOutputSchema( - const std::vector& inputs, - const std::vector& on_key_col_indices, - const std::vector& by_key_col_indices) { + const std::vector& inputs, const vec_col_index_t& indices_of_on_key, + const std::vector& indices_of_by_key) { std::vector> fields; + size_t n_by = indices_of_by_key[0].size(); const DataType* on_key_type = NULLPTR; - const DataType* by_key_type = NULLPTR; + std::vector by_key_type(n_by, NULLPTR); // Take all non-key, non-time RHS fields for (size_t j = 0; j < inputs.size(); ++j) { const auto& input_schema = inputs[j]->output_schema(); - const auto& on_field_ix = on_key_col_indices[j]; - const auto& by_field_ix = by_key_col_indices[j]; + const auto& on_field_ix = indices_of_on_key[j]; + const auto& by_field_ix = indices_of_by_key[j]; - if ((on_field_ix == -1) | (by_field_ix == -1)) { + if ((on_field_ix == -1) || std_has(by_field_ix, -1)) { return Status::Invalid("Missing join key on table ", j); } const auto& on_field = input_schema->fields()[on_field_ix]; - const auto& by_field = input_schema->fields()[by_field_ix]; - const auto& on_field_type = on_field->type(); - const auto& by_field_type = by_field->type(); + std::vector by_field(n_by); + for (size_t k = 0; k < n_by; k++) { + by_field[k] = input_schema->fields()[by_field_ix[k]].get(); + } + if (on_key_type == NULLPTR) { - on_key_type = on_field_type.get(); - } else if (*on_key_type != *on_field_type) { + on_key_type = on_field->type().get(); + } else if (*on_key_type != *on_field->type()) { return Status::Invalid("Expected on-key type ", *on_key_type, " but got ", - *on_field_type, " for field ", on_field->name(), + *on_field->type(), " for field ", on_field->name(), " in input ", j); } - if (by_key_type == NULLPTR) { - by_key_type = by_field_type.get(); - } else if (*by_key_type != *by_field_type) { - return Status::Invalid("Expected on-key type ", *by_key_type, " but got ", - *by_field_type, " for field ", by_field->name(), - " in input ", j); + for (size_t k = 0; k < n_by; k++) { + if (by_key_type[k] == NULLPTR) { + by_key_type[k] = by_field[k]->type().get(); + } else if (*by_key_type[k] != *by_field[k]->type()) { + return Status::Invalid("Expected on-key type ", *by_key_type[k], " but got ", + *by_field[k]->type(), " for field ", by_field[k]->name(), + " in input ", j); + } } for (int i = 0; i < input_schema->num_fields(); ++i) { @@ -739,7 +831,7 @@ class AsofJoinNode : public ExecNode { if (j == 0) { fields.push_back(field); } - } else if (i == by_field_ix) { + } else if (std_has(by_field_ix, i)) { if (!find_type(kSupportedByTypes_, field->type())) { return Status::Invalid("Unsupported type for by-key ", field->name(), " : ", field->type()->ToString()); @@ -771,49 +863,62 @@ class AsofJoinNode : public ExecNode { join_options.tolerance); } - std::vector input_labels(inputs.size()); - std::vector on_key_col_indices(inputs.size()); - std::vector by_key_col_indices(inputs.size()); - for (size_t i = 0; i < inputs.size(); ++i) { + size_t n_input = inputs.size(), n_by = join_options.by_key.size(); + std::vector input_labels(n_input); + vec_col_index_t indices_of_on_key(n_input); + std::vector indices_of_by_key(n_input, vec_col_index_t(n_by)); + for (size_t i = 0; i < n_input; ++i) { input_labels[i] = i == 0 ? "left" : "right_" + std::to_string(i); const auto& input_schema = inputs[i]->output_schema(); -#define ASOFJOIN_KEY_MATCH(k) \ - auto k##_key_match_res = join_options.k##_key.FindOne(*input_schema); \ - if (!k##_key_match_res.ok()) { \ - return Status::Invalid("Bad join key on table : ", \ - k##_key_match_res.status().message()); \ - } \ - auto k##_key_match = k##_key_match_res.ValueOrDie(); \ - if (k##_key_match.indices().size() != 1) { \ - return Status::Invalid("AsOfJoinNode does not support a nested " #k "-key ", \ - join_options.k##_key.ToString()); \ - } \ - k##_key_col_indices[i] = k##_key_match.indices()[0]; - - ASOFJOIN_KEY_MATCH(on) - ASOFJOIN_KEY_MATCH(by) +#define ASOFJOIN_KEY_MATCH(kopt, kacc) \ + auto kopt##_match_res = (join_options.kopt)kacc.FindOne(*input_schema); \ + if (!kopt##_match_res.ok()) { \ + return Status::Invalid("Bad join key on table : ", \ + kopt##_match_res.status().message()); \ + } \ + auto kopt##_match = kopt##_match_res.ValueOrDie(); \ + if (kopt##_match.indices().size() != 1) { \ + return Status::Invalid("AsOfJoinNode does not support a nested " #kopt "-key ", \ + (join_options.kopt)kacc.ToString()); \ + } \ + (indices_of_##kopt[i]) kacc = kopt##_match.indices()[0]; + + ASOFJOIN_KEY_MATCH(on_key, ) + for (size_t k = 0; k < n_by; k++) { + ASOFJOIN_KEY_MATCH(by_key, [k]) + } #undef ASOFJOIN_KEY_MATCH } ARROW_ASSIGN_OR_RAISE(std::shared_ptr output_schema, - MakeOutputSchema(inputs, on_key_col_indices, - by_key_col_indices)); - - return plan->EmplaceNode(plan, inputs, std::move(input_labels), - std::move(on_key_col_indices), - std::move(by_key_col_indices), - time_value(join_options.tolerance, 0), - std::move(output_schema)); + MakeOutputSchema(inputs, indices_of_on_key, indices_of_by_key)); + + auto node = plan->EmplaceNode( + plan, inputs, std::move(input_labels), std::move(indices_of_on_key), + std::move(indices_of_by_key), time_value(join_options.tolerance), + std::move(output_schema)); + auto node_output_schema = node->output_schema(); + auto node_indices_of_by_key = checked_cast(node)->indices_of_by_key(); + std::vector> key_hashers; + if (n_by > 1) { + for (size_t i = 0; i < n_input; i++) { + key_hashers.push_back( + ::arrow::internal::make_unique(node_indices_of_by_key[i])); + RETURN_NOT_OK(key_hashers[i]->Init(plan->exec_context(), node_output_schema)); + } + } + RETURN_NOT_OK(node->Init(std::move(key_hashers))); + return node; } const char* kind_name() const override { return "AsofJoinNode"; } void InputReceived(ExecNode* input, ExecBatch batch) override { // Get the input - ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); - size_t k = std::find(inputs_.begin(), inputs_.end(), input) - inputs_.begin(); + ARROW_DCHECK(std_has(inputs_, input)); + size_t k = std_find(inputs_, input) - inputs_.begin(); // Put into the queue auto rb = *batch.ToRecordBatch(input->output_schema()); @@ -827,8 +932,8 @@ class AsofJoinNode : public ExecNode { void InputFinished(ExecNode* input, int total_batches) override { { std::lock_guard guard(gate_); - ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); - size_t k = std::find(inputs_.begin(), inputs_.end(), input) - inputs_.begin(); + ARROW_DCHECK(std_has(inputs_, input)); + size_t k = std_find(inputs_, input) - inputs_.begin(); state_.at(k)->set_total_batches(total_batches); } // Trigger a process call @@ -856,6 +961,9 @@ class AsofJoinNode : public ExecNode { static const std::unordered_set> kSupportedDataTypes_; arrow::Future<> finished_; + std::vector> key_hashers_; + vec_col_index_t indices_of_on_key_; + std::vector indices_of_by_key_; // InputStates // Each input state correponds to an input table std::vector> state_; @@ -874,23 +982,17 @@ class AsofJoinNode : public ExecNode { AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, - const std::vector& on_key_col_indices, - const std::vector& by_key_col_indices, + const vec_col_index_t& indices_of_on_key, + const std::vector& indices_of_by_key, TimeType tolerance, std::shared_ptr output_schema) : ExecNode(plan, inputs, input_labels, /*output_schema=*/std::move(output_schema), /*num_outputs=*/1), + indices_of_on_key_(std::move(indices_of_on_key)), + indices_of_by_key_(std::move(indices_of_by_key)), tolerance_(tolerance), process_(), process_thread_(&AsofJoinNode::ProcessThreadWrapper, this) { - for (size_t i = 0; i < inputs.size(); ++i) { - state_.push_back(::arrow::internal::make_unique( - inputs[i]->output_schema(), on_key_col_indices[i], by_key_col_indices[i])); - } - col_index_t dst_offset = 0; - for (auto& state : state_) - dst_offset = state->InitSrcToDstMapping(dst_offset, !!dst_offset); - finished_ = arrow::Future<>::MakeFinished(); } diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index 614b0e3848391..65e39c748655d 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -22,6 +22,7 @@ #include #include "arrow/api.h" +#include "arrow/compute/api_scalar.h" #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/test_util.h" #include "arrow/compute/exec/util.h" @@ -39,16 +40,50 @@ using testing::UnorderedElementsAreArray; namespace arrow { namespace compute { +// mutates by copying from_key into to_key and changing from_key to zero +BatchesWithSchema MutateByKey(const BatchesWithSchema& batches, std::string from_key, + std::string to_key, bool replace_key = false) { + int from_index = batches.schema->GetFieldIndex(from_key); + int n_fields = batches.schema->num_fields(); + BatchesWithSchema new_batches; + auto new_field = batches.schema->field(from_index)->WithName(to_key); + new_batches.schema = (replace_key ? batches.schema->SetField(from_index, new_field) + : batches.schema->AddField(from_index, new_field)) + .ValueOrDie(); + for (const ExecBatch& batch : batches.batches) { + std::vector new_values; + for (int i = 0; i < n_fields; i++) { + const Datum& value = batch.values[i]; + if (i == from_index) { + new_values.push_back(Subtract(value, value).ValueOrDie()); + if (replace_key) { + continue; + } + } + new_values.push_back(value); + } + new_batches.batches.emplace_back(new_values, batch.length); + } + return new_batches; +} + +// code generation for the by_key types supported by AsofJoinNodeOptions constructors +// which cannot be directly done using templates because of failure to deduce the template +// argument for an invocation with a string- or initializer_list-typed keys-argument +#define EXPAND_BY_KEY_TYPE(macro) \ + macro(const FieldRef); \ + macro(std::vector); \ + macro(std::initializer_list); + void CheckRunOutput(const BatchesWithSchema& l_batches, const BatchesWithSchema& r0_batches, const BatchesWithSchema& r1_batches, - const BatchesWithSchema& exp_batches, const FieldRef time, - const FieldRef keys, const int64_t tolerance) { + const BatchesWithSchema& exp_batches, + const AsofJoinNodeOptions join_options) { auto exec_ctx = arrow::internal::make_unique(default_memory_pool(), nullptr); ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); - AsofJoinNodeOptions join_options(time, keys, tolerance); Declaration join{"asofjoin", join_options}; join.inputs.emplace_back(Declaration{ @@ -74,102 +109,20 @@ void CheckRunOutput(const BatchesWithSchema& l_batches, /*same_chunk_layout=*/true, /*flatten=*/true); } -struct BasicTestTypes { - std::shared_ptr time, key, l_val, r0_val, r1_val; -}; - -void DoRunBasicTestTypes(const std::vector& l_data, - const std::vector& r0_data, - const std::vector& r1_data, - const std::vector& exp_data, - int64_t tolerance, BasicTestTypes basic_test_types) { - const BasicTestTypes& b = basic_test_types; - auto l_schema = - schema({field("time", b.time), field("key", b.key), field("l_v0", b.l_val)}); - auto r0_schema = - schema({field("time", b.time), field("key", b.key), field("r0_v0", b.r0_val)}); - auto r1_schema = - schema({field("time", b.time), field("key", b.key), field("r1_v0", b.r1_val)}); - - auto exp_schema = schema({ - field("time", b.time), - field("key", b.key), - field("l_v0", b.l_val), - field("r0_v0", b.r0_val), - field("r1_v0", b.r1_val), - }); - - // Test three table join - BatchesWithSchema l_batches, r0_batches, r1_batches, exp_batches; - l_batches = MakeBatchesFromString(l_schema, l_data); - r0_batches = MakeBatchesFromString(r0_schema, r0_data); - r1_batches = MakeBatchesFromString(r1_schema, r1_data); - exp_batches = MakeBatchesFromString(exp_schema, exp_data); - CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", - tolerance); -} - -static inline void init_types( - const std::vector>& all_types, - std::vector>& types, - std::function)> type_cond) { - if (types.size() == 0) { - for (auto type : all_types) { - if (type_cond(type)) { - types.push_back(type); - } - } +#define CHECK_RUN_OUTPUT(by_key_type) \ + void CheckRunOutput( \ + const BatchesWithSchema& l_batches, const BatchesWithSchema& r0_batches, \ + const BatchesWithSchema& r1_batches, const BatchesWithSchema& exp_batches, \ + const FieldRef time, by_key_type keys, const int64_t tolerance) { \ + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, \ + AsofJoinNodeOptions(time, keys, tolerance)); \ } -} -void DoRunBasicTest(const std::vector& l_data, - const std::vector& r0_data, - const std::vector& r1_data, - const std::vector& exp_data, int64_t tolerance, - std::vector> time_types = {}, - std::vector> key_types = {}, - std::vector> l_types = {}, - std::vector> r0_types = {}, - std::vector> r1_types = {}) { - std::vector> all_types = {int8(), - int16(), - int32(), - int64(), - uint8(), - uint16(), - uint32(), - uint64(), - timestamp(TimeUnit::NANO, "UTC"), - timestamp(TimeUnit::MICRO, "UTC"), - timestamp(TimeUnit::MILLI, "UTC"), - timestamp(TimeUnit::SECOND, "UTC"), - float32(), - float64()}; - using T = const std::shared_ptr; - init_types(all_types, time_types, - [](T& t) { return t->byte_width() > 1 && !is_floating(t->id()); }); - init_types(all_types, key_types, [](T& t) { return is_integer(t->id()); }); - init_types(all_types, l_types, - [](T& t) { return t->byte_width() > 1 && t->id() != Type::TIMESTAMP; }); - init_types(all_types, r0_types, [](T& t) { return is_floating(t->id()); }); - init_types(all_types, r1_types, [](T& t) { return is_floating(t->id()); }); - for (auto time_type : time_types) { - for (auto key_type : key_types) { - for (auto l_type : l_types) { - for (auto r0_type : r0_types) { - for (auto r1_type : r1_types) { - DoRunBasicTestTypes(l_data, r0_data, r1_data, exp_data, tolerance, - {time_type, key_type, l_type, r0_type, r1_type}); - } - } - } - } - } -} +EXPAND_BY_KEY_TYPE(CHECK_RUN_OUTPUT) void DoRunInvalidPlanTest(const std::shared_ptr& l_schema, - const std::shared_ptr& r_schema, - FieldRef on_key, FieldRef by_key, int64_t tolerance, + const std::shared_ptr& r_schema, FieldRef on_key, + FieldRef by_key, int64_t tolerance, const std::string& expected_error_str) { BatchesWithSchema l_batches = MakeBatchesFromString(l_schema, {R"([])"}); BatchesWithSchema r_batches = MakeBatchesFromString(r_schema, {R"([])"}); @@ -184,13 +137,13 @@ void DoRunInvalidPlanTest(const std::shared_ptr& l_schema, join.inputs.emplace_back(Declaration{ "source", SourceNodeOptions{r_batches.schema, r_batches.gen(false, false)}}); - EXPECT_RAISES_WITH_MESSAGE_THAT( - Invalid, ::testing::HasSubstr(expected_error_str), join.AddToPlan(plan.get())); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr(expected_error_str), + join.AddToPlan(plan.get())); } void DoRunInvalidPlanTest(const std::shared_ptr& l_schema, - const std::shared_ptr& r_schema, - int64_t tolerance, const std::string& expected_error_str) { + const std::shared_ptr& r_schema, int64_t tolerance, + const std::string& expected_error_str) { DoRunInvalidPlanTest(l_schema, r_schema, "time", "key", tolerance, expected_error_str); } @@ -244,38 +197,226 @@ void DoRunAmbiguousByKeyTest(const std::shared_ptr& l_schema, DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table : Multiple matches"); } +struct BasicTestTypes { + std::shared_ptr time, key, l_val, r0_val, r1_val; +}; + +struct BasicTest { + BasicTest(const std::vector& l_data, + const std::vector& r0_data, + const std::vector& r1_data, + const std::vector& exp_nokey_data, + const std::vector& exp_data, int64_t tolerance) + : l_data(std::move(l_data)), + r0_data(std::move(r0_data)), + r1_data(std::move(r1_data)), + exp_nokey_data(std::move(exp_nokey_data)), + exp_data(std::move(exp_data)), + tolerance(tolerance) {} + + template + static inline void init_types(const std::vector>& all_types, + std::vector>& types, + TypeCond type_cond) { + if (types.size() == 0) { + for (auto type : all_types) { + if (type_cond(type)) { + types.push_back(type); + } + } + } + } + + void Run() { + RunSingleByKey(); + RunDoubleByKey(); + } + void RunSingleByKey(std::vector> time_types = {}, + std::vector> key_types = {}, + std::vector> l_types = {}, + std::vector> r0_types = {}, + std::vector> r1_types = {}) { + using B = BatchesWithSchema; + RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, + B exp_batches) { + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", + tolerance); + }); + } + void RunDoubleByKey(std::vector> time_types = {}, + std::vector> key_types = {}, + std::vector> l_types = {}, + std::vector> r0_types = {}, + std::vector> r1_types = {}) { + using B = BatchesWithSchema; + RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, + B exp_batches) { + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", + {"key", "key"}, tolerance); + }); + } + void RunMutateByKey(std::vector> time_types = {}, + std::vector> key_types = {}, + std::vector> l_types = {}, + std::vector> r0_types = {}, + std::vector> r1_types = {}) { + using B = BatchesWithSchema; + RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, + B exp_batches) { + l_batches = MutateByKey(l_batches, "key", "key2"); + r0_batches = MutateByKey(r0_batches, "key", "key2"); + r1_batches = MutateByKey(r1_batches, "key", "key2"); + exp_batches = MutateByKey(exp_batches, "key", "key2"); + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", + {"key", "key2"}, tolerance); + }); + } + void RunMutateNoKey(std::vector> time_types = {}, + std::vector> key_types = {}, + std::vector> l_types = {}, + std::vector> r0_types = {}, + std::vector> r1_types = {}) { + using B = BatchesWithSchema; + RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, + B exp_batches) { + l_batches = MutateByKey(l_batches, "key", "key2", true); + r0_batches = MutateByKey(r0_batches, "key", "key2", true); + r1_batches = MutateByKey(r1_batches, "key", "key2", true); + exp_batches = MutateByKey(exp_batches, "key", "key2", true); + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_nokey_batches, "time", "key2", + tolerance); + }); + } + template + void RunBatches(BatchesRunner batches_runner, + std::vector> time_types = {}, + std::vector> key_types = {}, + std::vector> l_types = {}, + std::vector> r0_types = {}, + std::vector> r1_types = {}) { + std::vector> all_types = { + int8(), + int16(), + int32(), + int64(), + uint8(), + uint16(), + uint32(), + uint64(), + timestamp(TimeUnit::NANO, "UTC"), + timestamp(TimeUnit::MICRO, "UTC"), + timestamp(TimeUnit::MILLI, "UTC"), + timestamp(TimeUnit::SECOND, "UTC"), + float32(), + float64()}; + using T = const std::shared_ptr; + init_types(all_types, time_types, + [](T& t) { return t->byte_width() > 1 && !is_floating(t->id()); }); + init_types(all_types, key_types, [](T& t) { return is_integer(t->id()); }); + init_types(all_types, l_types, + [](T& t) { return t->byte_width() > 1 && t->id() != Type::TIMESTAMP; }); + init_types(all_types, r0_types, [](T& t) { return is_floating(t->id()); }); + init_types(all_types, r1_types, [](T& t) { return is_floating(t->id()); }); + for (auto time_type : time_types) { + for (auto key_type : key_types) { + for (auto l_type : l_types) { + for (auto r0_type : r0_types) { + for (auto r1_type : r1_types) { + RunTypes({time_type, key_type, l_type, r0_type, r1_type}, batches_runner); + } + } + } + } + } + } + template + void RunTypes(BasicTestTypes basic_test_types, BatchesRunner batches_runner) { + const BasicTestTypes& b = basic_test_types; + auto l_schema = + schema({field("time", b.time), field("key", b.key), field("l_v0", b.l_val)}); + auto r0_schema = + schema({field("time", b.time), field("key", b.key), field("r0_v0", b.r0_val)}); + auto r1_schema = + schema({field("time", b.time), field("key", b.key), field("r1_v0", b.r1_val)}); + + auto exp_schema = schema({ + field("time", b.time), + field("key", b.key), + field("l_v0", b.l_val), + field("r0_v0", b.r0_val), + field("r1_v0", b.r1_val), + }); + + // Test three table join + BatchesWithSchema l_batches, r0_batches, r1_batches, exp_nokey_batches, exp_batches; + l_batches = MakeBatchesFromString(l_schema, l_data); + r0_batches = MakeBatchesFromString(r0_schema, r0_data); + r1_batches = MakeBatchesFromString(r1_schema, r1_data); + exp_nokey_batches = MakeBatchesFromString(exp_schema, exp_nokey_data); + exp_batches = MakeBatchesFromString(exp_schema, exp_data); + batches_runner(l_batches, r0_batches, r1_batches, exp_nokey_batches, exp_batches); + } + + std::vector l_data; + std::vector r0_data; + std::vector r1_data; + std::vector exp_nokey_data; + std::vector exp_data; + int64_t tolerance; +}; + class AsofJoinTest : public testing::Test {}; -TEST(AsofJoinTest, TestBasic1) { +#define ASOFJOIN_TEST_SET(name, num) \ + TEST(AsofJoinTest, Test##name##num##_SingleByKey) { \ + Get##name##Test##num().RunSingleByKey(); \ + } \ + TEST(AsofJoinTest, Test##name##num##_DoubleByKey) { \ + Get##name##Test##num().RunDoubleByKey(); \ + } \ + TEST(AsofJoinTest, Test##name##num##_MutateByKey) { \ + Get##name##Test##num().RunMutateByKey(); \ + } \ + TEST(AsofJoinTest, Test##name##num##_MutateNoKey) { \ + Get##name##Test##num().RunMutateNoKey(); \ + } + +BasicTest GetBasicTest1() { // Single key, single batch - DoRunBasicTest( + return BasicTest( /*l*/ {R"([[0, 1, 1], [1000, 1, 2]])"}, /*r0*/ {R"([[0, 1, 11]])"}, /*r1*/ {R"([[1000, 1, 101]])"}, + /*exp_nokey*/ {R"([[0, 0, 1, 11, null], [1000, 0, 2, 11, 101]])"}, /*exp*/ {R"([[0, 1, 1, 11, null], [1000, 1, 2, 11, 101]])"}, 1000); } +ASOFJOIN_TEST_SET(Basic, 1) -TEST(AsofJoinTest, TestBasic2) { +BasicTest GetBasicTest2() { // Single key, multiple batches - DoRunBasicTest( + return BasicTest( /*l*/ {R"([[0, 1, 1]])", R"([[1000, 1, 2]])"}, /*r0*/ {R"([[0, 1, 11]])", R"([[1000, 1, 12]])"}, /*r1*/ {R"([[0, 1, 101]])", R"([[1000, 1, 102]])"}, + /*exp_nokey*/ {R"([[0, 0, 1, 11, 101], [1000, 0, 2, 12, 102]])"}, /*exp*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"}, 1000); } +ASOFJOIN_TEST_SET(Basic, 2) -TEST(AsofJoinTest, TestBasic3) { +BasicTest GetBasicTest3() { // Single key, multiple left batches, single right batches - DoRunBasicTest( + return BasicTest( /*l*/ {R"([[0, 1, 1]])", R"([[1000, 1, 2]])"}, /*r0*/ {R"([[0, 1, 11], [1000, 1, 12]])"}, /*r1*/ {R"([[0, 1, 101], [1000, 1, 102]])"}, + /*exp_nokey*/ {R"([[0, 0, 1, 11, 101], [1000, 0, 2, 12, 102]])"}, /*exp*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"}, 1000); } +ASOFJOIN_TEST_SET(Basic, 3) -TEST(AsofJoinTest, TestBasic4) { +BasicTest GetBasicTest4() { // Multi key, multiple batches, misaligned batches - DoRunBasicTest( + return BasicTest( /*l*/ {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", R"([[2000, 1, 4], [2000, 2, 24]])"}, @@ -285,118 +426,146 @@ TEST(AsofJoinTest, TestBasic4) { /*r1*/ {R"([[0, 2, 1001], [500, 1, 101]])", R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([[0, 0, 1, 11, 1001], [0, 0, 21, 11, 1001], [500, 0, 2, 31, 101], [1000, 0, 22, 12, 102], [1500, 0, 3, 32, 1002], [1500, 0, 23, 32, 1002]])", + R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"}, /*exp*/ {R"([[0, 1, 1, 11, null], [0, 2, 21, null, 1001], [500, 1, 2, 11, 101], [1000, 2, 22, 31, 1001], [1500, 1, 3, 12, 102], [1500, 2, 23, 32, 1002]])", R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, 1000); } +ASOFJOIN_TEST_SET(Basic, 4) -TEST(AsofJoinTest, TestBasic5) { +BasicTest GetBasicTest5() { // Multi key, multiple batches, misaligned batches, smaller tolerance - DoRunBasicTest(/*l*/ - {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", - R"([[2000, 1, 4], [2000, 2, 24]])"}, - /*r0*/ - {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", - R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, - /*r1*/ - {R"([[0, 2, 1001], [500, 1, 101]])", - R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, - /*exp*/ - {R"([[0, 1, 1, 11, null], [0, 2, 21, null, 1001], [500, 1, 2, 11, 101], [1000, 2, 22, 31, null], [1500, 1, 3, 12, 102], [1500, 2, 23, 32, 1002]])", - R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, - 500); -} - -TEST(AsofJoinTest, TestBasic6) { + return BasicTest(/*l*/ + {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", + R"([[2000, 1, 4], [2000, 2, 24]])"}, + /*r0*/ + {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", + R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, + /*r1*/ + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([[0, 0, 1, 11, 1001], [0, 0, 21, 11, 1001], [500, 0, 2, 31, 101], [1000, 0, 22, 12, 102], [1500, 0, 3, 32, 1002], [1500, 0, 23, 32, 1002]])", + R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"}, + /*exp*/ + {R"([[0, 1, 1, 11, null], [0, 2, 21, null, 1001], [500, 1, 2, 11, 101], [1000, 2, 22, 31, null], [1500, 1, 3, 12, 102], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, + 500); +} +ASOFJOIN_TEST_SET(Basic, 5) + +BasicTest GetBasicTest6() { // Multi key, multiple batches, misaligned batches, zero tolerance - DoRunBasicTest(/*l*/ - {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", - R"([[2000, 1, 4], [2000, 2, 24]])"}, - /*r0*/ - {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", - R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, - /*r1*/ - {R"([[0, 2, 1001], [500, 1, 101]])", - R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, - /*exp*/ - {R"([[0, 1, 1, 11, null], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, null], [1500, 1, 3, null, null], [1500, 2, 23, 32, 1002]])", - R"([[2000, 1, 4, 13, 103], [2000, 2, 24, null, null]])"}, - 0); -} - -TEST(AsofJoinTest, TestEmpty1) { + return BasicTest(/*l*/ + {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", + R"([[2000, 1, 4], [2000, 2, 24]])"}, + /*r0*/ + {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", + R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, + /*r1*/ + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([[0, 0, 1, 11, 1001], [0, 0, 21, 11, 1001], [500, 0, 2, 31, 101], [1000, 0, 22, 12, 102], [1500, 0, 3, 32, 1002], [1500, 0, 23, 32, 1002]])", + R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"}, + /*exp*/ + {R"([[0, 1, 1, 11, null], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, null], [1500, 1, 3, null, null], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, null, null]])"}, + 0); +} +ASOFJOIN_TEST_SET(Basic, 6) + +BasicTest GetEmptyTest1() { // Empty left batch - DoRunBasicTest(/*l*/ - {R"([])", R"([[2000, 1, 4], [2000, 2, 24]])"}, - /*r0*/ - {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", - R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, - /*r1*/ - {R"([[0, 2, 1001], [500, 1, 101]])", - R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, - /*exp*/ - {R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, - 1000); -} - -TEST(AsofJoinTest, TestEmpty2) { + return BasicTest(/*l*/ + {R"([])", R"([[2000, 1, 4], [2000, 2, 24]])"}, + /*r0*/ + {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", + R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, + /*r1*/ + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"}, + /*exp*/ + {R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, 1000); +} +ASOFJOIN_TEST_SET(Empty, 1) + +BasicTest GetEmptyTest2() { // Empty left input - DoRunBasicTest(/*l*/ - {R"([])"}, - /*r0*/ - {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", - R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, - /*r1*/ - {R"([[0, 2, 1001], [500, 1, 101]])", - R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, - /*exp*/ - {R"([])"}, 1000); -} - -TEST(AsofJoinTest, TestEmpty3) { + return BasicTest(/*l*/ + {R"([])"}, + /*r0*/ + {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", + R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, + /*r1*/ + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([])"}, + /*exp*/ + {R"([])"}, 1000); +} +ASOFJOIN_TEST_SET(Empty, 2) + +BasicTest GetEmptyTest3() { // Empty right batch - DoRunBasicTest(/*l*/ - {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", - R"([[2000, 1, 4], [2000, 2, 24]])"}, - /*r0*/ - {R"([])", R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, - /*r1*/ - {R"([[0, 2, 1001], [500, 1, 101]])", - R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, - /*exp*/ - {R"([[0, 1, 1, null, null], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, 1001], [1500, 1, 3, null, 102], [1500, 2, 23, 32, 1002]])", - R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, - 1000); -} - -TEST(AsofJoinTest, TestEmpty4) { + return BasicTest(/*l*/ + {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", + R"([[2000, 1, 4], [2000, 2, 24]])"}, + /*r0*/ + {R"([])", R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, + /*r1*/ + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([[0, 0, 1, null, 1001], [0, 0, 21, null, 1001], [500, 0, 2, null, 101], [1000, 0, 22, null, 102], [1500, 0, 3, 32, 1002], [1500, 0, 23, 32, 1002]])", + R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"}, + /*exp*/ + {R"([[0, 1, 1, null, null], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, 1001], [1500, 1, 3, null, 102], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, + 1000); +} +ASOFJOIN_TEST_SET(Empty, 3) + +BasicTest GetEmptyTest4() { // Empty right input - DoRunBasicTest(/*l*/ - {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", - R"([[2000, 1, 4], [2000, 2, 24]])"}, - /*r0*/ - {R"([])"}, - /*r1*/ - {R"([[0, 2, 1001], [500, 1, 101]])", - R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, - /*exp*/ - {R"([[0, 1, 1, null, null], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, 1001], [1500, 1, 3, null, 102], [1500, 2, 23, null, 1002]])", - R"([[2000, 1, 4, null, 103], [2000, 2, 24, null, 1002]])"}, - 1000); -} - -TEST(AsofJoinTest, TestEmpty5) { + return BasicTest(/*l*/ + {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", + R"([[2000, 1, 4], [2000, 2, 24]])"}, + /*r0*/ + {R"([])"}, + /*r1*/ + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([[0, 0, 1, null, 1001], [0, 0, 21, null, 1001], [500, 0, 2, null, 101], [1000, 0, 22, null, 102], [1500, 0, 3, null, 1002], [1500, 0, 23, null, 1002]])", + R"([[2000, 0, 4, null, 103], [2000, 0, 24, null, 103]])"}, + /*exp*/ + {R"([[0, 1, 1, null, null], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, 1001], [1500, 1, 3, null, 102], [1500, 2, 23, null, 1002]])", + R"([[2000, 1, 4, null, 103], [2000, 2, 24, null, 1002]])"}, + 1000); +} +ASOFJOIN_TEST_SET(Empty, 4) + +BasicTest GetEmptyTest5() { // All empty - DoRunBasicTest(/*l*/ - {R"([])"}, - /*r0*/ - {R"([])"}, - /*r1*/ - {R"([])"}, - /*exp*/ - {R"([])"}, 1000); -} + return BasicTest(/*l*/ + {R"([])"}, + /*r0*/ + {R"([])"}, + /*r1*/ + {R"([])"}, + /*exp_nokey*/ + {R"([])"}, + /*exp*/ + {R"([])"}, 1000); +} +ASOFJOIN_TEST_SET(Empty, 5) TEST(AsofJoinTest, TestUnsupportedOntype) { DoRunInvalidTypeTest( diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index 1e7ba01b2f2b3..e9410a2d69585 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -395,8 +395,18 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { /// This node will output one row for each row in the left table. class ARROW_EXPORT AsofJoinNodeOptions : public ExecNodeOptions { public: - AsofJoinNodeOptions(FieldRef on_key, FieldRef by_key, int64_t tolerance) - : on_key(std::move(on_key)), by_key(std::move(by_key)), tolerance(tolerance) {} + AsofJoinNodeOptions(FieldRef on_key, const FieldRef& by_key, int64_t tolerance) + : on_key(std::move(on_key)), by_key(), tolerance(tolerance) { + this->by_key.push_back(std::move(by_key)); + } + + AsofJoinNodeOptions(FieldRef on_key, std::vector by_key, int64_t tolerance) + : on_key(std::move(on_key)), by_key(by_key), tolerance(tolerance) {} + + // resolves ambiguity between previous constructors when initializer list is given + AsofJoinNodeOptions(FieldRef on_key, std::initializer_list by_key, + int64_t tolerance) + : on_key(std::move(on_key)), by_key(by_key), tolerance(tolerance) {} /// \brief "on" key for the join. Each /// @@ -410,7 +420,7 @@ class ARROW_EXPORT AsofJoinNodeOptions : public ExecNodeOptions { /// All input tables must have the "by" key. Exact equality /// is used for the "by" key. /// Currently, the "by" key must be of an integer or timestamp type - FieldRef by_key; + std::vector by_key; /// Tolerance for inexact "on" key matching. Must be non-negative. int64_t tolerance; }; diff --git a/cpp/src/arrow/compute/light_array.cc b/cpp/src/arrow/compute/light_array.cc index 4bf3574d09fdb..9ea609c531048 100644 --- a/cpp/src/arrow/compute/light_array.cc +++ b/cpp/src/arrow/compute/light_array.cc @@ -141,6 +141,12 @@ Result ColumnArrayFromArrayData( const std::shared_ptr& array_data, int64_t start_row, int64_t num_rows) { ARROW_ASSIGN_OR_RAISE(KeyColumnMetadata metadata, ColumnMetadataFromDataType(array_data->type)); + return ColumnArrayFromArrayDataAndMetadata(array_data, metadata, start_row, num_rows); +} + +KeyColumnArray ColumnArrayFromArrayDataAndMetadata( + const std::shared_ptr& array_data, const KeyColumnMetadata& metadata, + int64_t start_row, int64_t num_rows) { KeyColumnArray column_array = KeyColumnArray( metadata, array_data->offset + start_row + num_rows, array_data->buffers[0] != NULLPTR ? array_data->buffers[0]->data() : nullptr, diff --git a/cpp/src/arrow/compute/light_array.h b/cpp/src/arrow/compute/light_array.h index f0e5c7068716a..6a5c205a7a48b 100644 --- a/cpp/src/arrow/compute/light_array.h +++ b/cpp/src/arrow/compute/light_array.h @@ -187,6 +187,17 @@ ARROW_EXPORT Result ColumnMetadataFromDataType( ARROW_EXPORT Result ColumnArrayFromArrayData( const std::shared_ptr& array_data, int64_t start_row, int64_t num_rows); +/// \brief Create KeyColumnArray from ArrayData and KeyColumnMetadata +/// +/// If `type` is a dictionary type then this will return the KeyColumnArray for +/// the indices array +/// +/// The caller should ensure this is only called on "key" columns. +/// \see ColumnMetadataFromDataType for details +ARROW_EXPORT KeyColumnArray ColumnArrayFromArrayDataAndMetadata( + const std::shared_ptr& array_data, const KeyColumnMetadata& metadata, + int64_t start_row, int64_t num_rows); + /// \brief Create KeyColumnMetadata instances from an ExecBatch /// /// column_metadatas will be resized to fit From e8177a3775b579cecdc74c6e4ee0e24de84ff9bf Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Mon, 15 Aug 2022 09:15:47 -0400 Subject: [PATCH 06/26] ARROW-17412: [C++] AsofJoin multiple keys and types --- cpp/src/arrow/compute/exec/asof_join_node.cc | 488 ++++++++++--- .../arrow/compute/exec/asof_join_node_test.cc | 665 +++++++++++++----- cpp/src/arrow/compute/exec/hash_join.cc | 1 - cpp/src/arrow/compute/exec/options.h | 26 +- cpp/src/arrow/compute/light_array.cc | 6 + cpp/src/arrow/compute/light_array.h | 11 + 6 files changed, 917 insertions(+), 280 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 3da612aa03e41..2319ba8249910 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -17,15 +17,17 @@ #include #include -#include #include #include +#include #include "arrow/array/builder_primitive.h" #include "arrow/compute/exec/exec_plan.h" +#include "arrow/compute/exec/key_hash.h" #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/schema_util.h" #include "arrow/compute/exec/util.h" +#include "arrow/compute/light_array.h" #include "arrow/record_batch.h" #include "arrow/result.h" #include "arrow/status.h" @@ -37,13 +39,41 @@ namespace arrow { namespace compute { -// Remove this when multiple keys and/or types is supported -typedef int32_t KeyType; +template +inline typename T::const_iterator std_find(const T& container, const V& val) { + return std::find(container.begin(), container.end(), val); +} + +template +inline bool std_has(const T& container, const V& val) { + return container.end() != std_find(container, val); +} + +typedef uint64_t KeyType; +typedef uint64_t TimeType; +typedef uint64_t HashType; // Maximum number of tables that can be joined #define MAX_JOIN_TABLES 64 typedef uint64_t row_index_t; typedef int col_index_t; +typedef std::vector vec_col_index_t; + +template ::value, bool> = true> +static inline uint64_t norm_value(T t) { + uint64_t bias = std::is_signed::value ? (uint64_t)1 << (8 * sizeof(T) - 1) : 0; + return t < 0 ? static_cast(t + bias) : static_cast(t); +} + +template ::value, bool> = true> +static inline uint64_t time_value(T t) { + return norm_value(t); +} + +template ::value, bool> = true> +static inline uint64_t key_value(T t) { + return norm_value(t); +} /** * Simple implementation for an unbound concurrent queue @@ -99,7 +129,7 @@ struct MemoStore { struct Entry { // Timestamp associated with the entry - int64_t time; + TimeType time; // Batch associated with the entry (perf is probably OK for this; batches change // rarely) @@ -111,7 +141,7 @@ struct MemoStore { std::unordered_map entries_; - void Store(const std::shared_ptr& batch, row_index_t row, int64_t time, + void Store(const std::shared_ptr& batch, row_index_t row, TimeType time, KeyType key) { auto& e = entries_[key]; // that we can do this assignment optionally, is why we @@ -128,7 +158,7 @@ struct MemoStore { return util::optional(&e->second); } - void RemoveEntriesWithLesserTime(int64_t ts) { + void RemoveEntriesWithLesserTime(TimeType ts) { for (auto e = entries_.begin(); e != entries_.end();) if (e->second.time < ts) e = entries_.erase(e); @@ -137,18 +167,84 @@ struct MemoStore { } }; +// a specialized higher-performance variation of Hashing64 logic +class KeyHasher { + static constexpr int kMiniBatchLength = util::MiniBatch::kMiniBatchLength; + + public: + explicit KeyHasher(const vec_col_index_t& indices) + : indices_(indices), + metadata_(indices.size()), + batch_(NULLPTR), + hashes_(), + ctx_(), + column_arrays_(), + stack_() { + ctx_.stack = &stack_; + column_arrays_.resize(indices.size()); + } + + Status Init(ExecContext* exec_context, const std::shared_ptr& schema) { + ctx_.hardware_flags = exec_context->cpu_info()->hardware_flags(); + const auto& fields = schema->fields(); + for (size_t k = 0; k < metadata_.size(); k++) { + ARROW_ASSIGN_OR_RAISE(metadata_[k], + ColumnMetadataFromDataType(fields[indices_[k]]->type())); + } + return stack_.Init(exec_context->memory_pool(), + 4 * kMiniBatchLength * sizeof(uint32_t)); + } + + const std::vector& HashesFor(const RecordBatch* batch) { + if (batch_ == batch) { + return hashes_; + } + batch_ = NULLPTR; // invalidate cached hashes for batch + size_t batch_length = batch->num_rows(); + hashes_.resize(batch_length); + for (int64_t i = 0; i < static_cast(batch_length); i += kMiniBatchLength) { + int64_t length = std::min(static_cast(batch_length - i), + static_cast(kMiniBatchLength)); + for (size_t k = 0; k < indices_.size(); k++) { + auto array_data = batch->column_data(indices_[k]); + column_arrays_[k] = + ColumnArrayFromArrayDataAndMetadata(array_data, metadata_[k], i, length); + } + Hashing64::HashMultiColumn(column_arrays_, &ctx_, hashes_.data() + i); + } + batch_ = batch; + return hashes_; + } + + private: + vec_col_index_t indices_; + std::vector metadata_; + const RecordBatch* batch_; + std::vector hashes_; + LightContext ctx_; + std::vector column_arrays_; + util::TempVectorStack stack_; +}; + class InputState { // InputState correponds to an input // Input record batches are queued up in InputState until processed and // turned into output record batches. public: - InputState(const std::shared_ptr& schema, - const std::string& time_col_name, const std::string& key_col_name) + InputState(KeyHasher* key_hasher, const std::shared_ptr& schema, + const col_index_t time_col_index, const vec_col_index_t& key_col_index) : queue_(), schema_(schema), - time_col_index_(schema->GetFieldIndex(time_col_name)), - key_col_index_(schema->GetFieldIndex(key_col_name)) {} + time_col_index_(time_col_index), + key_col_index_(key_col_index), + time_type_id_(schema_->fields()[time_col_index_]->type()->id()), + key_type_id_(schema_->num_fields()), + key_hasher_(key_hasher) { + for (size_t k = 0; k < key_col_index_.size(); k++) { + key_type_id_[k] = schema_->fields()[key_col_index_[k]]->type()->id(); + } + } col_index_t InitSrcToDstMapping(col_index_t dst_offset, bool skip_time_and_key_fields) { src_to_dst_.resize(schema_->num_fields()); @@ -164,7 +260,7 @@ class InputState { bool IsTimeOrKeyColumn(col_index_t i) const { DCHECK_LT(i, schema_->num_fields()); - return (i == time_col_index_) || (i == key_col_index_); + return (i == time_col_index_) || std_has(key_col_index_, i); } // Gets the latest row index, assuming the queue isn't empty @@ -184,18 +280,51 @@ class InputState { return queue_.UnsyncFront(); } +#define LATEST_VAL_CASE(id, val) \ + case Type::id: { \ + using T = typename TypeIdTraits::Type; \ + using CType = typename TypeTraits::CType; \ + return val(data->GetValues(1)[latest_ref_row_]); \ + } + KeyType GetLatestKey() const { - return queue_.UnsyncFront() - ->column_data(key_col_index_) - ->GetValues(1)[latest_ref_row_]; + if (key_hasher_ != NULLPTR) { + return key_hasher_->HashesFor(queue_.UnsyncFront().get())[latest_ref_row_]; + } + auto data = queue_.UnsyncFront()->column_data(key_col_index_[0]); + switch (key_type_id_[0]) { + LATEST_VAL_CASE(INT8, key_value) + LATEST_VAL_CASE(INT16, key_value) + LATEST_VAL_CASE(INT32, key_value) + LATEST_VAL_CASE(INT64, key_value) + LATEST_VAL_CASE(UINT8, key_value) + LATEST_VAL_CASE(UINT16, key_value) + LATEST_VAL_CASE(UINT32, key_value) + LATEST_VAL_CASE(UINT64, key_value) + default: + return 0; // cannot happen + } } - int64_t GetLatestTime() const { - return queue_.UnsyncFront() - ->column_data(time_col_index_) - ->GetValues(1)[latest_ref_row_]; + TimeType GetLatestTime() const { + auto data = queue_.UnsyncFront()->column_data(time_col_index_); + switch (time_type_id_) { + LATEST_VAL_CASE(INT8, time_value) + LATEST_VAL_CASE(INT16, time_value) + LATEST_VAL_CASE(INT32, time_value) + LATEST_VAL_CASE(INT64, time_value) + LATEST_VAL_CASE(UINT8, time_value) + LATEST_VAL_CASE(UINT16, time_value) + LATEST_VAL_CASE(UINT32, time_value) + LATEST_VAL_CASE(UINT64, time_value) + LATEST_VAL_CASE(TIMESTAMP, time_value) + default: + return 0; // cannot happen + } } +#undef LATEST_VAL_CASE + bool Finished() const { return batches_processed_ == total_batches_; } bool Advance() { @@ -222,28 +351,25 @@ class InputState { // latest_time and latest_ref_row to the value that immediately pass the // specified timestamp. // Returns true if updates were made, false if not. - bool AdvanceAndMemoize(int64_t ts) { + bool AdvanceAndMemoize(TimeType ts) { // Advance the right side row index until we reach the latest right row (for each key) // for the given left timestamp. // Check if already updated for TS (or if there is no latest) if (Empty()) return false; // can't advance if empty - auto latest_time = GetLatestTime(); - if (latest_time > ts) return false; // already advanced // Not updated. Try to update and possibly advance. bool updated = false; do { - latest_time = GetLatestTime(); + auto latest_time = GetLatestTime(); // if Advance() returns true, then the latest_ts must also be valid // Keep advancing right table until we hit the latest row that has // timestamp <= ts. This is because we only need the latest row for the // match given a left ts. - if (latest_time <= ts) { - memo_.Store(GetLatestBatch(), latest_ref_row_, latest_time, GetLatestKey()); - } else { + if (latest_time > ts) { break; // hit a future timestamp -- done updating for now } + memo_.Store(GetLatestBatch(), latest_ref_row_, latest_time, GetLatestKey()); updated = true; } while (Advance()); return updated; @@ -261,7 +387,7 @@ class InputState { return memo_.GetEntryForKey(key); } - util::optional GetMemoTimeForKey(KeyType key) { + util::optional GetMemoTimeForKey(KeyType key) { auto r = GetMemoEntryForKey(key); if (r.has_value()) { return (*r)->time; @@ -270,7 +396,7 @@ class InputState { } } - void RemoveMemoEntriesWithLesserTime(int64_t ts) { + void RemoveMemoEntriesWithLesserTime(TimeType ts) { memo_.RemoveEntriesWithLesserTime(ts); } @@ -294,7 +420,13 @@ class InputState { // Index of the time col col_index_t time_col_index_; // Index of the key col - col_index_t key_col_index_; + vec_col_index_t key_col_index_; + // Type id of the time column + Type::type time_type_id_; + // Type id of the key column + std::vector key_type_id_; + // Buffer for key elements + mutable KeyHasher* key_hasher_; // Index of the latest row reference within; if >0 then queue_ cannot be empty // Must be < queue_.front()->num_rows() if queue_ is non-empty row_index_t latest_ref_row_ = 0; @@ -336,7 +468,7 @@ class CompositeReferenceTable { // Adds the latest row from the input state as a new composite reference row // - LHS must have a valid key,timestep,and latest rows // - RHS must have valid data memo'ed for the key - void Emplace(std::vector>& in, int64_t tolerance) { + void Emplace(std::vector>& in, TimeType tolerance) { DCHECK_EQ(in.size(), n_tables_); // Get the LHS key @@ -347,7 +479,7 @@ class CompositeReferenceTable { DCHECK(!in[0]->Empty()); const std::shared_ptr& lhs_latest_batch = in[0]->GetLatestBatch(); row_index_t lhs_latest_row = in[0]->GetLatestRow(); - int64_t lhs_latest_time = in[0]->GetLatestTime(); + TimeType lhs_latest_time = in[0]->GetLatestTime(); if (0 == lhs_latest_row) { // On the first row of the batch, we resize the destination. // The destination size is dictated by the size of the LHS batch. @@ -407,29 +539,34 @@ class CompositeReferenceTable { DCHECK_EQ(src_field->name(), dst_field->name()); const auto& field_type = src_field->type(); - if (field_type->Equals(arrow::int32())) { - ARROW_ASSIGN_OR_RAISE( - arrays.at(i_dst_col), - (MaterializePrimitiveColumn( - memory_pool, i_table, i_src_col))); - } else if (field_type->Equals(arrow::int64())) { - ARROW_ASSIGN_OR_RAISE( - arrays.at(i_dst_col), - (MaterializePrimitiveColumn( - memory_pool, i_table, i_src_col))); - } else if (field_type->Equals(arrow::float32())) { - ARROW_ASSIGN_OR_RAISE(arrays.at(i_dst_col), - (MaterializePrimitiveColumn( - memory_pool, i_table, i_src_col))); - } else if (field_type->Equals(arrow::float64())) { - ARROW_ASSIGN_OR_RAISE( - arrays.at(i_dst_col), - (MaterializePrimitiveColumn( - memory_pool, i_table, i_src_col))); - } else { - ARROW_RETURN_NOT_OK( - Status::Invalid("Unsupported data type: ", src_field->name())); +#define ASOFJOIN_MATERIALIZE_CASE(id) \ + case Type::id: { \ + using T = typename TypeIdTraits::Type; \ + ARROW_ASSIGN_OR_RAISE( \ + arrays.at(i_dst_col), \ + MaterializeColumn(memory_pool, field_type, i_table, i_src_col)); \ + break; \ + } + + switch (field_type->id()) { + ASOFJOIN_MATERIALIZE_CASE(INT8) + ASOFJOIN_MATERIALIZE_CASE(INT16) + ASOFJOIN_MATERIALIZE_CASE(INT32) + ASOFJOIN_MATERIALIZE_CASE(INT64) + ASOFJOIN_MATERIALIZE_CASE(UINT8) + ASOFJOIN_MATERIALIZE_CASE(UINT16) + ASOFJOIN_MATERIALIZE_CASE(UINT32) + ASOFJOIN_MATERIALIZE_CASE(UINT64) + ASOFJOIN_MATERIALIZE_CASE(FLOAT) + ASOFJOIN_MATERIALIZE_CASE(DOUBLE) + ASOFJOIN_MATERIALIZE_CASE(TIMESTAMP) + default: + return Status::Invalid("Unsupported data type ", + src_field->type()->ToString(), " for field ", + src_field->name()); } + +#undef ASOFJOIN_MATERIALIZE_CASE } } } @@ -459,11 +596,13 @@ class CompositeReferenceTable { if (!_ptr2ref.count((uintptr_t)ref.get())) _ptr2ref[(uintptr_t)ref.get()] = ref; } - template - Result> MaterializePrimitiveColumn(MemoryPool* memory_pool, - size_t i_table, - col_index_t i_col) { - Builder builder(memory_pool); + template ::BuilderType, + class PrimitiveType = typename TypeTraits::CType> + Result> MaterializeColumn(MemoryPool* memory_pool, + const std::shared_ptr& type, + size_t i_table, col_index_t i_col) { + ARROW_ASSIGN_OR_RAISE(auto a_builder, MakeBuilder(type, memory_pool)); + Builder& builder = *checked_cast(a_builder.get()); ARROW_RETURN_NOT_OK(builder.Reserve(rows_.size())); for (row_index_t i_row = 0; i_row < rows_.size(); ++i_row) { const auto& ref = rows_[i_row].refs[i_table]; @@ -495,7 +634,7 @@ class AsofJoinNode : public ExecNode { bool IsUpToDateWithLhsRow() const { auto& lhs = *state_[0]; if (lhs.Empty()) return false; // can't proceed if nothing on the LHS - int64_t lhs_ts = lhs.GetLatestTime(); + TimeType lhs_ts = lhs.GetLatestTime(); for (size_t i = 1; i < state_.size(); ++i) { auto& rhs = *state_[i]; if (!rhs.Finished()) { @@ -531,7 +670,7 @@ class AsofJoinNode : public ExecNode { // the LHS and adding joined row to rows_ (done by Emplace). Finally, // input batches that are no longer needed are removed to free up memory. if (IsUpToDateWithLhsRow()) { - dst.Emplace(state_, options_.tolerance); + dst.Emplace(state_, tolerance_); if (!lhs.Advance()) break; // if we can't advance LHS, we're done for this batch } else { if (!any_rhs_advanced) break; // need to wait for new data @@ -541,8 +680,7 @@ class AsofJoinNode : public ExecNode { // Prune memo entries that have expired (to bound memory consumption) if (!lhs.Empty()) { for (size_t i = 1; i < state_.size(); ++i) { - state_[i]->RemoveMemoEntriesWithLesserTime(lhs.GetLatestTime() - - options_.tolerance); + state_[i]->RemoveMemoEntriesWithLesserTime(lhs.GetLatestTime() - tolerance_); } } @@ -602,52 +740,110 @@ class AsofJoinNode : public ExecNode { public: AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, - const AsofJoinNodeOptions& join_options, + const vec_col_index_t& indices_of_on_key, + const std::vector& indices_of_by_key, TimeType tolerance, std::shared_ptr output_schema); + Status Init(std::vector> key_hashers) { + key_hashers_.swap(key_hashers); + bool has_kp = key_hashers_.size() > 0; + auto inputs = this->inputs(); + for (size_t i = 0; i < inputs.size(); ++i) { + state_.push_back(::arrow::internal::make_unique( + has_kp ? key_hashers_[i].get() : NULLPTR, inputs[i]->output_schema(), + indices_of_on_key_[i], indices_of_by_key_[i])); + } + + col_index_t dst_offset = 0; + for (auto& state : state_) + dst_offset = state->InitSrcToDstMapping(dst_offset, !!dst_offset); + + return Status::OK(); + } + virtual ~AsofJoinNode() { process_.Push(false); // poison pill process_thread_.join(); } + const vec_col_index_t& indices_of_on_key() { return indices_of_on_key_; } + const std::vector& indices_of_by_key() { return indices_of_by_key_; } + + static bool find_type(std::unordered_set> type_set, + const std::shared_ptr& type) { + for (auto ty : type_set) { + if (*ty.get() == *type.get()) { + return true; + } + } + return false; + } + static arrow::Result> MakeOutputSchema( - const std::vector& inputs, const AsofJoinNodeOptions& options) { + const std::vector& inputs, const vec_col_index_t& indices_of_on_key, + const std::vector& indices_of_by_key) { std::vector> fields; - const auto& on_field_name = *options.on_key.name(); - const auto& by_field_name = *options.by_key.name(); - + size_t n_by = indices_of_by_key[0].size(); + const DataType* on_key_type = NULLPTR; + std::vector by_key_type(n_by, NULLPTR); // Take all non-key, non-time RHS fields for (size_t j = 0; j < inputs.size(); ++j) { const auto& input_schema = inputs[j]->output_schema(); - const auto& on_field_ix = input_schema->GetFieldIndex(on_field_name); - const auto& by_field_ix = input_schema->GetFieldIndex(by_field_name); + const auto& on_field_ix = indices_of_on_key[j]; + const auto& by_field_ix = indices_of_by_key[j]; - if ((on_field_ix == -1) | (by_field_ix == -1)) { + if ((on_field_ix == -1) || std_has(by_field_ix, -1)) { return Status::Invalid("Missing join key on table ", j); } + const auto& on_field = input_schema->fields()[on_field_ix]; + std::vector by_field(n_by); + for (size_t k = 0; k < n_by; k++) { + by_field[k] = input_schema->fields()[by_field_ix[k]].get(); + } + + if (on_key_type == NULLPTR) { + on_key_type = on_field->type().get(); + } else if (*on_key_type != *on_field->type()) { + return Status::Invalid("Expected on-key type ", *on_key_type, " but got ", + *on_field->type(), " for field ", on_field->name(), + " in input ", j); + } + for (size_t k = 0; k < n_by; k++) { + if (by_key_type[k] == NULLPTR) { + by_key_type[k] = by_field[k]->type().get(); + } else if (*by_key_type[k] != *by_field[k]->type()) { + return Status::Invalid("Expected on-key type ", *by_key_type[k], " but got ", + *by_field[k]->type(), " for field ", by_field[k]->name(), + " in input ", j); + } + } + for (int i = 0; i < input_schema->num_fields(); ++i) { const auto field = input_schema->field(i); - if (field->name() == on_field_name) { - if (kSupportedOnTypes_.find(field->type()) == kSupportedOnTypes_.end()) { - return Status::Invalid("Unsupported type for on key: ", field->name()); + if (i == on_field_ix) { + if (!find_type(kSupportedOnTypes_, field->type())) { + return Status::Invalid("Unsupported type for on-key ", field->name(), " : ", + field->type()->ToString()); } // Only add on field from the left table if (j == 0) { fields.push_back(field); } - } else if (field->name() == by_field_name) { - if (kSupportedByTypes_.find(field->type()) == kSupportedByTypes_.end()) { - return Status::Invalid("Unsupported type for by key: ", field->name()); + } else if (std_has(by_field_ix, i)) { + if (!find_type(kSupportedByTypes_, field->type())) { + return Status::Invalid("Unsupported type for by-key ", field->name(), " : ", + field->type()->ToString()); } // Only add by field from the left table if (j == 0) { fields.push_back(field); } } else { - if (kSupportedDataTypes_.find(field->type()) == kSupportedDataTypes_.end()) { - return Status::Invalid("Unsupported data type: ", field->name()); + if (!find_type(kSupportedDataTypes_, field->type())) { + return Status::Invalid("Unsupported type for field ", field->name(), " : ", + field->type()->ToString()); } fields.push_back(field); @@ -662,25 +858,67 @@ class AsofJoinNode : public ExecNode { DCHECK_GE(inputs.size(), 2) << "Must have at least two inputs"; const auto& join_options = checked_cast(options); - ARROW_ASSIGN_OR_RAISE(std::shared_ptr output_schema, - MakeOutputSchema(inputs, join_options)); + if (join_options.tolerance < 0) { + return Status::Invalid("AsOfJoin tolerance must be non-negative but is ", + join_options.tolerance); + } + + size_t n_input = inputs.size(), n_by = join_options.by_key.size(); + std::vector input_labels(n_input); + vec_col_index_t indices_of_on_key(n_input); + std::vector indices_of_by_key(n_input, vec_col_index_t(n_by)); + for (size_t i = 0; i < n_input; ++i) { + input_labels[i] = i == 0 ? "left" : "right_" + std::to_string(i); + const auto& input_schema = inputs[i]->output_schema(); + +#define ASOFJOIN_KEY_MATCH(kopt, kacc) \ + auto kopt##_match_res = (join_options.kopt)kacc.FindOne(*input_schema); \ + if (!kopt##_match_res.ok()) { \ + return Status::Invalid("Bad join key on table : ", \ + kopt##_match_res.status().message()); \ + } \ + auto kopt##_match = kopt##_match_res.ValueOrDie(); \ + if (kopt##_match.indices().size() != 1) { \ + return Status::Invalid("AsOfJoinNode does not support a nested " #kopt "-key ", \ + (join_options.kopt)kacc.ToString()); \ + } \ + (indices_of_##kopt[i]) kacc = kopt##_match.indices()[0]; + + ASOFJOIN_KEY_MATCH(on_key, ) + for (size_t k = 0; k < n_by; k++) { + ASOFJOIN_KEY_MATCH(by_key, [k]) + } - std::vector input_labels(inputs.size()); - input_labels[0] = "left"; - for (size_t i = 1; i < inputs.size(); ++i) { - input_labels[i] = "right_" + std::to_string(i); +#undef ASOFJOIN_KEY_MATCH } - return plan->EmplaceNode(plan, inputs, std::move(input_labels), - join_options, std::move(output_schema)); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr output_schema, + MakeOutputSchema(inputs, indices_of_on_key, indices_of_by_key)); + + auto node = plan->EmplaceNode( + plan, inputs, std::move(input_labels), std::move(indices_of_on_key), + std::move(indices_of_by_key), time_value(join_options.tolerance), + std::move(output_schema)); + auto node_output_schema = node->output_schema(); + auto node_indices_of_by_key = checked_cast(node)->indices_of_by_key(); + std::vector> key_hashers; + if (n_by > 1) { + for (size_t i = 0; i < n_input; i++) { + key_hashers.push_back( + ::arrow::internal::make_unique(node_indices_of_by_key[i])); + RETURN_NOT_OK(key_hashers[i]->Init(plan->exec_context(), node_output_schema)); + } + } + RETURN_NOT_OK(node->Init(std::move(key_hashers))); + return node; } const char* kind_name() const override { return "AsofJoinNode"; } void InputReceived(ExecNode* input, ExecBatch batch) override { // Get the input - ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); - size_t k = std::find(inputs_.begin(), inputs_.end(), input) - inputs_.begin(); + ARROW_DCHECK(std_has(inputs_, input)); + size_t k = std_find(inputs_, input) - inputs_.begin(); // Put into the queue auto rb = *batch.ToRecordBatch(input->output_schema()); @@ -694,8 +932,8 @@ class AsofJoinNode : public ExecNode { void InputFinished(ExecNode* input, int total_batches) override { { std::lock_guard guard(gate_); - ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); - size_t k = std::find(inputs_.begin(), inputs_.end(), input) - inputs_.begin(); + ARROW_DCHECK(std_has(inputs_, input)); + size_t k = std_find(inputs_, input) - inputs_.begin(); state_.at(k)->set_total_batches(total_batches); } // Trigger a process call @@ -718,16 +956,19 @@ class AsofJoinNode : public ExecNode { arrow::Future<> finished() override { return finished_; } private: - static const std::set> kSupportedOnTypes_; - static const std::set> kSupportedByTypes_; - static const std::set> kSupportedDataTypes_; + static const std::unordered_set> kSupportedOnTypes_; + static const std::unordered_set> kSupportedByTypes_; + static const std::unordered_set> kSupportedDataTypes_; arrow::Future<> finished_; + std::vector> key_hashers_; + vec_col_index_t indices_of_on_key_; + std::vector indices_of_by_key_; // InputStates // Each input state correponds to an input table std::vector> state_; std::mutex gate_; - AsofJoinNodeOptions options_; + TimeType tolerance_; // Queue for triggering processing of a given input // (a false value is a poison pill) @@ -741,29 +982,62 @@ class AsofJoinNode : public ExecNode { AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, - const AsofJoinNodeOptions& join_options, - std::shared_ptr output_schema) + const vec_col_index_t& indices_of_on_key, + const std::vector& indices_of_by_key, + TimeType tolerance, std::shared_ptr output_schema) : ExecNode(plan, inputs, input_labels, /*output_schema=*/std::move(output_schema), /*num_outputs=*/1), - options_(join_options), + indices_of_on_key_(std::move(indices_of_on_key)), + indices_of_by_key_(std::move(indices_of_by_key)), + tolerance_(tolerance), process_(), process_thread_(&AsofJoinNode::ProcessThreadWrapper, this) { - for (size_t i = 0; i < inputs.size(); ++i) - state_.push_back(::arrow::internal::make_unique( - inputs[i]->output_schema(), *options_.on_key.name(), *options_.by_key.name())); - col_index_t dst_offset = 0; - for (auto& state : state_) - dst_offset = state->InitSrcToDstMapping(dst_offset, !!dst_offset); - finished_ = arrow::Future<>::MakeFinished(); } // Currently supported types -const std::set> AsofJoinNode::kSupportedOnTypes_ = {int64()}; -const std::set> AsofJoinNode::kSupportedByTypes_ = {int32()}; -const std::set> AsofJoinNode::kSupportedDataTypes_ = { - int32(), int64(), float32(), float64()}; +const std::unordered_set> AsofJoinNode::kSupportedOnTypes_ = { + int8(), + int16(), + int32(), + int64(), + uint8(), + uint16(), + uint32(), + uint64(), + timestamp(TimeUnit::NANO, "UTC"), + timestamp(TimeUnit::MICRO, "UTC"), + timestamp(TimeUnit::MILLI, "UTC"), + timestamp(TimeUnit::SECOND, "UTC")}; +const std::unordered_set> AsofJoinNode::kSupportedByTypes_ = { + int8(), + int16(), + int32(), + int64(), + uint8(), + uint16(), + uint32(), + uint64(), + timestamp(TimeUnit::NANO, "UTC"), + timestamp(TimeUnit::MICRO, "UTC"), + timestamp(TimeUnit::MILLI, "UTC"), + timestamp(TimeUnit::SECOND, "UTC")}; +const std::unordered_set> AsofJoinNode::kSupportedDataTypes_ = { + int8(), + int16(), + int32(), + int64(), + uint8(), + uint16(), + uint32(), + uint64(), + timestamp(TimeUnit::NANO, "UTC"), + timestamp(TimeUnit::MICRO, "UTC"), + timestamp(TimeUnit::MILLI, "UTC"), + timestamp(TimeUnit::SECOND, "UTC"), + float32(), + float64()}; namespace internal { void RegisterAsofJoinNode(ExecFactoryRegistry* registry) { diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index 8b993764abe7f..65e39c748655d 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -22,6 +22,7 @@ #include #include "arrow/api.h" +#include "arrow/compute/api_scalar.h" #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/test_util.h" #include "arrow/compute/exec/util.h" @@ -39,16 +40,50 @@ using testing::UnorderedElementsAreArray; namespace arrow { namespace compute { +// mutates by copying from_key into to_key and changing from_key to zero +BatchesWithSchema MutateByKey(const BatchesWithSchema& batches, std::string from_key, + std::string to_key, bool replace_key = false) { + int from_index = batches.schema->GetFieldIndex(from_key); + int n_fields = batches.schema->num_fields(); + BatchesWithSchema new_batches; + auto new_field = batches.schema->field(from_index)->WithName(to_key); + new_batches.schema = (replace_key ? batches.schema->SetField(from_index, new_field) + : batches.schema->AddField(from_index, new_field)) + .ValueOrDie(); + for (const ExecBatch& batch : batches.batches) { + std::vector new_values; + for (int i = 0; i < n_fields; i++) { + const Datum& value = batch.values[i]; + if (i == from_index) { + new_values.push_back(Subtract(value, value).ValueOrDie()); + if (replace_key) { + continue; + } + } + new_values.push_back(value); + } + new_batches.batches.emplace_back(new_values, batch.length); + } + return new_batches; +} + +// code generation for the by_key types supported by AsofJoinNodeOptions constructors +// which cannot be directly done using templates because of failure to deduce the template +// argument for an invocation with a string- or initializer_list-typed keys-argument +#define EXPAND_BY_KEY_TYPE(macro) \ + macro(const FieldRef); \ + macro(std::vector); \ + macro(std::initializer_list); + void CheckRunOutput(const BatchesWithSchema& l_batches, const BatchesWithSchema& r0_batches, const BatchesWithSchema& r1_batches, - const BatchesWithSchema& exp_batches, const FieldRef time, - const FieldRef keys, const int64_t tolerance) { + const BatchesWithSchema& exp_batches, + const AsofJoinNodeOptions join_options) { auto exec_ctx = arrow::internal::make_unique(default_memory_pool(), nullptr); ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); - AsofJoinNodeOptions join_options(time, keys, tolerance); Declaration join{"asofjoin", join_options}; join.inputs.emplace_back(Declaration{ @@ -74,206 +109,463 @@ void CheckRunOutput(const BatchesWithSchema& l_batches, /*same_chunk_layout=*/true, /*flatten=*/true); } -void DoRunBasicTest(const std::vector& l_data, - const std::vector& r0_data, - const std::vector& r1_data, - const std::vector& exp_data, int64_t tolerance) { - auto l_schema = - schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}); - auto r0_schema = - schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())}); - auto r1_schema = - schema({field("time", int64()), field("key", int32()), field("r1_v0", float32())}); - - auto exp_schema = schema({ - field("time", int64()), - field("key", int32()), - field("l_v0", float64()), - field("r0_v0", float64()), - field("r1_v0", float32()), - }); - - // Test three table join - BatchesWithSchema l_batches, r0_batches, r1_batches, exp_batches; - l_batches = MakeBatchesFromString(l_schema, l_data); - r0_batches = MakeBatchesFromString(r0_schema, r0_data); - r1_batches = MakeBatchesFromString(r1_schema, r1_data); - exp_batches = MakeBatchesFromString(exp_schema, exp_data); - CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", - tolerance); -} - -void DoRunInvalidTypeTest(const std::shared_ptr& l_schema, - const std::shared_ptr& r_schema) { +#define CHECK_RUN_OUTPUT(by_key_type) \ + void CheckRunOutput( \ + const BatchesWithSchema& l_batches, const BatchesWithSchema& r0_batches, \ + const BatchesWithSchema& r1_batches, const BatchesWithSchema& exp_batches, \ + const FieldRef time, by_key_type keys, const int64_t tolerance) { \ + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, \ + AsofJoinNodeOptions(time, keys, tolerance)); \ + } + +EXPAND_BY_KEY_TYPE(CHECK_RUN_OUTPUT) + +void DoRunInvalidPlanTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema, FieldRef on_key, + FieldRef by_key, int64_t tolerance, + const std::string& expected_error_str) { BatchesWithSchema l_batches = MakeBatchesFromString(l_schema, {R"([])"}); BatchesWithSchema r_batches = MakeBatchesFromString(r_schema, {R"([])"}); ExecContext exec_ctx; ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_ctx)); - AsofJoinNodeOptions join_options("time", "key", 0); + AsofJoinNodeOptions join_options(on_key, by_key, tolerance); Declaration join{"asofjoin", join_options}; join.inputs.emplace_back(Declaration{ "source", SourceNodeOptions{l_batches.schema, l_batches.gen(false, false)}}); join.inputs.emplace_back(Declaration{ "source", SourceNodeOptions{r_batches.schema, r_batches.gen(false, false)}}); - ASSERT_RAISES(Invalid, join.AddToPlan(plan.get())); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr(expected_error_str), + join.AddToPlan(plan.get())); +} + +void DoRunInvalidPlanTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema, int64_t tolerance, + const std::string& expected_error_str) { + DoRunInvalidPlanTest(l_schema, r_schema, "time", "key", tolerance, expected_error_str); +} + +void DoRunInvalidTypeTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, 0, "Unsupported type for "); +} + +void DoRunInvalidToleranceTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, -1, + "AsOfJoin tolerance must be non-negative but is "); +} + +void DoRunMissingKeysTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table : No match"); +} + +void DoRunMissingOnKeyTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, "invalid_time", "key", 0, + "Bad join key on table : No match"); +} + +void DoRunMissingByKeyTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, "time", "invalid_key", 0, + "Bad join key on table : No match"); } +void DoRunNestedOnKeyTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, {0, "time"}, "key", 0, + "Bad join key on table : No match"); +} + +void DoRunNestedByKeyTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, "time", {0, "key"}, 0, + "Bad join key on table : No match"); +} + +void DoRunAmbiguousOnKeyTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table : Multiple matches"); +} + +void DoRunAmbiguousByKeyTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table : Multiple matches"); +} + +struct BasicTestTypes { + std::shared_ptr time, key, l_val, r0_val, r1_val; +}; + +struct BasicTest { + BasicTest(const std::vector& l_data, + const std::vector& r0_data, + const std::vector& r1_data, + const std::vector& exp_nokey_data, + const std::vector& exp_data, int64_t tolerance) + : l_data(std::move(l_data)), + r0_data(std::move(r0_data)), + r1_data(std::move(r1_data)), + exp_nokey_data(std::move(exp_nokey_data)), + exp_data(std::move(exp_data)), + tolerance(tolerance) {} + + template + static inline void init_types(const std::vector>& all_types, + std::vector>& types, + TypeCond type_cond) { + if (types.size() == 0) { + for (auto type : all_types) { + if (type_cond(type)) { + types.push_back(type); + } + } + } + } + + void Run() { + RunSingleByKey(); + RunDoubleByKey(); + } + void RunSingleByKey(std::vector> time_types = {}, + std::vector> key_types = {}, + std::vector> l_types = {}, + std::vector> r0_types = {}, + std::vector> r1_types = {}) { + using B = BatchesWithSchema; + RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, + B exp_batches) { + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", + tolerance); + }); + } + void RunDoubleByKey(std::vector> time_types = {}, + std::vector> key_types = {}, + std::vector> l_types = {}, + std::vector> r0_types = {}, + std::vector> r1_types = {}) { + using B = BatchesWithSchema; + RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, + B exp_batches) { + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", + {"key", "key"}, tolerance); + }); + } + void RunMutateByKey(std::vector> time_types = {}, + std::vector> key_types = {}, + std::vector> l_types = {}, + std::vector> r0_types = {}, + std::vector> r1_types = {}) { + using B = BatchesWithSchema; + RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, + B exp_batches) { + l_batches = MutateByKey(l_batches, "key", "key2"); + r0_batches = MutateByKey(r0_batches, "key", "key2"); + r1_batches = MutateByKey(r1_batches, "key", "key2"); + exp_batches = MutateByKey(exp_batches, "key", "key2"); + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", + {"key", "key2"}, tolerance); + }); + } + void RunMutateNoKey(std::vector> time_types = {}, + std::vector> key_types = {}, + std::vector> l_types = {}, + std::vector> r0_types = {}, + std::vector> r1_types = {}) { + using B = BatchesWithSchema; + RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, + B exp_batches) { + l_batches = MutateByKey(l_batches, "key", "key2", true); + r0_batches = MutateByKey(r0_batches, "key", "key2", true); + r1_batches = MutateByKey(r1_batches, "key", "key2", true); + exp_batches = MutateByKey(exp_batches, "key", "key2", true); + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_nokey_batches, "time", "key2", + tolerance); + }); + } + template + void RunBatches(BatchesRunner batches_runner, + std::vector> time_types = {}, + std::vector> key_types = {}, + std::vector> l_types = {}, + std::vector> r0_types = {}, + std::vector> r1_types = {}) { + std::vector> all_types = { + int8(), + int16(), + int32(), + int64(), + uint8(), + uint16(), + uint32(), + uint64(), + timestamp(TimeUnit::NANO, "UTC"), + timestamp(TimeUnit::MICRO, "UTC"), + timestamp(TimeUnit::MILLI, "UTC"), + timestamp(TimeUnit::SECOND, "UTC"), + float32(), + float64()}; + using T = const std::shared_ptr; + init_types(all_types, time_types, + [](T& t) { return t->byte_width() > 1 && !is_floating(t->id()); }); + init_types(all_types, key_types, [](T& t) { return is_integer(t->id()); }); + init_types(all_types, l_types, + [](T& t) { return t->byte_width() > 1 && t->id() != Type::TIMESTAMP; }); + init_types(all_types, r0_types, [](T& t) { return is_floating(t->id()); }); + init_types(all_types, r1_types, [](T& t) { return is_floating(t->id()); }); + for (auto time_type : time_types) { + for (auto key_type : key_types) { + for (auto l_type : l_types) { + for (auto r0_type : r0_types) { + for (auto r1_type : r1_types) { + RunTypes({time_type, key_type, l_type, r0_type, r1_type}, batches_runner); + } + } + } + } + } + } + template + void RunTypes(BasicTestTypes basic_test_types, BatchesRunner batches_runner) { + const BasicTestTypes& b = basic_test_types; + auto l_schema = + schema({field("time", b.time), field("key", b.key), field("l_v0", b.l_val)}); + auto r0_schema = + schema({field("time", b.time), field("key", b.key), field("r0_v0", b.r0_val)}); + auto r1_schema = + schema({field("time", b.time), field("key", b.key), field("r1_v0", b.r1_val)}); + + auto exp_schema = schema({ + field("time", b.time), + field("key", b.key), + field("l_v0", b.l_val), + field("r0_v0", b.r0_val), + field("r1_v0", b.r1_val), + }); + + // Test three table join + BatchesWithSchema l_batches, r0_batches, r1_batches, exp_nokey_batches, exp_batches; + l_batches = MakeBatchesFromString(l_schema, l_data); + r0_batches = MakeBatchesFromString(r0_schema, r0_data); + r1_batches = MakeBatchesFromString(r1_schema, r1_data); + exp_nokey_batches = MakeBatchesFromString(exp_schema, exp_nokey_data); + exp_batches = MakeBatchesFromString(exp_schema, exp_data); + batches_runner(l_batches, r0_batches, r1_batches, exp_nokey_batches, exp_batches); + } + + std::vector l_data; + std::vector r0_data; + std::vector r1_data; + std::vector exp_nokey_data; + std::vector exp_data; + int64_t tolerance; +}; + class AsofJoinTest : public testing::Test {}; -TEST(AsofJoinTest, TestBasic1) { +#define ASOFJOIN_TEST_SET(name, num) \ + TEST(AsofJoinTest, Test##name##num##_SingleByKey) { \ + Get##name##Test##num().RunSingleByKey(); \ + } \ + TEST(AsofJoinTest, Test##name##num##_DoubleByKey) { \ + Get##name##Test##num().RunDoubleByKey(); \ + } \ + TEST(AsofJoinTest, Test##name##num##_MutateByKey) { \ + Get##name##Test##num().RunMutateByKey(); \ + } \ + TEST(AsofJoinTest, Test##name##num##_MutateNoKey) { \ + Get##name##Test##num().RunMutateNoKey(); \ + } + +BasicTest GetBasicTest1() { // Single key, single batch - DoRunBasicTest( - /*l*/ {R"([[0, 1, 1.0], [1000, 1, 2.0]])"}, - /*r0*/ {R"([[0, 1, 11.0]])"}, - /*r1*/ {R"([[1000, 1, 101.0]])"}, - /*exp*/ {R"([[0, 1, 1.0, 11.0, null], [1000, 1, 2.0, 11.0, 101.0]])"}, 1000); + return BasicTest( + /*l*/ {R"([[0, 1, 1], [1000, 1, 2]])"}, + /*r0*/ {R"([[0, 1, 11]])"}, + /*r1*/ {R"([[1000, 1, 101]])"}, + /*exp_nokey*/ {R"([[0, 0, 1, 11, null], [1000, 0, 2, 11, 101]])"}, + /*exp*/ {R"([[0, 1, 1, 11, null], [1000, 1, 2, 11, 101]])"}, 1000); } +ASOFJOIN_TEST_SET(Basic, 1) -TEST(AsofJoinTest, TestBasic2) { +BasicTest GetBasicTest2() { // Single key, multiple batches - DoRunBasicTest( - /*l*/ {R"([[0, 1, 1.0]])", R"([[1000, 1, 2.0]])"}, - /*r0*/ {R"([[0, 1, 11.0]])", R"([[1000, 1, 12.0]])"}, - /*r1*/ {R"([[0, 1, 101.0]])", R"([[1000, 1, 102.0]])"}, - /*exp*/ {R"([[0, 1, 1.0, 11.0, 101.0], [1000, 1, 2.0, 12.0, 102.0]])"}, 1000); + return BasicTest( + /*l*/ {R"([[0, 1, 1]])", R"([[1000, 1, 2]])"}, + /*r0*/ {R"([[0, 1, 11]])", R"([[1000, 1, 12]])"}, + /*r1*/ {R"([[0, 1, 101]])", R"([[1000, 1, 102]])"}, + /*exp_nokey*/ {R"([[0, 0, 1, 11, 101], [1000, 0, 2, 12, 102]])"}, + /*exp*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"}, 1000); } +ASOFJOIN_TEST_SET(Basic, 2) -TEST(AsofJoinTest, TestBasic3) { +BasicTest GetBasicTest3() { // Single key, multiple left batches, single right batches - DoRunBasicTest( - /*l*/ {R"([[0, 1, 1.0]])", R"([[1000, 1, 2.0]])"}, - /*r0*/ {R"([[0, 1, 11.0], [1000, 1, 12.0]])"}, - /*r1*/ {R"([[0, 1, 101.0], [1000, 1, 102.0]])"}, - /*exp*/ {R"([[0, 1, 1.0, 11.0, 101.0], [1000, 1, 2.0, 12.0, 102.0]])"}, 1000); + return BasicTest( + /*l*/ {R"([[0, 1, 1]])", R"([[1000, 1, 2]])"}, + /*r0*/ {R"([[0, 1, 11], [1000, 1, 12]])"}, + /*r1*/ {R"([[0, 1, 101], [1000, 1, 102]])"}, + /*exp_nokey*/ {R"([[0, 0, 1, 11, 101], [1000, 0, 2, 12, 102]])"}, + /*exp*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"}, 1000); } +ASOFJOIN_TEST_SET(Basic, 3) -TEST(AsofJoinTest, TestBasic4) { +BasicTest GetBasicTest4() { // Multi key, multiple batches, misaligned batches - DoRunBasicTest( + return BasicTest( /*l*/ - {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", - R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, + {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", + R"([[2000, 1, 4], [2000, 2, 24]])"}, /*r0*/ - {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", - R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, + {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", + R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([[0, 0, 1, 11, 1001], [0, 0, 21, 11, 1001], [500, 0, 2, 31, 101], [1000, 0, 22, 12, 102], [1500, 0, 3, 32, 1002], [1500, 0, 23, 32, 1002]])", + R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"}, /*exp*/ - {R"([[0, 1, 1.0, 11.0, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, 11.0, 101.0], [1000, 2, 22.0, 31.0, 1001.0], [1500, 1, 3.0, 12.0, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", - R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}, + {R"([[0, 1, 1, 11, null], [0, 2, 21, null, 1001], [500, 1, 2, 11, 101], [1000, 2, 22, 31, 1001], [1500, 1, 3, 12, 102], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, 1000); } +ASOFJOIN_TEST_SET(Basic, 4) -TEST(AsofJoinTest, TestBasic5) { +BasicTest GetBasicTest5() { // Multi key, multiple batches, misaligned batches, smaller tolerance - DoRunBasicTest(/*l*/ - {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", - R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, - /*r0*/ - {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", - R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, - /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, - /*exp*/ - {R"([[0, 1, 1.0, 11.0, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, 11.0, 101.0], [1000, 2, 22.0, 31.0, null], [1500, 1, 3.0, 12.0, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", - R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}, - 500); -} - -TEST(AsofJoinTest, TestBasic6) { + return BasicTest(/*l*/ + {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", + R"([[2000, 1, 4], [2000, 2, 24]])"}, + /*r0*/ + {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", + R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, + /*r1*/ + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([[0, 0, 1, 11, 1001], [0, 0, 21, 11, 1001], [500, 0, 2, 31, 101], [1000, 0, 22, 12, 102], [1500, 0, 3, 32, 1002], [1500, 0, 23, 32, 1002]])", + R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"}, + /*exp*/ + {R"([[0, 1, 1, 11, null], [0, 2, 21, null, 1001], [500, 1, 2, 11, 101], [1000, 2, 22, 31, null], [1500, 1, 3, 12, 102], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, + 500); +} +ASOFJOIN_TEST_SET(Basic, 5) + +BasicTest GetBasicTest6() { // Multi key, multiple batches, misaligned batches, zero tolerance - DoRunBasicTest(/*l*/ - {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", - R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, - /*r0*/ - {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", - R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, - /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, - /*exp*/ - {R"([[0, 1, 1.0, 11.0, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, null, 101.0], [1000, 2, 22.0, null, null], [1500, 1, 3.0, null, null], [1500, 2, 23.0, 32.0, 1002.0]])", - R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, null, null]])"}, - 0); -} - -TEST(AsofJoinTest, TestEmpty1) { + return BasicTest(/*l*/ + {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", + R"([[2000, 1, 4], [2000, 2, 24]])"}, + /*r0*/ + {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", + R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, + /*r1*/ + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([[0, 0, 1, 11, 1001], [0, 0, 21, 11, 1001], [500, 0, 2, 31, 101], [1000, 0, 22, 12, 102], [1500, 0, 3, 32, 1002], [1500, 0, 23, 32, 1002]])", + R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"}, + /*exp*/ + {R"([[0, 1, 1, 11, null], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, null], [1500, 1, 3, null, null], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, null, null]])"}, + 0); +} +ASOFJOIN_TEST_SET(Basic, 6) + +BasicTest GetEmptyTest1() { // Empty left batch - DoRunBasicTest(/*l*/ - {R"([])", R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, - /*r0*/ - {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", - R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, - /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, - /*exp*/ - {R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}, - 1000); -} - -TEST(AsofJoinTest, TestEmpty2) { + return BasicTest(/*l*/ + {R"([])", R"([[2000, 1, 4], [2000, 2, 24]])"}, + /*r0*/ + {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", + R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, + /*r1*/ + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"}, + /*exp*/ + {R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, 1000); +} +ASOFJOIN_TEST_SET(Empty, 1) + +BasicTest GetEmptyTest2() { // Empty left input - DoRunBasicTest(/*l*/ - {R"([])"}, - /*r0*/ - {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", - R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, - /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, - /*exp*/ - {R"([])"}, 1000); -} - -TEST(AsofJoinTest, TestEmpty3) { + return BasicTest(/*l*/ + {R"([])"}, + /*r0*/ + {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", + R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, + /*r1*/ + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([])"}, + /*exp*/ + {R"([])"}, 1000); +} +ASOFJOIN_TEST_SET(Empty, 2) + +BasicTest GetEmptyTest3() { // Empty right batch - DoRunBasicTest(/*l*/ - {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", - R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, - /*r0*/ - {R"([])", R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, - /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, - /*exp*/ - {R"([[0, 1, 1.0, null, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, null, 101.0], [1000, 2, 22.0, null, 1001.0], [1500, 1, 3.0, null, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", - R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}, - 1000); -} - -TEST(AsofJoinTest, TestEmpty4) { + return BasicTest(/*l*/ + {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", + R"([[2000, 1, 4], [2000, 2, 24]])"}, + /*r0*/ + {R"([])", R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, + /*r1*/ + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([[0, 0, 1, null, 1001], [0, 0, 21, null, 1001], [500, 0, 2, null, 101], [1000, 0, 22, null, 102], [1500, 0, 3, 32, 1002], [1500, 0, 23, 32, 1002]])", + R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"}, + /*exp*/ + {R"([[0, 1, 1, null, null], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, 1001], [1500, 1, 3, null, 102], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, + 1000); +} +ASOFJOIN_TEST_SET(Empty, 3) + +BasicTest GetEmptyTest4() { // Empty right input - DoRunBasicTest(/*l*/ - {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", - R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, - /*r0*/ - {R"([])"}, - /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, - /*exp*/ - {R"([[0, 1, 1.0, null, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, null, 101.0], [1000, 2, 22.0, null, 1001.0], [1500, 1, 3.0, null, 102.0], [1500, 2, 23.0, null, 1002.0]])", - R"([[2000, 1, 4.0, null, 103.0], [2000, 2, 24.0, null, 1002.0]])"}, - 1000); -} - -TEST(AsofJoinTest, TestEmpty5) { + return BasicTest(/*l*/ + {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", + R"([[2000, 1, 4], [2000, 2, 24]])"}, + /*r0*/ + {R"([])"}, + /*r1*/ + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([[0, 0, 1, null, 1001], [0, 0, 21, null, 1001], [500, 0, 2, null, 101], [1000, 0, 22, null, 102], [1500, 0, 3, null, 1002], [1500, 0, 23, null, 1002]])", + R"([[2000, 0, 4, null, 103], [2000, 0, 24, null, 103]])"}, + /*exp*/ + {R"([[0, 1, 1, null, null], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, 1001], [1500, 1, 3, null, 102], [1500, 2, 23, null, 1002]])", + R"([[2000, 1, 4, null, 103], [2000, 2, 24, null, 1002]])"}, + 1000); +} +ASOFJOIN_TEST_SET(Empty, 4) + +BasicTest GetEmptyTest5() { // All empty - DoRunBasicTest(/*l*/ - {R"([])"}, - /*r0*/ - {R"([])"}, - /*r1*/ - {R"([])"}, - /*exp*/ - {R"([])"}, 1000); + return BasicTest(/*l*/ + {R"([])"}, + /*r0*/ + {R"([])"}, + /*r1*/ + {R"([])"}, + /*exp_nokey*/ + {R"([])"}, + /*exp*/ + {R"([])"}, 1000); } +ASOFJOIN_TEST_SET(Empty, 5) TEST(AsofJoinTest, TestUnsupportedOntype) { DoRunInvalidTypeTest( @@ -295,16 +587,61 @@ TEST(AsofJoinTest, TestUnsupportedDatatype) { } TEST(AsofJoinTest, TestMissingKeys) { - DoRunInvalidTypeTest( + DoRunMissingKeysTest( schema({field("time1", int64()), field("key", int32()), field("l_v0", float64())}), schema( {field("time1", int64()), field("key", int32()), field("r0_v0", float64())})); - DoRunInvalidTypeTest( + DoRunMissingKeysTest( schema({field("time", int64()), field("key1", int32()), field("l_v0", float64())}), schema( {field("time", int64()), field("key1", int32()), field("r0_v0", float64())})); } +TEST(AsofJoinTest, TestUnsupportedTolerance) { + // Utf8 is unsupported + DoRunInvalidToleranceTest( + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +} + +TEST(AsofJoinTest, TestMissingOnKey) { + DoRunMissingOnKeyTest( + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +} + +TEST(AsofJoinTest, TestMissingByKey) { + DoRunMissingByKeyTest( + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +} + +TEST(AsofJoinTest, TestNestedOnKey) { + DoRunNestedOnKeyTest( + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +} + +TEST(AsofJoinTest, TestNestedByKey) { + DoRunNestedByKeyTest( + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +} + +TEST(AsofJoinTest, TestAmbiguousOnKey) { + DoRunAmbiguousOnKeyTest( + schema({field("time", int64()), field("time", int64()), field("key", int32()), + field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +} + +TEST(AsofJoinTest, TestAmbiguousByKey) { + DoRunAmbiguousByKeyTest( + schema({field("time", int64()), field("key", int64()), field("key", int32()), + field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index 5cf66b3d09e48..da1710fe08dac 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -26,7 +26,6 @@ #include #include "arrow/compute/exec/hash_join_dict.h" -#include "arrow/compute/exec/key_hash.h" #include "arrow/compute/exec/task_util.h" #include "arrow/compute/kernels/row_encoder.h" #include "arrow/compute/row/encode_internal.h" diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index a8e8c1ee23096..232ff3f5180b6 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -397,23 +397,33 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { /// This node will output one row for each row in the left table. class ARROW_EXPORT AsofJoinNodeOptions : public ExecNodeOptions { public: - AsofJoinNodeOptions(FieldRef on_key, FieldRef by_key, int64_t tolerance) - : on_key(std::move(on_key)), by_key(std::move(by_key)), tolerance(tolerance) {} + AsofJoinNodeOptions(FieldRef on_key, const FieldRef& by_key, int64_t tolerance) + : on_key(std::move(on_key)), by_key(), tolerance(tolerance) { + this->by_key.push_back(std::move(by_key)); + } + + AsofJoinNodeOptions(FieldRef on_key, std::vector by_key, int64_t tolerance) + : on_key(std::move(on_key)), by_key(by_key), tolerance(tolerance) {} + + // resolves ambiguity between previous constructors when initializer list is given + AsofJoinNodeOptions(FieldRef on_key, std::initializer_list by_key, + int64_t tolerance) + : on_key(std::move(on_key)), by_key(by_key), tolerance(tolerance) {} /// \brief "on" key for the join. Each /// - /// All inputs tables must be sorted by the "on" key. Inexact - /// match is used on the "on" key. i.e., a row is considiered match iff + /// All inputs tables must be sorted by the "on" key. Must be a single field of a common + /// type. Inexact match is used on the "on" key. i.e., a row is considered match iff /// left_on - tolerance <= right_on <= left_on. - /// Currently, "on" key must be an int64 field + /// Currently, the "on" key must be of an integer or timestamp type FieldRef on_key; /// \brief "by" key for the join. /// /// All input tables must have the "by" key. Exact equality /// is used for the "by" key. - /// Currently, the "by" key must be an int32 field - FieldRef by_key; - /// Tolerance for inexact "on" key matching + /// Currently, the "by" key must be of an integer or timestamp type + std::vector by_key; + /// Tolerance for inexact "on" key matching. Must be non-negative. int64_t tolerance; }; diff --git a/cpp/src/arrow/compute/light_array.cc b/cpp/src/arrow/compute/light_array.cc index 4bf3574d09fdb..9ea609c531048 100644 --- a/cpp/src/arrow/compute/light_array.cc +++ b/cpp/src/arrow/compute/light_array.cc @@ -141,6 +141,12 @@ Result ColumnArrayFromArrayData( const std::shared_ptr& array_data, int64_t start_row, int64_t num_rows) { ARROW_ASSIGN_OR_RAISE(KeyColumnMetadata metadata, ColumnMetadataFromDataType(array_data->type)); + return ColumnArrayFromArrayDataAndMetadata(array_data, metadata, start_row, num_rows); +} + +KeyColumnArray ColumnArrayFromArrayDataAndMetadata( + const std::shared_ptr& array_data, const KeyColumnMetadata& metadata, + int64_t start_row, int64_t num_rows) { KeyColumnArray column_array = KeyColumnArray( metadata, array_data->offset + start_row + num_rows, array_data->buffers[0] != NULLPTR ? array_data->buffers[0]->data() : nullptr, diff --git a/cpp/src/arrow/compute/light_array.h b/cpp/src/arrow/compute/light_array.h index f0e5c7068716a..6a5c205a7a48b 100644 --- a/cpp/src/arrow/compute/light_array.h +++ b/cpp/src/arrow/compute/light_array.h @@ -187,6 +187,17 @@ ARROW_EXPORT Result ColumnMetadataFromDataType( ARROW_EXPORT Result ColumnArrayFromArrayData( const std::shared_ptr& array_data, int64_t start_row, int64_t num_rows); +/// \brief Create KeyColumnArray from ArrayData and KeyColumnMetadata +/// +/// If `type` is a dictionary type then this will return the KeyColumnArray for +/// the indices array +/// +/// The caller should ensure this is only called on "key" columns. +/// \see ColumnMetadataFromDataType for details +ARROW_EXPORT KeyColumnArray ColumnArrayFromArrayDataAndMetadata( + const std::shared_ptr& array_data, const KeyColumnMetadata& metadata, + int64_t start_row, int64_t num_rows); + /// \brief Create KeyColumnMetadata instances from an ExecBatch /// /// column_metadatas will be resized to fit From d4f06bdfa114efaf3179af0e5ec113b06ca9d203 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Mon, 15 Aug 2022 09:25:50 -0400 Subject: [PATCH 07/26] fix method overload conflict --- cpp/src/arrow/compute/exec/asof_join_node.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 2319ba8249910..e1378a0a4f545 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -744,7 +744,7 @@ class AsofJoinNode : public ExecNode { const std::vector& indices_of_by_key, TimeType tolerance, std::shared_ptr output_schema); - Status Init(std::vector> key_hashers) { + Status InternalInit(std::vector> key_hashers) { key_hashers_.swap(key_hashers); bool has_kp = key_hashers_.size() > 0; auto inputs = this->inputs(); @@ -909,7 +909,7 @@ class AsofJoinNode : public ExecNode { RETURN_NOT_OK(key_hashers[i]->Init(plan->exec_context(), node_output_schema)); } } - RETURN_NOT_OK(node->Init(std::move(key_hashers))); + RETURN_NOT_OK(node->InternalInit(std::move(key_hashers))); return node; } From f37b3c658242f7c888ef98f8de336fb21501e543 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Wed, 17 Aug 2022 10:58:24 -0400 Subject: [PATCH 08/26] AsofJoin additional temporal types, tests, cleanups --- cpp/src/arrow/compute/exec/asof_join_node.cc | 133 ++++++++++++------ .../arrow/compute/exec/asof_join_node_test.cc | 78 +++++++--- 2 files changed, 148 insertions(+), 63 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 2319ba8249910..8ba0a589b9ead 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -31,10 +31,12 @@ #include "arrow/record_batch.h" #include "arrow/result.h" #include "arrow/status.h" +#include "arrow/type_traits.h" #include "arrow/util/checked_cast.h" #include "arrow/util/future.h" #include "arrow/util/make_unique.h" #include "arrow/util/optional.h" +#include "arrow/util/string_view.h" namespace arrow { namespace compute { @@ -49,8 +51,8 @@ inline bool std_has(const T& container, const V& val) { return container.end() != std_find(container, val); } -typedef uint64_t KeyType; -typedef uint64_t TimeType; +typedef uint64_t ByType; +typedef uint64_t OnType; typedef uint64_t HashType; // Maximum number of tables that can be joined @@ -59,17 +61,20 @@ typedef uint64_t row_index_t; typedef int col_index_t; typedef std::vector vec_col_index_t; +// normalize the value to 64-bits while preserving ordering of values template ::value, bool> = true> static inline uint64_t norm_value(T t) { uint64_t bias = std::is_signed::value ? (uint64_t)1 << (8 * sizeof(T) - 1) : 0; return t < 0 ? static_cast(t + bias) : static_cast(t); } +// indicates normalization of a time value template ::value, bool> = true> static inline uint64_t time_value(T t) { return norm_value(t); } +// indicates normalization of a key value template ::value, bool> = true> static inline uint64_t key_value(T t) { return norm_value(t); @@ -129,7 +134,7 @@ struct MemoStore { struct Entry { // Timestamp associated with the entry - TimeType time; + OnType time; // Batch associated with the entry (perf is probably OK for this; batches change // rarely) @@ -139,10 +144,10 @@ struct MemoStore { row_index_t row; }; - std::unordered_map entries_; + std::unordered_map entries_; - void Store(const std::shared_ptr& batch, row_index_t row, TimeType time, - KeyType key) { + void Store(const std::shared_ptr& batch, row_index_t row, OnType time, + ByType key) { auto& e = entries_[key]; // that we can do this assignment optionally, is why we // can get array with using shared_ptr above (the batch @@ -152,13 +157,13 @@ struct MemoStore { e.time = time; } - util::optional GetEntryForKey(KeyType key) const { + util::optional GetEntryForKey(ByType key) const { auto e = entries_.find(key); if (entries_.end() == e) return util::nullopt; return util::optional(&e->second); } - void RemoveEntriesWithLesserTime(TimeType ts) { + void RemoveEntriesWithLesserTime(OnType ts) { for (auto e = entries_.begin(); e != entries_.end();) if (e->second.time < ts) e = entries_.erase(e); @@ -167,7 +172,8 @@ struct MemoStore { } }; -// a specialized higher-performance variation of Hashing64 logic +// a specialized higher-performance variation of Hashing64 logic from hash_join_node +// the code here avoids recreating objects that are independent of each batch processed class KeyHasher { static constexpr int kMiniBatchLength = util::MiniBatch::kMiniBatchLength; @@ -287,7 +293,7 @@ class InputState { return val(data->GetValues(1)[latest_ref_row_]); \ } - KeyType GetLatestKey() const { + ByType GetLatestKey() const { if (key_hasher_ != NULLPTR) { return key_hasher_->HashesFor(queue_.UnsyncFront().get())[latest_ref_row_]; } @@ -301,12 +307,16 @@ class InputState { LATEST_VAL_CASE(UINT16, key_value) LATEST_VAL_CASE(UINT32, key_value) LATEST_VAL_CASE(UINT64, key_value) + LATEST_VAL_CASE(DATE32, key_value) + LATEST_VAL_CASE(DATE64, key_value) + LATEST_VAL_CASE(TIME32, key_value) + LATEST_VAL_CASE(TIME64, key_value) default: return 0; // cannot happen } } - TimeType GetLatestTime() const { + OnType GetLatestTime() const { auto data = queue_.UnsyncFront()->column_data(time_col_index_); switch (time_type_id_) { LATEST_VAL_CASE(INT8, time_value) @@ -317,6 +327,10 @@ class InputState { LATEST_VAL_CASE(UINT16, time_value) LATEST_VAL_CASE(UINT32, time_value) LATEST_VAL_CASE(UINT64, time_value) + LATEST_VAL_CASE(DATE32, time_value) + LATEST_VAL_CASE(DATE64, time_value) + LATEST_VAL_CASE(TIME32, time_value) + LATEST_VAL_CASE(TIME64, time_value) LATEST_VAL_CASE(TIMESTAMP, time_value) default: return 0; // cannot happen @@ -351,7 +365,7 @@ class InputState { // latest_time and latest_ref_row to the value that immediately pass the // specified timestamp. // Returns true if updates were made, false if not. - bool AdvanceAndMemoize(TimeType ts) { + bool AdvanceAndMemoize(OnType ts) { // Advance the right side row index until we reach the latest right row (for each key) // for the given left timestamp. @@ -383,11 +397,11 @@ class InputState { } } - util::optional GetMemoEntryForKey(KeyType key) { + util::optional GetMemoEntryForKey(ByType key) { return memo_.GetEntryForKey(key); } - util::optional GetMemoTimeForKey(KeyType key) { + util::optional GetMemoTimeForKey(ByType key) { auto r = GetMemoEntryForKey(key); if (r.has_value()) { return (*r)->time; @@ -396,7 +410,7 @@ class InputState { } } - void RemoveMemoEntriesWithLesserTime(TimeType ts) { + void RemoveMemoEntriesWithLesserTime(OnType ts) { memo_.RemoveEntriesWithLesserTime(ts); } @@ -468,18 +482,18 @@ class CompositeReferenceTable { // Adds the latest row from the input state as a new composite reference row // - LHS must have a valid key,timestep,and latest rows // - RHS must have valid data memo'ed for the key - void Emplace(std::vector>& in, TimeType tolerance) { + void Emplace(std::vector>& in, OnType tolerance) { DCHECK_EQ(in.size(), n_tables_); // Get the LHS key - KeyType key = in[0]->GetLatestKey(); + ByType key = in[0]->GetLatestKey(); // Add row and setup LHS // (the LHS state comes just from the latest row of the LHS table) DCHECK(!in[0]->Empty()); const std::shared_ptr& lhs_latest_batch = in[0]->GetLatestBatch(); row_index_t lhs_latest_row = in[0]->GetLatestRow(); - TimeType lhs_latest_time = in[0]->GetLatestTime(); + OnType lhs_latest_time = in[0]->GetLatestTime(); if (0 == lhs_latest_row) { // On the first row of the batch, we resize the destination. // The destination size is dictated by the size of the LHS batch. @@ -559,6 +573,10 @@ class CompositeReferenceTable { ASOFJOIN_MATERIALIZE_CASE(UINT64) ASOFJOIN_MATERIALIZE_CASE(FLOAT) ASOFJOIN_MATERIALIZE_CASE(DOUBLE) + ASOFJOIN_MATERIALIZE_CASE(DATE32) + ASOFJOIN_MATERIALIZE_CASE(DATE64) + ASOFJOIN_MATERIALIZE_CASE(TIME32) + ASOFJOIN_MATERIALIZE_CASE(TIME64) ASOFJOIN_MATERIALIZE_CASE(TIMESTAMP) default: return Status::Invalid("Unsupported data type ", @@ -634,7 +652,7 @@ class AsofJoinNode : public ExecNode { bool IsUpToDateWithLhsRow() const { auto& lhs = *state_[0]; if (lhs.Empty()) return false; // can't proceed if nothing on the LHS - TimeType lhs_ts = lhs.GetLatestTime(); + OnType lhs_ts = lhs.GetLatestTime(); for (size_t i = 1; i < state_.size(); ++i) { auto& rhs = *state_[i]; if (!rhs.Finished()) { @@ -741,16 +759,16 @@ class AsofJoinNode : public ExecNode { public: AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, const vec_col_index_t& indices_of_on_key, - const std::vector& indices_of_by_key, TimeType tolerance, + const std::vector& indices_of_by_key, OnType tolerance, std::shared_ptr output_schema); - Status Init(std::vector> key_hashers) { + Status InternalInit(std::vector> key_hashers) { key_hashers_.swap(key_hashers); - bool has_kp = key_hashers_.size() > 0; + bool has_hashers = key_hashers_.size() > 0; auto inputs = this->inputs(); for (size_t i = 0; i < inputs.size(); ++i) { state_.push_back(::arrow::internal::make_unique( - has_kp ? key_hashers_[i].get() : NULLPTR, inputs[i]->output_schema(), + has_hashers ? key_hashers_[i].get() : NULLPTR, inputs[i]->output_schema(), indices_of_on_key_[i], indices_of_by_key_[i])); } @@ -853,11 +871,29 @@ class AsofJoinNode : public ExecNode { return std::make_shared(fields); } + static inline Result FindColIndex(const Schema& schema, + const FieldRef& field_ref, + const util::string_view& key_kind) { + auto match_res = field_ref.FindOne(schema); + if (!match_res.ok()) { + return Status::Invalid("Bad join key on table : ", match_res.status().message()); + } + auto match = match_res.ValueOrDie(); + if (match.indices().size() != 1) { + return Status::Invalid("AsOfJoinNode does not support a nested ", + to_string(key_kind), "-key ", field_ref.ToString()); + } + return match.indices()[0]; + } + static arrow::Result Make(ExecPlan* plan, std::vector inputs, const ExecNodeOptions& options) { DCHECK_GE(inputs.size(), 2) << "Must have at least two inputs"; const auto& join_options = checked_cast(options); + if (join_options.by_key.size() == 0) { + return Status::Invalid("AsOfJoin by_key must not be empty"); + } if (join_options.tolerance < 0) { return Status::Invalid("AsOfJoin tolerance must be non-negative but is ", join_options.tolerance); @@ -869,27 +905,13 @@ class AsofJoinNode : public ExecNode { std::vector indices_of_by_key(n_input, vec_col_index_t(n_by)); for (size_t i = 0; i < n_input; ++i) { input_labels[i] = i == 0 ? "left" : "right_" + std::to_string(i); - const auto& input_schema = inputs[i]->output_schema(); - -#define ASOFJOIN_KEY_MATCH(kopt, kacc) \ - auto kopt##_match_res = (join_options.kopt)kacc.FindOne(*input_schema); \ - if (!kopt##_match_res.ok()) { \ - return Status::Invalid("Bad join key on table : ", \ - kopt##_match_res.status().message()); \ - } \ - auto kopt##_match = kopt##_match_res.ValueOrDie(); \ - if (kopt##_match.indices().size() != 1) { \ - return Status::Invalid("AsOfJoinNode does not support a nested " #kopt "-key ", \ - (join_options.kopt)kacc.ToString()); \ - } \ - (indices_of_##kopt[i]) kacc = kopt##_match.indices()[0]; - - ASOFJOIN_KEY_MATCH(on_key, ) + const Schema& input_schema = *inputs[i]->output_schema(); + ARROW_ASSIGN_OR_RAISE(indices_of_on_key[i], + FindColIndex(input_schema, join_options.on_key, "on")); for (size_t k = 0; k < n_by; k++) { - ASOFJOIN_KEY_MATCH(by_key, [k]) + ARROW_ASSIGN_OR_RAISE(indices_of_by_key[i][k], + FindColIndex(input_schema, join_options.by_key[k], "by")); } - -#undef ASOFJOIN_KEY_MATCH } ARROW_ASSIGN_OR_RAISE(std::shared_ptr output_schema, @@ -901,15 +923,16 @@ class AsofJoinNode : public ExecNode { std::move(output_schema)); auto node_output_schema = node->output_schema(); auto node_indices_of_by_key = checked_cast(node)->indices_of_by_key(); + auto single_key_field = inputs[0]->output_schema()->field(indices_of_by_key[0][0]); std::vector> key_hashers; - if (n_by > 1) { + if (n_by > 1 || is_primitive(single_key_field->type()->id())) { for (size_t i = 0; i < n_input; i++) { key_hashers.push_back( ::arrow::internal::make_unique(node_indices_of_by_key[i])); RETURN_NOT_OK(key_hashers[i]->Init(plan->exec_context(), node_output_schema)); } } - RETURN_NOT_OK(node->Init(std::move(key_hashers))); + RETURN_NOT_OK(node->InternalInit(std::move(key_hashers))); return node; } @@ -968,7 +991,7 @@ class AsofJoinNode : public ExecNode { // Each input state correponds to an input table std::vector> state_; std::mutex gate_; - TimeType tolerance_; + OnType tolerance_; // Queue for triggering processing of a given input // (a false value is a poison pill) @@ -984,7 +1007,7 @@ AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, const vec_col_index_t& indices_of_on_key, const std::vector& indices_of_by_key, - TimeType tolerance, std::shared_ptr output_schema) + OnType tolerance, std::shared_ptr output_schema) : ExecNode(plan, inputs, input_labels, /*output_schema=*/std::move(output_schema), /*num_outputs=*/1), @@ -1006,6 +1029,12 @@ const std::unordered_set> AsofJoinNode::kSupportedOnTy uint16(), uint32(), uint64(), + date32(), + date64(), + time32(TimeUnit::MILLI), + time32(TimeUnit::SECOND), + time64(TimeUnit::NANO), + time64(TimeUnit::MICRO), timestamp(TimeUnit::NANO, "UTC"), timestamp(TimeUnit::MICRO, "UTC"), timestamp(TimeUnit::MILLI, "UTC"), @@ -1019,6 +1048,12 @@ const std::unordered_set> AsofJoinNode::kSupportedByTy uint16(), uint32(), uint64(), + date32(), + date64(), + time32(TimeUnit::MILLI), + time32(TimeUnit::SECOND), + time64(TimeUnit::NANO), + time64(TimeUnit::MICRO), timestamp(TimeUnit::NANO, "UTC"), timestamp(TimeUnit::MICRO, "UTC"), timestamp(TimeUnit::MILLI, "UTC"), @@ -1032,6 +1067,12 @@ const std::unordered_set> AsofJoinNode::kSupportedData uint16(), uint32(), uint64(), + date32(), + date64(), + time32(TimeUnit::MILLI), + time32(TimeUnit::SECOND), + time64(TimeUnit::NANO), + time64(TimeUnit::MICRO), timestamp(TimeUnit::NANO, "UTC"), timestamp(TimeUnit::MICRO, "UTC"), timestamp(TimeUnit::MILLI, "UTC"), diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index 65e39c748655d..c539d26ede65d 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -33,6 +33,7 @@ #include "arrow/testing/random.h" #include "arrow/util/checked_cast.h" #include "arrow/util/make_unique.h" +#include "arrow/util/string_view.h" #include "arrow/util/thread_pool.h" using testing::UnorderedElementsAreArray; @@ -40,11 +41,34 @@ using testing::UnorderedElementsAreArray; namespace arrow { namespace compute { +bool is_temporal_primitive(Type::type type_id) { + switch (type_id) { + case Type::TIME32: + case Type::TIME64: + case Type::DATE32: + case Type::DATE64: + case Type::TIMESTAMP: + return true; + default: + return false; + } +} + +void BuildZeroPrimitiveArray(std::shared_ptr& empty, + std::shared_ptr type, int64_t length) { + ASSERT_OK_AND_ASSIGN(auto builder, MakeBuilder(type, default_memory_pool())); + ASSERT_OK(builder->Reserve(length)); + ASSERT_OK_AND_ASSIGN(auto scalar, MakeScalar(type, 0)); + ASSERT_OK(builder->AppendScalar(*scalar, length)); + ASSERT_OK(builder->Finish(&empty)); +} + // mutates by copying from_key into to_key and changing from_key to zero -BatchesWithSchema MutateByKey(const BatchesWithSchema& batches, std::string from_key, +BatchesWithSchema MutateByKey(BatchesWithSchema& batches, std::string from_key, std::string to_key, bool replace_key = false) { int from_index = batches.schema->GetFieldIndex(from_key); int n_fields = batches.schema->num_fields(); + auto fields = batches.schema->fields(); BatchesWithSchema new_batches; auto new_field = batches.schema->field(from_index)->WithName(to_key); new_batches.schema = (replace_key ? batches.schema->SetField(from_index, new_field) @@ -55,7 +79,14 @@ BatchesWithSchema MutateByKey(const BatchesWithSchema& batches, std::string from for (int i = 0; i < n_fields; i++) { const Datum& value = batch.values[i]; if (i == from_index) { - new_values.push_back(Subtract(value, value).ValueOrDie()); + auto type = fields[i]->type(); + if (is_primitive(type->id())) { + std::shared_ptr empty; + BuildZeroPrimitiveArray(empty, type, batch.length); + new_values.push_back(empty); + } else { + new_values.push_back(Subtract(value, value).ValueOrDie()); + } if (replace_key) { continue; } @@ -121,8 +152,8 @@ void CheckRunOutput(const BatchesWithSchema& l_batches, EXPAND_BY_KEY_TYPE(CHECK_RUN_OUTPUT) void DoRunInvalidPlanTest(const std::shared_ptr& l_schema, - const std::shared_ptr& r_schema, FieldRef on_key, - FieldRef by_key, int64_t tolerance, + const std::shared_ptr& r_schema, + const AsofJoinNodeOptions& join_options, const std::string& expected_error_str) { BatchesWithSchema l_batches = MakeBatchesFromString(l_schema, {R"([])"}); BatchesWithSchema r_batches = MakeBatchesFromString(r_schema, {R"([])"}); @@ -130,7 +161,6 @@ void DoRunInvalidPlanTest(const std::shared_ptr& l_schema, ExecContext exec_ctx; ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_ctx)); - AsofJoinNodeOptions join_options(on_key, by_key, tolerance); Declaration join{"asofjoin", join_options}; join.inputs.emplace_back(Declaration{ "source", SourceNodeOptions{l_batches.schema, l_batches.gen(false, false)}}); @@ -144,7 +174,8 @@ void DoRunInvalidPlanTest(const std::shared_ptr& l_schema, void DoRunInvalidPlanTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema, int64_t tolerance, const std::string& expected_error_str) { - DoRunInvalidPlanTest(l_schema, r_schema, "time", "key", tolerance, expected_error_str); + DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions("time", "key", tolerance), + expected_error_str); } void DoRunInvalidTypeTest(const std::shared_ptr& l_schema, @@ -163,27 +194,33 @@ void DoRunMissingKeysTest(const std::shared_ptr& l_schema, DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table : No match"); } +void DoRunEmptyByKeyTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions("time", {}, 0), + "AsOfJoin by_key must not be empty"); +} + void DoRunMissingOnKeyTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { - DoRunInvalidPlanTest(l_schema, r_schema, "invalid_time", "key", 0, + DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions("invalid_time", "key", 0), "Bad join key on table : No match"); } void DoRunMissingByKeyTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { - DoRunInvalidPlanTest(l_schema, r_schema, "time", "invalid_key", 0, + DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions("time", "invalid_key", 0), "Bad join key on table : No match"); } void DoRunNestedOnKeyTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { - DoRunInvalidPlanTest(l_schema, r_schema, {0, "time"}, "key", 0, + DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions({0, "time"}, "key", 0), "Bad join key on table : No match"); } void DoRunNestedByKeyTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { - DoRunInvalidPlanTest(l_schema, r_schema, "time", {0, "key"}, 0, + DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions("time", FieldRef{0, 1}, 0), "Bad join key on table : No match"); } @@ -227,10 +264,6 @@ struct BasicTest { } } - void Run() { - RunSingleByKey(); - RunDoubleByKey(); - } void RunSingleByKey(std::vector> time_types = {}, std::vector> key_types = {}, std::vector> l_types = {}, @@ -303,6 +336,12 @@ struct BasicTest { uint16(), uint32(), uint64(), + date32(), + date64(), + time32(TimeUnit::MILLI), + time32(TimeUnit::SECOND), + time64(TimeUnit::NANO), + time64(TimeUnit::MICRO), timestamp(TimeUnit::NANO, "UTC"), timestamp(TimeUnit::MICRO, "UTC"), timestamp(TimeUnit::MILLI, "UTC"), @@ -312,9 +351,8 @@ struct BasicTest { using T = const std::shared_ptr; init_types(all_types, time_types, [](T& t) { return t->byte_width() > 1 && !is_floating(t->id()); }); - init_types(all_types, key_types, [](T& t) { return is_integer(t->id()); }); - init_types(all_types, l_types, - [](T& t) { return t->byte_width() > 1 && t->id() != Type::TIMESTAMP; }); + init_types(all_types, key_types, [](T& t) { return !is_floating(t->id()); }); + init_types(all_types, l_types, [](T& t) { return is_temporal_primitive(t->id()); }); init_types(all_types, r0_types, [](T& t) { return is_floating(t->id()); }); init_types(all_types, r1_types, [](T& t) { return is_floating(t->id()); }); for (auto time_type : time_types) { @@ -605,6 +643,12 @@ TEST(AsofJoinTest, TestUnsupportedTolerance) { schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); } +TEST(AsofJoinTest, TestEmptyByKey) { + DoRunEmptyByKeyTest( + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +} + TEST(AsofJoinTest, TestMissingOnKey) { DoRunMissingOnKeyTest( schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), From aab623d8df393ea6d2d61131c1bd1c117cca2cbb Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 21 Aug 2022 08:21:40 -0400 Subject: [PATCH 09/26] AsofJoin var-binary types with tests, Hashing32/64 large offsets --- cpp/src/arrow/compute/exec/asof_join_node.cc | 49 +++- .../arrow/compute/exec/asof_join_node_test.cc | 211 +++++++++++++----- cpp/src/arrow/compute/exec/key_hash.cc | 13 +- cpp/src/arrow/compute/light_array.h | 18 ++ 4 files changed, 228 insertions(+), 63 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 8ba0a589b9ead..78b12ecd4b475 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -21,6 +21,7 @@ #include #include +#include "arrow/array/builder_binary.h" #include "arrow/array/builder_primitive.h" #include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/key_hash.h" @@ -312,6 +313,7 @@ class InputState { LATEST_VAL_CASE(TIME32, key_value) LATEST_VAL_CASE(TIME64, key_value) default: + DCHECK(false); return 0; // cannot happen } } @@ -333,6 +335,7 @@ class InputState { LATEST_VAL_CASE(TIME64, time_value) LATEST_VAL_CASE(TIMESTAMP, time_value) default: + DCHECK(false); return 0; // cannot happen } } @@ -578,6 +581,10 @@ class CompositeReferenceTable { ASOFJOIN_MATERIALIZE_CASE(TIME32) ASOFJOIN_MATERIALIZE_CASE(TIME64) ASOFJOIN_MATERIALIZE_CASE(TIMESTAMP) + ASOFJOIN_MATERIALIZE_CASE(STRING) + ASOFJOIN_MATERIALIZE_CASE(LARGE_STRING) + ASOFJOIN_MATERIALIZE_CASE(BINARY) + ASOFJOIN_MATERIALIZE_CASE(LARGE_BINARY) default: return Status::Invalid("Unsupported data type ", src_field->type()->ToString(), " for field ", @@ -614,8 +621,33 @@ class CompositeReferenceTable { if (!_ptr2ref.count((uintptr_t)ref.get())) _ptr2ref[(uintptr_t)ref.get()] = ref; } - template ::BuilderType, - class PrimitiveType = typename TypeTraits::CType> + template + using is_fixed_width_type = std::is_base_of; + + template + using enable_if_fixed_width_type = enable_if_t::value, R>; + + template ::BuilderType> + enable_if_fixed_width_type BuilderAppend(Builder& builder, + const std::shared_ptr& source, + row_index_t row) { + using CType = typename TypeTraits::CType; + builder.UnsafeAppend(source->template GetValues(1)[row]); + } + + template ::BuilderType> + enable_if_base_binary BuilderAppend(Builder& builder, + const std::shared_ptr& source, + row_index_t row) { + using offset_type = typename Type::offset_type; + const uint8_t* data = source->buffers[2]->data(); + const offset_type* offsets = source->GetValues(1); + const offset_type offset0 = offsets[row]; + const offset_type offset1 = offsets[row + 1]; + builder.Append(data + offset0, offset1 - offset0); + } + + template ::BuilderType> Result> MaterializeColumn(MemoryPool* memory_pool, const std::shared_ptr& type, size_t i_table, col_index_t i_col) { @@ -625,8 +657,7 @@ class CompositeReferenceTable { for (row_index_t i_row = 0; i_row < rows_.size(); ++i_row) { const auto& ref = rows_[i_row].refs[i_table]; if (ref.batch) { - builder.UnsafeAppend( - ref.batch->column_data(i_col)->template GetValues(1)[ref.row]); + BuilderAppend(builder, ref.batch->column_data(i_col), ref.row); } else { builder.UnsafeAppendNull(); } @@ -925,7 +956,7 @@ class AsofJoinNode : public ExecNode { auto node_indices_of_by_key = checked_cast(node)->indices_of_by_key(); auto single_key_field = inputs[0]->output_schema()->field(indices_of_by_key[0][0]); std::vector> key_hashers; - if (n_by > 1 || is_primitive(single_key_field->type()->id())) { + if (n_by > 1 || !is_primitive(single_key_field->type()->id())) { for (size_t i = 0; i < n_input; i++) { key_hashers.push_back( ::arrow::internal::make_unique(node_indices_of_by_key[i])); @@ -1040,6 +1071,10 @@ const std::unordered_set> AsofJoinNode::kSupportedOnTy timestamp(TimeUnit::MILLI, "UTC"), timestamp(TimeUnit::SECOND, "UTC")}; const std::unordered_set> AsofJoinNode::kSupportedByTypes_ = { + utf8(), + large_utf8(), + binary(), + large_binary(), int8(), int16(), int32(), @@ -1059,6 +1094,10 @@ const std::unordered_set> AsofJoinNode::kSupportedByTy timestamp(TimeUnit::MILLI, "UTC"), timestamp(TimeUnit::SECOND, "UTC")}; const std::unordered_set> AsofJoinNode::kSupportedDataTypes_ = { + utf8(), + large_utf8(), + binary(), + large_binary(), int8(), int16(), int32(), diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index c539d26ede65d..cff9a6111ad10 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -17,6 +17,7 @@ #include +#include #include #include #include @@ -36,6 +37,14 @@ #include "arrow/util/string_view.h" #include "arrow/util/thread_pool.h" +#define TRACED_TEST(t_class, t_name) \ + static void _##t_class##_##t_name(); \ + TEST(t_class, t_name) { \ + ARROW_SCOPED_TRACE(#t_class "_" #t_name); \ + _##t_class##_##t_name(); \ + } \ + static void _##t_class##_##t_name() + using testing::UnorderedElementsAreArray; namespace arrow { @@ -54,8 +63,41 @@ bool is_temporal_primitive(Type::type type_id) { } } +BatchesWithSchema MakeBatchesFromNumString( + const std::shared_ptr& schema, + const std::vector& json_strings, int multiplicity = 1) { + FieldVector num_fields; + for (auto field : schema->fields()) { + num_fields.push_back( + is_base_binary_like(field->type()->id()) ? field->WithType(int64()) : field); + } + auto num_schema = + std::make_shared(num_fields, schema->endianness(), schema->metadata()); + BatchesWithSchema num_batches = + MakeBatchesFromString(num_schema, json_strings, multiplicity); + BatchesWithSchema batches; + batches.schema = schema; + int n_fields = schema->num_fields(); + for (auto num_batch : num_batches.batches) { + std::vector values; + for (int i = 0; i < n_fields; i++) { + auto type = schema->field(i)->type(); + if (is_base_binary_like(type->id())) { + // casting to string first enables casting to binary + Datum as_string = Cast(num_batch.values[i], utf8()).ValueOrDie(); + values.push_back(Cast(as_string, type).ValueOrDie()); + } else { + values.push_back(num_batch.values[i]); + } + } + ExecBatch batch(values, num_batch.length); + batches.batches.push_back(batch); + } + return batches; +} + void BuildZeroPrimitiveArray(std::shared_ptr& empty, - std::shared_ptr type, int64_t length) { + const std::shared_ptr& type, int64_t length) { ASSERT_OK_AND_ASSIGN(auto builder, MakeBuilder(type, default_memory_pool())); ASSERT_OK(builder->Reserve(length)); ASSERT_OK_AND_ASSIGN(auto scalar, MakeScalar(type, 0)); @@ -63,6 +105,16 @@ void BuildZeroPrimitiveArray(std::shared_ptr& empty, ASSERT_OK(builder->Finish(&empty)); } +template +void BuildZeroBaseBinaryArray(std::shared_ptr& empty, int64_t length) { + Builder builder(default_memory_pool()); + ASSERT_OK(builder.Reserve(length)); + for (int64_t i = 0; i < length; i++) { + ASSERT_OK(builder.Append("0", /*length=*/1)); + } + ASSERT_OK(builder.Finish(&empty)); +} + // mutates by copying from_key into to_key and changing from_key to zero BatchesWithSchema MutateByKey(BatchesWithSchema& batches, std::string from_key, std::string to_key, bool replace_key = false) { @@ -84,6 +136,26 @@ BatchesWithSchema MutateByKey(BatchesWithSchema& batches, std::string from_key, std::shared_ptr empty; BuildZeroPrimitiveArray(empty, type, batch.length); new_values.push_back(empty); + } else if (is_base_binary_like(type->id())) { + std::shared_ptr empty; + switch (type->id()) { + case Type::STRING: + BuildZeroBaseBinaryArray(empty, batch.length); + break; + case Type::LARGE_STRING: + BuildZeroBaseBinaryArray(empty, batch.length); + break; + case Type::BINARY: + BuildZeroBaseBinaryArray(empty, batch.length); + break; + case Type::LARGE_BINARY: + BuildZeroBaseBinaryArray(empty, batch.length); + break; + default: + DCHECK(false); + break; + } + new_values.push_back(empty); } else { new_values.push_back(Subtract(value, value).ValueOrDie()); } @@ -95,6 +167,10 @@ BatchesWithSchema MutateByKey(BatchesWithSchema& batches, std::string from_key, } new_batches.batches.emplace_back(new_values, batch.length); } + /*std::cerr << "new schema: " << new_batches.schema->ToString() << std::endl; + for (const ExecBatch& new_batch : new_batches.batches) { + std::cerr << "new batch: " << new_batch.ToString() << std::endl; + }*/ return new_batches; } @@ -155,8 +231,8 @@ void DoRunInvalidPlanTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema, const AsofJoinNodeOptions& join_options, const std::string& expected_error_str) { - BatchesWithSchema l_batches = MakeBatchesFromString(l_schema, {R"([])"}); - BatchesWithSchema r_batches = MakeBatchesFromString(r_schema, {R"([])"}); + BatchesWithSchema l_batches = MakeBatchesFromNumString(l_schema, {R"([])"}); + BatchesWithSchema r_batches = MakeBatchesFromNumString(r_schema, {R"([])"}); ExecContext exec_ctx; ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_ctx)); @@ -328,6 +404,10 @@ struct BasicTest { std::vector> r0_types = {}, std::vector> r1_types = {}) { std::vector> all_types = { + utf8(), + large_utf8(), + binary(), + large_binary(), int8(), int16(), int32(), @@ -351,20 +431,40 @@ struct BasicTest { using T = const std::shared_ptr; init_types(all_types, time_types, [](T& t) { return t->byte_width() > 1 && !is_floating(t->id()); }); - init_types(all_types, key_types, [](T& t) { return !is_floating(t->id()); }); - init_types(all_types, l_types, [](T& t) { return is_temporal_primitive(t->id()); }); - init_types(all_types, r0_types, [](T& t) { return is_floating(t->id()); }); - init_types(all_types, r1_types, [](T& t) { return is_floating(t->id()); }); - for (auto time_type : time_types) { - for (auto key_type : key_types) { - for (auto l_type : l_types) { - for (auto r0_type : r0_types) { - for (auto r1_type : r1_types) { - RunTypes({time_type, key_type, l_type, r0_type, r1_type}, batches_runner); - } - } - } - } + ASSERT_NE(0, time_types.size()); + init_types(all_types, key_types, [](T& t) { return is_base_binary_like(t->id()); }); + ASSERT_NE(0, key_types.size()); + init_types(all_types, l_types, [](T& t) { return true; }); + ASSERT_NE(0, l_types.size()); + init_types(all_types, r0_types, [](T& t) { return t->byte_width() > 1; }); + ASSERT_NE(0, r0_types.size()); + init_types(all_types, r1_types, [](T& t) { return t->byte_width() > 1; }); + ASSERT_NE(0, r1_types.size()); + + // sample a limited number of type-combinations to keep the runnning time reasonable + // the scoped-traces below help reproduce a test failure, should it happen + auto seed = std::chrono::system_clock::now().time_since_epoch().count(); + ARROW_SCOPED_TRACE("Types seed: ", seed); + std::default_random_engine engine(static_cast(seed)); + std::uniform_int_distribution time_distribution(0, time_types.size() - 1); + std::uniform_int_distribution key_distribution(0, key_types.size() - 1); + std::uniform_int_distribution l_distribution(0, l_types.size() - 1); + std::uniform_int_distribution r0_distribution(0, r0_types.size() - 1); + std::uniform_int_distribution r1_distribution(0, r1_types.size() - 1); + + for (int i = 0; i < 1000; i++) { + auto time_type = time_types[time_distribution(engine)]; + ARROW_SCOPED_TRACE("Time type: ", *time_type); + auto key_type = key_types[key_distribution(engine)]; + ARROW_SCOPED_TRACE("Key type: ", *key_type); + auto l_type = l_types[l_distribution(engine)]; + ARROW_SCOPED_TRACE("Left type: ", *l_type); + auto r0_type = r0_types[r0_distribution(engine)]; + ARROW_SCOPED_TRACE("Right-0 type: ", *r0_type); + auto r1_type = r1_types[r1_distribution(engine)]; + ARROW_SCOPED_TRACE("Right-1 type: ", *r1_type); + + RunTypes({time_type, key_type, l_type, r0_type, r1_type}, batches_runner); } } template @@ -387,11 +487,11 @@ struct BasicTest { // Test three table join BatchesWithSchema l_batches, r0_batches, r1_batches, exp_nokey_batches, exp_batches; - l_batches = MakeBatchesFromString(l_schema, l_data); - r0_batches = MakeBatchesFromString(r0_schema, r0_data); - r1_batches = MakeBatchesFromString(r1_schema, r1_data); - exp_nokey_batches = MakeBatchesFromString(exp_schema, exp_nokey_data); - exp_batches = MakeBatchesFromString(exp_schema, exp_data); + l_batches = MakeBatchesFromNumString(l_schema, l_data); + r0_batches = MakeBatchesFromNumString(r0_schema, r0_data); + r1_batches = MakeBatchesFromNumString(r1_schema, r1_data); + exp_nokey_batches = MakeBatchesFromNumString(exp_schema, exp_nokey_data); + exp_batches = MakeBatchesFromNumString(exp_schema, exp_data); batches_runner(l_batches, r0_batches, r1_batches, exp_nokey_batches, exp_batches); } @@ -405,18 +505,18 @@ struct BasicTest { class AsofJoinTest : public testing::Test {}; -#define ASOFJOIN_TEST_SET(name, num) \ - TEST(AsofJoinTest, Test##name##num##_SingleByKey) { \ - Get##name##Test##num().RunSingleByKey(); \ - } \ - TEST(AsofJoinTest, Test##name##num##_DoubleByKey) { \ - Get##name##Test##num().RunDoubleByKey(); \ - } \ - TEST(AsofJoinTest, Test##name##num##_MutateByKey) { \ - Get##name##Test##num().RunMutateByKey(); \ - } \ - TEST(AsofJoinTest, Test##name##num##_MutateNoKey) { \ - Get##name##Test##num().RunMutateNoKey(); \ +#define ASOFJOIN_TEST_SET(name, num) \ + TRACED_TEST(AsofJoinTest, Test##name##num##_SingleByKey) { \ + Get##name##Test##num().RunSingleByKey(); \ + } \ + TRACED_TEST(AsofJoinTest, Test##name##num##_DoubleByKey) { \ + Get##name##Test##num().RunDoubleByKey(); \ + } \ + TRACED_TEST(AsofJoinTest, Test##name##num##_MutateByKey) { \ + Get##name##Test##num().RunMutateByKey(); \ + } \ + TRACED_TEST(AsofJoinTest, Test##name##num##_MutateNoKey) { \ + Get##name##Test##num().RunMutateNoKey(); \ } BasicTest GetBasicTest1() { @@ -605,26 +705,29 @@ BasicTest GetEmptyTest5() { } ASOFJOIN_TEST_SET(Empty, 5) -TEST(AsofJoinTest, TestUnsupportedOntype) { - DoRunInvalidTypeTest( - schema({field("time", utf8()), field("key", int32()), field("l_v0", float64())}), - schema({field("time", utf8()), field("key", int32()), field("r0_v0", float32())})); +TRACED_TEST(AsofJoinTest, TestUnsupportedOntype) { + DoRunInvalidTypeTest(schema({field("time", list(int32())), field("key", int32()), + field("l_v0", float64())}), + schema({field("time", list(int32())), field("key", int32()), + field("r0_v0", float32())})); } -TEST(AsofJoinTest, TestUnsupportedBytype) { - DoRunInvalidTypeTest( - schema({field("time", int64()), field("key", utf8()), field("l_v0", float64())}), - schema({field("time", int64()), field("key", utf8()), field("r0_v0", float32())})); +TRACED_TEST(AsofJoinTest, TestUnsupportedBytype) { + DoRunInvalidTypeTest(schema({field("time", int64()), field("key", list(int32())), + field("l_v0", float64())}), + schema({field("time", int64()), field("key", list(int32())), + field("r0_v0", float32())})); } -TEST(AsofJoinTest, TestUnsupportedDatatype) { - // Utf8 is unsupported +TRACED_TEST(AsofJoinTest, TestUnsupportedDatatype) { + // List is unsupported DoRunInvalidTypeTest( schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), - schema({field("time", int64()), field("key", int32()), field("r0_v0", utf8())})); + schema({field("time", int64()), field("key", int32()), + field("r0_v0", list(int32()))})); } -TEST(AsofJoinTest, TestMissingKeys) { +TRACED_TEST(AsofJoinTest, TestMissingKeys) { DoRunMissingKeysTest( schema({field("time1", int64()), field("key", int32()), field("l_v0", float64())}), schema( @@ -636,51 +739,51 @@ TEST(AsofJoinTest, TestMissingKeys) { {field("time", int64()), field("key1", int32()), field("r0_v0", float64())})); } -TEST(AsofJoinTest, TestUnsupportedTolerance) { +TRACED_TEST(AsofJoinTest, TestUnsupportedTolerance) { // Utf8 is unsupported DoRunInvalidToleranceTest( schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); } -TEST(AsofJoinTest, TestEmptyByKey) { +TRACED_TEST(AsofJoinTest, TestEmptyByKey) { DoRunEmptyByKeyTest( schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); } -TEST(AsofJoinTest, TestMissingOnKey) { +TRACED_TEST(AsofJoinTest, TestMissingOnKey) { DoRunMissingOnKeyTest( schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); } -TEST(AsofJoinTest, TestMissingByKey) { +TRACED_TEST(AsofJoinTest, TestMissingByKey) { DoRunMissingByKeyTest( schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); } -TEST(AsofJoinTest, TestNestedOnKey) { +TRACED_TEST(AsofJoinTest, TestNestedOnKey) { DoRunNestedOnKeyTest( schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); } -TEST(AsofJoinTest, TestNestedByKey) { +TRACED_TEST(AsofJoinTest, TestNestedByKey) { DoRunNestedByKeyTest( schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); } -TEST(AsofJoinTest, TestAmbiguousOnKey) { +TRACED_TEST(AsofJoinTest, TestAmbiguousOnKey) { DoRunAmbiguousOnKeyTest( schema({field("time", int64()), field("time", int64()), field("key", int32()), field("l_v0", float64())}), schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); } -TEST(AsofJoinTest, TestAmbiguousByKey) { +TRACED_TEST(AsofJoinTest, TestAmbiguousByKey) { DoRunAmbiguousByKeyTest( schema({field("time", int64()), field("key", int64()), field("key", int32()), field("l_v0", float64())}), diff --git a/cpp/src/arrow/compute/exec/key_hash.cc b/cpp/src/arrow/compute/exec/key_hash.cc index 3f495bc9e6005..5ff0d4cf1e551 100644 --- a/cpp/src/arrow/compute/exec/key_hash.cc +++ b/cpp/src/arrow/compute/exec/key_hash.cc @@ -432,11 +432,14 @@ void Hashing32::HashMultiColumn(const std::vector& cols, cols[icol].data(1) + first_row * col_width, hashes + first_row, hash_temp); } - } else { - // TODO: add support for 64-bit offsets + } else if (cols[icol].metadata().fixed_length == sizeof(uint32_t)) { HashVarLen(ctx->hardware_flags, icol > 0, batch_size_next, cols[icol].offsets() + first_row, cols[icol].data(2), hashes + first_row, hash_temp); + } else { + HashVarLen(ctx->hardware_flags, icol > 0, batch_size_next, + cols[icol].large_offsets() + first_row, cols[icol].data(2), + hashes + first_row, hash_temp); } // Zero hash for nulls @@ -865,10 +868,12 @@ void Hashing64::HashMultiColumn(const std::vector& cols, HashFixed(icol > 0, batch_size_next, col_width, cols[icol].data(1) + first_row * col_width, hashes + first_row); } - } else { - // TODO: add support for 64-bit offsets + } else if (cols[icol].metadata().fixed_length == sizeof(uint32_t)) { HashVarLen(icol > 0, batch_size_next, cols[icol].offsets() + first_row, cols[icol].data(2), hashes + first_row); + } else { + HashVarLen(icol > 0, batch_size_next, cols[icol].large_offsets() + first_row, + cols[icol].data(2), hashes + first_row); } // Zero hash for nulls diff --git a/cpp/src/arrow/compute/light_array.h b/cpp/src/arrow/compute/light_array.h index 6a5c205a7a48b..e0adc9266fed1 100644 --- a/cpp/src/arrow/compute/light_array.h +++ b/cpp/src/arrow/compute/light_array.h @@ -135,6 +135,7 @@ class ARROW_EXPORT KeyColumnArray { /// Only valid if this is a view into a varbinary type uint32_t* mutable_offsets() { DCHECK(!metadata_.is_fixed_length); + DCHECK(metadata_.fixed_length == sizeof(uint32_t)); return reinterpret_cast(mutable_data(kFixedLengthBuffer)); } /// \brief Return a read-only version of the offsets buffer @@ -142,8 +143,25 @@ class ARROW_EXPORT KeyColumnArray { /// Only valid if this is a view into a varbinary type const uint32_t* offsets() const { DCHECK(!metadata_.is_fixed_length); + DCHECK(metadata_.fixed_length == sizeof(uint32_t)); return reinterpret_cast(data(kFixedLengthBuffer)); } + /// \brief Return a mutable version of the large-offsets buffer + /// + /// Only valid if this is a view into a large varbinary type + uint64_t* mutable_large_offsets() { + DCHECK(!metadata_.is_fixed_length); + DCHECK(metadata_.fixed_length == sizeof(uint64_t)); + return reinterpret_cast(mutable_data(kFixedLengthBuffer)); + } + /// \brief Return a read-only version of the large-offsets buffer + /// + /// Only valid if this is a view into a large varbinary type + const uint64_t* large_offsets() const { + DCHECK(!metadata_.is_fixed_length); + DCHECK(metadata_.fixed_length == sizeof(uint64_t)); + return reinterpret_cast(data(kFixedLengthBuffer)); + } /// \brief Return the type metadata const KeyColumnMetadata& metadata() const { return metadata_; } /// \brief Return the length (in rows) of the array From ca350d97c63f77beb51a4f0cca01df6e42d54f76 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 21 Aug 2022 08:39:32 -0400 Subject: [PATCH 10/26] cleanup --- cpp/src/arrow/compute/exec/asof_join_node_test.cc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index cff9a6111ad10..b1e503a0e0c2d 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -167,10 +167,6 @@ BatchesWithSchema MutateByKey(BatchesWithSchema& batches, std::string from_key, } new_batches.batches.emplace_back(new_values, batch.length); } - /*std::cerr << "new schema: " << new_batches.schema->ToString() << std::endl; - for (const ExecBatch& new_batch : new_batches.batches) { - std::cerr << "new batch: " << new_batch.ToString() << std::endl; - }*/ return new_batches; } From 717468899d9d34d6607dd82ea7104024a84f8280 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 21 Aug 2022 15:24:34 -0400 Subject: [PATCH 11/26] sanitize --- cpp/src/arrow/compute/exec/asof_join_node.cc | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 2e910d3a3bcb4..5cce574c538de 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -629,23 +629,22 @@ class CompositeReferenceTable { using enable_if_fixed_width_type = enable_if_t::value, R>; template ::BuilderType> - enable_if_fixed_width_type BuilderAppend(Builder& builder, - const std::shared_ptr& source, - row_index_t row) { + enable_if_fixed_width_type BuilderAppend( + Builder& builder, const std::shared_ptr& source, row_index_t row) { using CType = typename TypeTraits::CType; builder.UnsafeAppend(source->template GetValues(1)[row]); + return Status::OK(); } template ::BuilderType> - enable_if_base_binary BuilderAppend(Builder& builder, - const std::shared_ptr& source, - row_index_t row) { + enable_if_base_binary BuilderAppend( + Builder& builder, const std::shared_ptr& source, row_index_t row) { using offset_type = typename Type::offset_type; const uint8_t* data = source->buffers[2]->data(); const offset_type* offsets = source->GetValues(1); const offset_type offset0 = offsets[row]; const offset_type offset1 = offsets[row + 1]; - builder.Append(data + offset0, offset1 - offset0); + return builder.Append(data + offset0, offset1 - offset0); } template ::BuilderType> @@ -658,7 +657,9 @@ class CompositeReferenceTable { for (row_index_t i_row = 0; i_row < rows_.size(); ++i_row) { const auto& ref = rows_[i_row].refs[i_table]; if (ref.batch) { - BuilderAppend(builder, ref.batch->column_data(i_col), ref.row); + Status st = + BuilderAppend(builder, ref.batch->column_data(i_col), ref.row); + ARROW_RETURN_NOT_OK(st); } else { builder.UnsafeAppendNull(); } From c3d71e0a37bfda367f4deb22271af89e36f5710f Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 21 Aug 2022 17:08:26 -0400 Subject: [PATCH 12/26] AsofJoin test time limit --- cpp/src/arrow/compute/exec/asof_join_node_test.cc | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index c3966e4d1c407..adda53e477491 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -440,7 +441,8 @@ struct BasicTest { // sample a limited number of type-combinations to keep the runnning time reasonable // the scoped-traces below help reproduce a test failure, should it happen - auto seed = std::chrono::system_clock::now().time_since_epoch().count(); + auto start_time = std::chrono::system_clock::now(); + auto seed = start_time.time_since_epoch().count(); ARROW_SCOPED_TRACE("Types seed: ", seed); std::default_random_engine engine(static_cast(seed)); std::uniform_int_distribution time_distribution(0, time_types.size() - 1); @@ -462,6 +464,14 @@ struct BasicTest { ARROW_SCOPED_TRACE("Right-1 type: ", *r1_type); RunTypes({time_type, key_type, l_type, r0_type, r1_type}, batches_runner); + + auto end_time = std::chrono::system_clock::now(); + std::chrono::duration diff = end_time - start_time; + if (diff.count() > 2) { + std::cerr << "AsofJoin test reached time limit at iteration " << i << std::endl; + // this normally happens on slow CI systems, but is fine + break; + } } } template From 82b5a32515112d2946b8472873f3f52c1fb21c8f Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Tue, 23 Aug 2022 13:05:48 -0400 Subject: [PATCH 13/26] AsofJoin support/test for null values, null by-key, out-of-order on-key --- cpp/src/arrow/compute/exec/asof_join_node.cc | 85 ++++++--- .../arrow/compute/exec/asof_join_node_test.cc | 170 +++++++++++++++--- 2 files changed, 204 insertions(+), 51 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 5cce574c538de..cbfdbab409499 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -239,7 +239,8 @@ class InputState { // turned into output record batches. public: - InputState(KeyHasher* key_hasher, const std::shared_ptr& schema, + InputState(bool must_hash, KeyHasher* key_hasher, + const std::shared_ptr& schema, const col_index_t time_col_index, const vec_col_index_t& key_col_index) : queue_(), schema_(schema), @@ -247,7 +248,8 @@ class InputState { key_col_index_(key_col_index), time_type_id_(schema_->fields()[time_col_index_]->type()->id()), key_type_id_(schema_->num_fields()), - key_hasher_(key_hasher) { + key_hasher_(key_hasher), + must_hash_(must_hash) { for (size_t k = 0; k < key_col_index_.size(); k++) { key_type_id_[k] = schema_->fields()[key_col_index_[k]]->type()->id(); } @@ -295,10 +297,11 @@ class InputState { } ByType GetLatestKey() const { - if (key_hasher_ != NULLPTR) { - return key_hasher_->HashesFor(queue_.UnsyncFront().get())[latest_ref_row_]; + const RecordBatch* batch = queue_.UnsyncFront().get(); + if (must_hash_ || batch->column_data(key_col_index_[0])->GetNullCount() > 0) { + return key_hasher_->HashesFor(batch)[latest_ref_row_]; } - auto data = queue_.UnsyncFront()->column_data(key_col_index_[0]); + auto data = batch->column_data(key_col_index_[0]); switch (key_type_id_[0]) { LATEST_VAL_CASE(INT8, key_value) LATEST_VAL_CASE(INT16, key_value) @@ -345,7 +348,7 @@ class InputState { bool Finished() const { return batches_processed_ == total_batches_; } - bool Advance() { + Result Advance() { // Try advancing to the next row and update latest_ref_row_ // Returns true if able to advance, false if not. bool have_active_batch = @@ -361,6 +364,13 @@ class InputState { if (have_active_batch) DCHECK_GT(queue_.UnsyncFront()->num_rows(), 0); // empty batches disallowed } + if (have_active_batch) { + OnType next_time = GetLatestTime(); + if (latest_time_ > next_time) { + return Status::Invalid("AsofJoin does not allow out-of-order on-key values"); + } + latest_time_ = next_time; + } } return have_active_batch; } @@ -369,7 +379,7 @@ class InputState { // latest_time and latest_ref_row to the value that immediately pass the // specified timestamp. // Returns true if updates were made, false if not. - bool AdvanceAndMemoize(OnType ts) { + Result AdvanceAndMemoize(OnType ts) { // Advance the right side row index until we reach the latest right row (for each key) // for the given left timestamp. @@ -377,7 +387,7 @@ class InputState { if (Empty()) return false; // can't advance if empty // Not updated. Try to update and possibly advance. - bool updated = false; + bool advanced, updated = false; do { auto latest_time = GetLatestTime(); // if Advance() returns true, then the latest_ts must also be valid @@ -389,7 +399,8 @@ class InputState { } memo_.Store(GetLatestBatch(), latest_ref_row_, latest_time, GetLatestKey()); updated = true; - } while (Advance()); + ARROW_ASSIGN_OR_RAISE(advanced, Advance()); + } while (advanced); return updated; } @@ -443,11 +454,15 @@ class InputState { Type::type time_type_id_; // Type id of the key column std::vector key_type_id_; - // Buffer for key elements + // Hasher for key elements mutable KeyHasher* key_hasher_; + // True if hashing is mandatory + bool must_hash_; // Index of the latest row reference within; if >0 then queue_ cannot be empty // Must be < queue_.front()->num_rows() if queue_ is non-empty row_index_t latest_ref_row_ = 0; + // Time of latest row + OnType latest_time_ = std::numeric_limits::lowest(); // Stores latest known values for the various keys MemoStore memo_; // Mapping of source columns to destination columns @@ -622,6 +637,13 @@ class CompositeReferenceTable { if (!_ptr2ref.count((uintptr_t)ref.get())) _ptr2ref[(uintptr_t)ref.get()] = ref; } + // this should really be a method on ArrayData + static bool IsNull(const std::shared_ptr& source, row_index_t row) { + return ((source->buffers[0] != NULLPTR) + ? !bit_util::GetBit(source->buffers[0]->data(), row + source->offset) + : source->null_count.load() == source->length); + } + template using is_fixed_width_type = std::is_base_of; @@ -629,16 +651,23 @@ class CompositeReferenceTable { using enable_if_fixed_width_type = enable_if_t::value, R>; template ::BuilderType> - enable_if_fixed_width_type BuilderAppend( + enable_if_fixed_width_type static BuilderAppend( Builder& builder, const std::shared_ptr& source, row_index_t row) { + if (IsNull(source, row)) { + builder.UnsafeAppendNull(); + return Status::OK(); + } using CType = typename TypeTraits::CType; builder.UnsafeAppend(source->template GetValues(1)[row]); return Status::OK(); } template ::BuilderType> - enable_if_base_binary BuilderAppend( + enable_if_base_binary static BuilderAppend( Builder& builder, const std::shared_ptr& source, row_index_t row) { + if (IsNull(source, row)) { + return builder.AppendNull(); + } using offset_type = typename Type::offset_type; const uint8_t* data = source->buffers[2]->data(); const offset_type* offsets = source->GetValues(1); @@ -672,12 +701,14 @@ class CompositeReferenceTable { class AsofJoinNode : public ExecNode { // Advances the RHS as far as possible to be up to date for the current LHS timestamp - bool UpdateRhs() { + Result UpdateRhs() { auto& lhs = *state_.at(0); auto lhs_latest_time = lhs.GetLatestTime(); bool any_updated = false; - for (size_t i = 1; i < state_.size(); ++i) - any_updated |= state_[i]->AdvanceAndMemoize(lhs_latest_time); + for (size_t i = 1; i < state_.size(); ++i) { + ARROW_ASSIGN_OR_RAISE(bool advanced, state_[i]->AdvanceAndMemoize(lhs_latest_time)); + any_updated |= advanced; + } return any_updated; } @@ -713,7 +744,7 @@ class AsofJoinNode : public ExecNode { if (lhs.Finished() || lhs.Empty()) break; // Advance each of the RHS as far as possible to be up to date for the LHS timestamp - bool any_rhs_advanced = UpdateRhs(); + ARROW_ASSIGN_OR_RAISE(bool any_rhs_advanced, UpdateRhs()); // If we have received enough inputs to produce the next output batch // (decided by IsUpToDateWithLhsRow), we will perform the join and @@ -722,7 +753,8 @@ class AsofJoinNode : public ExecNode { // input batches that are no longer needed are removed to free up memory. if (IsUpToDateWithLhsRow()) { dst.Emplace(state_, tolerance_); - if (!lhs.Advance()) break; // if we can't advance LHS, we're done for this batch + ARROW_ASSIGN_OR_RAISE(bool advanced, lhs.Advance()); + if (!advanced) break; // if we can't advance LHS, we're done for this batch } else { if (!any_rhs_advanced) break; // need to wait for new data } @@ -795,13 +827,13 @@ class AsofJoinNode : public ExecNode { const std::vector& indices_of_by_key, OnType tolerance, std::shared_ptr output_schema); - Status InternalInit(std::vector> key_hashers) { + Status InternalInit(bool must_hash, + std::vector> key_hashers) { key_hashers_.swap(key_hashers); - bool has_hashers = key_hashers_.size() > 0; auto inputs = this->inputs(); for (size_t i = 0; i < inputs.size(); ++i) { state_.push_back(::arrow::internal::make_unique( - has_hashers ? key_hashers_[i].get() : NULLPTR, inputs[i]->output_schema(), + must_hash, key_hashers_[i].get(), inputs[i]->output_schema(), indices_of_on_key_[i], indices_of_by_key_[i])); } @@ -1003,14 +1035,13 @@ class AsofJoinNode : public ExecNode { auto node_indices_of_by_key = checked_cast(node)->indices_of_by_key(); auto single_key_field = inputs[0]->output_schema()->field(indices_of_by_key[0][0]); std::vector> key_hashers; - if (n_by > 1 || !is_primitive(single_key_field->type()->id())) { - for (size_t i = 0; i < n_input; i++) { - key_hashers.push_back( - ::arrow::internal::make_unique(node_indices_of_by_key[i])); - RETURN_NOT_OK(key_hashers[i]->Init(plan->exec_context(), node_output_schema)); - } + bool must_hash = n_by > 1 || !is_primitive(single_key_field->type()->id()); + for (size_t i = 0; i < n_input; i++) { + key_hashers.push_back( + ::arrow::internal::make_unique(node_indices_of_by_key[i])); + RETURN_NOT_OK(key_hashers[i]->Init(plan->exec_context(), node_output_schema)); } - RETURN_NOT_OK(node->InternalInit(std::move(key_hashers))); + RETURN_NOT_OK(node->InternalInit(must_hash, std::move(key_hashers))); return node; } diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index adda53e477491..28ea9b2387f49 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -97,6 +97,14 @@ BatchesWithSchema MakeBatchesFromNumString( return batches; } +void BuildNullArray(std::shared_ptr& empty, const std::shared_ptr& type, + int64_t length) { + ASSERT_OK_AND_ASSIGN(auto builder, MakeBuilder(type, default_memory_pool())); + ASSERT_OK(builder->Reserve(length)); + ASSERT_OK(builder->AppendNulls(length)); + ASSERT_OK(builder->Finish(&empty)); +} + void BuildZeroPrimitiveArray(std::shared_ptr& empty, const std::shared_ptr& type, int64_t length) { ASSERT_OK_AND_ASSIGN(auto builder, MakeBuilder(type, default_memory_pool())); @@ -118,7 +126,8 @@ void BuildZeroBaseBinaryArray(std::shared_ptr& empty, int64_t length) { // mutates by copying from_key into to_key and changing from_key to zero BatchesWithSchema MutateByKey(BatchesWithSchema& batches, std::string from_key, - std::string to_key, bool replace_key = false) { + std::string to_key, bool replace_key = false, + bool null_key = false) { int from_index = batches.schema->GetFieldIndex(from_key); int n_fields = batches.schema->num_fields(); auto fields = batches.schema->fields(); @@ -133,7 +142,11 @@ BatchesWithSchema MutateByKey(BatchesWithSchema& batches, std::string from_key, const Datum& value = batch.values[i]; if (i == from_index) { auto type = fields[i]->type(); - if (is_primitive(type->id())) { + if (null_key) { + std::shared_ptr empty; + BuildNullArray(empty, type, batch.length); + new_values.push_back(empty); + } else if (is_primitive(type->id())) { std::shared_ptr empty; BuildZeroPrimitiveArray(empty, type, batch.length); new_values.push_back(empty); @@ -224,13 +237,11 @@ void CheckRunOutput(const BatchesWithSchema& l_batches, EXPAND_BY_KEY_TYPE(CHECK_RUN_OUTPUT) -void DoRunInvalidPlanTest(const std::shared_ptr& l_schema, - const std::shared_ptr& r_schema, - const AsofJoinNodeOptions& join_options, - const std::string& expected_error_str) { - BatchesWithSchema l_batches = MakeBatchesFromNumString(l_schema, {R"([])"}); - BatchesWithSchema r_batches = MakeBatchesFromNumString(r_schema, {R"([])"}); - +void DoInvalidPlanTest(const BatchesWithSchema& l_batches, + const BatchesWithSchema& r_batches, + const AsofJoinNodeOptions& join_options, + const std::string& expected_error_str, + bool then_run_plan = false) { ExecContext exec_ctx; ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_ctx)); @@ -240,8 +251,34 @@ void DoRunInvalidPlanTest(const std::shared_ptr& l_schema, join.inputs.emplace_back(Declaration{ "source", SourceNodeOptions{r_batches.schema, r_batches.gen(false, false)}}); - EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr(expected_error_str), - join.AddToPlan(plan.get())); + if (then_run_plan) { + AsyncGenerator> sink_gen; + ASSERT_OK(Declaration::Sequence({join, {"sink", SinkNodeOptions{&sink_gen}}}) + .AddToPlan(plan.get())); + EXPECT_FINISHES_AND_RAISES_WITH_MESSAGE_THAT(Invalid, + ::testing::HasSubstr(expected_error_str), + StartAndCollect(plan.get(), sink_gen)); + } else { + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr(expected_error_str), + join.AddToPlan(plan.get())); + } +} + +void DoRunInvalidPlanTest(const BatchesWithSchema& l_batches, + const BatchesWithSchema& r_batches, + const AsofJoinNodeOptions& join_options, + const std::string& expected_error_str) { + DoInvalidPlanTest(l_batches, r_batches, join_options, expected_error_str); +} + +void DoRunInvalidPlanTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema, + const AsofJoinNodeOptions& join_options, + const std::string& expected_error_str) { + BatchesWithSchema l_batches = MakeBatchesFromNumString(l_schema, {R"([])"}); + BatchesWithSchema r_batches = MakeBatchesFromNumString(r_schema, {R"([])"}); + + return DoRunInvalidPlanTest(l_batches, r_batches, join_options, expected_error_str); } void DoRunInvalidPlanTest(const std::shared_ptr& l_schema, @@ -307,6 +344,51 @@ void DoRunAmbiguousByKeyTest(const std::shared_ptr& l_schema, DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table : Multiple matches"); } +std::string GetJsonStringWithOrder(int n_rows, int n_cols, bool unordered) { + std::stringstream s; + s << '['; + for (int i = 0; i < n_rows; i++) { + if (i > 0) { + s << ", "; + } + s << '['; + for (int j = 0; j < n_cols; j++) { + if (j > 0) { + s << ", " << j; + } else { + s << (i ^ unordered); + } + } + s << ']'; + } + s << ']'; + return s.str(); +} + +void DoRunUnorderedPlanTest(bool l_unordered, bool r_unordered, + const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema, + const AsofJoinNodeOptions& join_options, + const std::string& expected_error_str) { + ASSERT_TRUE(l_unordered || r_unordered); + int n_rows = 5; + std::string l_str = GetJsonStringWithOrder(n_rows, l_schema->num_fields(), l_unordered); + std::string r_str = GetJsonStringWithOrder(n_rows, r_schema->num_fields(), r_unordered); + BatchesWithSchema l_batches = MakeBatchesFromNumString(l_schema, {l_str}); + BatchesWithSchema r_batches = MakeBatchesFromNumString(r_schema, {r_str}); + + return DoInvalidPlanTest(l_batches, r_batches, join_options, expected_error_str, + /*then_run_plan=*/true); +} + +void DoRunUnorderedPlanTest(bool l_unordered, bool r_unordered, + const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunUnorderedPlanTest(l_unordered, r_unordered, l_schema, r_schema, + AsofJoinNodeOptions("time", "key", 1000), + "out-of-order on-key values"); +} + struct BasicTestTypes { std::shared_ptr time, key, l_val, r0_val, r1_val; }; @@ -388,7 +470,23 @@ struct BasicTest { l_batches = MutateByKey(l_batches, "key", "key2", true); r0_batches = MutateByKey(r0_batches, "key", "key2", true); r1_batches = MutateByKey(r1_batches, "key", "key2", true); - exp_batches = MutateByKey(exp_batches, "key", "key2", true); + exp_nokey_batches = MutateByKey(exp_nokey_batches, "key", "key2", true); + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_nokey_batches, "time", "key2", + tolerance); + }); + } + void RunMutateNullKey(std::vector> time_types = {}, + std::vector> key_types = {}, + std::vector> l_types = {}, + std::vector> r0_types = {}, + std::vector> r1_types = {}) { + using B = BatchesWithSchema; + RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, + B exp_batches) { + l_batches = MutateByKey(l_batches, "key", "key2", true, true); + r0_batches = MutateByKey(r0_batches, "key", "key2", true, true); + r1_batches = MutateByKey(r1_batches, "key", "key2", true, true); + exp_nokey_batches = MutateByKey(exp_nokey_batches, "key", "key2", true, true); CheckRunOutput(l_batches, r0_batches, r1_batches, exp_nokey_batches, "time", "key2", tolerance); }); @@ -512,18 +610,21 @@ struct BasicTest { class AsofJoinTest : public testing::Test {}; -#define ASOFJOIN_TEST_SET(name, num) \ - TRACED_TEST(AsofJoinTest, Test##name##num##_SingleByKey) { \ - Get##name##Test##num().RunSingleByKey(); \ - } \ - TRACED_TEST(AsofJoinTest, Test##name##num##_DoubleByKey) { \ - Get##name##Test##num().RunDoubleByKey(); \ - } \ - TRACED_TEST(AsofJoinTest, Test##name##num##_MutateByKey) { \ - Get##name##Test##num().RunMutateByKey(); \ - } \ - TRACED_TEST(AsofJoinTest, Test##name##num##_MutateNoKey) { \ - Get##name##Test##num().RunMutateNoKey(); \ +#define ASOFJOIN_TEST_SET(name, num) \ + TRACED_TEST(AsofJoinTest, Test##name##num##_SingleByKey) { \ + Get##name##Test##num().RunSingleByKey(); \ + } \ + TRACED_TEST(AsofJoinTest, Test##name##num##_DoubleByKey) { \ + Get##name##Test##num().RunDoubleByKey(); \ + } \ + TRACED_TEST(AsofJoinTest, Test##name##num##_MutateByKey) { \ + Get##name##Test##num().RunMutateByKey(); \ + } \ + TRACED_TEST(AsofJoinTest, Test##name##num##_MutateNoKey) { \ + Get##name##Test##num().RunMutateNoKey(); \ + } \ + TRACED_TEST(AsofJoinTest, Test##name##num##_MutateNullKey) { \ + Get##name##Test##num().RunMutateNullKey(); \ } BasicTest GetBasicTest1() { @@ -797,5 +898,26 @@ TRACED_TEST(AsofJoinTest, TestAmbiguousByKey) { schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); } +TRACED_TEST(AsofJoinTest, TestLeftUnorderedOnKey) { + DoRunUnorderedPlanTest( + /*l_unordered=*/true, /*r_unordered=*/false, + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +} + +TRACED_TEST(AsofJoinTest, TestRightUnorderedOnKey) { + DoRunUnorderedPlanTest( + /*l_unordered=*/false, /*r_unordered=*/true, + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +} + +TRACED_TEST(AsofJoinTest, TestUnorderedOnKey) { + DoRunUnorderedPlanTest( + /*l_unordered=*/true, /*r_unordered=*/true, + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +} + } // namespace compute } // namespace arrow From 0a9a6b66d846ea0152f36c47010af0d81f5be7f6 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Wed, 24 Aug 2022 09:33:03 -0400 Subject: [PATCH 14/26] AsofJoin fixes for nullable and out-of-order by-key --- cpp/src/arrow/compute/exec/asof_join_node.cc | 40 ++++++++++++------- .../arrow/compute/exec/asof_join_node_test.cc | 37 ++++++++++++++--- cpp/src/arrow/compute/exec/options.h | 37 ++++++++++++----- 3 files changed, 83 insertions(+), 31 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index cbfdbab409499..8dd1fb1dac758 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -239,7 +239,7 @@ class InputState { // turned into output record batches. public: - InputState(bool must_hash, KeyHasher* key_hasher, + InputState(bool must_hash, bool nullable_by_key, KeyHasher* key_hasher, const std::shared_ptr& schema, const col_index_t time_col_index, const vec_col_index_t& key_col_index) : queue_(), @@ -249,7 +249,8 @@ class InputState { time_type_id_(schema_->fields()[time_col_index_]->type()->id()), key_type_id_(schema_->num_fields()), key_hasher_(key_hasher), - must_hash_(must_hash) { + must_hash_(must_hash), + nullable_by_key_(nullable_by_key) { for (size_t k = 0; k < key_col_index_.size(); k++) { key_type_id_[k] = schema_->fields()[key_col_index_[k]]->type()->id(); } @@ -298,7 +299,7 @@ class InputState { ByType GetLatestKey() const { const RecordBatch* batch = queue_.UnsyncFront().get(); - if (must_hash_ || batch->column_data(key_col_index_[0])->GetNullCount() > 0) { + if (must_hash_) { return key_hasher_->HashesFor(batch)[latest_ref_row_]; } auto data = batch->column_data(key_col_index_[0]); @@ -355,6 +356,11 @@ class InputState { (latest_ref_row_ > 0 /*short circuit the lock on the queue*/) || !queue_.Empty(); if (have_active_batch) { + OnType next_time = GetLatestTime(); + if (latest_time_ > next_time) { + return Status::Invalid("AsofJoin does not allow out-of-order on-key values"); + } + latest_time_ = next_time; // If we have an active batch if (++latest_ref_row_ >= (row_index_t)queue_.UnsyncFront()->num_rows()) { // hit the end of the batch, need to get the next batch if possible. @@ -364,13 +370,6 @@ class InputState { if (have_active_batch) DCHECK_GT(queue_.UnsyncFront()->num_rows(), 0); // empty batches disallowed } - if (have_active_batch) { - OnType next_time = GetLatestTime(); - if (latest_time_ > next_time) { - return Status::Invalid("AsofJoin does not allow out-of-order on-key values"); - } - latest_time_ = next_time; - } } return have_active_batch; } @@ -404,12 +403,16 @@ class InputState { return updated; } - void Push(const std::shared_ptr& rb) { + Status Push(const std::shared_ptr& rb) { + if (!nullable_by_key_ && rb->column_data(key_col_index_[0])->GetNullCount() > 0) { + return Status::Invalid("AsofJoin does not allow unexpected null by-key values"); + } if (rb->num_rows() > 0) { queue_.Push(rb); } else { ++batches_processed_; // don't enqueue empty batches, just record as processed } + return Status::OK(); } util::optional GetMemoEntryForKey(ByType key) { @@ -458,6 +461,8 @@ class InputState { mutable KeyHasher* key_hasher_; // True if hashing is mandatory bool must_hash_; + // True if null by-key values are expected + bool nullable_by_key_; // Index of the latest row reference within; if >0 then queue_ cannot be empty // Must be < queue_.front()->num_rows() if queue_ is non-empty row_index_t latest_ref_row_ = 0; @@ -827,13 +832,13 @@ class AsofJoinNode : public ExecNode { const std::vector& indices_of_by_key, OnType tolerance, std::shared_ptr output_schema); - Status InternalInit(bool must_hash, + Status InternalInit(bool must_hash, bool nullable_by_key, std::vector> key_hashers) { key_hashers_.swap(key_hashers); auto inputs = this->inputs(); for (size_t i = 0; i < inputs.size(); ++i) { state_.push_back(::arrow::internal::make_unique( - must_hash, key_hashers_[i].get(), inputs[i]->output_schema(), + must_hash, nullable_by_key, key_hashers_[i].get(), inputs[i]->output_schema(), indices_of_on_key_[i], indices_of_by_key_[i])); } @@ -1036,12 +1041,13 @@ class AsofJoinNode : public ExecNode { auto single_key_field = inputs[0]->output_schema()->field(indices_of_by_key[0][0]); std::vector> key_hashers; bool must_hash = n_by > 1 || !is_primitive(single_key_field->type()->id()); + bool nullable_by_key = join_options.nullable_by_key; for (size_t i = 0; i < n_input; i++) { key_hashers.push_back( ::arrow::internal::make_unique(node_indices_of_by_key[i])); RETURN_NOT_OK(key_hashers[i]->Init(plan->exec_context(), node_output_schema)); } - RETURN_NOT_OK(node->InternalInit(must_hash, std::move(key_hashers))); + RETURN_NOT_OK(node->InternalInit(must_hash, nullable_by_key, std::move(key_hashers))); return node; } @@ -1054,7 +1060,11 @@ class AsofJoinNode : public ExecNode { // Put into the queue auto rb = *batch.ToRecordBatch(input->output_schema()); - state_.at(k)->Push(rb); + Status st = state_.at(k)->Push(rb); + if (!st.ok()) { + ErrorReceived(input, st); + return; + } process_.Push(true); } void ErrorReceived(ExecNode* input, Status error) override { diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index 28ea9b2387f49..2370923539c15 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -344,7 +344,7 @@ void DoRunAmbiguousByKeyTest(const std::shared_ptr& l_schema, DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table : Multiple matches"); } -std::string GetJsonStringWithOrder(int n_rows, int n_cols, bool unordered) { +std::string GetJsonString(int n_rows, int n_cols, bool unordered = false) { std::stringstream s; s << '['; for (int i = 0; i < n_rows; i++) { @@ -355,8 +355,10 @@ std::string GetJsonStringWithOrder(int n_rows, int n_cols, bool unordered) { for (int j = 0; j < n_cols; j++) { if (j > 0) { s << ", " << j; - } else { + } else if (j < 2) { s << (i ^ unordered); + } else { + s << i; } } s << ']'; @@ -372,8 +374,8 @@ void DoRunUnorderedPlanTest(bool l_unordered, bool r_unordered, const std::string& expected_error_str) { ASSERT_TRUE(l_unordered || r_unordered); int n_rows = 5; - std::string l_str = GetJsonStringWithOrder(n_rows, l_schema->num_fields(), l_unordered); - std::string r_str = GetJsonStringWithOrder(n_rows, r_schema->num_fields(), r_unordered); + std::string l_str = GetJsonString(n_rows, l_schema->num_fields(), l_unordered); + std::string r_str = GetJsonString(n_rows, r_schema->num_fields(), r_unordered); BatchesWithSchema l_batches = MakeBatchesFromNumString(l_schema, {l_str}); BatchesWithSchema r_batches = MakeBatchesFromNumString(r_schema, {r_str}); @@ -389,6 +391,22 @@ void DoRunUnorderedPlanTest(bool l_unordered, bool r_unordered, "out-of-order on-key values"); } +void DoRunNullByKeyPlanTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + AsofJoinNodeOptions join_options{"time", "key2", 1000}; + std::string expected_error_str = "unexpected null by-key values"; + int n_rows = 5; + std::string l_str = GetJsonString(n_rows, l_schema->num_fields()); + std::string r_str = GetJsonString(n_rows, r_schema->num_fields()); + BatchesWithSchema l_batches = MakeBatchesFromNumString(l_schema, {l_str}); + BatchesWithSchema r_batches = MakeBatchesFromNumString(r_schema, {r_str}); + l_batches = MutateByKey(l_batches, "key", "key2", true, true); + r_batches = MutateByKey(r_batches, "key", "key2", true, true); + + return DoInvalidPlanTest(l_batches, r_batches, join_options, expected_error_str, + /*then_run_plan=*/true); +} + struct BasicTestTypes { std::shared_ptr time, key, l_val, r0_val, r1_val; }; @@ -487,8 +505,9 @@ struct BasicTest { r0_batches = MutateByKey(r0_batches, "key", "key2", true, true); r1_batches = MutateByKey(r1_batches, "key", "key2", true, true); exp_nokey_batches = MutateByKey(exp_nokey_batches, "key", "key2", true, true); - CheckRunOutput(l_batches, r0_batches, r1_batches, exp_nokey_batches, "time", "key2", - tolerance); + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_nokey_batches, + AsofJoinNodeOptions("time", "key2", tolerance, + /*nullable_by_key=*/true)); }); } template @@ -919,5 +938,11 @@ TRACED_TEST(AsofJoinTest, TestUnorderedOnKey) { schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); } +TRACED_TEST(AsofJoinTest, TestNullByKey) { + DoRunNullByKeyPlanTest( + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index 232ff3f5180b6..75aadacbbd02e 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -397,34 +397,51 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { /// This node will output one row for each row in the left table. class ARROW_EXPORT AsofJoinNodeOptions : public ExecNodeOptions { public: - AsofJoinNodeOptions(FieldRef on_key, const FieldRef& by_key, int64_t tolerance) - : on_key(std::move(on_key)), by_key(), tolerance(tolerance) { + AsofJoinNodeOptions(FieldRef on_key, const FieldRef& by_key, int64_t tolerance, + bool nullable_by_key = false) + : on_key(std::move(on_key)), + by_key(), + tolerance(tolerance), + nullable_by_key(nullable_by_key) { this->by_key.push_back(std::move(by_key)); } - AsofJoinNodeOptions(FieldRef on_key, std::vector by_key, int64_t tolerance) - : on_key(std::move(on_key)), by_key(by_key), tolerance(tolerance) {} + AsofJoinNodeOptions(FieldRef on_key, std::vector by_key, int64_t tolerance, + bool nullable_by_key = false) + : on_key(std::move(on_key)), + by_key(by_key), + tolerance(tolerance), + nullable_by_key(nullable_by_key) {} // resolves ambiguity between previous constructors when initializer list is given AsofJoinNodeOptions(FieldRef on_key, std::initializer_list by_key, - int64_t tolerance) - : on_key(std::move(on_key)), by_key(by_key), tolerance(tolerance) {} + int64_t tolerance, bool nullable_by_key = false) + : on_key(std::move(on_key)), + by_key(by_key), + tolerance(tolerance), + nullable_by_key(nullable_by_key) {} - /// \brief "on" key for the join. Each + /// \brief "on" key for the join. /// /// All inputs tables must be sorted by the "on" key. Must be a single field of a common /// type. Inexact match is used on the "on" key. i.e., a row is considered match iff /// left_on - tolerance <= right_on <= left_on. - /// Currently, the "on" key must be of an integer or timestamp type + /// Currently, the "on" key must be of an integer, date, or timestamp type. FieldRef on_key; /// \brief "by" key for the join. /// /// All input tables must have the "by" key. Exact equality /// is used for the "by" key. - /// Currently, the "by" key must be of an integer or timestamp type + /// Currently, the "by" key must be of an integer, date, timestamp, or base-binary type std::vector by_key; - /// Tolerance for inexact "on" key matching. Must be non-negative. + /// \brief Tolerance for inexact "on" key matching. Must be non-negative. + /// + /// The tolerance is interpreted in the same units as the "on" key. int64_t tolerance; + /// \brief Whether the "by" key is nullable. + /// + /// Set to true if the "by" key is expected to take null values. + bool nullable_by_key; }; /// \brief Make a node which select top_k/bottom_k rows passed through it From 7e7398c8342638943d0a38fd5c7eb8de54c47171 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Thu, 25 Aug 2022 05:34:19 -0400 Subject: [PATCH 15/26] requested fixes --- cpp/src/arrow/compute/exec/asof_join_node.cc | 189 ++++++++++-------- .../arrow/compute/exec/asof_join_node_test.cc | 50 ++--- cpp/src/arrow/compute/light_array.h | 8 +- 3 files changed, 133 insertions(+), 114 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 8dd1fb1dac758..215753450cd18 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -60,7 +60,6 @@ typedef uint64_t HashType; #define MAX_JOIN_TABLES 64 typedef uint64_t row_index_t; typedef int col_index_t; -typedef std::vector vec_col_index_t; // normalize the value to 64-bits while preserving ordering of values template ::value, bool> = true> @@ -179,7 +178,7 @@ class KeyHasher { static constexpr int kMiniBatchLength = util::MiniBatch::kMiniBatchLength; public: - explicit KeyHasher(const vec_col_index_t& indices) + explicit KeyHasher(const std::vector& indices) : indices_(indices), metadata_(indices.size()), batch_(NULLPTR), @@ -224,7 +223,7 @@ class KeyHasher { } private: - vec_col_index_t indices_; + std::vector indices_; std::vector metadata_; const RecordBatch* batch_; std::vector hashes_; @@ -241,7 +240,8 @@ class InputState { public: InputState(bool must_hash, bool nullable_by_key, KeyHasher* key_hasher, const std::shared_ptr& schema, - const col_index_t time_col_index, const vec_col_index_t& key_col_index) + const col_index_t time_col_index, + const std::vector& key_col_index) : queue_(), schema_(schema), time_col_index_(time_col_index), @@ -452,7 +452,7 @@ class InputState { // Index of the time col col_index_t time_col_index_; // Index of the key col - vec_col_index_t key_col_index_; + std::vector key_col_index_; // Type id of the time column Type::type time_type_id_; // Type id of the key column @@ -828,9 +828,9 @@ class AsofJoinNode : public ExecNode { public: AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, - const vec_col_index_t& indices_of_on_key, - const std::vector& indices_of_by_key, OnType tolerance, - std::shared_ptr output_schema); + const std::vector& indices_of_on_key, + const std::vector>& indices_of_by_key, + OnType tolerance, std::shared_ptr output_schema); Status InternalInit(bool must_hash, bool nullable_by_key, std::vector> key_hashers) { @@ -854,12 +854,91 @@ class AsofJoinNode : public ExecNode { process_thread_.join(); } - const vec_col_index_t& indices_of_on_key() { return indices_of_on_key_; } - const std::vector& indices_of_by_key() { return indices_of_by_key_; } + const std::vector& indices_of_on_key() { return indices_of_on_key_; } + const std::vector>& indices_of_by_key() { + return indices_of_by_key_; + } + + static Status is_valid_on_field(const std::shared_ptr& field) { + switch (field->type()->id()) { + case Type::INT8: + case Type::INT16: + case Type::INT32: + case Type::INT64: + case Type::UINT8: + case Type::UINT16: + case Type::UINT32: + case Type::UINT64: + case Type::DATE32: + case Type::DATE64: + case Type::TIME32: + case Type::TIME64: + case Type::TIMESTAMP: + return Status::OK(); + default: + return Status::Invalid("Unsupported type for on-key ", field->name(), " : ", + field->type()->ToString()); + } + } + + static Status is_valid_by_field(const std::shared_ptr& field) { + switch (field->type()->id()) { + case Type::INT8: + case Type::INT16: + case Type::INT32: + case Type::INT64: + case Type::UINT8: + case Type::UINT16: + case Type::UINT32: + case Type::UINT64: + case Type::DATE32: + case Type::DATE64: + case Type::TIME32: + case Type::TIME64: + case Type::TIMESTAMP: + case Type::STRING: + case Type::LARGE_STRING: + case Type::BINARY: + case Type::LARGE_BINARY: + return Status::OK(); + default: + return Status::Invalid("Unsupported type for by-key ", field->name(), " : ", + field->type()->ToString()); + } + } + + static Status is_valid_data_field(const std::shared_ptr& field) { + switch (field->type()->id()) { + case Type::INT8: + case Type::INT16: + case Type::INT32: + case Type::INT64: + case Type::UINT8: + case Type::UINT16: + case Type::UINT32: + case Type::UINT64: + case Type::FLOAT: + case Type::DOUBLE: + case Type::DATE32: + case Type::DATE64: + case Type::TIME32: + case Type::TIME64: + case Type::TIMESTAMP: + case Type::STRING: + case Type::LARGE_STRING: + case Type::BINARY: + case Type::LARGE_BINARY: + return Status::OK(); + default: + return Status::Invalid("Unsupported type for data field ", field->name(), " : ", + field->type()->ToString()); + } + } static arrow::Result> MakeOutputSchema( - const std::vector& inputs, const vec_col_index_t& indices_of_on_key, - const std::vector& indices_of_by_key) { + const std::vector& inputs, + const std::vector& indices_of_on_key, + const std::vector>& indices_of_by_key) { std::vector> fields; size_t n_by = indices_of_by_key[0].size(); @@ -901,84 +980,19 @@ class AsofJoinNode : public ExecNode { for (int i = 0; i < input_schema->num_fields(); ++i) { const auto field = input_schema->field(i); if (i == on_field_ix) { - switch (field->type()->id()) { - case Type::INT8: - case Type::INT16: - case Type::INT32: - case Type::INT64: - case Type::UINT8: - case Type::UINT16: - case Type::UINT32: - case Type::UINT64: - case Type::DATE32: - case Type::DATE64: - case Type::TIME32: - case Type::TIME64: - case Type::TIMESTAMP: - break; - default: - return Status::Invalid("Unsupported type for on-key ", field->name(), " : ", - field->type()->ToString()); - } + ARROW_RETURN_NOT_OK(is_valid_on_field(field)); // Only add on field from the left table if (j == 0) { fields.push_back(field); } } else if (std_has(by_field_ix, i)) { - switch (field->type()->id()) { - case Type::INT8: - case Type::INT16: - case Type::INT32: - case Type::INT64: - case Type::UINT8: - case Type::UINT16: - case Type::UINT32: - case Type::UINT64: - case Type::DATE32: - case Type::DATE64: - case Type::TIME32: - case Type::TIME64: - case Type::TIMESTAMP: - case Type::STRING: - case Type::LARGE_STRING: - case Type::BINARY: - case Type::LARGE_BINARY: - break; - default: - return Status::Invalid("Unsupported type for by-key ", field->name(), " : ", - field->type()->ToString()); - } + ARROW_RETURN_NOT_OK(is_valid_by_field(field)); // Only add by field from the left table if (j == 0) { fields.push_back(field); } } else { - switch (field->type()->id()) { - case Type::INT8: - case Type::INT16: - case Type::INT32: - case Type::INT64: - case Type::UINT8: - case Type::UINT16: - case Type::UINT32: - case Type::UINT64: - case Type::FLOAT: - case Type::DOUBLE: - case Type::DATE32: - case Type::DATE64: - case Type::TIME32: - case Type::TIME64: - case Type::TIMESTAMP: - case Type::STRING: - case Type::LARGE_STRING: - case Type::BINARY: - case Type::LARGE_BINARY: - break; - default: - return Status::Invalid("Unsupported type for field ", field->name(), " : ", - field->type()->ToString()); - } - + ARROW_RETURN_NOT_OK(is_valid_data_field(field)); fields.push_back(field); } } @@ -988,7 +1002,7 @@ class AsofJoinNode : public ExecNode { static inline Result FindColIndex(const Schema& schema, const FieldRef& field_ref, - const util::string_view& key_kind) { + util::string_view key_kind) { auto match_res = field_ref.FindOne(schema); if (!match_res.ok()) { return Status::Invalid("Bad join key on table : ", match_res.status().message()); @@ -1016,8 +1030,9 @@ class AsofJoinNode : public ExecNode { size_t n_input = inputs.size(), n_by = join_options.by_key.size(); std::vector input_labels(n_input); - vec_col_index_t indices_of_on_key(n_input); - std::vector indices_of_by_key(n_input, vec_col_index_t(n_by)); + std::vector indices_of_on_key(n_input); + std::vector> indices_of_by_key( + n_input, std::vector(n_by)); for (size_t i = 0; i < n_input; ++i) { input_labels[i] = i == 0 ? "left" : "right_" + std::to_string(i); const Schema& input_schema = *inputs[i]->output_schema(); @@ -1100,8 +1115,8 @@ class AsofJoinNode : public ExecNode { private: arrow::Future<> finished_; std::vector> key_hashers_; - vec_col_index_t indices_of_on_key_; - std::vector indices_of_by_key_; + std::vector indices_of_on_key_; + std::vector> indices_of_by_key_; // InputStates // Each input state correponds to an input table std::vector> state_; @@ -1120,8 +1135,8 @@ class AsofJoinNode : public ExecNode { AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, - const vec_col_index_t& indices_of_on_key, - const std::vector& indices_of_by_key, + const std::vector& indices_of_on_key, + const std::vector>& indices_of_by_key, OnType tolerance, std::shared_ptr output_schema) : ExecNode(plan, inputs, input_labels, /*output_schema=*/std::move(output_schema), diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index 2370923539c15..675cf9592e1df 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -18,7 +18,6 @@ #include #include -#include #include #include #include @@ -64,7 +63,7 @@ bool is_temporal_primitive(Type::type type_id) { } } -BatchesWithSchema MakeBatchesFromNumString( +Result MakeBatchesFromNumString( const std::shared_ptr& schema, const std::vector& json_strings, int multiplicity = 1) { FieldVector num_fields; @@ -85,8 +84,9 @@ BatchesWithSchema MakeBatchesFromNumString( auto type = schema->field(i)->type(); if (is_base_binary_like(type->id())) { // casting to string first enables casting to binary - Datum as_string = Cast(num_batch.values[i], utf8()).ValueOrDie(); - values.push_back(Cast(as_string, type).ValueOrDie()); + ARROW_ASSIGN_OR_RAISE(Datum as_string, Cast(num_batch.values[i], utf8())); + ARROW_ASSIGN_OR_RAISE(Datum as_type, Cast(as_string, type)); + values.push_back(as_type); } else { values.push_back(num_batch.values[i]); } @@ -241,7 +241,7 @@ void DoInvalidPlanTest(const BatchesWithSchema& l_batches, const BatchesWithSchema& r_batches, const AsofJoinNodeOptions& join_options, const std::string& expected_error_str, - bool then_run_plan = false) { + bool fail_on_plan_creation = false) { ExecContext exec_ctx; ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_ctx)); @@ -251,7 +251,7 @@ void DoInvalidPlanTest(const BatchesWithSchema& l_batches, join.inputs.emplace_back(Declaration{ "source", SourceNodeOptions{r_batches.schema, r_batches.gen(false, false)}}); - if (then_run_plan) { + if (fail_on_plan_creation) { AsyncGenerator> sink_gen; ASSERT_OK(Declaration::Sequence({join, {"sink", SinkNodeOptions{&sink_gen}}}) .AddToPlan(plan.get())); @@ -275,8 +275,8 @@ void DoRunInvalidPlanTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema, const AsofJoinNodeOptions& join_options, const std::string& expected_error_str) { - BatchesWithSchema l_batches = MakeBatchesFromNumString(l_schema, {R"([])"}); - BatchesWithSchema r_batches = MakeBatchesFromNumString(r_schema, {R"([])"}); + ASSERT_OK_AND_ASSIGN(auto l_batches, MakeBatchesFromNumString(l_schema, {R"([])"})); + ASSERT_OK_AND_ASSIGN(auto r_batches, MakeBatchesFromNumString(r_schema, {R"([])"})); return DoRunInvalidPlanTest(l_batches, r_batches, join_options, expected_error_str); } @@ -344,7 +344,10 @@ void DoRunAmbiguousByKeyTest(const std::shared_ptr& l_schema, DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table : Multiple matches"); } -std::string GetJsonString(int n_rows, int n_cols, bool unordered = false) { +// Gets a batch for testing as a Json string +// The batch will have n_rows rows n_cols columns, the first column being the on-field +// If unordered is true then the first column will be out-of-order +std::string GetTestBatchAsJsonString(int n_rows, int n_cols, bool unordered = false) { std::stringstream s; s << '['; for (int i = 0; i < n_rows; i++) { @@ -374,10 +377,10 @@ void DoRunUnorderedPlanTest(bool l_unordered, bool r_unordered, const std::string& expected_error_str) { ASSERT_TRUE(l_unordered || r_unordered); int n_rows = 5; - std::string l_str = GetJsonString(n_rows, l_schema->num_fields(), l_unordered); - std::string r_str = GetJsonString(n_rows, r_schema->num_fields(), r_unordered); - BatchesWithSchema l_batches = MakeBatchesFromNumString(l_schema, {l_str}); - BatchesWithSchema r_batches = MakeBatchesFromNumString(r_schema, {r_str}); + auto l_str = GetTestBatchAsJsonString(n_rows, l_schema->num_fields(), l_unordered); + auto r_str = GetTestBatchAsJsonString(n_rows, r_schema->num_fields(), r_unordered); + ASSERT_OK_AND_ASSIGN(auto l_batches, MakeBatchesFromNumString(l_schema, {l_str})); + ASSERT_OK_AND_ASSIGN(auto r_batches, MakeBatchesFromNumString(r_schema, {r_str})); return DoInvalidPlanTest(l_batches, r_batches, join_options, expected_error_str, /*then_run_plan=*/true); @@ -396,10 +399,10 @@ void DoRunNullByKeyPlanTest(const std::shared_ptr& l_schema, AsofJoinNodeOptions join_options{"time", "key2", 1000}; std::string expected_error_str = "unexpected null by-key values"; int n_rows = 5; - std::string l_str = GetJsonString(n_rows, l_schema->num_fields()); - std::string r_str = GetJsonString(n_rows, r_schema->num_fields()); - BatchesWithSchema l_batches = MakeBatchesFromNumString(l_schema, {l_str}); - BatchesWithSchema r_batches = MakeBatchesFromNumString(r_schema, {r_str}); + auto l_str = GetTestBatchAsJsonString(n_rows, l_schema->num_fields()); + auto r_str = GetTestBatchAsJsonString(n_rows, r_schema->num_fields()); + ASSERT_OK_AND_ASSIGN(auto l_batches, MakeBatchesFromNumString(l_schema, {l_str})); + ASSERT_OK_AND_ASSIGN(auto r_batches, MakeBatchesFromNumString(r_schema, {r_str})); l_batches = MutateByKey(l_batches, "key", "key2", true, true); r_batches = MutateByKey(r_batches, "key", "key2", true, true); @@ -610,12 +613,13 @@ struct BasicTest { }); // Test three table join - BatchesWithSchema l_batches, r0_batches, r1_batches, exp_nokey_batches, exp_batches; - l_batches = MakeBatchesFromNumString(l_schema, l_data); - r0_batches = MakeBatchesFromNumString(r0_schema, r0_data); - r1_batches = MakeBatchesFromNumString(r1_schema, r1_data); - exp_nokey_batches = MakeBatchesFromNumString(exp_schema, exp_nokey_data); - exp_batches = MakeBatchesFromNumString(exp_schema, exp_data); + ASSERT_OK_AND_ASSIGN(auto l_batches, MakeBatchesFromNumString(l_schema, l_data)); + ASSERT_OK_AND_ASSIGN(auto r0_batches, MakeBatchesFromNumString(r0_schema, r0_data)); + ASSERT_OK_AND_ASSIGN(auto r1_batches, MakeBatchesFromNumString(r1_schema, r1_data)); + ASSERT_OK_AND_ASSIGN(auto exp_nokey_batches, + MakeBatchesFromNumString(exp_schema, exp_nokey_data)); + ASSERT_OK_AND_ASSIGN(auto exp_batches, + MakeBatchesFromNumString(exp_schema, exp_data)); batches_runner(l_batches, r0_batches, r1_batches, exp_nokey_batches, exp_batches); } diff --git a/cpp/src/arrow/compute/light_array.h b/cpp/src/arrow/compute/light_array.h index e0adc9266fed1..389b63cca4143 100644 --- a/cpp/src/arrow/compute/light_array.h +++ b/cpp/src/arrow/compute/light_array.h @@ -135,7 +135,7 @@ class ARROW_EXPORT KeyColumnArray { /// Only valid if this is a view into a varbinary type uint32_t* mutable_offsets() { DCHECK(!metadata_.is_fixed_length); - DCHECK(metadata_.fixed_length == sizeof(uint32_t)); + DCHECK_EQ(metadata_.fixed_length, sizeof(uint32_t)); return reinterpret_cast(mutable_data(kFixedLengthBuffer)); } /// \brief Return a read-only version of the offsets buffer @@ -143,7 +143,7 @@ class ARROW_EXPORT KeyColumnArray { /// Only valid if this is a view into a varbinary type const uint32_t* offsets() const { DCHECK(!metadata_.is_fixed_length); - DCHECK(metadata_.fixed_length == sizeof(uint32_t)); + DCHECK_EQ(metadata_.fixed_length, sizeof(uint32_t)); return reinterpret_cast(data(kFixedLengthBuffer)); } /// \brief Return a mutable version of the large-offsets buffer @@ -151,7 +151,7 @@ class ARROW_EXPORT KeyColumnArray { /// Only valid if this is a view into a large varbinary type uint64_t* mutable_large_offsets() { DCHECK(!metadata_.is_fixed_length); - DCHECK(metadata_.fixed_length == sizeof(uint64_t)); + DCHECK_EQ(metadata_.fixed_length, sizeof(uint64_t)); return reinterpret_cast(mutable_data(kFixedLengthBuffer)); } /// \brief Return a read-only version of the large-offsets buffer @@ -159,7 +159,7 @@ class ARROW_EXPORT KeyColumnArray { /// Only valid if this is a view into a large varbinary type const uint64_t* large_offsets() const { DCHECK(!metadata_.is_fixed_length); - DCHECK(metadata_.fixed_length == sizeof(uint64_t)); + DCHECK_EQ(metadata_.fixed_length, sizeof(uint64_t)); return reinterpret_cast(data(kFixedLengthBuffer)); } /// \brief Return the type metadata From faf79498929b3bce9ee02b8bc1f991f3e462865d Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Thu, 25 Aug 2022 08:11:00 -0400 Subject: [PATCH 16/26] more requested fixes --- cpp/src/arrow/compute/exec/asof_join_node.cc | 12 ++--- .../arrow/compute/exec/asof_join_node_test.cc | 47 ++++++++++--------- 2 files changed, 31 insertions(+), 28 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 215753450cd18..7ec717795ec9c 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -832,9 +832,9 @@ class AsofJoinNode : public ExecNode { const std::vector>& indices_of_by_key, OnType tolerance, std::shared_ptr output_schema); - Status InternalInit(bool must_hash, bool nullable_by_key, - std::vector> key_hashers) { - key_hashers_.swap(key_hashers); + void InternalInit(bool must_hash, bool nullable_by_key, + std::vector> key_hashers) { + key_hashers_ = std::move(key_hashers); auto inputs = this->inputs(); for (size_t i = 0; i < inputs.size(); ++i) { state_.push_back(::arrow::internal::make_unique( @@ -845,8 +845,6 @@ class AsofJoinNode : public ExecNode { col_index_t dst_offset = 0; for (auto& state : state_) dst_offset = state->InitSrcToDstMapping(dst_offset, !!dst_offset); - - return Status::OK(); } virtual ~AsofJoinNode() { @@ -1007,7 +1005,7 @@ class AsofJoinNode : public ExecNode { if (!match_res.ok()) { return Status::Invalid("Bad join key on table : ", match_res.status().message()); } - auto match = match_res.ValueOrDie(); + ARROW_ASSIGN_OR_RAISE(auto match, match_res); if (match.indices().size() != 1) { return Status::Invalid("AsOfJoinNode does not support a nested ", to_string(key_kind), "-key ", field_ref.ToString()); @@ -1062,7 +1060,7 @@ class AsofJoinNode : public ExecNode { ::arrow::internal::make_unique(node_indices_of_by_key[i])); RETURN_NOT_OK(key_hashers[i]->Init(plan->exec_context(), node_output_schema)); } - RETURN_NOT_OK(node->InternalInit(must_hash, nullable_by_key, std::move(key_hashers))); + node->InternalInit(must_hash, nullable_by_key, std::move(key_hashers)); return node; } diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index 675cf9592e1df..aaf5ddd151c88 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -125,17 +125,17 @@ void BuildZeroBaseBinaryArray(std::shared_ptr& empty, int64_t length) { } // mutates by copying from_key into to_key and changing from_key to zero -BatchesWithSchema MutateByKey(BatchesWithSchema& batches, std::string from_key, - std::string to_key, bool replace_key = false, - bool null_key = false) { +Result MutateByKey(BatchesWithSchema& batches, std::string from_key, + std::string to_key, bool replace_key = false, + bool null_key = false) { int from_index = batches.schema->GetFieldIndex(from_key); int n_fields = batches.schema->num_fields(); auto fields = batches.schema->fields(); BatchesWithSchema new_batches; auto new_field = batches.schema->field(from_index)->WithName(to_key); - new_batches.schema = (replace_key ? batches.schema->SetField(from_index, new_field) - : batches.schema->AddField(from_index, new_field)) - .ValueOrDie(); + ARROW_ASSIGN_OR_RAISE(new_batches.schema, + replace_key ? batches.schema->SetField(from_index, new_field) + : batches.schema->AddField(from_index, new_field)); for (const ExecBatch& batch : batches.batches) { std::vector new_values; for (int i = 0; i < n_fields; i++) { @@ -171,7 +171,8 @@ BatchesWithSchema MutateByKey(BatchesWithSchema& batches, std::string from_key, } new_values.push_back(empty); } else { - new_values.push_back(Subtract(value, value).ValueOrDie()); + ARROW_ASSIGN_OR_RAISE(auto sub, Subtract(value, value)); + new_values.push_back(sub); } if (replace_key) { continue; @@ -403,8 +404,8 @@ void DoRunNullByKeyPlanTest(const std::shared_ptr& l_schema, auto r_str = GetTestBatchAsJsonString(n_rows, r_schema->num_fields()); ASSERT_OK_AND_ASSIGN(auto l_batches, MakeBatchesFromNumString(l_schema, {l_str})); ASSERT_OK_AND_ASSIGN(auto r_batches, MakeBatchesFromNumString(r_schema, {r_str})); - l_batches = MutateByKey(l_batches, "key", "key2", true, true); - r_batches = MutateByKey(r_batches, "key", "key2", true, true); + ASSERT_OK_AND_ASSIGN(l_batches, MutateByKey(l_batches, "key", "key2", true, true)); + ASSERT_OK_AND_ASSIGN(r_batches, MutateByKey(r_batches, "key", "key2", true, true)); return DoInvalidPlanTest(l_batches, r_batches, join_options, expected_error_str, /*then_run_plan=*/true); @@ -472,10 +473,10 @@ struct BasicTest { using B = BatchesWithSchema; RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, B exp_batches) { - l_batches = MutateByKey(l_batches, "key", "key2"); - r0_batches = MutateByKey(r0_batches, "key", "key2"); - r1_batches = MutateByKey(r1_batches, "key", "key2"); - exp_batches = MutateByKey(exp_batches, "key", "key2"); + ASSERT_OK_AND_ASSIGN(l_batches, MutateByKey(l_batches, "key", "key2")); + ASSERT_OK_AND_ASSIGN(r0_batches, MutateByKey(r0_batches, "key", "key2")); + ASSERT_OK_AND_ASSIGN(r1_batches, MutateByKey(r1_batches, "key", "key2")); + ASSERT_OK_AND_ASSIGN(exp_batches, MutateByKey(exp_batches, "key", "key2")); CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", {"key", "key2"}, tolerance); }); @@ -488,10 +489,11 @@ struct BasicTest { using B = BatchesWithSchema; RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, B exp_batches) { - l_batches = MutateByKey(l_batches, "key", "key2", true); - r0_batches = MutateByKey(r0_batches, "key", "key2", true); - r1_batches = MutateByKey(r1_batches, "key", "key2", true); - exp_nokey_batches = MutateByKey(exp_nokey_batches, "key", "key2", true); + ASSERT_OK_AND_ASSIGN(l_batches, MutateByKey(l_batches, "key", "key2", true)); + ASSERT_OK_AND_ASSIGN(r0_batches, MutateByKey(r0_batches, "key", "key2", true)); + ASSERT_OK_AND_ASSIGN(r1_batches, MutateByKey(r1_batches, "key", "key2", true)); + ASSERT_OK_AND_ASSIGN(exp_nokey_batches, + MutateByKey(exp_nokey_batches, "key", "key2", true)); CheckRunOutput(l_batches, r0_batches, r1_batches, exp_nokey_batches, "time", "key2", tolerance); }); @@ -504,10 +506,13 @@ struct BasicTest { using B = BatchesWithSchema; RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, B exp_batches) { - l_batches = MutateByKey(l_batches, "key", "key2", true, true); - r0_batches = MutateByKey(r0_batches, "key", "key2", true, true); - r1_batches = MutateByKey(r1_batches, "key", "key2", true, true); - exp_nokey_batches = MutateByKey(exp_nokey_batches, "key", "key2", true, true); + ASSERT_OK_AND_ASSIGN(l_batches, MutateByKey(l_batches, "key", "key2", true, true)); + ASSERT_OK_AND_ASSIGN(r0_batches, + MutateByKey(r0_batches, "key", "key2", true, true)); + ASSERT_OK_AND_ASSIGN(r1_batches, + MutateByKey(r1_batches, "key", "key2", true, true)); + ASSERT_OK_AND_ASSIGN(exp_nokey_batches, + MutateByKey(exp_nokey_batches, "key", "key2", true, true)); CheckRunOutput(l_batches, r0_batches, r1_batches, exp_nokey_batches, AsofJoinNodeOptions("time", "key2", tolerance, /*nullable_by_key=*/true)); From 519f6fa52ceb92a36cf52fafd75fdee2bfc211ef Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Thu, 25 Aug 2022 18:11:09 -0400 Subject: [PATCH 17/26] clean up tests --- .../arrow/compute/exec/asof_join_node_test.cc | 254 +++++++++++------- 1 file changed, 156 insertions(+), 98 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index aaf5ddd151c88..5487a3d18ade7 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -37,13 +37,17 @@ #include "arrow/util/string_view.h" #include "arrow/util/thread_pool.h" -#define TRACED_TEST(t_class, t_name) \ - static void _##t_class##_##t_name(); \ +#define TRACED_TEST(t_class, t_name, t_body) \ TEST(t_class, t_name) { \ ARROW_SCOPED_TRACE(#t_class "_" #t_name); \ - _##t_class##_##t_name(); \ - } \ - static void _##t_class##_##t_name() + t_body; \ + } + +#define TRACED_TEST_P(t_class, t_name, t_body) \ + TEST_P(t_class, t_name) { \ + ARROW_SCOPED_TRACE(#t_class "_" #t_name "_" + std::get<1>(GetParam())); \ + t_body; \ + } using testing::UnorderedElementsAreArray; @@ -428,24 +432,24 @@ struct BasicTest { exp_data(std::move(exp_data)), tolerance(tolerance) {} + static inline void check_init(const std::vector>& types) { + ASSERT_NE(0, types.size()); + } + template - static inline void init_types(const std::vector>& all_types, - std::vector>& types, - TypeCond type_cond) { - if (types.size() == 0) { - for (auto type : all_types) { - if (type_cond(type)) { - types.push_back(type); - } + static inline std::vector> init_types( + const std::vector>& all_types, TypeCond type_cond) { + std::vector> types; + for (auto type : all_types) { + if (type_cond(type)) { + types.push_back(type); } } + check_init(types); + return types; } - void RunSingleByKey(std::vector> time_types = {}, - std::vector> key_types = {}, - std::vector> l_types = {}, - std::vector> r0_types = {}, - std::vector> r1_types = {}) { + void RunSingleByKey() { using B = BatchesWithSchema; RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, B exp_batches) { @@ -453,11 +457,8 @@ struct BasicTest { tolerance); }); } - void RunDoubleByKey(std::vector> time_types = {}, - std::vector> key_types = {}, - std::vector> l_types = {}, - std::vector> r0_types = {}, - std::vector> r1_types = {}) { + static void DoSingleByKey(BasicTest& basic_tests) { basic_tests.RunSingleByKey(); } + void RunDoubleByKey() { using B = BatchesWithSchema; RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, B exp_batches) { @@ -465,11 +466,8 @@ struct BasicTest { {"key", "key"}, tolerance); }); } - void RunMutateByKey(std::vector> time_types = {}, - std::vector> key_types = {}, - std::vector> l_types = {}, - std::vector> r0_types = {}, - std::vector> r1_types = {}) { + static void DoDoubleByKey(BasicTest& basic_tests) { basic_tests.RunDoubleByKey(); } + void RunMutateByKey() { using B = BatchesWithSchema; RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, B exp_batches) { @@ -481,11 +479,8 @@ struct BasicTest { {"key", "key2"}, tolerance); }); } - void RunMutateNoKey(std::vector> time_types = {}, - std::vector> key_types = {}, - std::vector> l_types = {}, - std::vector> r0_types = {}, - std::vector> r1_types = {}) { + static void DoMutateByKey(BasicTest& basic_tests) { basic_tests.RunMutateByKey(); } + void RunMutateNoKey() { using B = BatchesWithSchema; RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, B exp_batches) { @@ -498,11 +493,8 @@ struct BasicTest { tolerance); }); } - void RunMutateNullKey(std::vector> time_types = {}, - std::vector> key_types = {}, - std::vector> l_types = {}, - std::vector> r0_types = {}, - std::vector> r1_types = {}) { + static void DoMutateNoKey(BasicTest& basic_tests) { basic_tests.RunMutateNoKey(); } + void RunMutateNullKey() { using B = BatchesWithSchema; RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, B exp_batches) { @@ -518,13 +510,9 @@ struct BasicTest { /*nullable_by_key=*/true)); }); } + static void DoMutateNullKey(BasicTest& basic_tests) { basic_tests.RunMutateNullKey(); } template - void RunBatches(BatchesRunner batches_runner, - std::vector> time_types = {}, - std::vector> key_types = {}, - std::vector> l_types = {}, - std::vector> r0_types = {}, - std::vector> r1_types = {}) { + void RunBatches(BatchesRunner batches_runner) { std::vector> all_types = { utf8(), large_utf8(), @@ -552,17 +540,12 @@ struct BasicTest { float64()}; using T = const std::shared_ptr; // byte_width > 1 below allows fitting the tested data - init_types(all_types, time_types, - [](T& t) { return t->byte_width() > 1 && !is_floating(t->id()); }); - ASSERT_NE(0, time_types.size()); - init_types(all_types, key_types, [](T& t) { return !is_floating(t->id()); }); - ASSERT_NE(0, key_types.size()); - init_types(all_types, l_types, [](T& t) { return true; }); - ASSERT_NE(0, l_types.size()); - init_types(all_types, r0_types, [](T& t) { return t->byte_width() > 1; }); - ASSERT_NE(0, r0_types.size()); - init_types(all_types, r1_types, [](T& t) { return t->byte_width() > 1; }); - ASSERT_NE(0, r1_types.size()); + auto time_types = init_types( + all_types, [](T& t) { return t->byte_width() > 1 && !is_floating(t->id()); }); + auto key_types = init_types(all_types, [](T& t) { return !is_floating(t->id()); }); + auto l_types = init_types(all_types, [](T& t) { return true; }); + auto r0_types = init_types(all_types, [](T& t) { return t->byte_width() > 1; }); + auto r1_types = init_types(all_types, [](T& t) { return t->byte_width() > 1; }); // sample a limited number of type-combinations to keep the runnning time reasonable // the scoped-traces below help reproduce a test failure, should it happen @@ -636,6 +619,10 @@ struct BasicTest { int64_t tolerance; }; +using AsofJoinBasicParams = std::tuple, std::string>; + +struct AsofJoinBasicTest : public testing::TestWithParam {}; + class AsofJoinTest : public testing::Test {}; #define ASOFJOIN_TEST_SET(name, num) \ @@ -664,7 +651,12 @@ BasicTest GetBasicTest1() { /*exp_nokey*/ {R"([[0, 0, 1, 11, null], [1000, 0, 2, 11, 101]])"}, /*exp*/ {R"([[0, 1, 1, 11, null], [1000, 1, 2, 11, 101]])"}, 1000); } -ASOFJOIN_TEST_SET(Basic, 1) + +TRACED_TEST_P(AsofJoinBasicTest, TestBasic1, { + BasicTest basic_test = GetBasicTest1(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) BasicTest GetBasicTest2() { // Single key, multiple batches @@ -675,7 +667,12 @@ BasicTest GetBasicTest2() { /*exp_nokey*/ {R"([[0, 0, 1, 11, 101], [1000, 0, 2, 12, 102]])"}, /*exp*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"}, 1000); } -ASOFJOIN_TEST_SET(Basic, 2) + +TRACED_TEST_P(AsofJoinBasicTest, TestBasic2, { + BasicTest basic_test = GetBasicTest2(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) BasicTest GetBasicTest3() { // Single key, multiple left batches, single right batches @@ -686,7 +683,13 @@ BasicTest GetBasicTest3() { /*exp_nokey*/ {R"([[0, 0, 1, 11, 101], [1000, 0, 2, 12, 102]])"}, /*exp*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"}, 1000); } -ASOFJOIN_TEST_SET(Basic, 3) + +TRACED_TEST_P(AsofJoinBasicTest, TestBasic3, { + ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestBasic3_" + std::get<1>(GetParam())); + BasicTest basic_test = GetBasicTest3(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) BasicTest GetBasicTest4() { // Multi key, multiple batches, misaligned batches @@ -708,7 +711,12 @@ BasicTest GetBasicTest4() { R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, 1000); } -ASOFJOIN_TEST_SET(Basic, 4) + +TRACED_TEST_P(AsofJoinBasicTest, TestBasic4, { + BasicTest basic_test = GetBasicTest4(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) BasicTest GetBasicTest5() { // Multi key, multiple batches, misaligned batches, smaller tolerance @@ -729,7 +737,13 @@ BasicTest GetBasicTest5() { R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, 500); } -ASOFJOIN_TEST_SET(Basic, 5) + +TRACED_TEST_P(AsofJoinBasicTest, TestBasic5, { + ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestBasic5_" + std::get<1>(GetParam())); + BasicTest basic_test = GetBasicTest5(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) BasicTest GetBasicTest6() { // Multi key, multiple batches, misaligned batches, zero tolerance @@ -750,7 +764,13 @@ BasicTest GetBasicTest6() { R"([[2000, 1, 4, 13, 103], [2000, 2, 24, null, null]])"}, 0); } -ASOFJOIN_TEST_SET(Basic, 6) + +TRACED_TEST_P(AsofJoinBasicTest, TestBasic6, { + ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestBasic6_" + std::get<1>(GetParam())); + BasicTest basic_test = GetBasicTest6(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) BasicTest GetEmptyTest1() { // Empty left batch @@ -767,7 +787,13 @@ BasicTest GetEmptyTest1() { /*exp*/ {R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, 1000); } -ASOFJOIN_TEST_SET(Empty, 1) + +TRACED_TEST_P(AsofJoinBasicTest, TestEmpty1, { + ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestEmpty1_" + std::get<1>(GetParam())); + BasicTest basic_test = GetEmptyTest1(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) BasicTest GetEmptyTest2() { // Empty left input @@ -784,7 +810,13 @@ BasicTest GetEmptyTest2() { /*exp*/ {R"([])"}, 1000); } -ASOFJOIN_TEST_SET(Empty, 2) + +TRACED_TEST_P(AsofJoinBasicTest, TestEmpty2, { + ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestEmpty2_" + std::get<1>(GetParam())); + BasicTest basic_test = GetEmptyTest2(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) BasicTest GetEmptyTest3() { // Empty right batch @@ -804,7 +836,13 @@ BasicTest GetEmptyTest3() { R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, 1000); } -ASOFJOIN_TEST_SET(Empty, 3) + +TRACED_TEST_P(AsofJoinBasicTest, TestEmpty3, { + ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestEmpty3_" + std::get<1>(GetParam())); + BasicTest basic_test = GetEmptyTest3(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) BasicTest GetEmptyTest4() { // Empty right input @@ -824,7 +862,13 @@ BasicTest GetEmptyTest4() { R"([[2000, 1, 4, null, 103], [2000, 2, 24, null, 1002]])"}, 1000); } -ASOFJOIN_TEST_SET(Empty, 4) + +TRACED_TEST_P(AsofJoinBasicTest, TestEmpty4, { + ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestEmpty4_" + std::get<1>(GetParam())); + BasicTest basic_test = GetEmptyTest4(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) BasicTest GetEmptyTest5() { // All empty @@ -839,31 +883,45 @@ BasicTest GetEmptyTest5() { /*exp*/ {R"([])"}, 1000); } -ASOFJOIN_TEST_SET(Empty, 5) -TRACED_TEST(AsofJoinTest, TestUnsupportedOntype) { +TRACED_TEST_P(AsofJoinBasicTest, TestEmpty5, { + ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestEmpty5_" + std::get<1>(GetParam())); + BasicTest basic_test = GetEmptyTest5(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) + +INSTANTIATE_TEST_SUITE_P( + AsofJoinNodeTest, AsofJoinBasicTest, + testing::Values(AsofJoinBasicParams(BasicTest::DoSingleByKey, "single by-key"), + AsofJoinBasicParams(BasicTest::DoDoubleByKey, "double by-key"), + AsofJoinBasicParams(BasicTest::DoMutateByKey, "mutate by-key"), + AsofJoinBasicParams(BasicTest::DoMutateNoKey, "mutate no-key"), + AsofJoinBasicParams(BasicTest::DoMutateNullKey, "mutate null-key"))); + +TRACED_TEST(AsofJoinTest, TestUnsupportedOntype, { DoRunInvalidTypeTest(schema({field("time", list(int32())), field("key", int32()), field("l_v0", float64())}), schema({field("time", list(int32())), field("key", int32()), field("r0_v0", float32())})); -} +}) -TRACED_TEST(AsofJoinTest, TestUnsupportedBytype) { +TRACED_TEST(AsofJoinTest, TestUnsupportedBytype, { DoRunInvalidTypeTest(schema({field("time", int64()), field("key", list(int32())), field("l_v0", float64())}), schema({field("time", int64()), field("key", list(int32())), field("r0_v0", float32())})); -} +}) -TRACED_TEST(AsofJoinTest, TestUnsupportedDatatype) { +TRACED_TEST(AsofJoinTest, TestUnsupportedDatatype, { // List is unsupported DoRunInvalidTypeTest( schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), schema({field("time", int64()), field("key", int32()), field("r0_v0", list(int32()))})); -} +}) -TRACED_TEST(AsofJoinTest, TestMissingKeys) { +TRACED_TEST(AsofJoinTest, TestMissingKeys, { DoRunMissingKeysTest( schema({field("time1", int64()), field("key", int32()), field("l_v0", float64())}), schema( @@ -873,85 +931,85 @@ TRACED_TEST(AsofJoinTest, TestMissingKeys) { schema({field("time", int64()), field("key1", int32()), field("l_v0", float64())}), schema( {field("time", int64()), field("key1", int32()), field("r0_v0", float64())})); -} +}) -TRACED_TEST(AsofJoinTest, TestUnsupportedTolerance) { +TRACED_TEST(AsofJoinTest, TestUnsupportedTolerance, { // Utf8 is unsupported DoRunInvalidToleranceTest( schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); -} +}) -TRACED_TEST(AsofJoinTest, TestEmptyByKey) { +TRACED_TEST(AsofJoinTest, TestEmptyByKey, { DoRunEmptyByKeyTest( schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); -} +}) -TRACED_TEST(AsofJoinTest, TestMissingOnKey) { +TRACED_TEST(AsofJoinTest, TestMissingOnKey, { DoRunMissingOnKeyTest( schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); -} +}) -TRACED_TEST(AsofJoinTest, TestMissingByKey) { +TRACED_TEST(AsofJoinTest, TestMissingByKey, { DoRunMissingByKeyTest( schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); -} +}) -TRACED_TEST(AsofJoinTest, TestNestedOnKey) { +TRACED_TEST(AsofJoinTest, TestNestedOnKey, { DoRunNestedOnKeyTest( schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); -} +}) -TRACED_TEST(AsofJoinTest, TestNestedByKey) { +TRACED_TEST(AsofJoinTest, TestNestedByKey, { DoRunNestedByKeyTest( schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); -} +}) -TRACED_TEST(AsofJoinTest, TestAmbiguousOnKey) { +TRACED_TEST(AsofJoinTest, TestAmbiguousOnKey, { DoRunAmbiguousOnKeyTest( schema({field("time", int64()), field("time", int64()), field("key", int32()), field("l_v0", float64())}), schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); -} +}) -TRACED_TEST(AsofJoinTest, TestAmbiguousByKey) { +TRACED_TEST(AsofJoinTest, TestAmbiguousByKey, { DoRunAmbiguousByKeyTest( schema({field("time", int64()), field("key", int64()), field("key", int32()), field("l_v0", float64())}), schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); -} +}) -TRACED_TEST(AsofJoinTest, TestLeftUnorderedOnKey) { +TRACED_TEST(AsofJoinTest, TestLeftUnorderedOnKey, { DoRunUnorderedPlanTest( /*l_unordered=*/true, /*r_unordered=*/false, schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); -} +}) -TRACED_TEST(AsofJoinTest, TestRightUnorderedOnKey) { +TRACED_TEST(AsofJoinTest, TestRightUnorderedOnKey, { DoRunUnorderedPlanTest( /*l_unordered=*/false, /*r_unordered=*/true, schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); -} +}) -TRACED_TEST(AsofJoinTest, TestUnorderedOnKey) { +TRACED_TEST(AsofJoinTest, TestUnorderedOnKey, { DoRunUnorderedPlanTest( /*l_unordered=*/true, /*r_unordered=*/true, schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); -} +}) -TRACED_TEST(AsofJoinTest, TestNullByKey) { +TRACED_TEST(AsofJoinTest, TestNullByKey, { DoRunNullByKeyPlanTest( schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); -} +}) } // namespace compute } // namespace arrow From 494c57c62b0b81ba03309ec0ec3e737136534579 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sat, 27 Aug 2022 04:20:02 -0400 Subject: [PATCH 18/26] more fixes --- cpp/src/arrow/array/data.h | 5 ++++ cpp/src/arrow/compute/exec/asof_join_node.cc | 27 +++----------------- cpp/src/arrow/type_traits.h | 7 +++++ 3 files changed, 16 insertions(+), 23 deletions(-) diff --git a/cpp/src/arrow/array/data.h b/cpp/src/arrow/array/data.h index dde66ac79c44b..9190b8d5bea5f 100644 --- a/cpp/src/arrow/array/data.h +++ b/cpp/src/arrow/array/data.h @@ -167,6 +167,11 @@ struct ARROW_EXPORT ArrayData { std::shared_ptr Copy() const { return std::make_shared(*this); } + bool IsNull(int64_t i) const { + return ((buffers[0] != NULLPTR) ? !bit_util::GetBit(buffers[0]->data(), i + offset) + : null_count.load() == length); + } + // Access a buffer's data as a typed C pointer template inline const T* GetValues(int i, int64_t absolute_offset) const { diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 7ec717795ec9c..e8b0281aae0bc 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -63,21 +63,15 @@ typedef int col_index_t; // normalize the value to 64-bits while preserving ordering of values template ::value, bool> = true> -static inline uint64_t norm_value(T t) { +static inline uint64_t time_value(T t) { uint64_t bias = std::is_signed::value ? (uint64_t)1 << (8 * sizeof(T) - 1) : 0; return t < 0 ? static_cast(t + bias) : static_cast(t); } -// indicates normalization of a time value -template ::value, bool> = true> -static inline uint64_t time_value(T t) { - return norm_value(t); -} - // indicates normalization of a key value template ::value, bool> = true> static inline uint64_t key_value(T t) { - return norm_value(t); + return static_cast(t); } /** @@ -642,23 +636,10 @@ class CompositeReferenceTable { if (!_ptr2ref.count((uintptr_t)ref.get())) _ptr2ref[(uintptr_t)ref.get()] = ref; } - // this should really be a method on ArrayData - static bool IsNull(const std::shared_ptr& source, row_index_t row) { - return ((source->buffers[0] != NULLPTR) - ? !bit_util::GetBit(source->buffers[0]->data(), row + source->offset) - : source->null_count.load() == source->length); - } - - template - using is_fixed_width_type = std::is_base_of; - - template - using enable_if_fixed_width_type = enable_if_t::value, R>; - template ::BuilderType> enable_if_fixed_width_type static BuilderAppend( Builder& builder, const std::shared_ptr& source, row_index_t row) { - if (IsNull(source, row)) { + if (source->IsNull(row)) { builder.UnsafeAppendNull(); return Status::OK(); } @@ -670,7 +651,7 @@ class CompositeReferenceTable { template ::BuilderType> enable_if_base_binary static BuilderAppend( Builder& builder, const std::shared_ptr& source, row_index_t row) { - if (IsNull(source, row)) { + if (source->IsNull(row)) { return builder.AppendNull(); } using offset_type = typename Type::offset_type; diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index 221b35ce57323..8523c984e89f2 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -622,6 +622,13 @@ using is_fixed_size_binary_type = std::is_base_of; template using enable_if_fixed_size_binary = enable_if_t::value, R>; +// This includes primitive, dictionary, and fixed-size-binary types +template +using is_fixed_width_type = std::is_base_of; + +template +using enable_if_fixed_width_type = enable_if_t::value, R>; + template using is_binary_like_type = std::integral_constant::value && From 3d4afda7061bc302b591ec3809daeb0d63d4c1a9 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 28 Aug 2022 11:40:54 -0400 Subject: [PATCH 19/26] AsofJoin empty-key support and tests --- cpp/src/arrow/compute/exec/asof_join_node.cc | 24 ++-- .../arrow/compute/exec/asof_join_node_test.cc | 117 +++++++++++------- 2 files changed, 88 insertions(+), 53 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index e8b0281aae0bc..a4624edd3ecfa 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -241,7 +241,7 @@ class InputState { time_col_index_(time_col_index), key_col_index_(key_col_index), time_type_id_(schema_->fields()[time_col_index_]->type()->id()), - key_type_id_(schema_->num_fields()), + key_type_id_(key_col_index.size()), key_hasher_(key_hasher), must_hash_(must_hash), nullable_by_key_(nullable_by_key) { @@ -292,6 +292,9 @@ class InputState { } ByType GetLatestKey() const { + if (key_col_index_.size() == 0) { + return 0; + } const RecordBatch* batch = queue_.UnsyncFront().get(); if (must_hash_) { return key_hasher_->HashesFor(batch)[latest_ref_row_]; @@ -398,7 +401,8 @@ class InputState { } Status Push(const std::shared_ptr& rb) { - if (!nullable_by_key_ && rb->column_data(key_col_index_[0])->GetNullCount() > 0) { + if (!nullable_by_key_ && key_col_index_.size() == 1 && + rb->column_data(key_col_index_[0])->GetNullCount() > 0) { return Status::Invalid("AsofJoin does not allow unexpected null by-key values"); } if (rb->num_rows() > 0) { @@ -999,9 +1003,6 @@ class AsofJoinNode : public ExecNode { DCHECK_GE(inputs.size(), 2) << "Must have at least two inputs"; const auto& join_options = checked_cast(options); - if (join_options.by_key.size() == 0) { - return Status::Invalid("AsOfJoin by_key must not be empty"); - } if (join_options.tolerance < 0) { return Status::Invalid("AsOfJoin tolerance must be non-negative but is ", join_options.tolerance); @@ -1032,9 +1033,11 @@ class AsofJoinNode : public ExecNode { std::move(output_schema)); auto node_output_schema = node->output_schema(); auto node_indices_of_by_key = checked_cast(node)->indices_of_by_key(); - auto single_key_field = inputs[0]->output_schema()->field(indices_of_by_key[0][0]); std::vector> key_hashers; - bool must_hash = n_by > 1 || !is_primitive(single_key_field->type()->id()); + bool must_hash = + n_by != 1 || + !is_primitive( + inputs[0]->output_schema()->field(indices_of_by_key[0][0])->type()->id()); bool nullable_by_key = join_options.nullable_by_key; for (size_t i = 0; i < n_input; i++) { key_hashers.push_back( @@ -1088,7 +1091,12 @@ class AsofJoinNode : public ExecNode { DCHECK_EQ(output, outputs_[0]); StopProducing(); } - void StopProducing() override { finished_.MarkFinished(); } + void StopProducing() override { + // avoid finishing twice, to prevent "Plan was destroyed before finishing" error + if (finished_.state() == FutureState::PENDING) { + finished_.MarkFinished(); + } + } arrow::Future<> finished() override { return finished_; } private: diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index 5487a3d18ade7..f586c67a972bb 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -131,20 +131,27 @@ void BuildZeroBaseBinaryArray(std::shared_ptr& empty, int64_t length) { // mutates by copying from_key into to_key and changing from_key to zero Result MutateByKey(BatchesWithSchema& batches, std::string from_key, std::string to_key, bool replace_key = false, - bool null_key = false) { + bool null_key = false, bool remove_key = false) { int from_index = batches.schema->GetFieldIndex(from_key); int n_fields = batches.schema->num_fields(); auto fields = batches.schema->fields(); BatchesWithSchema new_batches; - auto new_field = batches.schema->field(from_index)->WithName(to_key); - ARROW_ASSIGN_OR_RAISE(new_batches.schema, - replace_key ? batches.schema->SetField(from_index, new_field) - : batches.schema->AddField(from_index, new_field)); + if (remove_key) { + ARROW_ASSIGN_OR_RAISE(new_batches.schema, batches.schema->RemoveField(from_index)); + } else { + auto new_field = batches.schema->field(from_index)->WithName(to_key); + ARROW_ASSIGN_OR_RAISE(new_batches.schema, + replace_key ? batches.schema->SetField(from_index, new_field) + : batches.schema->AddField(from_index, new_field)); + } for (const ExecBatch& batch : batches.batches) { std::vector new_values; for (int i = 0; i < n_fields; i++) { const Datum& value = batch.values[i]; if (i == from_index) { + if (remove_key) { + continue; + } auto type = fields[i]->type(); if (null_key) { std::shared_ptr empty; @@ -221,6 +228,9 @@ void CheckRunOutput(const BatchesWithSchema& l_batches, .AddToPlan(plan.get())); ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen)); + for (auto batch : res) { + ASSERT_EQ(exp_batches.schema->num_fields(), batch.values.size()); + } ASSERT_OK_AND_ASSIGN(auto exp_table, TableFromExecBatches(exp_batches.schema, exp_batches.batches)); @@ -309,12 +319,6 @@ void DoRunMissingKeysTest(const std::shared_ptr& l_schema, DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table : No match"); } -void DoRunEmptyByKeyTest(const std::shared_ptr& l_schema, - const std::shared_ptr& r_schema) { - DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions("time", {}, 0), - "AsOfJoin by_key must not be empty"); -} - void DoRunMissingOnKeyTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions("invalid_time", "key", 0), @@ -424,11 +428,13 @@ struct BasicTest { const std::vector& r0_data, const std::vector& r1_data, const std::vector& exp_nokey_data, + const std::vector& exp_emptykey_data, const std::vector& exp_data, int64_t tolerance) : l_data(std::move(l_data)), r0_data(std::move(r0_data)), r1_data(std::move(r1_data)), exp_nokey_data(std::move(exp_nokey_data)), + exp_emptykey_data(std::move(exp_emptykey_data)), exp_data(std::move(exp_data)), tolerance(tolerance) {} @@ -452,7 +458,7 @@ struct BasicTest { void RunSingleByKey() { using B = BatchesWithSchema; RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, - B exp_batches) { + B exp_emptykey_batches, B exp_batches) { CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", tolerance); }); @@ -461,7 +467,7 @@ struct BasicTest { void RunDoubleByKey() { using B = BatchesWithSchema; RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, - B exp_batches) { + B exp_emptykey_batches, B exp_batches) { CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", {"key", "key"}, tolerance); }); @@ -470,7 +476,7 @@ struct BasicTest { void RunMutateByKey() { using B = BatchesWithSchema; RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, - B exp_batches) { + B exp_emptykey_batches, B exp_batches) { ASSERT_OK_AND_ASSIGN(l_batches, MutateByKey(l_batches, "key", "key2")); ASSERT_OK_AND_ASSIGN(r0_batches, MutateByKey(r0_batches, "key", "key2")); ASSERT_OK_AND_ASSIGN(r1_batches, MutateByKey(r1_batches, "key", "key2")); @@ -483,7 +489,7 @@ struct BasicTest { void RunMutateNoKey() { using B = BatchesWithSchema; RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, - B exp_batches) { + B exp_emptykey_batches, B exp_batches) { ASSERT_OK_AND_ASSIGN(l_batches, MutateByKey(l_batches, "key", "key2", true)); ASSERT_OK_AND_ASSIGN(r0_batches, MutateByKey(r0_batches, "key", "key2", true)); ASSERT_OK_AND_ASSIGN(r1_batches, MutateByKey(r1_batches, "key", "key2", true)); @@ -497,7 +503,7 @@ struct BasicTest { void RunMutateNullKey() { using B = BatchesWithSchema; RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, - B exp_batches) { + B exp_emptykey_batches, B exp_batches) { ASSERT_OK_AND_ASSIGN(l_batches, MutateByKey(l_batches, "key", "key2", true, true)); ASSERT_OK_AND_ASSIGN(r0_batches, MutateByKey(r0_batches, "key", "key2", true, true)); @@ -511,6 +517,21 @@ struct BasicTest { }); } static void DoMutateNullKey(BasicTest& basic_tests) { basic_tests.RunMutateNullKey(); } + void RunMutateEmptyKey() { + using B = BatchesWithSchema; + RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches, + B exp_emptykey_batches, B exp_batches) { + ASSERT_OK_AND_ASSIGN(r0_batches, + MutateByKey(r0_batches, "key", "key", false, false, true)); + ASSERT_OK_AND_ASSIGN(r1_batches, + MutateByKey(r1_batches, "key", "key", false, false, true)); + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_emptykey_batches, + AsofJoinNodeOptions("time", {}, tolerance)); + }); + } + static void DoMutateEmptyKey(BasicTest& basic_tests) { + basic_tests.RunMutateEmptyKey(); + } template void RunBatches(BatchesRunner batches_runner) { std::vector> all_types = { @@ -606,15 +627,19 @@ struct BasicTest { ASSERT_OK_AND_ASSIGN(auto r1_batches, MakeBatchesFromNumString(r1_schema, r1_data)); ASSERT_OK_AND_ASSIGN(auto exp_nokey_batches, MakeBatchesFromNumString(exp_schema, exp_nokey_data)); + ASSERT_OK_AND_ASSIGN(auto exp_emptykey_batches, + MakeBatchesFromNumString(exp_schema, exp_emptykey_data)); ASSERT_OK_AND_ASSIGN(auto exp_batches, MakeBatchesFromNumString(exp_schema, exp_data)); - batches_runner(l_batches, r0_batches, r1_batches, exp_nokey_batches, exp_batches); + batches_runner(l_batches, r0_batches, r1_batches, exp_nokey_batches, + exp_emptykey_batches, exp_batches); } std::vector l_data; std::vector r0_data; std::vector r1_data; std::vector exp_nokey_data; + std::vector exp_emptykey_data; std::vector exp_data; int64_t tolerance; }; @@ -625,23 +650,6 @@ struct AsofJoinBasicTest : public testing::TestWithParam {} class AsofJoinTest : public testing::Test {}; -#define ASOFJOIN_TEST_SET(name, num) \ - TRACED_TEST(AsofJoinTest, Test##name##num##_SingleByKey) { \ - Get##name##Test##num().RunSingleByKey(); \ - } \ - TRACED_TEST(AsofJoinTest, Test##name##num##_DoubleByKey) { \ - Get##name##Test##num().RunDoubleByKey(); \ - } \ - TRACED_TEST(AsofJoinTest, Test##name##num##_MutateByKey) { \ - Get##name##Test##num().RunMutateByKey(); \ - } \ - TRACED_TEST(AsofJoinTest, Test##name##num##_MutateNoKey) { \ - Get##name##Test##num().RunMutateNoKey(); \ - } \ - TRACED_TEST(AsofJoinTest, Test##name##num##_MutateNullKey) { \ - Get##name##Test##num().RunMutateNullKey(); \ - } - BasicTest GetBasicTest1() { // Single key, single batch return BasicTest( @@ -649,6 +657,7 @@ BasicTest GetBasicTest1() { /*r0*/ {R"([[0, 1, 11]])"}, /*r1*/ {R"([[1000, 1, 101]])"}, /*exp_nokey*/ {R"([[0, 0, 1, 11, null], [1000, 0, 2, 11, 101]])"}, + /*exp_emptykey*/ {R"([[0, 1, 1, 11, null], [1000, 1, 2, 11, 101]])"}, /*exp*/ {R"([[0, 1, 1, 11, null], [1000, 1, 2, 11, 101]])"}, 1000); } @@ -665,6 +674,7 @@ BasicTest GetBasicTest2() { /*r0*/ {R"([[0, 1, 11]])", R"([[1000, 1, 12]])"}, /*r1*/ {R"([[0, 1, 101]])", R"([[1000, 1, 102]])"}, /*exp_nokey*/ {R"([[0, 0, 1, 11, 101], [1000, 0, 2, 12, 102]])"}, + /*exp_emptykey*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"}, /*exp*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"}, 1000); } @@ -681,6 +691,7 @@ BasicTest GetBasicTest3() { /*r0*/ {R"([[0, 1, 11], [1000, 1, 12]])"}, /*r1*/ {R"([[0, 1, 101], [1000, 1, 102]])"}, /*exp_nokey*/ {R"([[0, 0, 1, 11, 101], [1000, 0, 2, 12, 102]])"}, + /*exp_emptykey*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"}, /*exp*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"}, 1000); } @@ -706,6 +717,9 @@ BasicTest GetBasicTest4() { /*exp_nokey*/ {R"([[0, 0, 1, 11, 1001], [0, 0, 21, 11, 1001], [500, 0, 2, 31, 101], [1000, 0, 22, 12, 102], [1500, 0, 3, 32, 1002], [1500, 0, 23, 32, 1002]])", R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"}, + /*exp_emptykey*/ + {R"([[0, 1, 1, 11, 1001], [0, 2, 21, 11, 1001], [500, 1, 2, 31, 101], [1000, 2, 22, 12, 102], [1500, 1, 3, 32, 1002], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 13, 103]])"}, /*exp*/ {R"([[0, 1, 1, 11, null], [0, 2, 21, null, 1001], [500, 1, 2, 11, 101], [1000, 2, 22, 31, 1001], [1500, 1, 3, 12, 102], [1500, 2, 23, 32, 1002]])", R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, @@ -732,6 +746,9 @@ BasicTest GetBasicTest5() { /*exp_nokey*/ {R"([[0, 0, 1, 11, 1001], [0, 0, 21, 11, 1001], [500, 0, 2, 31, 101], [1000, 0, 22, 12, 102], [1500, 0, 3, 32, 1002], [1500, 0, 23, 32, 1002]])", R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"}, + /*exp_emptykey*/ + {R"([[0, 1, 1, 11, 1001], [0, 2, 21, 11, 1001], [500, 1, 2, 31, 101], [1000, 2, 22, 12, 102], [1500, 1, 3, 32, 1002], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 13, 103]])"}, /*exp*/ {R"([[0, 1, 1, 11, null], [0, 2, 21, null, 1001], [500, 1, 2, 11, 101], [1000, 2, 22, 31, null], [1500, 1, 3, 12, 102], [1500, 2, 23, 32, 1002]])", R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, @@ -759,6 +776,9 @@ BasicTest GetBasicTest6() { /*exp_nokey*/ {R"([[0, 0, 1, 11, 1001], [0, 0, 21, 11, 1001], [500, 0, 2, 31, 101], [1000, 0, 22, 12, 102], [1500, 0, 3, 32, 1002], [1500, 0, 23, 32, 1002]])", R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"}, + /*exp_emptykey*/ + {R"([[0, 1, 1, 11, 1001], [0, 2, 21, 11, 1001], [500, 1, 2, 31, 101], [1000, 2, 22, 12, 102], [1500, 1, 3, 32, 1002], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 13, 103]])"}, /*exp*/ {R"([[0, 1, 1, 11, null], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, null], [1500, 1, 3, null, null], [1500, 2, 23, 32, 1002]])", R"([[2000, 1, 4, 13, 103], [2000, 2, 24, null, null]])"}, @@ -784,6 +804,8 @@ BasicTest GetEmptyTest1() { R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, /*exp_nokey*/ {R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"}, + /*exp_emptykey*/ + {R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 13, 103]])"}, /*exp*/ {R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, 1000); } @@ -807,6 +829,8 @@ BasicTest GetEmptyTest2() { R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, /*exp_nokey*/ {R"([])"}, + /*exp_emptykey*/ + {R"([])"}, /*exp*/ {R"([])"}, 1000); } @@ -831,6 +855,9 @@ BasicTest GetEmptyTest3() { /*exp_nokey*/ {R"([[0, 0, 1, null, 1001], [0, 0, 21, null, 1001], [500, 0, 2, null, 101], [1000, 0, 22, null, 102], [1500, 0, 3, 32, 1002], [1500, 0, 23, 32, 1002]])", R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"}, + /*exp_emptykey*/ + {R"([[0, 1, 1, null, 1001], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, 102], [1500, 1, 3, 32, 1002], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 13, 103]])"}, /*exp*/ {R"([[0, 1, 1, null, null], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, 1001], [1500, 1, 3, null, 102], [1500, 2, 23, 32, 1002]])", R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, @@ -857,6 +884,9 @@ BasicTest GetEmptyTest4() { /*exp_nokey*/ {R"([[0, 0, 1, null, 1001], [0, 0, 21, null, 1001], [500, 0, 2, null, 101], [1000, 0, 22, null, 102], [1500, 0, 3, null, 1002], [1500, 0, 23, null, 1002]])", R"([[2000, 0, 4, null, 103], [2000, 0, 24, null, 103]])"}, + /*exp_emptykey*/ + {R"([[0, 1, 1, null, 1001], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, 102], [1500, 1, 3, null, 1002], [1500, 2, 23, null, 1002]])", + R"([[2000, 1, 4, null, 103], [2000, 2, 24, null, 103]])"}, /*exp*/ {R"([[0, 1, 1, null, null], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, 1001], [1500, 1, 3, null, 102], [1500, 2, 23, null, 1002]])", R"([[2000, 1, 4, null, 103], [2000, 2, 24, null, 1002]])"}, @@ -880,6 +910,8 @@ BasicTest GetEmptyTest5() { {R"([])"}, /*exp_nokey*/ {R"([])"}, + /*exp_emptykey*/ + {R"([])"}, /*exp*/ {R"([])"}, 1000); } @@ -893,11 +925,12 @@ TRACED_TEST_P(AsofJoinBasicTest, TestEmpty5, { INSTANTIATE_TEST_SUITE_P( AsofJoinNodeTest, AsofJoinBasicTest, - testing::Values(AsofJoinBasicParams(BasicTest::DoSingleByKey, "single by-key"), - AsofJoinBasicParams(BasicTest::DoDoubleByKey, "double by-key"), - AsofJoinBasicParams(BasicTest::DoMutateByKey, "mutate by-key"), - AsofJoinBasicParams(BasicTest::DoMutateNoKey, "mutate no-key"), - AsofJoinBasicParams(BasicTest::DoMutateNullKey, "mutate null-key"))); + testing::Values(AsofJoinBasicParams(BasicTest::DoSingleByKey, "SingleByKey"), + AsofJoinBasicParams(BasicTest::DoDoubleByKey, "DoubleByKey"), + AsofJoinBasicParams(BasicTest::DoMutateByKey, "MutateByKey"), + AsofJoinBasicParams(BasicTest::DoMutateNoKey, "MutateNoKey"), + AsofJoinBasicParams(BasicTest::DoMutateNullKey, "MutateNullKey"), + AsofJoinBasicParams(BasicTest::DoMutateEmptyKey, "MutateEmptyKey"))); TRACED_TEST(AsofJoinTest, TestUnsupportedOntype, { DoRunInvalidTypeTest(schema({field("time", list(int32())), field("key", int32()), @@ -940,12 +973,6 @@ TRACED_TEST(AsofJoinTest, TestUnsupportedTolerance, { schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); }) -TRACED_TEST(AsofJoinTest, TestEmptyByKey, { - DoRunEmptyByKeyTest( - schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), - schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); -}) - TRACED_TEST(AsofJoinTest, TestMissingOnKey, { DoRunMissingOnKeyTest( schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), From 053cbad1405f96fb9a2e91271225214008d0192d Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 28 Aug 2022 15:22:31 -0400 Subject: [PATCH 20/26] fix compilation warning --- cpp/src/arrow/compute/exec/asof_join_node_test.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index f586c67a972bb..2de638535b768 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -357,6 +357,7 @@ void DoRunAmbiguousByKeyTest(const std::shared_ptr& l_schema, // The batch will have n_rows rows n_cols columns, the first column being the on-field // If unordered is true then the first column will be out-of-order std::string GetTestBatchAsJsonString(int n_rows, int n_cols, bool unordered = false) { + int order_mask = unordered ? 1 : 0; std::stringstream s; s << '['; for (int i = 0; i < n_rows; i++) { @@ -368,7 +369,7 @@ std::string GetTestBatchAsJsonString(int n_rows, int n_cols, bool unordered = fa if (j > 0) { s << ", " << j; } else if (j < 2) { - s << (i ^ unordered); + s << (i ^ order_mask); } else { s << i; } From dad8e36d423ec8e341e77c6c50a2d4644573f20f Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Tue, 30 Aug 2022 10:19:19 -0400 Subject: [PATCH 21/26] AsofJoinNode initialization cleanup --- cpp/src/arrow/compute/exec/asof_join_node.cc | 47 +++++++++++--------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index a4624edd3ecfa..802287c6cfe07 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -815,21 +815,24 @@ class AsofJoinNode : public ExecNode { AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, const std::vector& indices_of_on_key, const std::vector>& indices_of_by_key, - OnType tolerance, std::shared_ptr output_schema); + OnType tolerance, std::shared_ptr output_schema, + std::vector> key_hashers, bool must_hash, + bool nullable_by_key); - void InternalInit(bool must_hash, bool nullable_by_key, - std::vector> key_hashers) { - key_hashers_ = std::move(key_hashers); + Status Init() { auto inputs = this->inputs(); - for (size_t i = 0; i < inputs.size(); ++i) { + for (size_t i = 0; i < inputs.size(); i++) { + RETURN_NOT_OK(key_hashers_[i]->Init(plan()->exec_context(), output_schema())); state_.push_back(::arrow::internal::make_unique( - must_hash, nullable_by_key, key_hashers_[i].get(), inputs[i]->output_schema(), + must_hash_, nullable_by_key_, key_hashers_[i].get(), inputs[i]->output_schema(), indices_of_on_key_[i], indices_of_by_key_[i])); } col_index_t dst_offset = 0; for (auto& state : state_) dst_offset = state->InitSrcToDstMapping(dst_offset, !!dst_offset); + + return Status::OK(); } virtual ~AsofJoinNode() { @@ -1027,25 +1030,20 @@ class AsofJoinNode : public ExecNode { ARROW_ASSIGN_OR_RAISE(std::shared_ptr output_schema, MakeOutputSchema(inputs, indices_of_on_key, indices_of_by_key)); - auto node = plan->EmplaceNode( - plan, inputs, std::move(input_labels), std::move(indices_of_on_key), - std::move(indices_of_by_key), time_value(join_options.tolerance), - std::move(output_schema)); - auto node_output_schema = node->output_schema(); - auto node_indices_of_by_key = checked_cast(node)->indices_of_by_key(); std::vector> key_hashers; + for (size_t i = 0; i < n_input; i++) { + key_hashers.push_back( + ::arrow::internal::make_unique(indices_of_by_key[i])); + } bool must_hash = n_by != 1 || !is_primitive( inputs[0]->output_schema()->field(indices_of_by_key[0][0])->type()->id()); bool nullable_by_key = join_options.nullable_by_key; - for (size_t i = 0; i < n_input; i++) { - key_hashers.push_back( - ::arrow::internal::make_unique(node_indices_of_by_key[i])); - RETURN_NOT_OK(key_hashers[i]->Init(plan->exec_context(), node_output_schema)); - } - node->InternalInit(must_hash, nullable_by_key, std::move(key_hashers)); - return node; + return plan->EmplaceNode( + plan, inputs, std::move(input_labels), std::move(indices_of_on_key), + std::move(indices_of_by_key), time_value(join_options.tolerance), + std::move(output_schema), std::move(key_hashers), must_hash, nullable_by_key); } const char* kind_name() const override { return "AsofJoinNode"; } @@ -1101,9 +1099,11 @@ class AsofJoinNode : public ExecNode { private: arrow::Future<> finished_; - std::vector> key_hashers_; std::vector indices_of_on_key_; std::vector> indices_of_by_key_; + std::vector> key_hashers_; + bool must_hash_; + bool nullable_by_key_; // InputStates // Each input state correponds to an input table std::vector> state_; @@ -1124,12 +1124,17 @@ AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, const std::vector& indices_of_on_key, const std::vector>& indices_of_by_key, - OnType tolerance, std::shared_ptr output_schema) + OnType tolerance, std::shared_ptr output_schema, + std::vector> key_hashers, + bool must_hash, bool nullable_by_key) : ExecNode(plan, inputs, input_labels, /*output_schema=*/std::move(output_schema), /*num_outputs=*/1), indices_of_on_key_(std::move(indices_of_on_key)), indices_of_by_key_(std::move(indices_of_by_key)), + key_hashers_(std::move(key_hashers)), + must_hash_(must_hash), + nullable_by_key_(nullable_by_key), tolerance_(tolerance), process_(), process_thread_(&AsofJoinNode::ProcessThreadWrapper, this) { From 8bee2346717c217d2f1e233a1273577d0df687eb Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Wed, 31 Aug 2022 16:32:27 -0400 Subject: [PATCH 22/26] fix override --- cpp/src/arrow/compute/exec/asof_join_node.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 802287c6cfe07..4817fd2cc70c0 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -819,7 +819,7 @@ class AsofJoinNode : public ExecNode { std::vector> key_hashers, bool must_hash, bool nullable_by_key); - Status Init() { + Status Init() override { auto inputs = this->inputs(); for (size_t i = 0; i < inputs.size(); i++) { RETURN_NOT_OK(key_hashers_[i]->Init(plan()->exec_context(), output_schema())); From 9abb02e7724cfc834402dd377c38a703afc779fe Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Fri, 2 Sep 2022 05:05:20 -0400 Subject: [PATCH 23/26] rehashing on null by-key, cleaner user API --- cpp/src/arrow/array/data.h | 10 +-- .../arrow/compute/exec/asof_join_benchmark.cc | 2 +- cpp/src/arrow/compute/exec/asof_join_node.cc | 81 ++++++++++++------- .../arrow/compute/exec/asof_join_node_test.cc | 45 +++-------- cpp/src/arrow/compute/exec/options.h | 29 +------ 5 files changed, 69 insertions(+), 98 deletions(-) diff --git a/cpp/src/arrow/array/data.h b/cpp/src/arrow/array/data.h index 9190b8d5bea5f..e024483f66551 100644 --- a/cpp/src/arrow/array/data.h +++ b/cpp/src/arrow/array/data.h @@ -329,18 +329,14 @@ struct ARROW_EXPORT ArraySpan { return GetValues(i, this->offset); } - bool IsNull(int64_t i) const { - return ((this->buffers[0].data != NULLPTR) - ? !bit_util::GetBit(this->buffers[0].data, i + this->offset) - : this->null_count == this->length); - } - - bool IsValid(int64_t i) const { + inline bool IsValid(int64_t i) const { return ((this->buffers[0].data != NULLPTR) ? bit_util::GetBit(this->buffers[0].data, i + this->offset) : this->null_count != this->length); } + inline bool IsNull(int64_t i) const { return !IsValid(i); } + std::shared_ptr ToArrayData() const; std::shared_ptr ToArray() const; diff --git a/cpp/src/arrow/compute/exec/asof_join_benchmark.cc b/cpp/src/arrow/compute/exec/asof_join_benchmark.cc index 543a4ece575bb..7d8abc0ba4c14 100644 --- a/cpp/src/arrow/compute/exec/asof_join_benchmark.cc +++ b/cpp/src/arrow/compute/exec/asof_join_benchmark.cc @@ -109,7 +109,7 @@ static void TableJoinOverhead(benchmark::State& state, static void AsOfJoinOverhead(benchmark::State& state) { int64_t tolerance = 0; - AsofJoinNodeOptions options = AsofJoinNodeOptions(kTimeCol, kKeyCol, tolerance); + AsofJoinNodeOptions options = AsofJoinNodeOptions(kTimeCol, {kKeyCol}, tolerance); TableJoinOverhead( state, TableGenerationProperties{int(state.range(0)), int(state.range(1)), diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 4817fd2cc70c0..d4e704ba623f2 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -232,7 +232,7 @@ class InputState { // turned into output record batches. public: - InputState(bool must_hash, bool nullable_by_key, KeyHasher* key_hasher, + InputState(bool must_hash, bool may_rehash, KeyHasher* key_hasher, const std::shared_ptr& schema, const col_index_t time_col_index, const std::vector& key_col_index) @@ -244,7 +244,7 @@ class InputState { key_type_id_(key_col_index.size()), key_hasher_(key_hasher), must_hash_(must_hash), - nullable_by_key_(nullable_by_key) { + may_rehash_(may_rehash) { for (size_t k = 0; k < key_col_index_.size(); k++) { key_type_id_[k] = schema_->fields()[key_col_index_[k]]->type()->id(); } @@ -284,21 +284,24 @@ class InputState { return queue_.UnsyncFront(); } -#define LATEST_VAL_CASE(id, val) \ - case Type::id: { \ - using T = typename TypeIdTraits::Type; \ - using CType = typename TypeTraits::CType; \ - return val(data->GetValues(1)[latest_ref_row_]); \ +#define LATEST_VAL_CASE(id, val) \ + case Type::id: { \ + using T = typename TypeIdTraits::Type; \ + using CType = typename TypeTraits::CType; \ + return val(data->GetValues(1)[row]); \ } - ByType GetLatestKey() const { + inline ByType GetLatestKey() const { + return GetLatestKey(queue_.UnsyncFront().get(), latest_ref_row_); + } + + inline ByType GetLatestKey(const RecordBatch* batch, row_index_t row) const { + if (must_hash_) { + return key_hasher_->HashesFor(batch)[row]; + } if (key_col_index_.size() == 0) { return 0; } - const RecordBatch* batch = queue_.UnsyncFront().get(); - if (must_hash_) { - return key_hasher_->HashesFor(batch)[latest_ref_row_]; - } auto data = batch->column_data(key_col_index_[0]); switch (key_type_id_[0]) { LATEST_VAL_CASE(INT8, key_value) @@ -320,8 +323,12 @@ class InputState { } } - OnType GetLatestTime() const { - auto data = queue_.UnsyncFront()->column_data(time_col_index_); + inline OnType GetLatestTime() const { + return GetLatestTime(queue_.UnsyncFront().get(), latest_ref_row_); + } + + inline ByType GetLatestTime(const RecordBatch* batch, row_index_t row) const { + auto data = batch->column_data(time_col_index_); switch (time_type_id_) { LATEST_VAL_CASE(INT8, time_value) LATEST_VAL_CASE(INT16, time_value) @@ -393,18 +400,29 @@ class InputState { if (latest_time > ts) { break; // hit a future timestamp -- done updating for now } - memo_.Store(GetLatestBatch(), latest_ref_row_, latest_time, GetLatestKey()); + auto rb = GetLatestBatch(); + if (may_rehash_ && rb->column_data(key_col_index_[0])->GetNullCount() > 0) { + must_hash_ = true; + may_rehash_ = false; + Rehash(); + } + memo_.Store(rb, latest_ref_row_, latest_time, GetLatestKey()); updated = true; ARROW_ASSIGN_OR_RAISE(advanced, Advance()); } while (advanced); return updated; } - Status Push(const std::shared_ptr& rb) { - if (!nullable_by_key_ && key_col_index_.size() == 1 && - rb->column_data(key_col_index_[0])->GetNullCount() > 0) { - return Status::Invalid("AsofJoin does not allow unexpected null by-key values"); + void Rehash() { + MemoStore new_memo; + for (const auto& entry : memo_.entries_) { + const auto& e = entry.second; + new_memo.Store(e.batch, e.row, e.time, GetLatestKey(e.batch.get(), e.row)); } + memo_ = new_memo; + } + + Status Push(const std::shared_ptr& rb) { if (rb->num_rows() > 0) { queue_.Push(rb); } else { @@ -459,8 +477,8 @@ class InputState { mutable KeyHasher* key_hasher_; // True if hashing is mandatory bool must_hash_; - // True if null by-key values are expected - bool nullable_by_key_; + // True if by-key values may be rehashed + bool may_rehash_; // Index of the latest row reference within; if >0 then queue_ cannot be empty // Must be < queue_.front()->num_rows() if queue_ is non-empty row_index_t latest_ref_row_ = 0; @@ -817,14 +835,14 @@ class AsofJoinNode : public ExecNode { const std::vector>& indices_of_by_key, OnType tolerance, std::shared_ptr output_schema, std::vector> key_hashers, bool must_hash, - bool nullable_by_key); + bool may_rehash); Status Init() override { auto inputs = this->inputs(); for (size_t i = 0; i < inputs.size(); i++) { RETURN_NOT_OK(key_hashers_[i]->Init(plan()->exec_context(), output_schema())); state_.push_back(::arrow::internal::make_unique( - must_hash_, nullable_by_key_, key_hashers_[i].get(), inputs[i]->output_schema(), + must_hash_, may_rehash_, key_hashers_[i].get(), inputs[i]->output_schema(), indices_of_on_key_[i], indices_of_by_key_[i])); } @@ -1036,14 +1054,15 @@ class AsofJoinNode : public ExecNode { ::arrow::internal::make_unique(indices_of_by_key[i])); } bool must_hash = - n_by != 1 || - !is_primitive( - inputs[0]->output_schema()->field(indices_of_by_key[0][0])->type()->id()); - bool nullable_by_key = join_options.nullable_by_key; + n_by > 1 || + (n_by == 1 && + !is_primitive( + inputs[0]->output_schema()->field(indices_of_by_key[0][0])->type()->id())); + bool may_rehash = n_by == 1 && !must_hash; return plan->EmplaceNode( plan, inputs, std::move(input_labels), std::move(indices_of_on_key), std::move(indices_of_by_key), time_value(join_options.tolerance), - std::move(output_schema), std::move(key_hashers), must_hash, nullable_by_key); + std::move(output_schema), std::move(key_hashers), must_hash, may_rehash); } const char* kind_name() const override { return "AsofJoinNode"; } @@ -1103,7 +1122,7 @@ class AsofJoinNode : public ExecNode { std::vector> indices_of_by_key_; std::vector> key_hashers_; bool must_hash_; - bool nullable_by_key_; + bool may_rehash_; // InputStates // Each input state correponds to an input table std::vector> state_; @@ -1126,7 +1145,7 @@ AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, const std::vector>& indices_of_by_key, OnType tolerance, std::shared_ptr output_schema, std::vector> key_hashers, - bool must_hash, bool nullable_by_key) + bool must_hash, bool may_rehash) : ExecNode(plan, inputs, input_labels, /*output_schema=*/std::move(output_schema), /*num_outputs=*/1), @@ -1134,7 +1153,7 @@ AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, indices_of_by_key_(std::move(indices_of_by_key)), key_hashers_(std::move(key_hashers)), must_hash_(must_hash), - nullable_by_key_(nullable_by_key), + may_rehash_(may_rehash), tolerance_(tolerance), process_(), process_thread_(&AsofJoinNode::ProcessThreadWrapper, this) { diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index 2de638535b768..82682abec9fd3 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -245,9 +245,9 @@ void CheckRunOutput(const BatchesWithSchema& l_batches, void CheckRunOutput( \ const BatchesWithSchema& l_batches, const BatchesWithSchema& r0_batches, \ const BatchesWithSchema& r1_batches, const BatchesWithSchema& exp_batches, \ - const FieldRef time, by_key_type keys, const int64_t tolerance) { \ + const FieldRef time, by_key_type key, const int64_t tolerance) { \ CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, \ - AsofJoinNodeOptions(time, keys, tolerance)); \ + AsofJoinNodeOptions(time, {key}, tolerance)); \ } EXPAND_BY_KEY_TYPE(CHECK_RUN_OUTPUT) @@ -299,7 +299,8 @@ void DoRunInvalidPlanTest(const std::shared_ptr& l_schema, void DoRunInvalidPlanTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema, int64_t tolerance, const std::string& expected_error_str) { - DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions("time", "key", tolerance), + DoRunInvalidPlanTest(l_schema, r_schema, + AsofJoinNodeOptions("time", {"key"}, tolerance), expected_error_str); } @@ -321,25 +322,28 @@ void DoRunMissingKeysTest(const std::shared_ptr& l_schema, void DoRunMissingOnKeyTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { - DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions("invalid_time", "key", 0), + DoRunInvalidPlanTest(l_schema, r_schema, + AsofJoinNodeOptions("invalid_time", {"key"}, 0), "Bad join key on table : No match"); } void DoRunMissingByKeyTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { - DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions("time", "invalid_key", 0), + DoRunInvalidPlanTest(l_schema, r_schema, + AsofJoinNodeOptions("time", {"invalid_key"}, 0), "Bad join key on table : No match"); } void DoRunNestedOnKeyTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { - DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions({0, "time"}, "key", 0), + DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions({0, "time"}, {"key"}, 0), "Bad join key on table : No match"); } void DoRunNestedByKeyTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { - DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions("time", FieldRef{0, 1}, 0), + DoRunInvalidPlanTest(l_schema, r_schema, + AsofJoinNodeOptions("time", {FieldRef{0, 1}}, 0), "Bad join key on table : No match"); } @@ -400,26 +404,10 @@ void DoRunUnorderedPlanTest(bool l_unordered, bool r_unordered, const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { DoRunUnorderedPlanTest(l_unordered, r_unordered, l_schema, r_schema, - AsofJoinNodeOptions("time", "key", 1000), + AsofJoinNodeOptions("time", {"key"}, 1000), "out-of-order on-key values"); } -void DoRunNullByKeyPlanTest(const std::shared_ptr& l_schema, - const std::shared_ptr& r_schema) { - AsofJoinNodeOptions join_options{"time", "key2", 1000}; - std::string expected_error_str = "unexpected null by-key values"; - int n_rows = 5; - auto l_str = GetTestBatchAsJsonString(n_rows, l_schema->num_fields()); - auto r_str = GetTestBatchAsJsonString(n_rows, r_schema->num_fields()); - ASSERT_OK_AND_ASSIGN(auto l_batches, MakeBatchesFromNumString(l_schema, {l_str})); - ASSERT_OK_AND_ASSIGN(auto r_batches, MakeBatchesFromNumString(r_schema, {r_str})); - ASSERT_OK_AND_ASSIGN(l_batches, MutateByKey(l_batches, "key", "key2", true, true)); - ASSERT_OK_AND_ASSIGN(r_batches, MutateByKey(r_batches, "key", "key2", true, true)); - - return DoInvalidPlanTest(l_batches, r_batches, join_options, expected_error_str, - /*then_run_plan=*/true); -} - struct BasicTestTypes { std::shared_ptr time, key, l_val, r0_val, r1_val; }; @@ -513,8 +501,7 @@ struct BasicTest { ASSERT_OK_AND_ASSIGN(exp_nokey_batches, MutateByKey(exp_nokey_batches, "key", "key2", true, true)); CheckRunOutput(l_batches, r0_batches, r1_batches, exp_nokey_batches, - AsofJoinNodeOptions("time", "key2", tolerance, - /*nullable_by_key=*/true)); + AsofJoinNodeOptions("time", {"key2"}, tolerance)); }); } static void DoMutateNullKey(BasicTest& basic_tests) { basic_tests.RunMutateNullKey(); } @@ -1033,11 +1020,5 @@ TRACED_TEST(AsofJoinTest, TestUnorderedOnKey, { schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); }) -TRACED_TEST(AsofJoinTest, TestNullByKey, { - DoRunNullByKeyPlanTest( - schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), - schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); -}) - } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index 75aadacbbd02e..e0172bff7f762 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -397,29 +397,8 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { /// This node will output one row for each row in the left table. class ARROW_EXPORT AsofJoinNodeOptions : public ExecNodeOptions { public: - AsofJoinNodeOptions(FieldRef on_key, const FieldRef& by_key, int64_t tolerance, - bool nullable_by_key = false) - : on_key(std::move(on_key)), - by_key(), - tolerance(tolerance), - nullable_by_key(nullable_by_key) { - this->by_key.push_back(std::move(by_key)); - } - - AsofJoinNodeOptions(FieldRef on_key, std::vector by_key, int64_t tolerance, - bool nullable_by_key = false) - : on_key(std::move(on_key)), - by_key(by_key), - tolerance(tolerance), - nullable_by_key(nullable_by_key) {} - - // resolves ambiguity between previous constructors when initializer list is given - AsofJoinNodeOptions(FieldRef on_key, std::initializer_list by_key, - int64_t tolerance, bool nullable_by_key = false) - : on_key(std::move(on_key)), - by_key(by_key), - tolerance(tolerance), - nullable_by_key(nullable_by_key) {} + AsofJoinNodeOptions(FieldRef on_key, std::vector by_key, int64_t tolerance) + : on_key(std::move(on_key)), by_key(by_key), tolerance(tolerance) {} /// \brief "on" key for the join. /// @@ -438,10 +417,6 @@ class ARROW_EXPORT AsofJoinNodeOptions : public ExecNodeOptions { /// /// The tolerance is interpreted in the same units as the "on" key. int64_t tolerance; - /// \brief Whether the "by" key is nullable. - /// - /// Set to true if the "by" key is expected to take null values. - bool nullable_by_key; }; /// \brief Make a node which select top_k/bottom_k rows passed through it From a91212830bf275194545dc4a14f09e672c47a9b2 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Fri, 2 Sep 2022 17:19:40 -0400 Subject: [PATCH 24/26] requested fixes --- cpp/src/arrow/compute/exec/asof_join_node.cc | 15 ++++++++------- cpp/src/arrow/compute/exec/asof_join_node_test.cc | 1 - 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index d4e704ba623f2..2a6c377aee4d3 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -94,6 +94,11 @@ class ConcurrentQueue { cond_.notify_one(); } + void Clear() { + std::unique_lock lock(mutex_); + queue_ = {}; + } + util::optional TryPop() { // Try to pop the oldest value from the queue (or return nullopt if none) std::unique_lock lock(mutex_); @@ -801,7 +806,6 @@ class AsofJoinNode : public ExecNode { ExecBatch out_b(*out_rb); outputs_[0]->InputReceived(this, std::move(out_b)); } else { - StopProducing(); ErrorIfNotOk(result.status()); return; } @@ -813,8 +817,8 @@ class AsofJoinNode : public ExecNode { // It may happen here in cases where InputFinished was called before we were finished // producing results (so we didn't know the output size at that time) if (state_.at(0)->Finished()) { - StopProducing(); outputs_[0]->InputFinished(this, batches_produced_); + finished_.MarkFinished(); } } @@ -1083,7 +1087,6 @@ class AsofJoinNode : public ExecNode { } void ErrorReceived(ExecNode* input, Status error) override { outputs_[0]->ErrorReceived(this, std::move(error)); - StopProducing(); } void InputFinished(ExecNode* input, int total_batches) override { { @@ -1109,10 +1112,8 @@ class AsofJoinNode : public ExecNode { StopProducing(); } void StopProducing() override { - // avoid finishing twice, to prevent "Plan was destroyed before finishing" error - if (finished_.state() == FutureState::PENDING) { - finished_.MarkFinished(); - } + process_.Clear(); + process_.Push(false); } arrow::Future<> finished() override { return finished_; } diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index 82682abec9fd3..48d1ae6410b15 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -585,7 +585,6 @@ struct BasicTest { auto end_time = std::chrono::system_clock::now(); std::chrono::duration diff = end_time - start_time; if (diff.count() > 2) { - std::cerr << "AsofJoin test reached time limit at iteration " << i << std::endl; // this normally happens on slow CI systems, but is fine break; } From a371f8c56754e5e7738f36c8afa8aef6b9b45ecc Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sat, 3 Sep 2022 04:26:11 -0400 Subject: [PATCH 25/26] fix queue initialization --- cpp/src/arrow/compute/exec/asof_join_node.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 2a6c377aee4d3..980d8f75b18ae 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -96,7 +96,7 @@ class ConcurrentQueue { void Clear() { std::unique_lock lock(mutex_); - queue_ = {}; + queue_ = std::queue(); } util::optional TryPop() { From 3cd3042e1aec6ffa05e40aaf80fae412b56b5bdc Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Thu, 8 Sep 2022 12:32:16 -0400 Subject: [PATCH 26/26] add doc --- cpp/src/arrow/compute/exec/asof_join_node.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 980d8f75b18ae..869456a577531 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -712,6 +712,11 @@ class CompositeReferenceTable { } }; +// TODO: Currently, AsofJoinNode uses 64-bit hashing which leads to a non-negligible +// probability of collision, which can cause incorrect results when many different by-key +// values are processed. Thus, AsofJoinNode is currently limited to about 100k by-keys for +// guaranteeing this probability is below 1 in a billion. The fix is 128-bit hashing. +// See ARROW-17653 class AsofJoinNode : public ExecNode { // Advances the RHS as far as possible to be up to date for the current LHS timestamp Result UpdateRhs() {