diff --git a/benchmarks/expected-plans/q11.txt b/benchmarks/expected-plans/q11.txt index b408340a32a0..0e886e2e74b7 100644 --- a/benchmarks/expected-plans/q11.txt +++ b/benchmarks/expected-plans/q11.txt @@ -1,6 +1,6 @@ Sort: value DESC NULLS FIRST Projection: partsupp.ps_partkey, SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS value - Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) > __sq_1.__value + Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 15)) > CAST(__sq_1.__value AS Decimal128(38, 15)) CrossJoin: Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]] Inner Join: supplier.s_nationkey = nation.n_nationkey @@ -9,7 +9,7 @@ Sort: value DESC NULLS FIRST TableScan: supplier projection=[s_suppkey, s_nationkey] Filter: nation.n_name = Utf8("GERMANY") TableScan: nation projection=[n_nationkey, n_name] - Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * Decimal128(Some(10000000000000),38,17) AS __value, alias=__sq_1 + Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS __value, alias=__sq_1 Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]] Inner Join: supplier.s_nationkey = nation.n_nationkey Inner Join: partsupp.ps_suppkey = supplier.s_suppkey diff --git a/benchmarks/expected-plans/q14.txt b/benchmarks/expected-plans/q14.txt index c410363a5821..edafe4608210 100644 --- a/benchmarks/expected-plans/q14.txt +++ b/benchmarks/expected-plans/q14.txt @@ -1,4 +1,4 @@ -Projection: CAST(Decimal128(Some(1000000000000000000000),38,19) * CAST(SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END) AS Decimal128(38, 19)) AS Decimal128(38, 38)) / CAST(SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS Decimal128(38, 38)) AS promo_revenue +Projection: Float64(100) * CAST(SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END) AS Float64) / CAST(SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS Float64) AS promo_revenue Aggregate: groupBy=[[]], aggr=[[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2))CAST(lineitem.l_discount AS Decimal128(23, 2))lineitem.l_discountDecimal128(Some(100),23,2)CAST(lineitem.l_extendedprice AS Decimal128(38, 4))lineitem.l_extendedprice AS lineitem.l_extendedprice * Decimal128(Some(100),23,2) - lineitem.l_discount ELSE Decimal128(Some(0),38,4) END) AS SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2))CAST(lineitem.l_discount AS Decimal128(23, 2))lineitem.l_discountDecimal128(Some(100),23,2)CAST(lineitem.l_extendedprice AS Decimal128(38, 4))lineitem.l_extendedprice AS lineitem.l_extendedprice * Decimal128(Some(100),23,2) - lineitem.l_discount) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] Projection: CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)) AS CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2))CAST(lineitem.l_discount AS Decimal128(23, 2))lineitem.l_discountDecimal128(Some(100),23,2)CAST(lineitem.l_extendedprice AS Decimal128(38, 4))lineitem.l_extendedprice, part.p_type Inner Join: lineitem.l_partkey = part.p_partkey diff --git a/benchmarks/expected-plans/q20.txt b/benchmarks/expected-plans/q20.txt index e5398325e966..0d095a735c29 100644 --- a/benchmarks/expected-plans/q20.txt +++ b/benchmarks/expected-plans/q20.txt @@ -6,14 +6,14 @@ Sort: supplier.s_name ASC NULLS LAST Filter: nation.n_name = Utf8("CANADA") TableScan: nation projection=[n_nationkey, n_name] Projection: partsupp.ps_suppkey AS ps_suppkey, alias=__sq_2 - Filter: CAST(partsupp.ps_availqty AS Decimal128(38, 17)) > __sq_3.__value + Filter: CAST(partsupp.ps_availqty AS Float64) > __sq_3.__value Inner Join: partsupp.ps_partkey = __sq_3.l_partkey, partsupp.ps_suppkey = __sq_3.l_suppkey LeftSemi Join: partsupp.ps_partkey = __sq_1.p_partkey TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty] Projection: part.p_partkey AS p_partkey, alias=__sq_1 Filter: part.p_name LIKE Utf8("forest%") TableScan: part projection=[p_partkey, p_name] - Projection: lineitem.l_partkey, lineitem.l_suppkey, Decimal128(Some(50000000000000000),38,17) * CAST(SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3 + Projection: lineitem.l_partkey, lineitem.l_suppkey, Float64(0.5) * CAST(SUM(lineitem.l_quantity) AS Float64) AS __value, alias=__sq_3 Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]], aggr=[[SUM(lineitem.l_quantity)]] Filter: lineitem.l_shipdate >= Date32("8766") AND lineitem.l_shipdate < Date32("9131") TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate] \ No newline at end of file diff --git a/datafusion/core/tests/sql/decimal.rs b/datafusion/core/tests/sql/decimal.rs index 2e3e3d2abdfa..e0c2c1773e48 100644 --- a/datafusion/core/tests/sql/decimal.rs +++ b/datafusion/core/tests/sql/decimal.rs @@ -879,3 +879,37 @@ async fn decimal_null_array_scalar_comparison() -> Result<()> { assert_eq!(&DataType::Boolean, actual[0].column(0).data_type()); Ok(()) } + +#[tokio::test] +async fn decimal_multiply_float() -> Result<()> { + let ctx = SessionContext::new(); + let sql = "select cast(400420638.54 as decimal(12,2));"; + let actual = execute_to_batches(&ctx, sql).await; + + assert_eq!( + &DataType::Decimal128(12, 2), + actual[0].schema().field(0).data_type() + ); + let expected = vec![ + "+-----------------------+", + "| Float64(400420638.54) |", + "+-----------------------+", + "| 400420638.54 |", + "+-----------------------+", + ]; + assert_batches_eq!(expected, &actual); + + let sql = "select cast(400420638.54 as decimal(12,2)) * 1.0;"; + let actual = execute_to_batches(&ctx, sql).await; + assert_eq!(&DataType::Float64, actual[0].schema().field(0).data_type()); + let expected = vec![ + "+------------------------------------+", + "| Float64(400420638.54) * Float64(1) |", + "+------------------------------------+", + "| 400420638.54 |", + "+------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index ed65d43919b3..4fb97d5eb3bb 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -328,14 +328,14 @@ order by s_name; Filter: nation.n_name = Utf8("CANADA") TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("CANADA")] Projection: partsupp.ps_suppkey AS ps_suppkey, alias=__sq_2 - Filter: CAST(partsupp.ps_availqty AS Decimal128(38, 17)) > __sq_3.__value + Filter: CAST(partsupp.ps_availqty AS Float64) > __sq_3.__value Inner Join: partsupp.ps_partkey = __sq_3.l_partkey, partsupp.ps_suppkey = __sq_3.l_suppkey LeftSemi Join: partsupp.ps_partkey = __sq_1.p_partkey TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty] Projection: part.p_partkey AS p_partkey, alias=__sq_1 Filter: part.p_name LIKE Utf8("forest%") TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8("forest%")] - Projection: lineitem.l_partkey, lineitem.l_suppkey, Decimal128(Some(50000000000000000),38,17) * CAST(SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3 + Projection: lineitem.l_partkey, lineitem.l_suppkey, Float64(0.5) * CAST(SUM(lineitem.l_quantity) AS Float64) AS __value, alias=__sq_3 Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]], aggr=[[SUM(lineitem.l_quantity)]] Filter: lineitem.l_shipdate >= Date32("8766") TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("8766")]"# @@ -443,7 +443,7 @@ order by value desc; let actual = format!("{}", plan.display_indent()); let expected = r#"Sort: value DESC NULLS FIRST Projection: partsupp.ps_partkey, SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS value - Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) > __sq_1.__value + Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 15)) > CAST(__sq_1.__value AS Decimal128(38, 15)) CrossJoin: Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]] Inner Join: supplier.s_nationkey = nation.n_nationkey @@ -452,7 +452,7 @@ order by value desc; TableScan: supplier projection=[s_suppkey, s_nationkey] Filter: nation.n_name = Utf8("GERMANY") TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")] - Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * Decimal128(Some(10000000000000),38,17) AS __value, alias=__sq_1 + Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS __value, alias=__sq_1 Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]] Inner Join: supplier.s_nationkey = nation.n_nationkey Inner Join: partsupp.ps_suppkey = supplier.s_suppkey diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index d65ed5228754..ce169f6ec253 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1488,6 +1488,7 @@ impl Subquery { pub fn try_from_expr(plan: &Expr) -> datafusion_common::Result<&Subquery> { match plan { Expr::ScalarSubquery(it) => Ok(it), + Expr::Cast(cast) => Subquery::try_from_expr(cast.expr.as_ref()), _ => plan_err!("Could not coerce into ScalarSubquery!"), } } diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 2a125d56b39f..45510cb03e0e 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -333,6 +333,8 @@ fn mathematics_numerical_coercion( (Null, dec_type @ Decimal128(_, _)) | (dec_type @ Decimal128(_, _), Null) => { Some(dec_type.clone()) } + (Decimal128(_, _), Float32 | Float64) => Some(Float64), + (Float32 | Float64, Decimal128(_, _)) => Some(Float64), (Decimal128(_, _), _) => { let converted_decimal_type = coerce_numeric_type_to_decimal(rhs_type); match converted_decimal_type { diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 8b93f49c1a03..4aa5d5e1e70f 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -2574,9 +2574,22 @@ mod tests { let right_expr = if right.data_type().eq(&op_type) { col("b", schema)? } else { - try_cast(col("b", schema)?, schema, op_type)? + try_cast(col("b", schema)?, schema, op_type.clone())? }; - let arithmetic_op = binary_simple(left_expr, op, right_expr, schema); + + let coerced_schema = Schema::new(vec![ + Field::new( + schema.field(0).name(), + op_type.clone(), + schema.field(0).is_nullable(), + ), + Field::new( + schema.field(1).name(), + op_type, + schema.field(1).is_nullable(), + ), + ]); + let arithmetic_op = binary_simple(left_expr, op, right_expr, &coerced_schema); let data: Vec = vec![left.clone(), right.clone()]; let batch = RecordBatch::try_new(schema.clone(), data)?; let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); @@ -2704,6 +2717,125 @@ mod tests { Ok(()) } + #[test] + fn arithmetic_decimal_float_expr_test() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float64, true), + Field::new("b", DataType::Decimal128(10, 2), true), + ])); + let value: i128 = 123; + let decimal_array = Arc::new(create_decimal_array( + &[ + Some(value as i128), // 1.23 + None, + Some((value - 1) as i128), // 1.22 + Some((value + 1) as i128), // 1.24 + ], + 10, + 2, + )) as ArrayRef; + let float64_array = Arc::new(Float64Array::from(vec![ + Some(123.0), + Some(122.0), + Some(123.0), + Some(124.0), + ])) as ArrayRef; + + // add: float64 array add decimal array + let expect = Arc::new(Float64Array::from(vec![ + Some(124.23), + None, + Some(124.22), + Some(125.24), + ])) as ArrayRef; + apply_arithmetic_op( + &schema, + &float64_array, + &decimal_array, + Operator::Plus, + expect, + ) + .unwrap(); + + // subtract: decimal array subtract float64 array + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float64, true), + Field::new("b", DataType::Decimal128(10, 2), true), + ])); + let expect = Arc::new(Float64Array::from(vec![ + Some(121.77), + None, + Some(121.78), + Some(122.76), + ])) as ArrayRef; + apply_arithmetic_op( + &schema, + &float64_array, + &decimal_array, + Operator::Minus, + expect, + ) + .unwrap(); + + // multiply: decimal array multiply float64 array + let expect = Arc::new(Float64Array::from(vec![ + Some(151.29), + None, + Some(150.06), + Some(153.76), + ])) as ArrayRef; + apply_arithmetic_op( + &schema, + &float64_array, + &decimal_array, + Operator::Multiply, + expect, + ) + .unwrap(); + + // divide: float64 array divide decimal array + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float64, true), + Field::new("b", DataType::Decimal128(10, 2), true), + ])); + let expect = Arc::new(Float64Array::from(vec![ + Some(100.0), + None, + Some(100.81967213114754), + Some(100.0), + ])) as ArrayRef; + apply_arithmetic_op( + &schema, + &float64_array, + &decimal_array, + Operator::Divide, + expect, + ) + .unwrap(); + + // modulus: float64 array modulus decimal array + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float64, true), + Field::new("b", DataType::Decimal128(10, 2), true), + ])); + let expect = Arc::new(Float64Array::from(vec![ + Some(1.7763568394002505e-15), + None, + Some(1.0000000000000027), + Some(8.881784197001252e-16), + ])) as ArrayRef; + apply_arithmetic_op( + &schema, + &float64_array, + &decimal_array, + Operator::Modulo, + expect, + ) + .unwrap(); + + Ok(()) + } + #[test] fn bitwise_array_test() -> Result<()> { let left = Arc::new(Int32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef;