From af10480b01233238852f07d059ecaa88b2040591 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 30 Oct 2022 15:14:46 -0700 Subject: [PATCH 1/5] Fix Decimal and Floating type coerce rule --- benchmarks/expected-plans/q11.txt | 4 +- benchmarks/expected-plans/q14.txt | 2 +- benchmarks/expected-plans/q20.txt | 4 +- datafusion/core/tests/sql/decimal.rs | 34 +++++ datafusion/core/tests/sql/subqueries.rs | 8 +- datafusion/expr/src/logical_plan/plan.rs | 1 + datafusion/expr/src/type_coercion/binary.rs | 2 + .../physical-expr/src/expressions/binary.rs | 136 +++++++++++++++++- 8 files changed, 180 insertions(+), 11 deletions(-) diff --git a/benchmarks/expected-plans/q11.txt b/benchmarks/expected-plans/q11.txt index b408340a32a0..0e886e2e74b7 100644 --- a/benchmarks/expected-plans/q11.txt +++ b/benchmarks/expected-plans/q11.txt @@ -1,6 +1,6 @@ Sort: value DESC NULLS FIRST Projection: partsupp.ps_partkey, SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS value - Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) > __sq_1.__value + Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 15)) > CAST(__sq_1.__value AS Decimal128(38, 15)) CrossJoin: Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]] Inner Join: supplier.s_nationkey = nation.n_nationkey @@ -9,7 +9,7 @@ Sort: value DESC NULLS FIRST TableScan: supplier projection=[s_suppkey, s_nationkey] Filter: nation.n_name = Utf8("GERMANY") TableScan: nation projection=[n_nationkey, n_name] - Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * Decimal128(Some(10000000000000),38,17) AS __value, alias=__sq_1 + Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS __value, alias=__sq_1 Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) 
* CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]] Inner Join: supplier.s_nationkey = nation.n_nationkey Inner Join: partsupp.ps_suppkey = supplier.s_suppkey diff --git a/benchmarks/expected-plans/q14.txt b/benchmarks/expected-plans/q14.txt index c410363a5821..edafe4608210 100644 --- a/benchmarks/expected-plans/q14.txt +++ b/benchmarks/expected-plans/q14.txt @@ -1,4 +1,4 @@ -Projection: CAST(Decimal128(Some(1000000000000000000000),38,19) * CAST(SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END) AS Decimal128(38, 19)) AS Decimal128(38, 38)) / CAST(SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS Decimal128(38, 38)) AS promo_revenue +Projection: Float64(100) * CAST(SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END) AS Float64) / CAST(SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS Float64) AS promo_revenue Aggregate: groupBy=[[]], aggr=[[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2))CAST(lineitem.l_discount AS Decimal128(23, 2))lineitem.l_discountDecimal128(Some(100),23,2)CAST(lineitem.l_extendedprice AS Decimal128(38, 4))lineitem.l_extendedprice AS lineitem.l_extendedprice * Decimal128(Some(100),23,2) - lineitem.l_discount ELSE Decimal128(Some(0),38,4) END) AS SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 
4))CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2))CAST(lineitem.l_discount AS Decimal128(23, 2))lineitem.l_discountDecimal128(Some(100),23,2)CAST(lineitem.l_extendedprice AS Decimal128(38, 4))lineitem.l_extendedprice AS lineitem.l_extendedprice * Decimal128(Some(100),23,2) - lineitem.l_discount) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] Projection: CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)) AS CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2))CAST(lineitem.l_discount AS Decimal128(23, 2))lineitem.l_discountDecimal128(Some(100),23,2)CAST(lineitem.l_extendedprice AS Decimal128(38, 4))lineitem.l_extendedprice, part.p_type Inner Join: lineitem.l_partkey = part.p_partkey diff --git a/benchmarks/expected-plans/q20.txt b/benchmarks/expected-plans/q20.txt index e5398325e966..0d095a735c29 100644 --- a/benchmarks/expected-plans/q20.txt +++ b/benchmarks/expected-plans/q20.txt @@ -6,14 +6,14 @@ Sort: supplier.s_name ASC NULLS LAST Filter: nation.n_name = Utf8("CANADA") TableScan: nation projection=[n_nationkey, n_name] Projection: partsupp.ps_suppkey AS ps_suppkey, alias=__sq_2 - Filter: CAST(partsupp.ps_availqty AS Decimal128(38, 17)) > __sq_3.__value + Filter: CAST(partsupp.ps_availqty AS Float64) > __sq_3.__value Inner Join: partsupp.ps_partkey = __sq_3.l_partkey, partsupp.ps_suppkey = __sq_3.l_suppkey LeftSemi Join: partsupp.ps_partkey = __sq_1.p_partkey TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty] 
Projection: part.p_partkey AS p_partkey, alias=__sq_1 Filter: part.p_name LIKE Utf8("forest%") TableScan: part projection=[p_partkey, p_name] - Projection: lineitem.l_partkey, lineitem.l_suppkey, Decimal128(Some(50000000000000000),38,17) * CAST(SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3 + Projection: lineitem.l_partkey, lineitem.l_suppkey, Float64(0.5) * CAST(SUM(lineitem.l_quantity) AS Float64) AS __value, alias=__sq_3 Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]], aggr=[[SUM(lineitem.l_quantity)]] Filter: lineitem.l_shipdate >= Date32("8766") AND lineitem.l_shipdate < Date32("9131") TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate] \ No newline at end of file diff --git a/datafusion/core/tests/sql/decimal.rs b/datafusion/core/tests/sql/decimal.rs index 2e3e3d2abdfa..e0c2c1773e48 100644 --- a/datafusion/core/tests/sql/decimal.rs +++ b/datafusion/core/tests/sql/decimal.rs @@ -879,3 +879,37 @@ async fn decimal_null_array_scalar_comparison() -> Result<()> { assert_eq!(&DataType::Boolean, actual[0].column(0).data_type()); Ok(()) } + +#[tokio::test] +async fn decimal_multiply_float() -> Result<()> { + let ctx = SessionContext::new(); + let sql = "select cast(400420638.54 as decimal(12,2));"; + let actual = execute_to_batches(&ctx, sql).await; + + assert_eq!( + &DataType::Decimal128(12, 2), + actual[0].schema().field(0).data_type() + ); + let expected = vec![ + "+-----------------------+", + "| Float64(400420638.54) |", + "+-----------------------+", + "| 400420638.54 |", + "+-----------------------+", + ]; + assert_batches_eq!(expected, &actual); + + let sql = "select cast(400420638.54 as decimal(12,2)) * 1.0;"; + let actual = execute_to_batches(&ctx, sql).await; + assert_eq!(&DataType::Float64, actual[0].schema().field(0).data_type()); + let expected = vec![ + "+------------------------------------+", + "| Float64(400420638.54) * Float64(1) |", + "+------------------------------------+", + "| 
400420638.54 |", + "+------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index ed65d43919b3..4fb97d5eb3bb 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -328,14 +328,14 @@ order by s_name; Filter: nation.n_name = Utf8("CANADA") TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("CANADA")] Projection: partsupp.ps_suppkey AS ps_suppkey, alias=__sq_2 - Filter: CAST(partsupp.ps_availqty AS Decimal128(38, 17)) > __sq_3.__value + Filter: CAST(partsupp.ps_availqty AS Float64) > __sq_3.__value Inner Join: partsupp.ps_partkey = __sq_3.l_partkey, partsupp.ps_suppkey = __sq_3.l_suppkey LeftSemi Join: partsupp.ps_partkey = __sq_1.p_partkey TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty] Projection: part.p_partkey AS p_partkey, alias=__sq_1 Filter: part.p_name LIKE Utf8("forest%") TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8("forest%")] - Projection: lineitem.l_partkey, lineitem.l_suppkey, Decimal128(Some(50000000000000000),38,17) * CAST(SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3 + Projection: lineitem.l_partkey, lineitem.l_suppkey, Float64(0.5) * CAST(SUM(lineitem.l_quantity) AS Float64) AS __value, alias=__sq_3 Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]], aggr=[[SUM(lineitem.l_quantity)]] Filter: lineitem.l_shipdate >= Date32("8766") TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("8766")]"# @@ -443,7 +443,7 @@ order by value desc; let actual = format!("{}", plan.display_indent()); let expected = r#"Sort: value DESC NULLS FIRST Projection: partsupp.ps_partkey, SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS value - Filter: 
CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) > __sq_1.__value + Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 15)) > CAST(__sq_1.__value AS Decimal128(38, 15)) CrossJoin: Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]] Inner Join: supplier.s_nationkey = nation.n_nationkey @@ -452,7 +452,7 @@ order by value desc; TableScan: supplier projection=[s_suppkey, s_nationkey] Filter: nation.n_name = Utf8("GERMANY") TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")] - Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * Decimal128(Some(10000000000000),38,17) AS __value, alias=__sq_1 + Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS __value, alias=__sq_1 Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]] Inner Join: supplier.s_nationkey = nation.n_nationkey Inner Join: partsupp.ps_suppkey = supplier.s_suppkey diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index d65ed5228754..ce169f6ec253 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1488,6 +1488,7 @@ impl Subquery { pub fn try_from_expr(plan: &Expr) -> datafusion_common::Result<&Subquery> { match plan { Expr::ScalarSubquery(it) => Ok(it), + Expr::Cast(cast) => Subquery::try_from_expr(cast.expr.as_ref()), _ => plan_err!("Could not coerce into ScalarSubquery!"), } } diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 2a125d56b39f..45510cb03e0e 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -333,6 
+333,8 @@ fn mathematics_numerical_coercion( (Null, dec_type @ Decimal128(_, _)) | (dec_type @ Decimal128(_, _), Null) => { Some(dec_type.clone()) } + (Decimal128(_, _), Float32 | Float64) => Some(Float64), + (Float32 | Float64, Decimal128(_, _)) => Some(Float64), (Decimal128(_, _), _) => { let converted_decimal_type = coerce_numeric_type_to_decimal(rhs_type); match converted_decimal_type { diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 8b93f49c1a03..4aa5d5e1e70f 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -2574,9 +2574,22 @@ mod tests { let right_expr = if right.data_type().eq(&op_type) { col("b", schema)? } else { - try_cast(col("b", schema)?, schema, op_type)? + try_cast(col("b", schema)?, schema, op_type.clone())? }; - let arithmetic_op = binary_simple(left_expr, op, right_expr, schema); + + let coerced_schema = Schema::new(vec![ + Field::new( + schema.field(0).name(), + op_type.clone(), + schema.field(0).is_nullable(), + ), + Field::new( + schema.field(1).name(), + op_type, + schema.field(1).is_nullable(), + ), + ]); + let arithmetic_op = binary_simple(left_expr, op, right_expr, &coerced_schema); let data: Vec = vec![left.clone(), right.clone()]; let batch = RecordBatch::try_new(schema.clone(), data)?; let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); @@ -2704,6 +2717,125 @@ mod tests { Ok(()) } + #[test] + fn arithmetic_decimal_float_expr_test() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float64, true), + Field::new("b", DataType::Decimal128(10, 2), true), + ])); + let value: i128 = 123; + let decimal_array = Arc::new(create_decimal_array( + &[ + Some(value as i128), // 1.23 + None, + Some((value - 1) as i128), // 1.22 + Some((value + 1) as i128), // 1.24 + ], + 10, + 2, + )) as ArrayRef; + let float64_array = 
Arc::new(Float64Array::from(vec![ + Some(123.0), + Some(122.0), + Some(123.0), + Some(124.0), + ])) as ArrayRef; + + // add: float64 array add decimal array + let expect = Arc::new(Float64Array::from(vec![ + Some(124.23), + None, + Some(124.22), + Some(125.24), + ])) as ArrayRef; + apply_arithmetic_op( + &schema, + &float64_array, + &decimal_array, + Operator::Plus, + expect, + ) + .unwrap(); + + // subtract: decimal array subtract float64 array + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float64, true), + Field::new("b", DataType::Decimal128(10, 2), true), + ])); + let expect = Arc::new(Float64Array::from(vec![ + Some(121.77), + None, + Some(121.78), + Some(122.76), + ])) as ArrayRef; + apply_arithmetic_op( + &schema, + &float64_array, + &decimal_array, + Operator::Minus, + expect, + ) + .unwrap(); + + // multiply: decimal array multiply float64 array + let expect = Arc::new(Float64Array::from(vec![ + Some(151.29), + None, + Some(150.06), + Some(153.76), + ])) as ArrayRef; + apply_arithmetic_op( + &schema, + &float64_array, + &decimal_array, + Operator::Multiply, + expect, + ) + .unwrap(); + + // divide: float64 array divide decimal array + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float64, true), + Field::new("b", DataType::Decimal128(10, 2), true), + ])); + let expect = Arc::new(Float64Array::from(vec![ + Some(100.0), + None, + Some(100.81967213114754), + Some(100.0), + ])) as ArrayRef; + apply_arithmetic_op( + &schema, + &float64_array, + &decimal_array, + Operator::Divide, + expect, + ) + .unwrap(); + + // modulus: float64 array modulus decimal array + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float64, true), + Field::new("b", DataType::Decimal128(10, 2), true), + ])); + let expect = Arc::new(Float64Array::from(vec![ + Some(1.7763568394002505e-15), + None, + Some(1.0000000000000027), + Some(8.881784197001252e-16), + ])) as ArrayRef; + apply_arithmetic_op( + &schema, + &float64_array, 
+ &decimal_array, + Operator::Modulo, + expect, + ) + .unwrap(); + + Ok(()) + } + #[test] fn bitwise_array_test() -> Result<()> { let left = Arc::new(Int32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef; From 4d06a3458e111a9b6266374653d5f298d164316c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 31 Oct 2022 08:28:54 -0600 Subject: [PATCH 2/5] Enable more queries in benchmark verification tests --- benchmarks/src/bin/tpch.rs | 34 +++++++++++++++++----- benchmarks/src/tpch.rs | 59 +++++++++++++++++++++++++------------- 2 files changed, 66 insertions(+), 27 deletions(-) diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index b9afe4d6a17e..948c178169c8 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -668,7 +668,6 @@ mod tests { } #[cfg(feature = "ci")] - #[ignore] // TODO produces correct result but has rounding error #[tokio::test] async fn verify_q9() -> Result<()> { verify_query(9).await @@ -681,7 +680,6 @@ mod tests { } #[cfg(feature = "ci")] - #[ignore] // https://github.com/apache/arrow-datafusion/issues/4023 #[tokio::test] async fn verify_q11() -> Result<()> { verify_query(11).await @@ -700,7 +698,6 @@ mod tests { } #[cfg(feature = "ci")] - #[ignore] // https://github.com/apache/arrow-datafusion/issues/4025 #[tokio::test] async fn verify_q14() -> Result<()> { verify_query(14).await @@ -719,7 +716,6 @@ mod tests { } #[cfg(feature = "ci")] - #[ignore] // https://github.com/apache/arrow-datafusion/issues/4026 #[tokio::test] async fn verify_q17() -> Result<()> { verify_query(17).await @@ -896,6 +892,7 @@ mod tests { #[cfg(feature = "ci")] async fn verify_query(n: usize) -> Result<()> { use datafusion::arrow::datatypes::{DataType, Field}; + use datafusion::common::ScalarValue; use datafusion::logical_expr::expr::Cast; use datafusion::logical_expr::Expr; use std::env; @@ -990,7 +987,12 @@ mod tests { } data_type => data_type == e.data_type(), }); - assert!(schema_matches); + if !schema_matches { + panic!( + 
"expected_fields: {:?}\ntransformed_fields: {:?}", + expected_fields, transformed_fields + ) + } // convert both datasets to Vec> for simple comparison let expected_vec = result_vec(&expected); @@ -1000,8 +1002,26 @@ mod tests { assert_eq!(expected_vec.len(), actual_vec.len()); // compare each row. this works as all TPC-H queries have deterministically ordered results - for i in 0..actual_vec.len() { - assert_eq!(expected_vec[i], actual_vec[i]); + for i in 0..expected_vec.len() { + let expected_row = &expected_vec[i]; + let actual_row = &actual_vec[i]; + assert_eq!(expected_row.len(), actual_row.len()); + + for j in 0..expected.len() { + match (&expected_row[j], &actual_row[j]) { + (ScalarValue::Float64(Some(l)), ScalarValue::Float64(Some(r))) => { + // allow for rounding errors until we move to decimal types + let tolerance = 1.0; + if (l - r).abs() > tolerance { + panic!( + "Expected: {}; Actual: {}; Tolerance: {}", + l, r, tolerance + ) + } + } + (l, r) => assert_eq!(format!("{:?}", l), format!("{:?}", r)), + } + } } Ok(()) diff --git a/benchmarks/src/tpch.rs b/benchmarks/src/tpch.rs index 46c53edf120e..ad61de8a3b0b 100644 --- a/benchmarks/src/tpch.rs +++ b/benchmarks/src/tpch.rs @@ -15,7 +15,10 @@ // specific language governing permissions and limitations // under the License. 
-use arrow::array::ArrayRef; +use arrow::array::{ + Array, ArrayRef, Date32Array, Decimal128Array, Float64Array, Int32Array, Int64Array, + StringArray, +}; use arrow::record_batch::RecordBatch; use std::fs; use std::ops::{Div, Mul}; @@ -23,7 +26,7 @@ use std::path::Path; use std::sync::Arc; use std::time::Instant; -use datafusion::arrow::util::display::array_value_to_string; +use datafusion::common::ScalarValue; use datafusion::logical_expr::Cast; use datafusion::prelude::*; use datafusion::{ @@ -229,11 +232,7 @@ pub fn get_answer_schema(n: usize) -> Schema { Field::new("custdist", DataType::Int64, true), ]), - 14 => Schema::new(vec![Field::new( - "promo_revenue", - DataType::Decimal128(38, 2), - true, - )]), + 14 => Schema::new(vec![Field::new("promo_revenue", DataType::Float64, true)]), 15 => Schema::new(vec![ Field::new("s_suppkey", DataType::Int64, true), @@ -250,11 +249,7 @@ pub fn get_answer_schema(n: usize) -> Schema { Field::new("supplier_cnt", DataType::Int64, true), ]), - 17 => Schema::new(vec![Field::new( - "avg_yearly", - DataType::Decimal128(38, 2), - true, - )]), + 17 => Schema::new(vec![Field::new("avg_yearly", DataType::Float64, true)]), 18 => Schema::new(vec![ Field::new("c_name", DataType::Utf8, true), @@ -389,14 +384,14 @@ pub async fn convert_tbl( /// Converts the results into a 2d array of strings, `result[row][column]` /// Special cases nulls to NULL for testing -pub fn result_vec(results: &[RecordBatch]) -> Vec> { +pub fn result_vec(results: &[RecordBatch]) -> Vec> { let mut result = vec![]; for batch in results { for row_index in 0..batch.num_rows() { let row_vec = batch .columns() .iter() - .map(|column| col_str(column, row_index)) + .map(|column| col_to_scalar(column, row_index)) .collect(); result.push(row_vec); } @@ -422,13 +417,37 @@ pub fn string_schema(schema: Schema) -> Schema { ) } -/// Specialised String representation -fn col_str(column: &ArrayRef, row_index: usize) -> String { +fn col_to_scalar(column: &ArrayRef, row_index: 
usize) -> ScalarValue { if column.is_null(row_index) { - return "NULL".to_string(); + return ScalarValue::Null; + } + match column.data_type() { + DataType::Int32 => { + let array = column.as_any().downcast_ref::().unwrap(); + ScalarValue::Int32(Some(array.value(row_index))) + } + DataType::Int64 => { + let array = column.as_any().downcast_ref::().unwrap(); + ScalarValue::Int64(Some(array.value(row_index))) + } + DataType::Float64 => { + let array = column.as_any().downcast_ref::().unwrap(); + ScalarValue::Float64(Some(array.value(row_index))) + } + DataType::Decimal128(p, s) => { + let array = column.as_any().downcast_ref::().unwrap(); + ScalarValue::Decimal128(Some(array.value(row_index)), *p, *s) + } + DataType::Date32 => { + let array = column.as_any().downcast_ref::().unwrap(); + ScalarValue::Date32(Some(array.value(row_index))) + } + DataType::Utf8 => { + let array = column.as_any().downcast_ref::().unwrap(); + ScalarValue::Utf8(Some(array.value(row_index).to_string())) + } + other => panic!("unexpected data type in benchmark: {}", other), } - - array_value_to_string(column, row_index).unwrap() } pub async fn transform_actual_result( @@ -460,7 +479,7 @@ pub async fn transform_actual_result( Expr::Alias( Box::new(Expr::Cast(Cast::new( round, - DataType::Decimal128(38, 2), + DataType::Decimal128(15, 2), ))), Field::name(field).to_string(), ) From 365cd46bb484b5ec229c919e3c40113520384bfb Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 31 Oct 2022 09:16:10 -0600 Subject: [PATCH 3/5] update comparison_binary_numeric_coercion --- benchmarks/expected-plans/q6.txt | 4 ++-- benchmarks/src/bin/tpch.rs | 1 - datafusion/expr/src/type_coercion/binary.rs | 2 ++ 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/benchmarks/expected-plans/q6.txt b/benchmarks/expected-plans/q6.txt index efc17a2724b8..0d15bd5a3bfe 100644 --- a/benchmarks/expected-plans/q6.txt +++ b/benchmarks/expected-plans/q6.txt @@ -1,6 +1,6 @@ Projection: SUM(lineitem.l_extendedprice * 
lineitem.l_discount) AS revenue Aggregate: groupBy=[[]], aggr=[[SUM(lineitem.l_extendedprice * lineitem.l_discount)]] Projection: lineitem.l_extendedprice, lineitem.l_discount - Filter: lineitem.l_shipdate >= Date32("8766") AND lineitem.l_shipdate < Date32("9131") AND CAST(lineitem.l_discount AS Decimal128(30, 15))lineitem.l_discount >= Decimal128(Some(49999999999999),30,15) AND CAST(lineitem.l_discount AS Decimal128(30, 15))lineitem.l_discount <= Decimal128(Some(69999999999999),30,15) AND lineitem.l_quantity < Decimal128(Some(2400),15,2) - Projection: CAST(lineitem.l_discount AS Decimal128(30, 15)) AS CAST(lineitem.l_discount AS Decimal128(30, 15))lineitem.l_discount, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_shipdate + Filter: lineitem.l_shipdate >= Date32("8766") AND lineitem.l_shipdate < Date32("9131") AND CAST(lineitem.l_discount AS Float64)lineitem.l_discount >= Float64(0.049999999999999996) AND CAST(lineitem.l_discount AS Float64)lineitem.l_discount <= Float64(0.06999999999999999) AND lineitem.l_quantity < Decimal128(Some(2400),15,2) + Projection: CAST(lineitem.l_discount AS Float64) AS CAST(lineitem.l_discount AS Float64)lineitem.l_discount, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_shipdate TableScan: lineitem projection=[l_quantity, l_extendedprice, l_discount, l_shipdate] \ No newline at end of file diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index 948c178169c8..022191e70fa9 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -894,7 +894,6 @@ mod tests { use datafusion::arrow::datatypes::{DataType, Field}; use datafusion::common::ScalarValue; use datafusion::logical_expr::expr::Cast; - use datafusion::logical_expr::Expr; use std::env; let path = env::var("TPCH_DATA").unwrap_or("benchmarks/data".to_string()); diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 
45510cb03e0e..8125f667dd2e 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -226,6 +226,8 @@ fn comparison_binary_numeric_coercion( match (lhs_type, rhs_type) { // support decimal data type for comparison operation (d1 @ Decimal128(_, _), d2 @ Decimal128(_, _)) => get_wider_decimal_type(d1, d2), + (Decimal128(_, _), Float32 | Float64) => Some(Float64), + (Float32 | Float64, Decimal128(_, _)) => Some(Float64), (Decimal128(_, _), _) => get_comparison_common_decimal_type(lhs_type, rhs_type), (_, Decimal128(_, _)) => get_comparison_common_decimal_type(rhs_type, lhs_type), (Float64, _) | (_, Float64) => Some(Float64), From 44b306b62a0194d0a6b5a1a2aaac8b8da5cbdc89 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 31 Oct 2022 09:25:19 -0600 Subject: [PATCH 4/5] revert type coercion change in comparison_binary_numeric_coercion --- benchmarks/expected-plans/q6.txt | 4 ++-- datafusion/expr/src/type_coercion/binary.rs | 2 -- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/benchmarks/expected-plans/q6.txt b/benchmarks/expected-plans/q6.txt index 0d15bd5a3bfe..efc17a2724b8 100644 --- a/benchmarks/expected-plans/q6.txt +++ b/benchmarks/expected-plans/q6.txt @@ -1,6 +1,6 @@ Projection: SUM(lineitem.l_extendedprice * lineitem.l_discount) AS revenue Aggregate: groupBy=[[]], aggr=[[SUM(lineitem.l_extendedprice * lineitem.l_discount)]] Projection: lineitem.l_extendedprice, lineitem.l_discount - Filter: lineitem.l_shipdate >= Date32("8766") AND lineitem.l_shipdate < Date32("9131") AND CAST(lineitem.l_discount AS Float64)lineitem.l_discount >= Float64(0.049999999999999996) AND CAST(lineitem.l_discount AS Float64)lineitem.l_discount <= Float64(0.06999999999999999) AND lineitem.l_quantity < Decimal128(Some(2400),15,2) - Projection: CAST(lineitem.l_discount AS Float64) AS CAST(lineitem.l_discount AS Float64)lineitem.l_discount, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, 
lineitem.l_shipdate + Filter: lineitem.l_shipdate >= Date32("8766") AND lineitem.l_shipdate < Date32("9131") AND CAST(lineitem.l_discount AS Decimal128(30, 15))lineitem.l_discount >= Decimal128(Some(49999999999999),30,15) AND CAST(lineitem.l_discount AS Decimal128(30, 15))lineitem.l_discount <= Decimal128(Some(69999999999999),30,15) AND lineitem.l_quantity < Decimal128(Some(2400),15,2) + Projection: CAST(lineitem.l_discount AS Decimal128(30, 15)) AS CAST(lineitem.l_discount AS Decimal128(30, 15))lineitem.l_discount, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_shipdate TableScan: lineitem projection=[l_quantity, l_extendedprice, l_discount, l_shipdate] \ No newline at end of file diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 8125f667dd2e..45510cb03e0e 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -226,8 +226,6 @@ fn comparison_binary_numeric_coercion( match (lhs_type, rhs_type) { // support decimal data type for comparison operation (d1 @ Decimal128(_, _), d2 @ Decimal128(_, _)) => get_wider_decimal_type(d1, d2), - (Decimal128(_, _), Float32 | Float64) => Some(Float64), - (Float32 | Float64, Decimal128(_, _)) => Some(Float64), (Decimal128(_, _), _) => get_comparison_common_decimal_type(lhs_type, rhs_type), (_, Decimal128(_, _)) => get_comparison_common_decimal_type(rhs_type, lhs_type), (Float64, _) | (_, Float64) => Some(Float64), From 5f4d13f9617ce7f03524a362273f57884a2dbf33 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 31 Oct 2022 10:10:59 -0600 Subject: [PATCH 5/5] smaller tolerance --- benchmarks/src/bin/tpch.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index 022191e70fa9..df64537bd008 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -1010,7 +1010,7 @@ mod tests { match (&expected_row[j], 
&actual_row[j]) { (ScalarValue::Float64(Some(l)), ScalarValue::Float64(Some(r))) => { // allow for rounding errors until we move to decimal types - let tolerance = 1.0; + let tolerance = 0.1; if (l - r).abs() > tolerance { panic!( "Expected: {}; Actual: {}; Tolerance: {}",