diff --git a/cpp/src/arrow/compute/exec/hash_join.h b/cpp/src/arrow/compute/exec/hash_join.h index 11b36d9af27c9..6520e4ae4a3f3 100644 --- a/cpp/src/arrow/compute/exec/hash_join.h +++ b/cpp/src/arrow/compute/exec/hash_join.h @@ -66,6 +66,7 @@ class ARROW_EXPORT HashJoinSchema { SchemaProjectionMaps proj_maps[2]; private: + static bool IsTypeSupported(const DataType& type); static Result> VectorDiff(const Schema& schema, const std::vector& a, const std::vector& b); diff --git a/cpp/src/arrow/compute/exec/hash_join_node.cc b/cpp/src/arrow/compute/exec/hash_join_node.cc index 583ac9a14685b..4bccb761070f4 100644 --- a/cpp/src/arrow/compute/exec/hash_join_node.cc +++ b/cpp/src/arrow/compute/exec/hash_join_node.cc @@ -34,6 +34,15 @@ using internal::checked_cast; namespace compute { +// Check if a type is supported in a join (as either a key or non-key column) +bool HashJoinSchema::IsTypeSupported(const DataType& type) { + const Type::type id = type.id(); + if (id == Type::DICTIONARY) { + return IsTypeSupported(*checked_cast(type).value_type()); + } + return is_fixed_width(id) || is_binary_like(id) || is_large_binary_like(id); +} + Result> HashJoinSchema::VectorDiff(const Schema& schema, const std::vector& a, const std::vector& b) { @@ -141,8 +150,7 @@ Status HashJoinSchema::ValidateSchemas(JoinType join_type, const Schema& left_sc // 2. Same number of key fields on left and right // 3. At least one key field // 4. Equal data types for corresponding key fields - // 5. Dictionary type is not supported in a key field - // 6. Some other data types may not be allowed in a key field + // 5. Some data types may not be allowed in a key field or non-key field // if (left_keys.size() != right_keys.size()) { return Status::Invalid("Different number of key fields on left (", left_keys.size(), @@ -164,11 +172,8 @@ Status HashJoinSchema::ValidateSchemas(JoinType join_type, const Schema& left_sc const FieldPath& match = result.ValueUnsafe(); const std::shared_ptr& type = (left_side ? left_schema.fields() : right_schema.fields())[match[0]]->type(); - if ((type->id() != Type::BOOL && !is_fixed_width(type->id()) && - !is_binary_like(type->id())) || - is_large_binary_like(type->id())) { - return Status::Invalid("Data type ", type->ToString(), - " is not supported in join key field"); + if (!IsTypeSupported(*type)) { + return Status::Invalid("Data type ", *type, " is not supported in join key field"); } } for (size_t i = 0; i < left_keys.size(); ++i) { @@ -185,6 +190,20 @@ Status HashJoinSchema::ValidateSchemas(JoinType join_type, const Schema& left_sc right_ref.ToString(), " of type ", right_type->ToString()); } } + for (const auto& field : left_schema.fields()) { + const auto& type = *field->type(); + if (!IsTypeSupported(type)) { + return Status::Invalid("Data type ", type, + " is not supported in join non-key field"); + } + } + for (const auto& field : right_schema.fields()) { + const auto& type = *field->type(); + if (!IsTypeSupported(type)) { + return Status::Invalid("Data type ", type, + " is not supported in join non-key field"); + } + } // Check for output fields: // 1. Output field refs must match exactly one input field diff --git a/cpp/src/arrow/compute/exec/hash_join_node_test.cc b/cpp/src/arrow/compute/exec/hash_join_node_test.cc index d20b456fec513..9afddf3c5dc2f 100644 --- a/cpp/src/arrow/compute/exec/hash_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/hash_join_node_test.cc @@ -1656,5 +1656,38 @@ TEST(HashJoin, DictNegative) { } } +TEST(HashJoin, UnsupportedTypes) { + // ARROW-14519 + const bool parallel = false; + const bool slow = false; + + auto l_schema = schema({field("l_i32", int32()), field("l_list", list(int32()))}); + auto l_schema_nolist = schema({field("l_i32", int32())}); + auto r_schema = schema({field("r_i32", int32()), field("r_list", list(int32()))}); + auto r_schema_nolist = schema({field("r_i32", int32())}); + + std::vector, std::shared_ptr>> cases{ + {l_schema, r_schema}, {l_schema_nolist, r_schema}, {l_schema, r_schema_nolist}}; + std::vector l_keys{{"l_i32"}}; + std::vector r_keys{{"r_i32"}}; + + for (const auto& schemas : cases) { + BatchesWithSchema l_batches = GenerateBatchesFromString(schemas.first, {R"([])"}); + BatchesWithSchema r_batches = GenerateBatchesFromString(schemas.second, {R"([])"}); + + ExecContext exec_ctx; + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_ctx)); + + HashJoinNodeOptions join_options{JoinType::LEFT_SEMI, l_keys, r_keys}; + Declaration join{"hashjoin", join_options}; + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{l_batches.schema, l_batches.gen(parallel, slow)}}); + join.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{r_batches.schema, r_batches.gen(parallel, slow)}}); + + ASSERT_RAISES(Invalid, join.AddToPlan(plan.get())); + } +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/schema_util.h b/cpp/src/arrow/compute/exec/schema_util.h index 33f42701ff586..279cbb806db32 100644 --- a/cpp/src/arrow/compute/exec/schema_util.h +++ b/cpp/src/arrow/compute/exec/schema_util.h @@ -62,7 +62,7 @@ class SchemaProjectionMaps { const std::vector& projection_handles, const std::vector*>& projections) { ARROW_DCHECK(projection_handles.size() == projections.size()); - RegisterSchema(full_schema_handle, schema); + ARROW_RETURN_NOT_OK(RegisterSchema(full_schema_handle, schema)); for (size_t i = 0; i < projections.size(); ++i) { ARROW_RETURN_NOT_OK( RegisterProjectedSchema(projection_handles[i], *(projections[i]), schema)); @@ -76,11 +76,6 @@ class SchemaProjectionMaps { return static_cast(schemas_[id].second.size()); } - const KeyEncoder::KeyColumnMetadata& column_metadata(ProjectionIdEnum schema_handle, - int field_id) const { - return field(schema_handle, field_id).column_metadata; - } - const std::string& field_name(ProjectionIdEnum schema_handle, int field_id) const { return field(schema_handle, field_id).field_name; } @@ -105,10 +100,9 @@ class SchemaProjectionMaps { int field_path; std::string field_name; std::shared_ptr data_type; - KeyEncoder::KeyColumnMetadata column_metadata; }; - void RegisterSchema(ProjectionIdEnum handle, const Schema& schema) { + Status RegisterSchema(ProjectionIdEnum handle, const Schema& schema) { std::vector out_fields; const FieldVector& in_fields = schema.fields(); out_fields.resize(in_fields.size()); @@ -118,9 +112,9 @@ class SchemaProjectionMaps { out_fields[i].field_path = static_cast(i); out_fields[i].field_name = name; out_fields[i].data_type = type; - out_fields[i].column_metadata = ColumnMetadataFromDataType(type); } schemas_.push_back(std::make_pair(handle, out_fields)); + return Status::OK(); } Status RegisterProjectedSchema(ProjectionIdEnum handle, @@ -137,7 +131,6 @@ class SchemaProjectionMaps { out_fields[i].field_path = match[0]; out_fields[i].field_name = name; out_fields[i].data_type = type; - out_fields[i].column_metadata = ColumnMetadataFromDataType(type); } schemas_.push_back(std::make_pair(handle, out_fields)); return Status::OK(); @@ -153,25 +146,6 @@ class SchemaProjectionMaps { } } - KeyEncoder::KeyColumnMetadata ColumnMetadataFromDataType( - const std::shared_ptr& type) { - if (type->id() == Type::DICTIONARY) { - auto bit_width = checked_cast(*type).bit_width(); - ARROW_DCHECK(bit_width % 8 == 0); - return KeyEncoder::KeyColumnMetadata(true, bit_width / 8); - } else if (type->id() == Type::BOOL) { - return KeyEncoder::KeyColumnMetadata(true, 0); - } else if (is_fixed_width(type->id())) { - return KeyEncoder::KeyColumnMetadata( - true, checked_cast(*type).bit_width() / 8); - } else if (is_binary_like(type->id())) { - return KeyEncoder::KeyColumnMetadata(false, sizeof(uint32_t)); - } else { - ARROW_DCHECK(false); - return KeyEncoder::KeyColumnMetadata(true, 0); - } - } - int schema_id(ProjectionIdEnum schema_handle) const { for (size_t i = 0; i < schemas_.size(); ++i) { if (schemas_[i].first == schema_handle) {