diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index b9afe4d6a17e..df64537bd008 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -668,7 +668,6 @@ mod tests { } #[cfg(feature = "ci")] - #[ignore] // TODO produces correct result but has rounding error #[tokio::test] async fn verify_q9() -> Result<()> { verify_query(9).await @@ -681,7 +680,6 @@ mod tests { } #[cfg(feature = "ci")] - #[ignore] // https://github.com/apache/arrow-datafusion/issues/4023 #[tokio::test] async fn verify_q11() -> Result<()> { verify_query(11).await @@ -700,7 +698,6 @@ mod tests { } #[cfg(feature = "ci")] - #[ignore] // https://github.com/apache/arrow-datafusion/issues/4025 #[tokio::test] async fn verify_q14() -> Result<()> { verify_query(14).await @@ -719,7 +716,6 @@ mod tests { } #[cfg(feature = "ci")] - #[ignore] // https://github.com/apache/arrow-datafusion/issues/4026 #[tokio::test] async fn verify_q17() -> Result<()> { verify_query(17).await @@ -896,8 +892,8 @@ mod tests { #[cfg(feature = "ci")] async fn verify_query(n: usize) -> Result<()> { use datafusion::arrow::datatypes::{DataType, Field}; + use datafusion::common::ScalarValue; use datafusion::logical_expr::expr::Cast; - use datafusion::logical_expr::Expr; use std::env; let path = env::var("TPCH_DATA").unwrap_or("benchmarks/data".to_string()); @@ -990,7 +986,12 @@ mod tests { } data_type => data_type == e.data_type(), }); - assert!(schema_matches); + if !schema_matches { + panic!( + "expected_fields: {:?}\ntransformed_fields: {:?}", + expected_fields, transformed_fields + ) + } // convert both datasets to Vec<Vec<String>> for simple comparison let expected_vec = result_vec(&expected); @@ -1000,8 +1001,26 @@ mod tests { assert_eq!(expected_vec.len(), actual_vec.len()); // compare each row. 
this works as all TPC-H queries have deterministically ordered results - for i in 0..actual_vec.len() { - assert_eq!(expected_vec[i], actual_vec[i]); + for i in 0..expected_vec.len() { + let expected_row = &expected_vec[i]; + let actual_row = &actual_vec[i]; + assert_eq!(expected_row.len(), actual_row.len()); + + for j in 0..expected.len() { + match (&expected_row[j], &actual_row[j]) { + (ScalarValue::Float64(Some(l)), ScalarValue::Float64(Some(r))) => { + // allow for rounding errors until we move to decimal types + let tolerance = 0.1; + if (l - r).abs() > tolerance { + panic!( + "Expected: {}; Actual: {}; Tolerance: {}", + l, r, tolerance + ) + } + } + (l, r) => assert_eq!(format!("{:?}", l), format!("{:?}", r)), + } + } } Ok(()) diff --git a/benchmarks/src/tpch.rs b/benchmarks/src/tpch.rs index 46c53edf120e..ad61de8a3b0b 100644 --- a/benchmarks/src/tpch.rs +++ b/benchmarks/src/tpch.rs @@ -15,7 +15,10 @@ // specific language governing permissions and limitations // under the License. 
-use arrow::array::ArrayRef; +use arrow::array::{ + Array, ArrayRef, Date32Array, Decimal128Array, Float64Array, Int32Array, Int64Array, + StringArray, +}; use arrow::record_batch::RecordBatch; use std::fs; use std::ops::{Div, Mul}; @@ -23,7 +26,7 @@ use std::path::Path; use std::sync::Arc; use std::time::Instant; -use datafusion::arrow::util::display::array_value_to_string; +use datafusion::common::ScalarValue; use datafusion::logical_expr::Cast; use datafusion::prelude::*; use datafusion::{ @@ -229,11 +232,7 @@ pub fn get_answer_schema(n: usize) -> Schema { Field::new("custdist", DataType::Int64, true), ]), - 14 => Schema::new(vec![Field::new( - "promo_revenue", - DataType::Decimal128(38, 2), - true, - )]), + 14 => Schema::new(vec![Field::new("promo_revenue", DataType::Float64, true)]), 15 => Schema::new(vec![ Field::new("s_suppkey", DataType::Int64, true), @@ -250,11 +249,7 @@ pub fn get_answer_schema(n: usize) -> Schema { Field::new("supplier_cnt", DataType::Int64, true), ]), - 17 => Schema::new(vec![Field::new( - "avg_yearly", - DataType::Decimal128(38, 2), - true, - )]), + 17 => Schema::new(vec![Field::new("avg_yearly", DataType::Float64, true)]), 18 => Schema::new(vec![ Field::new("c_name", DataType::Utf8, true), @@ -389,14 +384,14 @@ pub async fn convert_tbl( /// Converts the results into a 2d array of strings, `result[row][column]` /// Special cases nulls to NULL for testing -pub fn result_vec(results: &[RecordBatch]) -> Vec<Vec<String>> { +pub fn result_vec(results: &[RecordBatch]) -> Vec<Vec<ScalarValue>> { let mut result = vec![]; for batch in results { for row_index in 0..batch.num_rows() { let row_vec = batch .columns() .iter() - .map(|column| col_str(column, row_index)) + .map(|column| col_to_scalar(column, row_index)) .collect(); result.push(row_vec); } @@ -422,13 +417,37 @@ pub fn string_schema(schema: Schema) -> Schema { ) } -/// Specialised String representation -fn col_str(column: &ArrayRef, row_index: usize) -> String { +fn col_to_scalar(column: &ArrayRef, row_index: 
usize) -> ScalarValue { if column.is_null(row_index) { - return "NULL".to_string(); + return ScalarValue::Null; + } + match column.data_type() { + DataType::Int32 => { + let array = column.as_any().downcast_ref::<Int32Array>().unwrap(); + ScalarValue::Int32(Some(array.value(row_index))) + } + DataType::Int64 => { + let array = column.as_any().downcast_ref::<Int64Array>().unwrap(); + ScalarValue::Int64(Some(array.value(row_index))) + } + DataType::Float64 => { + let array = column.as_any().downcast_ref::<Float64Array>().unwrap(); + ScalarValue::Float64(Some(array.value(row_index))) + } + DataType::Decimal128(p, s) => { + let array = column.as_any().downcast_ref::<Decimal128Array>().unwrap(); + ScalarValue::Decimal128(Some(array.value(row_index)), *p, *s) + } + DataType::Date32 => { + let array = column.as_any().downcast_ref::<Date32Array>().unwrap(); + ScalarValue::Date32(Some(array.value(row_index))) + } + DataType::Utf8 => { + let array = column.as_any().downcast_ref::<StringArray>().unwrap(); + ScalarValue::Utf8(Some(array.value(row_index).to_string())) + } + other => panic!("unexpected data type in benchmark: {}", other), } - - array_value_to_string(column, row_index).unwrap() } pub async fn transform_actual_result( @@ -460,7 +479,7 @@ pub async fn transform_actual_result( Expr::Alias( Box::new(Expr::Cast(Cast::new( round, - DataType::Decimal128(38, 2), + DataType::Decimal128(15, 2), ))), Field::name(field).to_string(), )