diff --git a/benchmarks/README.md b/benchmarks/README.md index 7b4dd3001060..505469fc5ea7 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -25,7 +25,8 @@ implementations as well as other query engines. ## Benchmark derived from TPC-H -These benchmarks are derived from the [TPC-H][1] benchmark. +These benchmarks are derived from the [TPC-H][1] benchmark. The data generator (tpch-gen) and the expected answers come from this repository: +https://github.com/databricks/tpch-dbgen.git, which is based on version [2.17.1](https://www.tpc.org/tpc_documents_current_versions/pdf/tpc-h_v2.17.1.pdf) of TPC-H. ## Generating Test Data diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index 43db654e83f9..963833ee9aed 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -197,8 +197,21 @@ async fn benchmark_datafusion(opt: DataFusionBenchmarkOpt) -> Result Schema { Field::new("p_type", DataType::Utf8, false), Field::new("p_size", DataType::Int32, false), Field::new("p_container", DataType::Utf8, false), - Field::new("p_retailprice", DataType::Float64, false), + Field::new("p_retailprice", DataType::Decimal128(15, 2), false), Field::new("p_comment", DataType::Utf8, false), ]), @@ -452,7 +466,7 @@ fn get_schema(table: &str) -> Schema { Field::new("s_address", DataType::Utf8, false), Field::new("s_nationkey", DataType::Int64, false), Field::new("s_phone", DataType::Utf8, false), - Field::new("s_acctbal", DataType::Float64, false), + Field::new("s_acctbal", DataType::Decimal128(15, 2), false), Field::new("s_comment", DataType::Utf8, false), ]), @@ -460,7 +474,7 @@ fn get_schema(table: &str) -> Schema { Field::new("ps_partkey", DataType::Int64, false), Field::new("ps_suppkey", DataType::Int64, false), Field::new("ps_availqty", DataType::Int32, false), - Field::new("ps_supplycost", DataType::Float64, false), + Field::new("ps_supplycost", DataType::Decimal128(15, 2), false), Field::new("ps_comment", DataType::Utf8, false), ]), @@ -470,7 +484,7 @@ fn 
get_schema(table: &str) -> Schema { Field::new("c_address", DataType::Utf8, false), Field::new("c_nationkey", DataType::Int64, false), Field::new("c_phone", DataType::Utf8, false), - Field::new("c_acctbal", DataType::Float64, false), + Field::new("c_acctbal", DataType::Decimal128(15, 2), false), Field::new("c_mktsegment", DataType::Utf8, false), Field::new("c_comment", DataType::Utf8, false), ]), @@ -479,7 +493,7 @@ fn get_schema(table: &str) -> Schema { Field::new("o_orderkey", DataType::Int64, false), Field::new("o_custkey", DataType::Int64, false), Field::new("o_orderstatus", DataType::Utf8, false), - Field::new("o_totalprice", DataType::Float64, false), + Field::new("o_totalprice", DataType::Decimal128(15, 2), false), Field::new("o_orderdate", DataType::Date32, false), Field::new("o_orderpriority", DataType::Utf8, false), Field::new("o_clerk", DataType::Utf8, false), @@ -492,10 +506,10 @@ fn get_schema(table: &str) -> Schema { Field::new("l_partkey", DataType::Int64, false), Field::new("l_suppkey", DataType::Int64, false), Field::new("l_linenumber", DataType::Int32, false), - Field::new("l_quantity", DataType::Float64, false), - Field::new("l_extendedprice", DataType::Float64, false), - Field::new("l_discount", DataType::Float64, false), - Field::new("l_tax", DataType::Float64, false), + Field::new("l_quantity", DataType::Decimal128(15, 2), false), + Field::new("l_extendedprice", DataType::Decimal128(15, 2), false), + Field::new("l_discount", DataType::Decimal128(15, 2), false), + Field::new("l_tax", DataType::Decimal128(15, 2), false), Field::new("l_returnflag", DataType::Utf8, false), Field::new("l_linestatus", DataType::Utf8, false), Field::new("l_shipdate", DataType::Date32, false), @@ -575,12 +589,39 @@ struct QueryResult { mod tests { use super::*; use std::env; + use std::ops::{Div, Mul}; use std::sync::Arc; use datafusion::arrow::array::*; use datafusion::arrow::util::display::array_value_to_string; - use datafusion::logical_plan::Expr; - use 
datafusion::logical_plan::Expr::Cast; + use datafusion::logical_expr::Expr; + use datafusion::logical_expr::Expr::Cast; + use datafusion::logical_expr::Expr::ScalarFunction; + + const QUERY_LIMIT: [Option; 22] = [ + None, + Some(100), + Some(10), + None, + None, + None, + None, + None, + None, + Some(20), + None, + None, + None, + None, + None, + None, + None, + Some(100), + None, + None, + Some(100), + None, + ]; #[tokio::test] async fn q1() -> Result<()> { @@ -672,6 +713,7 @@ mod tests { verify_query(18).await } + #[ignore] #[tokio::test] async fn q19() -> Result<()> { verify_query(19).await @@ -762,7 +804,6 @@ mod tests { run_query(14).await } - #[ignore] // https://github.com/apache/arrow-datafusion/issues/166 #[tokio::test] async fn run_q15() -> Result<()> { run_query(15).await @@ -794,7 +835,6 @@ mod tests { run_query(20).await } - #[ignore] // https://github.com/apache/arrow-datafusion/issues/172 #[tokio::test] async fn run_q21() -> Result<()> { run_query(21).await @@ -836,21 +876,21 @@ mod tests { 1 => Schema::new(vec![ Field::new("l_returnflag", DataType::Utf8, true), Field::new("l_linestatus", DataType::Utf8, true), - Field::new("sum_qty", DataType::Float64, true), - Field::new("sum_base_price", DataType::Float64, true), - Field::new("sum_disc_price", DataType::Float64, true), - Field::new("sum_charge", DataType::Float64, true), - Field::new("avg_qty", DataType::Float64, true), - Field::new("avg_price", DataType::Float64, true), - Field::new("avg_disc", DataType::Float64, true), - Field::new("count_order", DataType::UInt64, true), + Field::new("sum_qty", DataType::Decimal128(15, 2), true), + Field::new("sum_base_price", DataType::Decimal128(15, 2), true), + Field::new("sum_disc_price", DataType::Decimal128(15, 2), true), + Field::new("sum_charge", DataType::Decimal128(15, 2), true), + Field::new("avg_qty", DataType::Decimal128(15, 2), true), + Field::new("avg_price", DataType::Decimal128(15, 2), true), + Field::new("avg_disc", DataType::Decimal128(15, 2), 
true), + Field::new("count_order", DataType::Int64, true), ]), 2 => Schema::new(vec![ - Field::new("s_acctbal", DataType::Float64, true), + Field::new("s_acctbal", DataType::Decimal128(15, 2), true), Field::new("s_name", DataType::Utf8, true), Field::new("n_name", DataType::Utf8, true), - Field::new("p_partkey", DataType::Int32, true), + Field::new("p_partkey", DataType::Int64, true), Field::new("p_mfgr", DataType::Utf8, true), Field::new("s_address", DataType::Utf8, true), Field::new("s_phone", DataType::Utf8, true), @@ -858,47 +898,51 @@ mod tests { ]), 3 => Schema::new(vec![ - Field::new("l_orderkey", DataType::Int32, true), - Field::new("revenue", DataType::Float64, true), + Field::new("l_orderkey", DataType::Int64, true), + Field::new("revenue", DataType::Decimal128(15, 2), true), Field::new("o_orderdate", DataType::Date32, true), Field::new("o_shippriority", DataType::Int32, true), ]), 4 => Schema::new(vec![ Field::new("o_orderpriority", DataType::Utf8, true), - Field::new("order_count", DataType::Int32, true), + Field::new("order_count", DataType::Int64, true), ]), 5 => Schema::new(vec![ Field::new("n_name", DataType::Utf8, true), - Field::new("revenue", DataType::Float64, true), + Field::new("revenue", DataType::Decimal128(15, 2), true), ]), - 6 => Schema::new(vec![Field::new("revenue", DataType::Float64, true)]), + 6 => Schema::new(vec![Field::new( + "revenue", + DataType::Decimal128(15, 2), + true, + )]), 7 => Schema::new(vec![ Field::new("supp_nation", DataType::Utf8, true), Field::new("cust_nation", DataType::Utf8, true), Field::new("l_year", DataType::Int32, true), - Field::new("revenue", DataType::Float64, true), + Field::new("revenue", DataType::Decimal128(15, 2), true), ]), 8 => Schema::new(vec![ Field::new("o_year", DataType::Int32, true), - Field::new("mkt_share", DataType::Float64, true), + Field::new("mkt_share", DataType::Decimal128(15, 2), true), ]), 9 => Schema::new(vec![ Field::new("nation", DataType::Utf8, true), Field::new("o_year", 
DataType::Int32, true), - Field::new("sum_profit", DataType::Float64, true), + Field::new("sum_profit", DataType::Decimal128(15, 2), true), ]), 10 => Schema::new(vec![ - Field::new("c_custkey", DataType::Int32, true), + Field::new("c_custkey", DataType::Int64, true), Field::new("c_name", DataType::Utf8, true), - Field::new("revenue", DataType::Float64, true), - Field::new("c_acctbal", DataType::Float64, true), + Field::new("revenue", DataType::Decimal128(15, 2), true), + Field::new("c_acctbal", DataType::Decimal128(15, 2), true), Field::new("n_name", DataType::Utf8, true), Field::new("c_address", DataType::Utf8, true), Field::new("c_phone", DataType::Utf8, true), @@ -906,8 +950,8 @@ mod tests { ]), 11 => Schema::new(vec![ - Field::new("ps_partkey", DataType::Int32, true), - Field::new("value", DataType::Float64, true), + Field::new("ps_partkey", DataType::Int64, true), + Field::new("value", DataType::Decimal128(15, 2), true), ]), 12 => Schema::new(vec![ @@ -923,24 +967,30 @@ mod tests { 14 => Schema::new(vec![Field::new("promo_revenue", DataType::Float64, true)]), - 15 => Schema::new(vec![Field::new("promo_revenue", DataType::Float64, true)]), + 15 => Schema::new(vec![ + Field::new("s_suppkey", DataType::Int64, true), + Field::new("s_name", DataType::Utf8, true), + Field::new("s_address", DataType::Utf8, true), + Field::new("s_phone", DataType::Utf8, true), + Field::new("total_revenue", DataType::Decimal128(15, 2), true), + ]), 16 => Schema::new(vec![ Field::new("p_brand", DataType::Utf8, true), Field::new("p_type", DataType::Utf8, true), - Field::new("c_phone", DataType::Int32, true), - Field::new("c_comment", DataType::Int32, true), + Field::new("p_size", DataType::Int32, true), + Field::new("supplier_cnt", DataType::Int64, true), ]), 17 => Schema::new(vec![Field::new("avg_yearly", DataType::Float64, true)]), 18 => Schema::new(vec![ Field::new("c_name", DataType::Utf8, true), - Field::new("c_custkey", DataType::Int32, true), - Field::new("o_orderkey", 
DataType::Int32, true), + Field::new("c_custkey", DataType::Int64, true), + Field::new("o_orderkey", DataType::Int64, true), Field::new("o_orderdate", DataType::Date32, true), - Field::new("o_totalprice", DataType::Float64, true), - Field::new("sum_l_quantity", DataType::Float64, true), + Field::new("o_totalprice", DataType::Decimal128(15, 2), true), + Field::new("sum_l_quantity", DataType::Decimal128(15, 2), true), ]), 19 => Schema::new(vec![Field::new("revenue", DataType::Float64, true)]), @@ -952,13 +1002,13 @@ mod tests { 21 => Schema::new(vec![ Field::new("s_name", DataType::Utf8, true), - Field::new("numwait", DataType::Int32, true), + Field::new("numwait", DataType::Int64, true), ]), 22 => Schema::new(vec![ - Field::new("cntrycode", DataType::Int32, true), - Field::new("numcust", DataType::Int32, true), - Field::new("totacctbal", DataType::Float64, true), + Field::new("cntrycode", DataType::Utf8, true), + Field::new("numcust", DataType::Int64, true), + Field::new("totacctbal", DataType::Decimal128(15, 2), true), ]), _ => unimplemented!(), @@ -983,22 +1033,59 @@ mod tests { ) } - // convert the schema to the same but with all columns set to nullable=true. - // this allows direct schema comparison ignoring nullable. - fn nullable_schema(schema: Arc) -> Schema { - Schema::new( - schema - .fields() - .iter() - .map(|field| { - Field::new( - Field::name(field), - Field::data_type(field).to_owned(), - true, - ) - }) - .collect::>(), - ) + async fn transform_actual_result( + result: Vec, + n: usize, + ) -> Result> { + // to compare the recorded answers to the answers we got back from running the query, + // we need to round the decimal columns and trim the Utf8 columns + let ctx = SessionContext::new(); + let result_schema = result[0].schema(); + let table = Arc::new(MemTable::try_new(result_schema.clone(), vec![result])?); + let mut df = ctx.read_table(table)? 
+ .select( + result_schema + .fields + .iter() + .map(|field| { + match Field::data_type(field) { + DataType::Decimal128(_,_) => { + // if decimal, then round it to 2 decimal places like the answers + // round() doesn't support the second argument for decimal places to round to + // this can be simplified to remove the mul and div when + // https://github.com/apache/arrow-datafusion/issues/2420 is completed + // cast it back to an over-sized Decimal (precision 38, scale 2) when done rounding + let round = Box::new(ScalarFunction { + fun: datafusion::logical_expr::BuiltinScalarFunction::Round, + args: vec![col(Field::name(field)).mul(lit(100))] + }.div(lit(100))); + Expr::Alias( + Box::new(Cast { + expr: round, + data_type: DataType::Decimal128(38,2), + }), + Field::name(field).to_string(), + ) + } + DataType::Utf8 => { + // if string, then trim it like the answers got trimmed + Expr::Alias( + Box::new(trim(col(Field::name(field)))), + Field::name(field).to_string() + ) + } + _ => { + col(Field::name(field)) + } + } + }).collect() + )?; + if let Some(x) = QUERY_LIMIT[n - 1] { + df = df.limit(0, Some(x))?; + } + + let df = df.collect().await?; + Ok(df) } async fn run_query(n: usize) -> Result<()> { @@ -1026,6 +1113,11 @@ mod tests { Ok(()) } + /// compares query results against stored answers from the git repo + /// verifies that: + /// * the datatypes of the returned columns are correct + /// * the correct number of rows is returned + /// * the content of the rows is correct async fn verify_query(n: usize) -> Result<()> { if let Ok(path) = env::var("TPCH_DATA") { // load expected answers from tpch-dbgen @@ -1045,13 +1137,30 @@ mod tests { .fields() .iter() .map(|field| { - Expr::Alias( - Box::new(Cast { - expr: Box::new(trim(col(Field::name(field)))), - data_type: Field::data_type(field).to_owned(), - }), - Field::name(field).to_string(), - ) + match Field::data_type(field) { + DataType::Decimal128(_, _) => { + // there's no support for casting from Utf8 to Decimal, so + // we'll 
cast from Utf8 to Float64 to Decimal for Decimal types + let inner_cast = Box::new(Cast { + expr: Box::new(trim(col(Field::name(field)))), + data_type: DataType::Float64, + }); + Expr::Alias( + Box::new(Cast { + expr: inner_cast, + data_type: Field::data_type(field).to_owned(), + }), + Field::name(field).to_string(), + ) + } + _ => Expr::Alias( + Box::new(Cast { + expr: Box::new(trim(col(Field::name(field)))), + data_type: Field::data_type(field).to_owned(), + }), + Field::name(field).to_string(), + ), + } }) .collect::>(), )?; @@ -1071,20 +1180,30 @@ mod tests { }; let actual = benchmark_datafusion(opt).await?; - // assert schema equality without comparing nullable values - assert_eq!( - nullable_schema(expected[0].schema()), - nullable_schema(actual[0].schema()) - ); + let transformed = transform_actual_result(actual, n).await?; + + // assert schema data types match + let transformed_fields = &transformed[0].schema().fields; + let expected_fields = &expected[0].schema().fields; + let schema_matches = transformed_fields + .iter() + .zip(expected_fields.iter()) + .all(|(t, e)| match t.data_type() { + DataType::Decimal128(_, _) => { + matches!(e.data_type(), DataType::Decimal128(_, _)) + } + data_type => data_type == e.data_type(), + }); + assert!(schema_matches); // convert both datasets to Vec> for simple comparison let expected_vec = result_vec(&expected); - let actual_vec = result_vec(&actual); + let actual_vec = result_vec(&transformed); // basic result comparison assert_eq!(expected_vec.len(), actual_vec.len()); - // compare each row. this works as all TPC-H queries have determinisically ordered results + // compare each row. this works as all TPC-H queries have deterministically ordered results for i in 0..actual_vec.len() { assert_eq!(expected_vec[i], actual_vec[i]); }