diff --git a/datafusion/core/tests/sql/union.rs b/datafusion/core/tests/sql/union.rs index 804833bb9d77..e10f6e237751 100644 --- a/datafusion/core/tests/sql/union.rs +++ b/datafusion/core/tests/sql/union.rs @@ -98,3 +98,49 @@ async fn union_with_type_coercion() -> Result<()> { ); Ok(()) } + +#[tokio::test] +async fn test_union_upcast_types() -> Result<()> { + let config = SessionConfig::new() + .with_repartition_windows(false) + .with_target_partitions(1); + let ctx = SessionContext::with_config(config); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT c1, c9 FROM aggregate_test_100 + UNION ALL + SELECT c1, c3 FROM aggregate_test_100 + ORDER BY c9 DESC LIMIT 5"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + + let expected_logical_plan = vec![ + "Limit: skip=0, fetch=5 [c1:Utf8, c9:Int64]", + " Sort: c9 DESC NULLS FIRST [c1:Utf8, c9:Int64]", + " Union [c1:Utf8, c9:Int64]", + " Projection: aggregate_test_100.c1, CAST(aggregate_test_100.c9 AS Int64) AS c9 [c1:Utf8, c9:Int64]", + " TableScan: aggregate_test_100 [c1:Utf8, c2:UInt32, c3:Int8, c4:Int16, c5:Int32, c6:Int64, c7:UInt8, c8:UInt16, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]", + " Projection: aggregate_test_100.c1, CAST(aggregate_test_100.c3 AS Int64) AS c9 [c1:Utf8, c9:Int64]", + " TableScan: aggregate_test_100 [c1:Utf8, c2:UInt32, c3:Int8, c4:Int16, c5:Int32, c6:Int64, c7:UInt8, c8:UInt16, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]", + ]; + let formatted_logical_plan = + dataframe.logical_plan().display_indent_schema().to_string(); + let actual_logical_plan: Vec<&str> = formatted_logical_plan.trim().lines().collect(); + assert_eq!(expected_logical_plan, actual_logical_plan, "\n\nexpected:\n\n{expected_logical_plan:#?}\nactual:\n\n{actual_logical_plan:#?}\n\n"); + + let actual = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+----+------------+", + "| c1 | c9 |", + "+----+------------+", + "| c | 4268716378 |", + "| e | 4229654142 |", + "| d | 4216440507 |", + "| e | 4144173353 |", + "| b | 4076864659 |", + "+----+------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 3a2ca9c79922..ba24f4a901d5 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -238,13 +238,36 @@ fn comparison_binary_numeric_coercion( (_, Decimal128(_, _)) => get_comparison_common_decimal_type(rhs_type, lhs_type), (Float64, _) | (_, Float64) => Some(Float64), (_, Float32) | (Float32, _) => Some(Float32), - (Int64, _) | (_, Int64) => Some(Int64), - (Int32, _) | (_, Int32) => Some(Int32), - (Int16, _) | (_, Int16) => Some(Int16), - (Int8, _) | (_, Int8) => Some(Int8), + // The following match arms encode the following logic: Given the two + // integral types, we choose the narrowest possible integral type that + // accommodates all values of both types. Note that some information + // loss is inevitable when we have a signed type and a `UInt64`, in + // which case we use `Int64`;i.e. the widest signed integral type. + (Int64, _) + | (_, Int64) + | (UInt64, Int8) + | (Int8, UInt64) + | (UInt64, Int16) + | (Int16, UInt64) + | (UInt64, Int32) + | (Int32, UInt64) + | (UInt32, Int8) + | (Int8, UInt32) + | (UInt32, Int16) + | (Int16, UInt32) + | (UInt32, Int32) + | (Int32, UInt32) => Some(Int64), (UInt64, _) | (_, UInt64) => Some(UInt64), + (Int32, _) + | (_, Int32) + | (UInt16, Int16) + | (Int16, UInt16) + | (UInt16, Int8) + | (Int8, UInt16) => Some(Int32), (UInt32, _) | (_, UInt32) => Some(UInt32), + (Int16, _) | (_, Int16) | (Int8, UInt8) | (UInt8, Int8) => Some(Int16), (UInt16, _) | (_, UInt16) => Some(UInt16), + (Int8, _) | (_, Int8) => Some(Int8), (UInt8, _) | (_, UInt8) => Some(UInt8), _ => None, }