Skip to content

Commit

Permalink
ARROW-14519: [C++] Properly error if joining on unsupported type
Browse files Browse the repository at this point in the history
Instead of failing a debug-only DCHECK, return an error Status (Status::Invalid) when a join involves a column of an unsupported type.

Closes #11625 from lidavidm/arrow-14519

Authored-by: David Li <li.davidm96@gmail.com>
Signed-off-by: David Li <li.davidm96@gmail.com>
  • Loading branch information
lidavidm committed Nov 6, 2021
1 parent ae808e0 commit 412da89
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 36 deletions.
1 change: 1 addition & 0 deletions cpp/src/arrow/compute/exec/hash_join.h
Expand Up @@ -66,6 +66,7 @@ class ARROW_EXPORT HashJoinSchema {
SchemaProjectionMaps<HashJoinProjection> proj_maps[2];

private:
static bool IsTypeSupported(const DataType& type);
static Result<std::vector<FieldRef>> VectorDiff(const Schema& schema,
const std::vector<FieldRef>& a,
const std::vector<FieldRef>& b);
Expand Down
33 changes: 26 additions & 7 deletions cpp/src/arrow/compute/exec/hash_join_node.cc
Expand Up @@ -34,6 +34,15 @@ using internal::checked_cast;

namespace compute {

// Check if a type is supported in a join (as either a key or non-key column)
bool HashJoinSchema::IsTypeSupported(const DataType& type) {
const Type::type id = type.id();
if (id == Type::DICTIONARY) {
return IsTypeSupported(*checked_cast<const DictionaryType&>(type).value_type());
}
return is_fixed_width(id) || is_binary_like(id) || is_large_binary_like(id);
}

Result<std::vector<FieldRef>> HashJoinSchema::VectorDiff(const Schema& schema,
const std::vector<FieldRef>& a,
const std::vector<FieldRef>& b) {
Expand Down Expand Up @@ -141,8 +150,7 @@ Status HashJoinSchema::ValidateSchemas(JoinType join_type, const Schema& left_sc
// 2. Same number of key fields on left and right
// 3. At least one key field
// 4. Equal data types for corresponding key fields
// 5. Dictionary type is not supported in a key field
// 6. Some other data types may not be allowed in a key field
// 5. Some data types may not be allowed in a key field or non-key field
//
if (left_keys.size() != right_keys.size()) {
return Status::Invalid("Different number of key fields on left (", left_keys.size(),
Expand All @@ -164,11 +172,8 @@ Status HashJoinSchema::ValidateSchemas(JoinType join_type, const Schema& left_sc
const FieldPath& match = result.ValueUnsafe();
const std::shared_ptr<DataType>& type =
(left_side ? left_schema.fields() : right_schema.fields())[match[0]]->type();
if ((type->id() != Type::BOOL && !is_fixed_width(type->id()) &&
!is_binary_like(type->id())) ||
is_large_binary_like(type->id())) {
return Status::Invalid("Data type ", type->ToString(),
" is not supported in join key field");
if (!IsTypeSupported(*type)) {
return Status::Invalid("Data type ", *type, " is not supported in join key field");
}
}
for (size_t i = 0; i < left_keys.size(); ++i) {
Expand All @@ -185,6 +190,20 @@ Status HashJoinSchema::ValidateSchemas(JoinType join_type, const Schema& left_sc
right_ref.ToString(), " of type ", right_type->ToString());
}
}
for (const auto& field : left_schema.fields()) {
const auto& type = *field->type();
if (!IsTypeSupported(type)) {
return Status::Invalid("Data type ", type,
" is not supported in join non-key field");
}
}
for (const auto& field : right_schema.fields()) {
const auto& type = *field->type();
if (!IsTypeSupported(type)) {
return Status::Invalid("Data type ", type,
" is not supported in join non-key field");
}
}

// Check for output fields:
// 1. Output field refs must match exactly one input field
Expand Down
33 changes: 33 additions & 0 deletions cpp/src/arrow/compute/exec/hash_join_node_test.cc
Expand Up @@ -1656,5 +1656,38 @@ TEST(HashJoin, DictNegative) {
}
}

TEST(HashJoin, UnsupportedTypes) {
  // ARROW-14519: a join touching a column whose type the hash join cannot
  // handle (here: list) must surface an error instead of crashing.
  const bool parallel = false;
  const bool slow = false;

  auto l_schema = schema({field("l_i32", int32()), field("l_list", list(int32()))});
  auto l_schema_nolist = schema({field("l_i32", int32())});
  auto r_schema = schema({field("r_i32", int32()), field("r_list", list(int32()))});
  auto r_schema_nolist = schema({field("r_i32", int32())});

  // Unsupported column present on both sides, right side only, left side only.
  std::vector<std::pair<std::shared_ptr<Schema>, std::shared_ptr<Schema>>> cases{
      {l_schema, r_schema}, {l_schema_nolist, r_schema}, {l_schema, r_schema_nolist}};
  std::vector<FieldRef> l_keys{{"l_i32"}};
  std::vector<FieldRef> r_keys{{"r_i32"}};

  for (const auto& schemas : cases) {
    const auto& left_schema = schemas.first;
    const auto& right_schema = schemas.second;
    BatchesWithSchema l_batches = GenerateBatchesFromString(left_schema, {R"([])"});
    BatchesWithSchema r_batches = GenerateBatchesFromString(right_schema, {R"([])"});

    ExecContext exec_ctx;
    ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_ctx));

    Declaration left_source{
        "source", SourceNodeOptions{l_batches.schema, l_batches.gen(parallel, slow)}};
    Declaration right_source{
        "source", SourceNodeOptions{r_batches.schema, r_batches.gen(parallel, slow)}};

    HashJoinNodeOptions join_options{JoinType::LEFT_SEMI, l_keys, r_keys};
    Declaration join{"hashjoin", join_options};
    join.inputs.emplace_back(std::move(left_source));
    join.inputs.emplace_back(std::move(right_source));

    // Adding the join node runs schema validation, which must reject the
    // unsupported column type with Status::Invalid.
    ASSERT_RAISES(Invalid, join.AddToPlan(plan.get()));
  }
}

} // namespace compute
} // namespace arrow
32 changes: 3 additions & 29 deletions cpp/src/arrow/compute/exec/schema_util.h
Expand Up @@ -62,7 +62,7 @@ class SchemaProjectionMaps {
const std::vector<ProjectionIdEnum>& projection_handles,
const std::vector<const std::vector<FieldRef>*>& projections) {
ARROW_DCHECK(projection_handles.size() == projections.size());
RegisterSchema(full_schema_handle, schema);
ARROW_RETURN_NOT_OK(RegisterSchema(full_schema_handle, schema));
for (size_t i = 0; i < projections.size(); ++i) {
ARROW_RETURN_NOT_OK(
RegisterProjectedSchema(projection_handles[i], *(projections[i]), schema));
Expand All @@ -76,11 +76,6 @@ class SchemaProjectionMaps {
return static_cast<int>(schemas_[id].second.size());
}

const KeyEncoder::KeyColumnMetadata& column_metadata(ProjectionIdEnum schema_handle,
int field_id) const {
return field(schema_handle, field_id).column_metadata;
}

const std::string& field_name(ProjectionIdEnum schema_handle, int field_id) const {
return field(schema_handle, field_id).field_name;
}
Expand All @@ -105,10 +100,9 @@ class SchemaProjectionMaps {
int field_path;
std::string field_name;
std::shared_ptr<DataType> data_type;
KeyEncoder::KeyColumnMetadata column_metadata;
};

void RegisterSchema(ProjectionIdEnum handle, const Schema& schema) {
Status RegisterSchema(ProjectionIdEnum handle, const Schema& schema) {
std::vector<FieldInfo> out_fields;
const FieldVector& in_fields = schema.fields();
out_fields.resize(in_fields.size());
Expand All @@ -118,9 +112,9 @@ class SchemaProjectionMaps {
out_fields[i].field_path = static_cast<int>(i);
out_fields[i].field_name = name;
out_fields[i].data_type = type;
out_fields[i].column_metadata = ColumnMetadataFromDataType(type);
}
schemas_.push_back(std::make_pair(handle, out_fields));
return Status::OK();
}

Status RegisterProjectedSchema(ProjectionIdEnum handle,
Expand All @@ -137,7 +131,6 @@ class SchemaProjectionMaps {
out_fields[i].field_path = match[0];
out_fields[i].field_name = name;
out_fields[i].data_type = type;
out_fields[i].column_metadata = ColumnMetadataFromDataType(type);
}
schemas_.push_back(std::make_pair(handle, out_fields));
return Status::OK();
Expand All @@ -153,25 +146,6 @@ class SchemaProjectionMaps {
}
}

// Translate an Arrow DataType into the KeyEncoder's column descriptor
// (fixed-width flag + byte width).
// NOTE(review): an unsupported type only trips ARROW_DCHECK(false) below, so
// in release builds it silently falls through to a bogus descriptor — callers
// must have validated the type beforehand.
KeyEncoder::KeyColumnMetadata ColumnMetadataFromDataType(
const std::shared_ptr<DataType>& type) {
if (type->id() == Type::DICTIONARY) {
// Dictionary columns are described by their index width; the DCHECK
// requires that width to be a whole number of bytes.
auto bit_width = checked_cast<const FixedWidthType&>(*type).bit_width();
ARROW_DCHECK(bit_width % 8 == 0);
return KeyEncoder::KeyColumnMetadata(true, bit_width / 8);
} else if (type->id() == Type::BOOL) {
// Width 0 presumably encodes the bit-packed boolean case — confirm against
// KeyColumnMetadata's definition.
return KeyEncoder::KeyColumnMetadata(true, 0);
} else if (is_fixed_width(type->id())) {
// Other fixed-width types: byte width derived from the type's bit width.
return KeyEncoder::KeyColumnMetadata(
true, checked_cast<const FixedWidthType&>(*type).bit_width() / 8);
} else if (is_binary_like(type->id())) {
// Variable-length binary/string: not fixed width; the stored width is the
// size of a 32-bit offset entry.
return KeyEncoder::KeyColumnMetadata(false, sizeof(uint32_t));
} else {
// Unreachable for supported types; see NOTE above.
ARROW_DCHECK(false);
return KeyEncoder::KeyColumnMetadata(true, 0);
}
}

int schema_id(ProjectionIdEnum schema_handle) const {
for (size_t i = 0; i < schemas_.size(); ++i) {
if (schemas_[i].first == schema_handle) {
Expand Down

0 comments on commit 412da89

Please sign in to comment.