diff --git a/cpp/src/arrow/dataset/discovery.cc b/cpp/src/arrow/dataset/discovery.cc index 0f9d479b9d6..8b12f3ea815 100644 --- a/cpp/src/arrow/dataset/discovery.cc +++ b/cpp/src/arrow/dataset/discovery.cc @@ -43,7 +43,7 @@ Result> DatasetFactory::Inspect(InspectOptions options) return arrow::schema({}); } - return UnifySchemas(schemas); + return UnifySchemas(schemas, options.field_merge_options); } Result> DatasetFactory::Finish() { diff --git a/cpp/src/arrow/dataset/discovery.h b/cpp/src/arrow/dataset/discovery.h index 40c02051955..382b23e4caa 100644 --- a/cpp/src/arrow/dataset/discovery.h +++ b/cpp/src/arrow/dataset/discovery.h @@ -58,6 +58,10 @@ struct InspectOptions { /// `kInspectAllFragments`. A value of `0` disables inspection of fragments /// altogether so only the partitioning schema will be inspected. int fragments = 1; + + /// Control how to unify types. By default, types are merged strictly (the + /// type must match exactly, except nulls can be merged with other types). + Field::MergeOptions field_merge_options = Field::MergeOptions::Defaults(); }; struct FinishOptions { diff --git a/cpp/src/arrow/dataset/discovery_test.cc b/cpp/src/arrow/dataset/discovery_test.cc index a51b3c09971..8842d084b69 100644 --- a/cpp/src/arrow/dataset/discovery_test.cc +++ b/cpp/src/arrow/dataset/discovery_test.cc @@ -120,6 +120,12 @@ TEST_F(MockDatasetFactoryTest, UnifySchemas) { ASSERT_RAISES(Invalid, factory_->Inspect()); // Return the individual schema for closer inspection should not fail. 
AssertInspectSchemas({schema({i32, f64}), schema({f64, i32_fake})}); + + MakeFactory({schema({field("num", int32())}), schema({field("num", float64())})}); + ASSERT_RAISES(Invalid, factory_->Inspect()); + InspectOptions permissive_options; + permissive_options.field_merge_options = Field::MergeOptions::Permissive(); + AssertInspect(schema({field("num", float64())}), permissive_options); } class FileSystemDatasetFactoryTest : public DatasetFactoryTest { @@ -473,6 +479,12 @@ TEST(UnionDatasetFactoryTest, ConflictingSchemas) { auto i32_schema = schema({i32}); ASSERT_OK_AND_ASSIGN(auto dataset, factory->Finish(i32_schema)); EXPECT_EQ(*dataset->schema(), *i32_schema); + + // The user decided to allow merging the types. + FinishOptions options; + options.inspect_options.field_merge_options = Field::MergeOptions::Permissive(); + ASSERT_OK_AND_ASSIGN(dataset, factory->Finish(options)); + EXPECT_EQ(*dataset->schema(), *schema({f64, i32})); } } // namespace dataset diff --git a/cpp/src/arrow/table.cc b/cpp/src/arrow/table.cc index 7d7ad61bca5..48756f8f6fc 100644 --- a/cpp/src/arrow/table.cc +++ b/cpp/src/arrow/table.cc @@ -38,9 +38,15 @@ #include "arrow/type_fwd.h" #include "arrow/type_traits.h" #include "arrow/util/checked_cast.h" +// Get ARROW_COMPUTE definition +#include "arrow/util/config.h" #include "arrow/util/logging.h" #include "arrow/util/vector.h" +#ifdef ARROW_COMPUTE +#include "arrow/compute/cast.h" +#endif + namespace arrow { using internal::checked_cast; @@ -504,9 +510,24 @@ Result> PromoteTableToSchema(const std::shared_ptr continue; } +#ifdef ARROW_COMPUTE + if (!compute::CanCast(*current_field->type(), *field->type())) { + return Status::Invalid("Unable to promote field ", field->name(), + ": incompatible types: ", field->type()->ToString(), " vs ", + current_field->type()->ToString()); + } + compute::ExecContext ctx(pool); + auto options = compute::CastOptions::Safe(); + ARROW_ASSIGN_OR_RAISE(auto casted, compute::Cast(table->column(field_index), + 
+ field->type(), options, &ctx)); + columns.push_back(casted.chunked_array()); +#else return Status::Invalid("Unable to promote field ", field->name(), ": incompatible types: ", field->type()->ToString(), " vs ", - current_field->type()->ToString()); + current_field->type()->ToString(), + " (Arrow must be built with ARROW_COMPUTE " + "in order to cast incompatible types)"); +#endif } auto unseen_field_iter = std::find(fields_seen.begin(), fields_seen.end(), false); diff --git a/cpp/src/arrow/table.h b/cpp/src/arrow/table.h index 1d6cdd56765..f23756c4849 100644 --- a/cpp/src/arrow/table.h +++ b/cpp/src/arrow/table.h @@ -293,14 +293,18 @@ Result> ConcatenateTables( /// \brief Promotes a table to conform to the given schema. /// -/// If a field in the schema does not have a corresponding column in the -/// table, a column of nulls will be added to the resulting table. -/// If the corresponding column is of type Null, it will be promoted to -/// the type specified by schema, with null values filled. +/// If a field in the schema does not have a corresponding column in +/// the table, a column of nulls will be added to the resulting table. +/// If the corresponding column is of type Null, it will be promoted +/// to the type specified by schema, with null values filled. If Arrow +/// was built with ARROW_COMPUTE, a column of a differing type will be +/// cast to the type specified by the schema. +/// /// Returns an error: /// - if the corresponding column's type is not compatible with the /// schema. /// - if there is a column in the table that does not exist in the schema. +/// - if the cast fails or casting would be required but is not available. 
/// /// \param[in] table the input Table /// \param[in] schema the target schema to promote to diff --git a/cpp/src/arrow/table_test.cc b/cpp/src/arrow/table_test.cc index 3f6589fdf94..c4dddacb28d 100644 --- a/cpp/src/arrow/table_test.cc +++ b/cpp/src/arrow/table_test.cc @@ -34,6 +34,7 @@ #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" #include "arrow/type.h" +#include "arrow/util/config.h" #include "arrow/util/key_value_metadata.h" namespace arrow { @@ -417,8 +418,9 @@ TEST_F(TestPromoteTableToSchema, IncompatibleTypes) { // Invalid promotion: int32 to null. ASSERT_RAISES(Invalid, PromoteTableToSchema(table, schema({field("field", null())}))); - // Invalid promotion: int32 to uint32. - ASSERT_RAISES(Invalid, PromoteTableToSchema(table, schema({field("field", uint32())}))); + // Invalid promotion: int32 to list. + ASSERT_RAISES(Invalid, + PromoteTableToSchema(table, schema({field("field", list(int32()))}))); } TEST_F(TestPromoteTableToSchema, IncompatibleNullity) { @@ -517,6 +519,42 @@ TEST_F(ConcatenateTablesWithPromotionTest, Simple) { AssertTablesEqualUnorderedFields(*expected, *result); } +TEST_F(ConcatenateTablesWithPromotionTest, Unify) { + auto t1 = TableFromJSON(schema({field("f0", int32())}), {"[[0], [1]]"}); + auto t2 = TableFromJSON(schema({field("f0", int64())}), {"[[2], [3]]"}); + auto t3 = TableFromJSON(schema({field("f0", null())}), {"[[null], [null]]"}); + + auto expected_int64 = + TableFromJSON(schema({field("f0", int64())}), {"[[0], [1], [2], [3]]"}); + auto expected_null = + TableFromJSON(schema({field("f0", int32())}), {"[[0], [1], [null], [null]]"}); + + ConcatenateTablesOptions options; + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + ::testing::HasSubstr("Schema at index 1 was different"), + ConcatenateTables({t1, t2}, options)); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + ::testing::HasSubstr("Schema at index 1 was different"), + ConcatenateTables({t1, t3}, options)); + + options.unify_schemas = true; + 
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + ::testing::HasSubstr("Field f0 has incompatible types"), + ConcatenateTables({t1, t2}, options)); + ASSERT_OK_AND_ASSIGN(auto actual, ConcatenateTables({t1, t3}, options)); + AssertTablesEqual(*expected_null, *actual, /*same_chunk_layout=*/false); + + options.field_merge_options.promote_numeric_width = true; +#ifdef ARROW_COMPUTE + ASSERT_OK_AND_ASSIGN(actual, ConcatenateTables({t1, t2}, options)); + AssertTablesEqual(*expected_int64, *actual, /*same_chunk_layout=*/false); +#else + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("must be built with ARROW_COMPUTE"), + ConcatenateTables({t1, t2}, options)); +#endif +} + TEST_F(TestTable, Slice) { const int64_t length = 10; diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 2a382662497..333eb0ac39c 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -36,6 +36,7 @@ #include "arrow/result.h" #include "arrow/status.h" #include "arrow/util/checked_cast.h" +#include "arrow/util/decimal.h" #include "arrow/util/hash_util.h" #include "arrow/util/hashing.h" #include "arrow/util/key_value_metadata.h" @@ -47,6 +48,8 @@ namespace arrow { +using internal::checked_cast; + constexpr Type::type NullType::type_id; constexpr Type::type ListType::type_id; constexpr Type::type LargeListType::type_id; @@ -216,27 +219,6 @@ std::shared_ptr GetPhysicalType(const std::shared_ptr& real_ return std::move(visitor.result); } -namespace { - -using internal::checked_cast; - -// Merges `existing` and `other` if one of them is of NullType, otherwise -// returns nullptr. -// - if `other` if of NullType or is nullable, the unified field will be nullable. 
-// - if `existing` is of NullType but other is not, the unified field will -// have `other`'s type and will be nullable -std::shared_ptr MaybePromoteNullTypes(const Field& existing, const Field& other) { - if (existing.type()->id() != Type::NA && other.type()->id() != Type::NA) { - return nullptr; - } - if (existing.type()->id() == Type::NA) { - return other.WithNullable(true)->WithMetadata(existing.metadata()); - } - // `other` must be null. - return existing.WithNullable(true); -} -} // namespace - Field::~Field() {} bool Field::HasMetadata() const { @@ -275,6 +257,431 @@ std::shared_ptr Field::WithNullable(const bool nullable) const { return std::make_shared(name_, type_, nullable, metadata_); } +Field::MergeOptions Field::MergeOptions::Permissive() { + MergeOptions options = Defaults(); + options.promote_nullability = true; + options.promote_decimal = true; + options.promote_decimal_float = true; + options.promote_integer_decimal = true; + options.promote_integer_float = true; + options.promote_integer_sign = true; + options.promote_numeric_width = true; + options.promote_binary = true; + options.promote_date = true; + options.promote_duration = true; + options.promote_time = true; + options.promote_timestamp = true; + options.promote_dictionary = true; + options.promote_dictionary_ordered = false; + options.promote_large = true; + options.promote_nested = true; + return options; +} + +std::string Field::MergeOptions::ToString() const { + std::stringstream ss; + ss << "MergeOptions{"; + ss << "promote_nullability=" << (promote_nullability ? "true" : "false"); + ss << ", promote_numeric_width=" << (promote_numeric_width ? "true" : "false"); + ss << ", promote_integer_float=" << (promote_integer_float ? "true" : "false"); + ss << ", promote_integer_decimal=" << (promote_integer_decimal ? "true" : "false"); + ss << ", promote_decimal_float=" << (promote_decimal_float ? "true" : "false"); + ss << ", promote_date=" << (promote_date ? 
"true" : "false"); + ss << ", promote_time=" << (promote_time ? "true" : "false"); + ss << ", promote_duration=" << (promote_duration ? "true" : "false"); + ss << ", promote_timestamp=" << (promote_timestamp ? "true" : "false"); + ss << ", promote_nested=" << (promote_nested ? "true" : "false"); + ss << ", promote_dictionary=" << (promote_dictionary ? "true" : "false"); + ss << ", promote_integer_sign=" << (promote_integer_sign ? "true" : "false"); + ss << ", promote_large=" << (promote_large ? "true" : "false"); + ss << ", promote_binary=" << (promote_binary ? "true" : "false"); + ss << '}'; + return ss.str(); +} + +namespace { +// Utilities for Field::MergeWith + +std::shared_ptr MakeSigned(const DataType& type) { + switch (type.id()) { + case Type::INT8: + case Type::UINT8: + return int8(); + case Type::INT16: + case Type::UINT16: + return int16(); + case Type::INT32: + case Type::UINT32: + return int32(); + case Type::INT64: + case Type::UINT64: + return int64(); + default: + DCHECK(false) << "unreachable"; + } + return std::shared_ptr(nullptr); +} +std::shared_ptr MakeBinary(const DataType& type) { + switch (type.id()) { + case Type::BINARY: + case Type::STRING: + return binary(); + case Type::LARGE_BINARY: + case Type::LARGE_STRING: + return large_binary(); + default: + DCHECK(false) << "unreachable"; + } + return std::shared_ptr(nullptr); +} +TimeUnit::type CommonTimeUnit(TimeUnit::type left, TimeUnit::type right) { + if (left == TimeUnit::NANO || right == TimeUnit::NANO) { + return TimeUnit::NANO; + } else if (left == TimeUnit::MICRO || right == TimeUnit::MICRO) { + return TimeUnit::MICRO; + } else if (left == TimeUnit::MILLI || right == TimeUnit::MILLI) { + return TimeUnit::MILLI; + } + return TimeUnit::SECOND; +} + +Result> MergeTypes(std::shared_ptr promoted_type, + std::shared_ptr other_type, + const Field::MergeOptions& options); + +// Merge two dictionary types, or else give an error. 
+Result> MergeDictionaryTypes( + const std::shared_ptr& promoted_type, + const std::shared_ptr& other_type, const Field::MergeOptions& options) { + const auto& left = checked_cast(*promoted_type); + const auto& right = checked_cast(*other_type); + if (!options.promote_dictionary_ordered && left.ordered() != right.ordered()) { + return Status::Invalid( + "Cannot merge ordered and unordered dictionary unless " + "promote_dictionary_ordered=true"); + } + Field::MergeOptions index_options = options; + index_options.promote_integer_sign = true; + index_options.promote_numeric_width = true; + ARROW_ASSIGN_OR_RAISE(auto indices, + MergeTypes(left.index_type(), right.index_type(), index_options)); + ARROW_ASSIGN_OR_RAISE(auto values, + MergeTypes(left.value_type(), right.value_type(), options)); + auto ordered = left.ordered() && right.ordered(); + if (indices && values) { + return dictionary(indices, values, ordered); + } else if (values) { + return Status::Invalid("Could not merge index types"); + } + return Status::Invalid("Could not merge value types"); +} + +// Merge temporal types based on options. Returns nullptr for non-temporal types. 
+Result> MaybeMergeTemporalTypes( + const std::shared_ptr& promoted_type, + const std::shared_ptr& other_type, const Field::MergeOptions& options) { + if (options.promote_date) { + if (promoted_type->id() == Type::DATE32 && other_type->id() == Type::DATE64) { + return date64(); + } + if (promoted_type->id() == Type::DATE64 && other_type->id() == Type::DATE32) { + return date64(); + } + } + + if (options.promote_duration && promoted_type->id() == Type::DURATION && + other_type->id() == Type::DURATION) { + const auto& left = checked_cast(*promoted_type); + const auto& right = checked_cast(*other_type); + return duration(CommonTimeUnit(left.unit(), right.unit())); + } + + if (options.promote_time && is_time(promoted_type->id()) && is_time(other_type->id())) { + const auto& left = checked_cast(*promoted_type); + const auto& right = checked_cast(*other_type); + const auto unit = CommonTimeUnit(left.unit(), right.unit()); + if (unit == TimeUnit::MICRO || unit == TimeUnit::NANO) { + return time64(unit); + } + return time32(unit); + } + + if (options.promote_timestamp && promoted_type->id() == Type::TIMESTAMP && + other_type->id() == Type::TIMESTAMP) { + const auto& left = checked_cast(*promoted_type); + const auto& right = checked_cast(*other_type); + if (left.timezone().empty() ^ right.timezone().empty()) { + return Status::Invalid( + "Cannot merge timestamp with timezone and timestamp without timezone"); + } + if (left.timezone() != right.timezone()) { + return Status::Invalid("Cannot merge timestamps with differing timezones"); + } + return timestamp(CommonTimeUnit(left.unit(), right.unit()), left.timezone()); + } + + return nullptr; +} + +// Merge numeric types based on options. Returns nullptr if no numeric promotion applies. 
+Result> MaybeMergeNumericTypes( + std::shared_ptr promoted_type, std::shared_ptr other_type, + const Field::MergeOptions& options) { + bool promoted = false; + if (options.promote_decimal_float) { + if (is_decimal(promoted_type->id()) && is_floating(other_type->id())) { + promoted_type = other_type; + promoted = true; + } else if (is_floating(promoted_type->id()) && is_decimal(other_type->id())) { + other_type = promoted_type; + promoted = true; + } + } + + if (options.promote_integer_decimal) { + if (is_integer(promoted_type->id()) && is_decimal(other_type->id())) { + promoted_type.swap(other_type); + } + + if (is_decimal(promoted_type->id()) && is_integer(other_type->id())) { + ARROW_ASSIGN_OR_RAISE(const int32_t precision, + MaxDecimalDigitsForInteger(other_type->id())); + ARROW_ASSIGN_OR_RAISE(other_type, + DecimalType::Make(promoted_type->id(), precision, 0)); + promoted = true; + } + } + + if (options.promote_decimal && is_decimal(promoted_type->id()) && + is_decimal(other_type->id())) { + const auto& left = checked_cast(*promoted_type); + const auto& right = checked_cast(*other_type); + if (!options.promote_numeric_width && left.bit_width() != right.bit_width()) { + return Status::Invalid( + "Cannot promote decimal128 to decimal256 without promote_numeric_width=true"); + } + const int32_t max_scale = std::max(left.scale(), right.scale()); + const int32_t common_precision = + std::max(left.precision() + max_scale - left.scale(), + right.precision() + max_scale - right.scale()); + if (left.id() == Type::DECIMAL256 || right.id() == Type::DECIMAL256 || + (options.promote_numeric_width && + common_precision > BasicDecimal128::kMaxPrecision)) { + return DecimalType::Make(Type::DECIMAL256, common_precision, max_scale); + } + return DecimalType::Make(Type::DECIMAL128, common_precision, max_scale); + } + + if (options.promote_integer_sign) { + if (is_unsigned_integer(promoted_type->id()) && is_signed_integer(other_type->id())) { + promoted = 
bit_width(other_type->id()) >= bit_width(promoted_type->id()); + promoted_type = MakeSigned(*promoted_type); + } else if (is_signed_integer(promoted_type->id()) && + is_unsigned_integer(other_type->id())) { + promoted = bit_width(promoted_type->id()) >= bit_width(other_type->id()); + other_type = MakeSigned(*other_type); + } + } + + if (options.promote_integer_float && + ((is_floating(promoted_type->id()) && is_integer(other_type->id())) || + (is_integer(promoted_type->id()) && is_floating(other_type->id())))) { + const int max_width = + std::max(bit_width(promoted_type->id()), bit_width(other_type->id())); + if (max_width >= 64) { + promoted_type = float64(); + } else if (max_width >= 32) { + promoted_type = float32(); + } else { + promoted_type = float16(); + } + promoted = true; + } + + if (options.promote_numeric_width) { + const int max_width = + std::max(bit_width(promoted_type->id()), bit_width(other_type->id())); + if (is_floating(promoted_type->id()) && is_floating(other_type->id())) { + if (max_width >= 64) { + return float64(); + } else if (max_width >= 32) { + return float32(); + } + return float16(); + } else if (is_signed_integer(promoted_type->id()) && + is_signed_integer(other_type->id())) { + if (max_width >= 64) { + return int64(); + } else if (max_width >= 32) { + return int32(); + } else if (max_width >= 16) { + return int16(); + } + return int8(); + } else if (is_unsigned_integer(promoted_type->id()) && + is_unsigned_integer(other_type->id())) { + if (max_width >= 64) { + return uint64(); + } else if (max_width >= 32) { + return uint32(); + } else if (max_width >= 16) { + return uint16(); + } + return uint8(); + } + } + + return promoted ? 
promoted_type : nullptr; +} + +Result> MergeTypes(std::shared_ptr promoted_type, + std::shared_ptr other_type, + const Field::MergeOptions& options) { + if (promoted_type->Equals(*other_type)) return promoted_type; + + bool promoted = false; + if (options.promote_nullability) { + if (promoted_type->id() == Type::NA) { + return other_type; + } else if (other_type->id() == Type::NA) { + return promoted_type; + } + } else if (promoted_type->id() == Type::NA || other_type->id() == Type::NA) { + return Status::Invalid("Cannot merge type with null unless promote_nullability=true"); + } + + if (options.promote_dictionary && is_dictionary(promoted_type->id()) && + is_dictionary(other_type->id())) { + return MergeDictionaryTypes(promoted_type, other_type, options); + } + + ARROW_ASSIGN_OR_RAISE(auto maybe_promoted, + MaybeMergeTemporalTypes(promoted_type, other_type, options)); + if (maybe_promoted) return maybe_promoted; + + ARROW_ASSIGN_OR_RAISE(maybe_promoted, + MaybeMergeNumericTypes(promoted_type, other_type, options)); + if (maybe_promoted) return maybe_promoted; + + if (options.promote_large) { + if (promoted_type->id() == Type::FIXED_SIZE_BINARY && + is_base_binary_like(other_type->id())) { + promoted_type = binary(); + promoted = other_type->id() == Type::BINARY; + } + if (other_type->id() == Type::FIXED_SIZE_BINARY && + is_base_binary_like(promoted_type->id())) { + other_type = binary(); + promoted = promoted_type->id() == Type::BINARY; + } + + if (promoted_type->id() == Type::FIXED_SIZE_LIST && + is_var_size_list(other_type->id())) { + promoted_type = + list(checked_cast(*promoted_type).value_field()); + promoted = other_type->Equals(*promoted_type); + } + if (other_type->id() == Type::FIXED_SIZE_LIST && + is_var_size_list(promoted_type->id())) { + other_type = list(checked_cast(*other_type).value_field()); + promoted = other_type->Equals(*promoted_type); + } + } + + if (options.promote_binary) { + if (promoted_type->id() == Type::FIXED_SIZE_BINARY && + 
other_type->id() == Type::FIXED_SIZE_BINARY) { + return binary(); + } + if (is_string(promoted_type->id()) && is_binary(other_type->id())) { + promoted_type = MakeBinary(*promoted_type); + promoted = + offset_bit_width(promoted_type->id()) == offset_bit_width(other_type->id()); + } else if (is_binary(promoted_type->id()) && is_string(other_type->id())) { + other_type = MakeBinary(*other_type); + promoted = + offset_bit_width(promoted_type->id()) == offset_bit_width(other_type->id()); + } + } + + if (options.promote_large) { + if ((promoted_type->id() == Type::STRING && other_type->id() == Type::LARGE_STRING) || + (promoted_type->id() == Type::LARGE_STRING && other_type->id() == Type::STRING)) { + return large_utf8(); + } else if ((promoted_type->id() == Type::BINARY && + other_type->id() == Type::LARGE_BINARY) || + (promoted_type->id() == Type::LARGE_BINARY && + other_type->id() == Type::BINARY)) { + return large_binary(); + } + if ((promoted_type->id() == Type::LIST && other_type->id() == Type::LARGE_LIST) || + (promoted_type->id() == Type::LARGE_LIST && other_type->id() == Type::LIST)) { + promoted_type = + large_list(checked_cast(*promoted_type).value_field()); + promoted = true; + } + } + + if (options.promote_nested) { + if ((promoted_type->id() == Type::LIST && other_type->id() == Type::LIST) || + (promoted_type->id() == Type::LARGE_LIST && + other_type->id() == Type::LARGE_LIST) || + (promoted_type->id() == Type::FIXED_SIZE_LIST && + other_type->id() == Type::FIXED_SIZE_LIST)) { + const auto& left = checked_cast(*promoted_type); + const auto& right = checked_cast(*other_type); + ARROW_ASSIGN_OR_RAISE( + auto value_field, + left.value_field()->MergeWith( + *right.value_field()->WithName(left.value_field()->name()), options)); + if (promoted_type->id() == Type::LIST) { + return list(std::move(value_field)); + } else if (promoted_type->id() == Type::LARGE_LIST) { + return large_list(std::move(value_field)); + } + const auto left_size = + 
checked_cast(*promoted_type).list_size(); + const auto right_size = + checked_cast(*other_type).list_size(); + if (left_size == right_size) { + return fixed_size_list(std::move(value_field), left_size); + } + return Status::Invalid("Cannot merge fixed_size_list of different sizes"); + } else if (promoted_type->id() == Type::MAP && other_type->id() == Type::MAP) { + const auto& left = checked_cast(*promoted_type); + const auto& right = checked_cast(*other_type); + // While we try to preserve nonstandard field names here, note that + // MapType comparisons ignore field name. See ARROW-7173, ARROW-14999. + ARROW_ASSIGN_OR_RAISE( + auto key_field, + left.key_field()->MergeWith( + *right.key_field()->WithName(left.key_field()->name()), options)); + ARROW_ASSIGN_OR_RAISE( + auto item_field, + left.item_field()->MergeWith( + *right.item_field()->WithName(left.item_field()->name()), options)); + return map(std::move(key_field), std::move(item_field), + /*keys_sorted=*/left.keys_sorted() && right.keys_sorted()); + } else if (promoted_type->id() == Type::STRUCT && other_type->id() == Type::STRUCT) { + SchemaBuilder builder(SchemaBuilder::CONFLICT_APPEND, options); + // Add the LHS fields. Duplicates will be preserved. + RETURN_NOT_OK(builder.AddFields(promoted_type->fields())); + + // Add the RHS fields. Duplicates will be merged, unless the field was + // already a duplicate, in which case we error (since we don't know which + // field to merge with). + builder.SetPolicy(SchemaBuilder::CONFLICT_MERGE); + RETURN_NOT_OK(builder.AddFields(other_type->fields())); + + ARROW_ASSIGN_OR_RAISE(auto schema, builder.Finish()); + return struct_(schema->fields()); + } + } + + return promoted ? 
+ promoted_type : nullptr; +} +} // namespace + Result> Field::MergeWith(const Field& other, MergeOptions options) const { if (name() != other.name()) { @@ -286,14 +693,27 @@ Result> Field::MergeWith(const Field& other, return Copy(); } - if (options.promote_nullability) { - if (type()->Equals(other.type())) { - return Copy()->WithNullable(nullable() || other.nullable()); + auto maybe_promoted_type = MergeTypes(type_, other.type(), options); + if (!maybe_promoted_type.ok()) { + return maybe_promoted_type.status().WithMessage( + "Unable to merge: Field ", name(), + " has incompatible types: ", type()->ToString(), " vs ", other.type()->ToString(), + ": ", maybe_promoted_type.status().message()); + } + auto promoted_type = std::move(maybe_promoted_type).MoveValueUnsafe(); + if (promoted_type) { + bool nullable = nullable_; + if (options.promote_nullability) { + nullable = nullable || other.nullable() || type_->id() == Type::NA || + other.type()->id() == Type::NA; + } else if (nullable_ != other.nullable()) { + return Status::Invalid("Unable to merge: Field ", name(), + " has incompatible nullability: ", nullable_, " vs ", + other.nullable()); } - std::shared_ptr promoted = MaybePromoteNullTypes(*this, other); - if (promoted) return promoted; - } + return std::make_shared(name_, promoted_type, nullable, metadata_); + } return Status::Invalid("Unable to merge: Field ", name(), " has incompatible types: ", type()->ToString(), " vs ", other.type()->ToString()); @@ -1668,7 +2088,8 @@ class SchemaBuilder::Impl { if (policy_ == CONFLICT_REPLACE) { fields_[i] = field; } else if (policy_ == CONFLICT_MERGE) { - ARROW_ASSIGN_OR_RAISE(fields_[i], fields_[i]->MergeWith(field)); + ARROW_ASSIGN_OR_RAISE(fields_[i], + fields_[i]->MergeWith(field, field_merge_options_)); } return Status::OK(); @@ -2238,6 +2659,12 @@ std::shared_ptr map(std::shared_ptr key_type, keys_sorted); } +std::shared_ptr map(std::shared_ptr key_field, + std::shared_ptr item_field, bool keys_sorted) { + return 
+ std::make_shared(std::move(key_field), std::move(item_field), + keys_sorted); +} + std::shared_ptr fixed_size_list(const std::shared_ptr& value_type, int32_t list_size) { return std::make_shared(value_type, list_size); diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 463636b0537..8afa40bc010 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -303,14 +303,78 @@ class ARROW_EXPORT Field : public detail::Fingerprintable { /// \brief Options that control the behavior of `MergeWith`. /// Options are to be added to allow type conversions, including integer /// widening, promotion from integer to float, or conversion to or from boolean. - struct MergeOptions { + struct ARROW_EXPORT MergeOptions : public util::ToStringOstreamable { /// If true, a Field of NullType can be unified with a Field of another type. /// The unified field will be of the other type and become nullable. /// Nullability will be promoted to the looser option (nullable if one is not /// nullable). bool promote_nullability = true; + /// Allow a decimal to be unified with another decimal of the same + /// width, adjusting scale and precision as appropriate. May fail + /// if the adjustment is not possible. + bool promote_decimal = false; + + /// Allow a decimal to be promoted to a float. The float type will + /// not itself be promoted (e.g. Decimal128 + Float32 = Float32). + bool promote_decimal_float = false; + + /// Allow an integer to be promoted to a decimal. + /// + /// May fail if the decimal has insufficient precision to + /// accommodate the integer. (See promote_numeric_width.) + bool promote_integer_decimal = false; + + /// Allow an integer of a given bit width to be promoted to a + /// float; the result will be a float of an equal or greater bit + /// width to both of the inputs. + bool promote_integer_float = false; + + /// Allow an unsigned integer of a given bit width to be promoted + /// to a signed integer of equal or greater bit width. 
+ bool promote_integer_sign = false; + + /// Allow an integer, float, or decimal of a given bit width to be + /// promoted to an equivalent type of a greater bit width. + bool promote_numeric_width = false; + + /// Allow strings to be promoted to binary types. + bool promote_binary = false; + + /// Promote Date32 to Date64. + bool promote_date = false; + + /// Promote duration units: second to millisecond, etc. + bool promote_duration = false; + + /// Promote Time32 to Time64, or Time32(SECOND) to Time32(MILLI), etc. + bool promote_time = false; + + /// Promote timestamp units: second to millisecond, etc. Timezones must match. + bool promote_timestamp = false; + + /// Promote dictionary index types to a common type, and unify the + /// value types. + bool promote_dictionary = false; + + /// Allow merging ordered and non-ordered dictionaries, else + /// error. The result will be ordered if and only if both inputs + /// are ordered. + bool promote_dictionary_ordered = false; + + /// Allow a type to be promoted to the Large variant. + bool promote_large = false; + + /// Recursively merge nested types. + bool promote_nested = false; + + /// Get default options. Only NullType will be merged with other types. static MergeOptions Defaults() { return MergeOptions(); } + /// Get permissive options. All options are enabled, except + /// promote_dictionary_ordered. + static MergeOptions Permissive(); + /// Get a human-readable representation of the options. + std::string ToString() const; }; /// \brief Merge the current field with a field of the same name. diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h index 45afd7af2e6..1e566fa9ebd 100644 --- a/cpp/src/arrow/type_fwd.h +++ b/cpp/src/arrow/type_fwd.h @@ -503,6 +503,14 @@ std::shared_ptr map(std::shared_ptr key_type, std::shared_ptr item_field, bool keys_sorted = false); +/// \brief Create a MapType instance from its key field and value field. 
+/// +/// The field override is provided to communicate nullability of the value. 
+ARROW_EXPORT
+std::shared_ptr<DataType> map(std::shared_ptr<Field> key_field,
+                              std::shared_ptr<Field> item_field,
+                              bool keys_sorted = false);
+
 /// \brief Create a FixedSizeListType instance from its child Field type
 ARROW_EXPORT
 std::shared_ptr<DataType> fixed_size_list(const std::shared_ptr<Field>& value_type,
diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc
index c7ac5f6c7f2..b447c01833f 100644
--- a/cpp/src/arrow/type_test.cc
+++ b/cpp/src/arrow/type_test.cc
@@ -978,6 +978,124 @@ class TestUnifySchemas : public TestSchema {
           << lhs_field->ToString() << " vs " << rhs_field->ToString();
     }
   }
+
+  // Assert that field1->MergeWith(field2, options) succeeds and yields `expected`.
+  void CheckUnifyAsymmetric(
+      const std::shared_ptr<Field>& field1, const std::shared_ptr<Field>& field2,
+      const std::shared_ptr<Field>& expected,
+      const Field::MergeOptions& options = Field::MergeOptions::Defaults()) {
+    ARROW_SCOPED_TRACE("options: ", options);
+    ARROW_SCOPED_TRACE("field2: ", field2->ToString());
+    ARROW_SCOPED_TRACE("field1: ", field1->ToString());
+    ASSERT_OK_AND_ASSIGN(auto merged, field1->MergeWith(field2, options));
+    AssertFieldEqual(merged, expected);
+  }
+
+  // Like CheckUnifyAsymmetric, but also checks the merge with the operands swapped.
+  void CheckUnify(const std::shared_ptr<Field>& field1,
+                  const std::shared_ptr<Field>& field2,
+                  const std::shared_ptr<Field>& expected,
+                  const Field::MergeOptions& options = Field::MergeOptions::Defaults()) {
+    CheckUnifyAsymmetric(field1, field2, expected, options);
+    CheckUnifyAsymmetric(field2, field1, expected, options);
+  }
+
+  // Assert that merging the two fields raises Invalid in both directions.
+  void CheckUnifyFails(
+      const std::shared_ptr<Field>& field1, const std::shared_ptr<Field>& field2,
+      const Field::MergeOptions& options = Field::MergeOptions::Defaults()) {
+    ARROW_SCOPED_TRACE("options: ", options);
+    ARROW_SCOPED_TRACE("field2: ", field2->ToString());
+    ARROW_SCOPED_TRACE("field1: ", field1->ToString());
+    ASSERT_RAISES(Invalid, field1->MergeWith(field2, options));
+    ASSERT_RAISES(Invalid, field2->MergeWith(field1, options));
+  }
+
+  // Wrap the types in fields named "a" and check the symmetric merge for all four
+  // nullability combinations: the result is non-nullable only if both inputs are.
+  void CheckUnify(const std::shared_ptr<DataType>& left,
+                  const std::shared_ptr<DataType>& right,
+                  const std::shared_ptr<DataType>& expected,
+                  const Field::MergeOptions& options = Field::MergeOptions::Defaults()) {
+    auto field1 = field("a", left);
+    auto field2 = field("a", right);
+    CheckUnify(field1, field2, field("a", expected), options);
+
+    field1 = field("a", left, /*nullable=*/false);
+    field2 = field("a", right, /*nullable=*/false);
+    CheckUnify(field1, field2, field("a", expected, /*nullable=*/false), options);
+
+    field1 = field("a", left);
+    field2 = field("a", right, /*nullable=*/false);
+    CheckUnify(field1, field2, field("a", expected, /*nullable=*/true), options);
+
+    field1 = field("a", left, /*nullable=*/false);
+    field2 = field("a", right);
+    CheckUnify(field1, field2, field("a", expected, /*nullable=*/true), options);
+  }
+
+  // One-directional variant of the type-based CheckUnify above.
+  void CheckUnifyAsymmetric(
+      const std::shared_ptr<DataType>& left, const std::shared_ptr<DataType>& right,
+      const std::shared_ptr<DataType>& expected,
+      const Field::MergeOptions& options = Field::MergeOptions::Defaults()) {
+    auto field1 = field("a", left);
+    auto field2 = field("a", right);
+    CheckUnifyAsymmetric(field1, field2, field("a", expected), options);
+
+    field1 = field("a", left, /*nullable=*/false);
+    field2 = field("a", right, /*nullable=*/false);
+    CheckUnifyAsymmetric(field1, field2, field("a", expected, /*nullable=*/false),
+                         options);
+
+    field1 = field("a", left);
+    field2 = field("a", right, /*nullable=*/false);
+    CheckUnifyAsymmetric(field1, field2, field("a", expected, /*nullable=*/true),
+                         options);
+
+    field1 = field("a", left, /*nullable=*/false);
+    field2 = field("a", right);
+    CheckUnifyAsymmetric(field1, field2, field("a", expected, /*nullable=*/true),
+                         options);
+  }
+
+  // Assert that nullable fields "a" of the two types cannot be merged.
+  void CheckUnifyFails(
+      const std::shared_ptr<DataType>& left, const std::shared_ptr<DataType>& right,
+      const Field::MergeOptions& options = Field::MergeOptions::Defaults()) {
+    auto field1 = field("a", left);
+    auto field2 = field("a", right);
+    CheckUnifyFails(field1, field2, options);
+  }
+
+  // Assert that `from` unifies with each type in `to`, yielding that type.
+  void CheckUnify(const std::shared_ptr<DataType>& from,
+                  const std::vector<std::shared_ptr<DataType>>& to,
+                  const Field::MergeOptions& options = Field::MergeOptions::Defaults()) {
+    for (const auto& ty : to) {
+      CheckUnify(from, ty, ty, options);
+    }
+  }
+
+  // Assert that `from` fails to unify with every type in `to`.
+  void CheckUnifyFails(
+      const std::shared_ptr<DataType>& from,
+      const std::vector<std::shared_ptr<DataType>>& to,
+      const Field::MergeOptions& options = Field::MergeOptions::Defaults()) {
+    for (const auto& ty : to) {
+      CheckUnifyFails(from, ty, options);
+    }
+  }
+
+  // Assert that no type in `from` unifies with any type in `to`.
+  void CheckUnifyFails(
+      const std::vector<std::shared_ptr<DataType>>& from,
+      const std::vector<std::shared_ptr<DataType>>& to,
+      const Field::MergeOptions& options = Field::MergeOptions::Defaults()) {
+    for (const auto& ty : from) {
+      CheckUnifyFails(ty, to, options);
+    }
+  }
 };
 
 TEST_F(TestUnifySchemas, EmptyInput) { ASSERT_RAISES(Invalid, UnifySchemas({})); }
@@ -1069,6 +1187,242 @@ TEST_F(TestUnifySchemas, MoreSchemas) {
                                       utf8_field->WithNullable(true)}));
 }
 
+TEST_F(TestUnifySchemas, Numeric) {
+  auto options = Field::MergeOptions::Defaults();
+  options.promote_numeric_width = true;
+  options.promote_integer_float = true;
+  options.promote_integer_sign = true;
+  CheckUnify(uint8(),
+             {int8(), uint16(), int16(), uint32(), int32(), uint64(), int64(), float32(),
+              float64()},
+             options);
+  CheckUnify(int8(), {int16(), int32(), int64(), float32(), float64()}, options);
+  CheckUnify(uint16(),
+             {int16(), uint32(), int32(), uint64(), int64(), float32(), float64()},
+             options);
+  CheckUnify(int16(), {int32(), int64(), float32(), float64()}, options);
+  CheckUnify(uint32(), {int32(), uint64(), int64(), float32(), float64()}, options);
+  CheckUnify(int32(), {int64(), float32(), float64()}, options);
+  CheckUnify(uint64(), {int64(), float64()}, options);
+  CheckUnify(int64(), {float64()}, options);
+  CheckUnify(float16(), {float32(), float64()}, options);
+  CheckUnify(float32(), {float64()}, options);
+  CheckUnify(uint64(), float32(), float64(), options);
+  CheckUnify(int64(), float32(), float64(), options);
+
+  options.promote_integer_sign = false;
+  CheckUnify(uint8(), {uint16(), uint32(), uint64()}, options);
+  CheckUnify(int8(), {int16(), int32(), int64()}, options);
+  CheckUnifyFails(uint8(), {int8(), int16(), int32(), int64()}, options);
+  CheckUnify(uint16(), {uint32(), uint64()}, options);
+  CheckUnify(int16(), {int32(), int64()}, options);
+  CheckUnifyFails(uint16(), {int16(), int32(), int64()}, options);
+  CheckUnify(uint32(), {uint64()}, options);
+  CheckUnify(int32(), {int64()}, options);
+  CheckUnifyFails(uint32(), {int32(), int64()}, options);
+  CheckUnifyFails(uint64(), {int64()}, options);
+
+  options.promote_integer_sign = true;
+  options.promote_integer_float = false;
+  CheckUnifyFails(IntTypes(), FloatingPointTypes(), options);
+
+  options.promote_integer_float = true;
+  options.promote_numeric_width = false;
+  CheckUnifyFails(int8(), {int16(), int32(), int64()}, options);
+  CheckUnifyFails(int16(), {int32(), int64()}, options);
+  CheckUnifyFails(int32(), {int64()}, options);
+}
+
+TEST_F(TestUnifySchemas, Decimal) {
+  auto options = Field::MergeOptions::Defaults();
+
+  options.promote_decimal_float = true;
+  CheckUnify(decimal128(3, 2), {float32(), float64()}, options);
+  CheckUnify(decimal256(3, 2), {float32(), float64()}, options);
+
+  options.promote_integer_decimal = true;
+  CheckUnify(int32(), decimal128(3, 2), decimal128(3, 2), options);
+  CheckUnify(int32(), decimal128(3, -2), decimal128(3, -2), options);
+
+  options.promote_decimal = true;
+  CheckUnify(decimal128(3, 2), decimal128(5, 2), decimal128(5, 2), options);
+  CheckUnify(decimal128(3, 2), decimal128(5, 3), decimal128(5, 3), options);
+  CheckUnify(decimal128(3, 2), decimal128(5, 1), decimal128(6, 2), options);
+  CheckUnify(decimal128(3, 2), decimal128(5, -2), decimal128(9, 2), options);
+  CheckUnify(decimal128(3, -2), decimal128(5, -2), decimal128(5, -2), options);
+
+  // int32() is essentially decimal128(10, 0)
+  CheckUnify(int32(), decimal128(3, 2), decimal128(12, 2), options);
+  CheckUnify(int32(), decimal128(3, -2), decimal128(10, 0), options);
+
+  CheckUnifyFails(decimal256(1, 0), decimal128(1, 0), options);
+  CheckUnifyFails(int64(), decimal128(38, 37), options);
+
+  options.promote_numeric_width = true;
+  CheckUnify(decimal128(3, 2), decimal256(5, 2), decimal256(5, 2), options);
+  CheckUnify(int32(), decimal128(38, 37), decimal256(47, 37), options);
+
+  CheckUnifyFails(int64(), decimal256(76, 75), options);
+}
+
+TEST_F(TestUnifySchemas, Temporal) {
+  auto options = Field::MergeOptions::Defaults();
+
+  options.promote_date = true;
+  CheckUnify(date32(), {date64()}, options);
+
+  options.promote_time = true;
+  CheckUnify(time32(TimeUnit::SECOND),
+             {time32(TimeUnit::MILLI), time64(TimeUnit::MICRO), time64(TimeUnit::NANO)},
+             options);
+  CheckUnify(time32(TimeUnit::MILLI), {time64(TimeUnit::MICRO), time64(TimeUnit::NANO)},
+             options);
+  CheckUnify(time64(TimeUnit::MICRO), {time64(TimeUnit::NANO)}, options);
+
+  options.promote_duration = true;
+  CheckUnify(
+      duration(TimeUnit::SECOND),
+      {duration(TimeUnit::MILLI), duration(TimeUnit::MICRO), duration(TimeUnit::NANO)},
+      options);
+  CheckUnify(duration(TimeUnit::MILLI),
+             {duration(TimeUnit::MICRO), duration(TimeUnit::NANO)}, options);
+  CheckUnify(duration(TimeUnit::MICRO), {duration(TimeUnit::NANO)}, options);
+
+  options.promote_timestamp = true;
+  CheckUnify(
+      timestamp(TimeUnit::SECOND),
+      {timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MICRO), timestamp(TimeUnit::NANO)},
+      options);
+  CheckUnify(timestamp(TimeUnit::MILLI),
+             {timestamp(TimeUnit::MICRO), timestamp(TimeUnit::NANO)}, options);
+  CheckUnify(timestamp(TimeUnit::MICRO), {timestamp(TimeUnit::NANO)}, options);
+
+  CheckUnifyFails(timestamp(TimeUnit::SECOND), timestamp(TimeUnit::SECOND, "UTC"),
+                  options);
+  CheckUnifyFails(timestamp(TimeUnit::SECOND, "America/New_York"),
+                  timestamp(TimeUnit::SECOND, "UTC"), options);
+}
+
+TEST_F(TestUnifySchemas, Binary) {
+  auto options = Field::MergeOptions::Defaults();
+  options.promote_large = true;
+  options.promote_binary = true;
+  CheckUnify(utf8(), {large_utf8(), binary(), large_binary()}, options);
+  CheckUnify(binary(), {large_binary()}, options);
+  CheckUnify(fixed_size_binary(2), {fixed_size_binary(2), binary(), large_binary()},
+             options);
+  CheckUnify(fixed_size_binary(2), fixed_size_binary(4), binary(), options);
+
+  options.promote_large = false;
+  CheckUnifyFails({utf8(), binary()}, {large_utf8(), large_binary()}, options);
+  CheckUnifyFails(fixed_size_binary(2), BaseBinaryTypes(), options);
+
+  options.promote_binary = false;
+  CheckUnifyFails(utf8(), {binary(), large_binary(), fixed_size_binary(2)}, options);
+}
+
+TEST_F(TestUnifySchemas, List) {
+  auto options = Field::MergeOptions::Defaults();
+  options.promote_numeric_width = true;
+  CheckUnifyFails(fixed_size_list(int8(), 2),
+                  {fixed_size_list(int8(), 3), list(int8()), large_list(int8())},
+                  options);
+
+  options.promote_large = true;
+  CheckUnify(list(int8()), {large_list(int8())}, options);
+  CheckUnify(fixed_size_list(int8(), 2), {list(int8()), large_list(int8())}, options);
+
+  options.promote_nested = true;
+  CheckUnify(list(int8()), {list(int16()), list(int32()), list(int64())}, options);
+  CheckUnify(fixed_size_list(int8(), 2),
+             {fixed_size_list(int16(), 2), list(int16()), list(int32()), list(int64())},
+             options);
+
+  auto ty = list(field("foo", int8(), /*nullable=*/false));
+  CheckUnifyAsymmetric(ty, list(int8()), list(field("foo", int8(), /*nullable=*/true)),
+                       options);
+  CheckUnifyAsymmetric(ty, list(field("bar", int16(), /*nullable=*/false)),
+                       list(field("foo", int16(), /*nullable=*/false)), options);
+}
+
+TEST_F(TestUnifySchemas, Map) {
+  auto options = Field::MergeOptions::Defaults();
+  options.promote_nested = true;
+  options.promote_numeric_width = true;
+
+  CheckUnify(map(int8(), int32()),
+             {map(int8(), int64()), map(int16(), int32()), map(int64(), int64())},
+             options);
+
+  // Do not test field names, since MapType intentionally ignores them in comparisons
+  // See ARROW-7173, ARROW-14999
+  auto ty = map(field("key", int8(), /*nullable=*/false),
+                field("value", int32(), /*nullable=*/false));
+  CheckUnify(ty, map(int8(), int32()),
+             map(field("key", int8(), /*nullable=*/true),
+                 field("value", int32(), /*nullable=*/true)),
+             options);
+  CheckUnify(ty,
+             map(field("key", int16(), /*nullable=*/false),
+                 field("value", int64(), /*nullable=*/false)),
+             map(field("key", int16(), /*nullable=*/false),
+                 field("value", int64(), /*nullable=*/false)),
+             options);
+}
+
+TEST_F(TestUnifySchemas, Struct) {
+  auto options = Field::MergeOptions::Defaults();
+  options.promote_nested = true;
+  options.promote_numeric_width = true;
+  options.promote_binary = true;
+
+  CheckUnify(struct_({}), struct_({field("a", int8())}), struct_({field("a", int8())}),
+             options);
+
+  CheckUnifyAsymmetric(struct_({field("b", utf8())}), struct_({field("a", int8())}),
+                       struct_({field("b", utf8()), field("a", int8())}), options);
+  CheckUnifyAsymmetric(struct_({field("a", int8())}), struct_({field("b", utf8())}),
+                       struct_({field("a", int8()), field("b", utf8())}), options);
+
+  CheckUnify(struct_({field("b", utf8())}), struct_({field("b", binary())}),
+             struct_({field("b", binary())}), options);
+
+  CheckUnifyAsymmetric(
+      struct_({field("a", int8()), field("b", utf8()), field("a", int64())}),
+      struct_({field("b", binary())}),
+      struct_({field("a", int8()), field("b", binary()), field("a", int64())}), options);
+
+  ASSERT_RAISES(
+      Invalid,
+      field("foo", struct_({field("a", int8()), field("b", utf8()), field("a", int64())}))
+          ->MergeWith(field("foo", struct_({field("a", int64())})), options));
+}
+
+TEST_F(TestUnifySchemas, Dictionary) {
+  auto options = Field::MergeOptions::Defaults();
+  options.promote_dictionary = true;
+  options.promote_large = true;
+
+  CheckUnify(dictionary(int8(), utf8()),
+             {
+                 dictionary(int64(), utf8()),
+                 dictionary(int8(), large_utf8()),
+             },
+             options);
+  CheckUnify(dictionary(int8(), utf8(), /*ordered=*/true),
+             {
+                 dictionary(int64(), utf8(), /*ordered=*/true),
+                 dictionary(int8(), large_utf8(), /*ordered=*/true),
+             },
+             options);
+  CheckUnifyFails(dictionary(int8(), utf8()),
+                  dictionary(int8(), utf8(), /*ordered=*/true), options);
+
+  options.promote_dictionary_ordered = true;
+  CheckUnify(dictionary(int8(), utf8()), dictionary(int8(), utf8(), /*ordered=*/true),
+             dictionary(int8(), utf8(), /*ordered=*/false), options);
+}
+
 TEST_F(TestUnifySchemas, IncompatibleTypes) {
   auto int32_field = field("f", int32());
   auto uint8_field = field("f", uint8(), false);
diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h
index 4b4cb5d15d3..26ea9e79c0c 100644
--- a/cpp/src/arrow/type_traits.h
+++ b/cpp/src/arrow/type_traits.h
@@ -879,6 +879,18 @@ static inline bool is_decimal(Type::type type_id) {
   return false;
 }
 
+/// \brief Check for a time type (TIME32 or TIME64)
+static inline bool is_time(Type::type type_id) {
+  switch (type_id) {
+    case Type::TIME32:
+    case Type::TIME64:
+      return true;
+    default:
+      break;
+  }
+  return false;
+}
+
 static inline bool is_primitive(Type::type type_id) {
   switch (type_id) {
     case Type::BOOL:
@@ -944,6 +956,42 @@ static inline bool is_large_binary_like(Type::type type_id) {
   return false;
 }
 
+/// \brief Check for a variable-size binary type (BINARY or LARGE_BINARY)
+static inline bool is_binary(Type::type type_id) {
+  switch (type_id) {
+    case Type::BINARY:
+    case Type::LARGE_BINARY:
+      return true;
+    default:
+      break;
+  }
+  return false;
+}
+
+/// \brief Check for a string type (STRING or LARGE_STRING)
+static inline bool is_string(Type::type type_id) {
+  switch (type_id) {
+    case Type::STRING:
+    case Type::LARGE_STRING:
+      return true;
+    default:
+      break;
+  }
+  return false;
+}
+
+/// \brief Check for a variable-size list type (LIST or LARGE_LIST)
+static inline bool is_var_size_list(Type::type type_id) {
+  switch (type_id) {
+    case Type::LIST:
+    case Type::LARGE_LIST:
+      return true;
+    default:
+      break;
+  }
+  return false;
+}
+
 static inline bool is_dictionary(Type::type type_id) {
   return type_id == Type::DICTIONARY;
 }
diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py
index 07ef7f4b078..69ea79b50cf 100644
--- a/python/pyarrow/__init__.py
+++ b/python/pyarrow/__init__.py
@@ -115,6 +115,7 @@ def show_versions():
                          DictionaryMemo,
                          KeyValueMetadata,
                          Field,
+                         FieldMergeOptions,
                          Schema,
                          schema,
                          unify_schemas,
diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi
index de0d3a74dfb..6e71b5a51aa
100644
--- a/python/pyarrow/array.pxi
+++ b/python/pyarrow/array.pxi
@@ -1013,11 +1013,11 @@ cdef class Array(_PandasConvertible):
         Parameters
         ----------
         indent : int, default 2
-            How much to indent the internal items in the string to
+            How much to indent the internal items in the string to
             the right, by default ``2``.
         top_level_indent : int, default 0
             How much to indent right the entire content of the array,
-            by default ``0``.
+            by default ``0``.
         window : int
             How many items to preview at the begin and end of the
             array when the arrays is bigger than the window.
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 1e6c741ac30..3202831f641 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -387,12 +387,16 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
         int scale()
 
     cdef cppclass CField" arrow::Field":
-        cppclass CMergeOptions "arrow::Field::MergeOptions":
+        cppclass CMergeOptions "MergeOptions":
+            CMergeOptions()
             c_bool promote_nullability
 
             @staticmethod
             CMergeOptions Defaults()
 
+            @staticmethod
+            CMergeOptions Permissive()
+
         const c_string& name()
         shared_ptr[CDataType] type()
         c_bool nullable()
@@ -483,7 +487,8 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
         shared_ptr[CSchema] RemoveMetadata()
 
     CResult[shared_ptr[CSchema]] UnifySchemas(
-        const vector[shared_ptr[CSchema]]& schemas)
+        const vector[shared_ptr[CSchema]]& schemas,
+        CField.CMergeOptions field_merge_options)
 
     cdef cppclass PrettyPrintOptions:
         PrettyPrintOptions()
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index 5f42d71c7e3..e86782daf7f 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -2349,7 +2349,8 @@ def table(data, names=None, schema=None, metadata=None, nthreads=None):
             "Expected pandas DataFrame, python dictionary or list of arrays")
 
 
-def concat_tables(tables, c_bool promote=False, MemoryPool memory_pool=None):
+def concat_tables(tables, c_bool promote=False, MemoryPool memory_pool=None,
+                  FieldMergeOptions field_merge_options=None):
     """
     Concatenate pyarrow.Table objects.
 
@@ -2371,9 +2372,13 @@ def concat_tables(tables, c_bool promote=False, MemoryPool memory_pool=None):
     tables : iterable of pyarrow.Table objects
         Pyarrow tables to concatenate into a single Table.
     promote : bool, default False
-        If True, concatenate tables with null-filling and null type promotion.
+        If True, concatenate tables with null-filling and type promotion.
+        See field_merge_options for the type promotion behavior.
     memory_pool : MemoryPool, default None
         For memory allocations, if required, otherwise use default pool.
+    field_merge_options : FieldMergeOptions, default None
+        The type promotion options; by default, null and only null can
+        be unified with another type.
     """
     cdef:
         vector[shared_ptr[CTable]] c_tables
@@ -2386,6 +2391,9 @@ def concat_tables(tables, c_bool promote=False, MemoryPool memory_pool=None):
     for table in tables:
         c_tables.push_back(table.sp_table)
 
+    if field_merge_options is not None:
+        options.field_merge_options = field_merge_options.c_options
+
     with nogil:
         options.unify_schemas = promote
         c_result_table = GetResultValue(
diff --git a/python/pyarrow/tests/test_schema.py b/python/pyarrow/tests/test_schema.py
index f26eaaf5fc1..b208a995833 100644
--- a/python/pyarrow/tests/test_schema.py
+++ b/python/pyarrow/tests/test_schema.py
@@ -718,6 +718,10 @@ def test_schema_merge():
     result = pa.unify_schemas((a, b, c))
     assert result.equals(expected)
 
+    result = pa.unify_schemas(
+        [b, d], options=pa.FieldMergeOptions.permissive())
+    assert result.equals(d)
+
 
 def test_undecodable_metadata():
     # ARROW-10214: undecodable metadata shouldn't fail repr()
diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py
index e5e27332d13..010c2a10cf5 100644
--- a/python/pyarrow/tests/test_table.py
+++ b/python/pyarrow/tests/test_table.py
@@ -1189,6 +1189,15 @@ def test_concat_tables_with_promotion():
         pa.array([None, None, 1.0, 2.0], type=pa.float32()),
     ], ["int64_field", "float_field"]))
 
+    t3 = pa.Table.from_arrays(
+        [pa.array([1, 2], type=pa.int32())], ["int64_field"])
+    result = pa.concat_tables(
+        [t1, t3], promote=True,
+        field_merge_options=pa.FieldMergeOptions.permissive())
+    assert result.equals(pa.Table.from_arrays([
+        pa.array([1, 2, 1, 2], type=pa.int64()),
+    ], ["int64_field"]))
+
 
 def test_concat_tables_with_promotion_error():
     t1 = pa.Table.from_arrays(
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index 32d70887aab..e81a1d5534e 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -1118,6 +1118,33 @@ cdef KeyValueMetadata ensure_metadata(object meta, c_bool allow_none=False):
         return KeyValueMetadata(meta)
 
 
+cdef class FieldMergeOptions(_Weakrefable):
+    """
+    Options controlling how to merge the types of two fields.
+
+    By default, types must match exactly, except the null type can be
+    merged with any other type.
+
+    """
+
+    cdef:
+        CField.CMergeOptions c_options
+
+    __slots__ = ()
+
+    def __init__(self, *):
+        self.c_options = CField.CMergeOptions.Defaults()
+
+    @staticmethod
+    def permissive():
+        """
+        Allow merging generally compatible types (e.g. float64 and int64).
+        """
+        cdef FieldMergeOptions options = FieldMergeOptions()
+        options.c_options = CField.CMergeOptions.Permissive()
+        return options
+
+
 cdef class Field(_Weakrefable):
     """
     A named field, with a data type, nullability, and optional metadata.
@@ -1783,13 +1810,13 @@ cdef class Schema(_Weakrefable):
         return self.__str__()
 
 
-def unify_schemas(schemas):
+def unify_schemas(schemas, *, options=None):
     """
     Unify schemas by merging fields by name.
 
     The resulting schema will contain the union of fields from all schemas.
     Fields with the same name will be merged. Note that two fields with
-    different types will fail merging.
+    different types will fail merging by default.
 
     - The unified field will inherit the metadata from the schema where
         that field is first defined.
@@ -1803,6 +1830,8 @@ def unify_schemas(schemas):
     ----------
     schemas : list of Schema
         Schemas to merge into a single one.
+    options : FieldMergeOptions, optional
+        Options for merging duplicate fields.
 
     Returns
     -------
@@ -1816,10 +1845,16 @@ def unify_schemas(schemas):
     """
     cdef:
         Schema schema
+        CField.CMergeOptions c_options
         vector[shared_ptr[CSchema]] c_schemas
     for schema in schemas:
         c_schemas.push_back(pyarrow_unwrap_schema(schema))
-    return pyarrow_wrap_schema(GetResultValue(UnifySchemas(c_schemas)))
+    if options is not None:
+        c_options = (<FieldMergeOptions?> options).c_options
+    else:
+        c_options = CField.CMergeOptions.Defaults()
+    return pyarrow_wrap_schema(
+        GetResultValue(UnifySchemas(c_schemas, c_options)))
 
 
 cdef dict _type_cache = {}