From 412da89146f2366925abda86d34c49d25a78b294 Mon Sep 17 00:00:00 2001
From: David Li
Date: Sat, 6 Nov 2021 09:48:09 -0400
Subject: [PATCH] ARROW-14519: [C++] Properly error if joining on unsupported
type
Instead of DCHECK, return a NotImplemented.
Closes #11625 from lidavidm/arrow-14519
Authored-by: David Li
Signed-off-by: David Li
---
cpp/src/arrow/compute/exec/hash_join.h | 1 +
cpp/src/arrow/compute/exec/hash_join_node.cc | 33 +++++++++++++++----
.../arrow/compute/exec/hash_join_node_test.cc | 33 +++++++++++++++++++
cpp/src/arrow/compute/exec/schema_util.h | 32 ++----------------
4 files changed, 63 insertions(+), 36 deletions(-)
diff --git a/cpp/src/arrow/compute/exec/hash_join.h b/cpp/src/arrow/compute/exec/hash_join.h
index 11b36d9af27c9..6520e4ae4a3f3 100644
--- a/cpp/src/arrow/compute/exec/hash_join.h
+++ b/cpp/src/arrow/compute/exec/hash_join.h
@@ -66,6 +66,7 @@ class ARROW_EXPORT HashJoinSchema {
SchemaProjectionMaps proj_maps[2];
private:
+ static bool IsTypeSupported(const DataType& type);
static Result> VectorDiff(const Schema& schema,
const std::vector& a,
const std::vector& b);
diff --git a/cpp/src/arrow/compute/exec/hash_join_node.cc b/cpp/src/arrow/compute/exec/hash_join_node.cc
index 583ac9a14685b..4bccb761070f4 100644
--- a/cpp/src/arrow/compute/exec/hash_join_node.cc
+++ b/cpp/src/arrow/compute/exec/hash_join_node.cc
@@ -34,6 +34,15 @@ using internal::checked_cast;
namespace compute {
+// Check if a type is supported in a join (as either a key or non-key column)
+bool HashJoinSchema::IsTypeSupported(const DataType& type) {
+ const Type::type id = type.id();
+ if (id == Type::DICTIONARY) {
+ return IsTypeSupported(*checked_cast(type).value_type());
+ }
+ return is_fixed_width(id) || is_binary_like(id) || is_large_binary_like(id);
+}
+
Result> HashJoinSchema::VectorDiff(const Schema& schema,
const std::vector& a,
const std::vector& b) {
@@ -141,8 +150,7 @@ Status HashJoinSchema::ValidateSchemas(JoinType join_type, const Schema& left_sc
// 2. Same number of key fields on left and right
// 3. At least one key field
// 4. Equal data types for corresponding key fields
- // 5. Dictionary type is not supported in a key field
- // 6. Some other data types may not be allowed in a key field
+ // 5. Some data types may not be allowed in a key field or non-key field
//
if (left_keys.size() != right_keys.size()) {
return Status::Invalid("Different number of key fields on left (", left_keys.size(),
@@ -164,11 +172,8 @@ Status HashJoinSchema::ValidateSchemas(JoinType join_type, const Schema& left_sc
const FieldPath& match = result.ValueUnsafe();
const std::shared_ptr& type =
(left_side ? left_schema.fields() : right_schema.fields())[match[0]]->type();
- if ((type->id() != Type::BOOL && !is_fixed_width(type->id()) &&
- !is_binary_like(type->id())) ||
- is_large_binary_like(type->id())) {
- return Status::Invalid("Data type ", type->ToString(),
- " is not supported in join key field");
+ if (!IsTypeSupported(*type)) {
+ return Status::Invalid("Data type ", *type, " is not supported in join key field");
}
}
for (size_t i = 0; i < left_keys.size(); ++i) {
@@ -185,6 +190,20 @@ Status HashJoinSchema::ValidateSchemas(JoinType join_type, const Schema& left_sc
right_ref.ToString(), " of type ", right_type->ToString());
}
}
+ for (const auto& field : left_schema.fields()) {
+ const auto& type = *field->type();
+ if (!IsTypeSupported(type)) {
+ return Status::Invalid("Data type ", type,
+ " is not supported in join non-key field");
+ }
+ }
+ for (const auto& field : right_schema.fields()) {
+ const auto& type = *field->type();
+ if (!IsTypeSupported(type)) {
+ return Status::Invalid("Data type ", type,
+ " is not supported in join non-key field");
+ }
+ }
// Check for output fields:
// 1. Output field refs must match exactly one input field
diff --git a/cpp/src/arrow/compute/exec/hash_join_node_test.cc b/cpp/src/arrow/compute/exec/hash_join_node_test.cc
index d20b456fec513..9afddf3c5dc2f 100644
--- a/cpp/src/arrow/compute/exec/hash_join_node_test.cc
+++ b/cpp/src/arrow/compute/exec/hash_join_node_test.cc
@@ -1656,5 +1656,38 @@ TEST(HashJoin, DictNegative) {
}
}
+TEST(HashJoin, UnsupportedTypes) {
+ // ARROW-14519
+ const bool parallel = false;
+ const bool slow = false;
+
+ auto l_schema = schema({field("l_i32", int32()), field("l_list", list(int32()))});
+ auto l_schema_nolist = schema({field("l_i32", int32())});
+ auto r_schema = schema({field("r_i32", int32()), field("r_list", list(int32()))});
+ auto r_schema_nolist = schema({field("r_i32", int32())});
+
+ std::vector, std::shared_ptr>> cases{
+ {l_schema, r_schema}, {l_schema_nolist, r_schema}, {l_schema, r_schema_nolist}};
+ std::vector l_keys{{"l_i32"}};
+ std::vector r_keys{{"r_i32"}};
+
+ for (const auto& schemas : cases) {
+ BatchesWithSchema l_batches = GenerateBatchesFromString(schemas.first, {R"([])"});
+ BatchesWithSchema r_batches = GenerateBatchesFromString(schemas.second, {R"([])"});
+
+ ExecContext exec_ctx;
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_ctx));
+
+ HashJoinNodeOptions join_options{JoinType::LEFT_SEMI, l_keys, r_keys};
+ Declaration join{"hashjoin", join_options};
+ join.inputs.emplace_back(Declaration{
+ "source", SourceNodeOptions{l_batches.schema, l_batches.gen(parallel, slow)}});
+ join.inputs.emplace_back(Declaration{
+ "source", SourceNodeOptions{r_batches.schema, r_batches.gen(parallel, slow)}});
+
+ ASSERT_RAISES(Invalid, join.AddToPlan(plan.get()));
+ }
+}
+
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/compute/exec/schema_util.h b/cpp/src/arrow/compute/exec/schema_util.h
index 33f42701ff586..279cbb806db32 100644
--- a/cpp/src/arrow/compute/exec/schema_util.h
+++ b/cpp/src/arrow/compute/exec/schema_util.h
@@ -62,7 +62,7 @@ class SchemaProjectionMaps {
const std::vector& projection_handles,
const std::vector*>& projections) {
ARROW_DCHECK(projection_handles.size() == projections.size());
- RegisterSchema(full_schema_handle, schema);
+ ARROW_RETURN_NOT_OK(RegisterSchema(full_schema_handle, schema));
for (size_t i = 0; i < projections.size(); ++i) {
ARROW_RETURN_NOT_OK(
RegisterProjectedSchema(projection_handles[i], *(projections[i]), schema));
@@ -76,11 +76,6 @@ class SchemaProjectionMaps {
return static_cast(schemas_[id].second.size());
}
- const KeyEncoder::KeyColumnMetadata& column_metadata(ProjectionIdEnum schema_handle,
- int field_id) const {
- return field(schema_handle, field_id).column_metadata;
- }
-
const std::string& field_name(ProjectionIdEnum schema_handle, int field_id) const {
return field(schema_handle, field_id).field_name;
}
@@ -105,10 +100,9 @@ class SchemaProjectionMaps {
int field_path;
std::string field_name;
std::shared_ptr data_type;
- KeyEncoder::KeyColumnMetadata column_metadata;
};
- void RegisterSchema(ProjectionIdEnum handle, const Schema& schema) {
+ Status RegisterSchema(ProjectionIdEnum handle, const Schema& schema) {
std::vector out_fields;
const FieldVector& in_fields = schema.fields();
out_fields.resize(in_fields.size());
@@ -118,9 +112,9 @@ class SchemaProjectionMaps {
out_fields[i].field_path = static_cast(i);
out_fields[i].field_name = name;
out_fields[i].data_type = type;
- out_fields[i].column_metadata = ColumnMetadataFromDataType(type);
}
schemas_.push_back(std::make_pair(handle, out_fields));
+ return Status::OK();
}
Status RegisterProjectedSchema(ProjectionIdEnum handle,
@@ -137,7 +131,6 @@ class SchemaProjectionMaps {
out_fields[i].field_path = match[0];
out_fields[i].field_name = name;
out_fields[i].data_type = type;
- out_fields[i].column_metadata = ColumnMetadataFromDataType(type);
}
schemas_.push_back(std::make_pair(handle, out_fields));
return Status::OK();
@@ -153,25 +146,6 @@ class SchemaProjectionMaps {
}
}
- KeyEncoder::KeyColumnMetadata ColumnMetadataFromDataType(
- const std::shared_ptr& type) {
- if (type->id() == Type::DICTIONARY) {
- auto bit_width = checked_cast(*type).bit_width();
- ARROW_DCHECK(bit_width % 8 == 0);
- return KeyEncoder::KeyColumnMetadata(true, bit_width / 8);
- } else if (type->id() == Type::BOOL) {
- return KeyEncoder::KeyColumnMetadata(true, 0);
- } else if (is_fixed_width(type->id())) {
- return KeyEncoder::KeyColumnMetadata(
- true, checked_cast(*type).bit_width() / 8);
- } else if (is_binary_like(type->id())) {
- return KeyEncoder::KeyColumnMetadata(false, sizeof(uint32_t));
- } else {
- ARROW_DCHECK(false);
- return KeyEncoder::KeyColumnMetadata(true, 0);
- }
- }
-
int schema_id(ProjectionIdEnum schema_handle) const {
for (size_t i = 0; i < schemas_.size(); ++i) {
if (schemas_[i].first == schema_handle) {