diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml
index 4c6daf7554b7..ebb1c07a47f9 100644
--- a/.github/workflows/rust.yml
+++ b/.github/workflows/rust.yml
@@ -117,6 +117,47 @@ jobs:
       - name: Verify Working Directory Clean
         run: git diff --exit-code
 
+  # verify that the benchmark queries return the correct results
+  verify-benchmark-results:
+    name: verify benchmark results (amd64)
+    needs: [linux-build-lib]
+    runs-on: ubuntu-latest
+    container:
+      image: amd64/rust
+      env:
+        # Disable full debug symbol generation to speed up CI build and keep memory down
+        # "1" means line tables only, which is useful for panic tracebacks.
+        RUSTFLAGS: "-C debuginfo=1"
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          submodules: true
+      - name: Cache Cargo
+        uses: actions/cache@v3
+        with:
+          path: /github/home/.cargo
+          # this key matches the one used by `linux-build-lib` so the cache can be re-used
+          key: cargo-cache-
+      - name: Setup Rust toolchain
+        uses: ./.github/actions/setup-builder
+        with:
+          rust-version: stable
+      - name: Generate benchmark data and expected query results
+        run: |
+          mkdir -p benchmarks/data/answers
+          git clone https://github.com/databricks/tpch-dbgen.git
+          cd tpch-dbgen
+          make
+          ./dbgen -f -s 1
+          mv *.tbl ../benchmarks/data
+          mv ./answers/* ../benchmarks/data/answers/
+      - name: Verify that benchmark queries return expected results
+        run: |
+          export TPCH_DATA=`pwd`/benchmarks/data
+          cargo test verify_q --profile release-nonlto --features=ci -- --test-threads=1
+      - name: Verify Working Directory Clean
+        run: git diff --exit-code
+
   integration-test:
     name: "Compare to postgres"
     needs: [linux-build-lib]
diff --git a/Cargo.toml b/Cargo.toml
index ab3f427e49be..36a9405b0fbe 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -16,21 +16,24 @@
 # under the License.
 
 [workspace]
-members = [
-    "datafusion/common",
-    "datafusion/core",
-    "datafusion/expr",
-    "datafusion/jit",
-    "datafusion/optimizer",
-    "datafusion/physical-expr",
-    "datafusion/proto",
-    "datafusion/row",
-    "datafusion/sql",
-    "datafusion-examples",
-    "benchmarks",
-]
 exclude = ["datafusion-cli"]
+members = ["datafusion/common", "datafusion/core", "datafusion/expr", "datafusion/jit", "datafusion/optimizer", "datafusion/physical-expr", "datafusion/proto", "datafusion/row", "datafusion/sql", "datafusion-examples", "benchmarks",
+]
 
 [profile.release]
 codegen-units = 1
 lto = true
+
+# the release profile takes a long time to build, so this profile can be used during development to save time
+# cargo build --profile release-nonlto
+[profile.release-nonlto]
+codegen-units = 16
+debug = false
+debug-assertions = false
+incremental = false
+inherits = "release"
+lto = false
+opt-level = 3
+overflow-checks = false
+panic = 'unwind'
+rpath = false
diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml
index 7105a6033693..8795a8611193 100644
--- a/benchmarks/Cargo.toml
+++ b/benchmarks/Cargo.toml
@@ -24,10 +24,10 @@ authors = ["Apache Arrow <dev@arrow.apache.org>"]
 homepage = "https://github.com/apache/arrow-datafusion"
 repository = "https://github.com/apache/arrow-datafusion"
 license = "Apache-2.0"
-publish = false
 rust-version = "1.62"
 
 [features]
+ci = []
 default = ["mimalloc"]
 simd = ["datafusion/simd"]
 snmalloc = ["snmalloc-rs"]
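The `ci` feature added above keeps the comparatively slow verification tests out of a normal `cargo test` run, while `release-nonlto` gives CI an optimized build without the long LTO step. A minimal sketch of the gating pattern, with an illustrative test name (the real `verify_q1`..`verify_q22` tests appear in the tpch.rs diff below):

```rust
// Only compiled when the crate is built with `--features=ci`, which is why
// the workflow invokes `cargo test verify_q --features=ci`.
// Hypothetical test shown for illustration only.
#[cfg(feature = "ci")]
#[tokio::test]
async fn verify_example() -> datafusion::error::Result<()> {
    // ... run a query and compare it against the recorded answer files ...
    Ok(())
}
```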
diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs
index 02de551f2808..b9afe4d6a17e 100644
--- a/benchmarks/src/bin/tpch.rs
+++ b/benchmarks/src/bin/tpch.rs
@@ -18,7 +18,7 @@
 //! Benchmark derived from TPC-H. This is not an official TPC-H benchmark.
 use std::{
-    fs::{self, File},
+    fs::File,
     io::Write,
     iter::Iterator,
     path::{Path, PathBuf},
@@ -29,15 +29,9 @@ use std::{
 use datafusion::datasource::{MemTable, TableProvider};
 use datafusion::error::{DataFusionError, Result};
 use datafusion::parquet::basic::Compression;
-use datafusion::parquet::file::properties::WriterProperties;
 use datafusion::physical_plan::display::DisplayableExecutionPlan;
 use datafusion::physical_plan::{collect, displayable};
 use datafusion::prelude::*;
-use datafusion::{
-    arrow::datatypes::{DataType, Field, Schema},
-    datasource::file_format::{csv::CsvFormat, FileFormat},
-    DATAFUSION_VERSION,
-};
 use datafusion::{
     arrow::record_batch::RecordBatch, datasource::file_format::parquet::ParquetFormat,
 };
@@ -45,6 +39,11 @@ use datafusion::{
     arrow::util::pretty,
     datasource::listing::{ListingOptions, ListingTable, ListingTableConfig},
 };
+use datafusion::{
+    datasource::file_format::{csv::CsvFormat, FileFormat},
+    DATAFUSION_VERSION,
+};
+use datafusion_benchmarks::tpch::*;
 
 use datafusion::datasource::file_format::csv::DEFAULT_CSV_EXTENSION;
 use datafusion::datasource::file_format::parquet::DEFAULT_PARQUET_EXTENSION;
@@ -145,10 +144,6 @@ enum TpchOpt {
     Convert(ConvertOpt),
 }
 
-const TABLES: &[&str] = &[
-    "part", "supplier", "partsupp", "customer", "orders", "lineitem", "nation", "region",
-];
-
 #[tokio::main]
 async fn main() -> Result<()> {
     use BenchmarkSubCommandOpt::*;
@@ -158,7 +153,32 @@ async fn main() -> Result<()> {
         TpchOpt::Benchmark(DataFusionBenchmark(opt)) => {
             benchmark_datafusion(opt).await.map(|_| ())
         }
-        TpchOpt::Convert(opt) => convert_tbl(opt).await,
+        TpchOpt::Convert(opt) => {
+            let compression = match opt.compression.as_str() {
+                "none" => Compression::UNCOMPRESSED,
+                "snappy" => Compression::SNAPPY,
+                "brotli" => Compression::BROTLI,
+                "gzip" => Compression::GZIP,
+                "lz4" => Compression::LZ4,
+                "lz0" => Compression::LZO,
+                "zstd" => Compression::ZSTD,
+                other => {
+                    return Err(DataFusionError::NotImplemented(format!(
+                        "Invalid compression format: {}",
+                        other
+                    )));
+                }
+            };
+            convert_tbl(
+                opt.input_path.to_str().unwrap(),
+                opt.output_path.to_str().unwrap(),
+                &opt.file_format,
+                opt.partitions,
+                opt.batch_size,
+                compression,
+            )
+            .await
+        }
     }
 }
 
@@ -173,7 +193,7 @@ async fn benchmark_datafusion(opt: DataFusionBenchmarkOpt) -> Result<Vec<RecordBatch>> {
     Ok(())
 }
 
Expected value between 1 and 22".to_owned(), - )) - } -} - async fn execute_query( ctx: &SessionContext, sql: &str, @@ -335,77 +323,6 @@ async fn execute_query( Ok(result) } -async fn convert_tbl(opt: ConvertOpt) -> Result<()> { - let output_root_path = Path::new(&opt.output_path); - for table in TABLES { - let start = Instant::now(); - let schema = get_schema(table); - - let input_path = format!("{}/{}.tbl", opt.input_path.to_str().unwrap(), table); - let options = CsvReadOptions::new() - .schema(&schema) - .has_header(false) - .delimiter(b'|') - .file_extension(".tbl"); - - let config = SessionConfig::new().with_batch_size(opt.batch_size); - let ctx = SessionContext::with_config(config); - - // build plan to read the TBL file - let mut csv = ctx.read_csv(&input_path, options).await?; - - // optionally, repartition the file - if opt.partitions > 1 { - csv = csv.repartition(Partitioning::RoundRobinBatch(opt.partitions))? - } - - // create the physical plan - let csv = csv.to_logical_plan()?; - let csv = ctx.create_physical_plan(&csv).await?; - - let output_path = output_root_path.join(table); - let output_path = output_path.to_str().unwrap().to_owned(); - - println!( - "Converting '{}' to {} files in directory '{}'", - &input_path, &opt.file_format, &output_path - ); - match opt.file_format.as_str() { - "csv" => ctx.write_csv(csv, output_path).await?, - "parquet" => { - let compression = match opt.compression.as_str() { - "none" => Compression::UNCOMPRESSED, - "snappy" => Compression::SNAPPY, - "brotli" => Compression::BROTLI, - "gzip" => Compression::GZIP, - "lz4" => Compression::LZ4, - "lz0" => Compression::LZO, - "zstd" => Compression::ZSTD, - other => { - return Err(DataFusionError::NotImplemented(format!( - "Invalid compression format: {}", - other - ))); - } - }; - let props = WriterProperties::builder() - .set_compression(compression) - .build(); - ctx.write_parquet(csv, output_path, Some(props)).await? 
-            }
-            other => {
-                return Err(DataFusionError::NotImplemented(format!(
-                    "Invalid output format: {}",
-                    other
-                )));
-            }
-        }
-        println!("Conversion completed in {} ms", start.elapsed().as_millis());
-    }
-
-    Ok(())
-}
-
 async fn get_table(
     ctx: &mut SessionState,
     path: &str,
@@ -443,7 +360,7 @@ async fn get_table(
             unimplemented!("Invalid file format '{}'", other);
         }
     };
-    let schema = Arc::new(get_schema(table));
+    let schema = Arc::new(get_tpch_table_schema(table));
 
     let options = ListingOptions {
         format,
@@ -465,101 +382,6 @@ async fn get_table(
     Ok(Arc::new(ListingTable::try_new(config)?))
 }
 
-fn get_schema(table: &str) -> Schema {
-    // note that the schema intentionally uses signed integers so that any generated Parquet
-    // files can also be used to benchmark tools that only support signed integers, such as
-    // Apache Spark
-
-    match table {
-        "part" => Schema::new(vec![
-            Field::new("p_partkey", DataType::Int64, false),
-            Field::new("p_name", DataType::Utf8, false),
-            Field::new("p_mfgr", DataType::Utf8, false),
-            Field::new("p_brand", DataType::Utf8, false),
-            Field::new("p_type", DataType::Utf8, false),
-            Field::new("p_size", DataType::Int32, false),
-            Field::new("p_container", DataType::Utf8, false),
-            Field::new("p_retailprice", DataType::Decimal128(15, 2), false),
-            Field::new("p_comment", DataType::Utf8, false),
-        ]),
-
-        "supplier" => Schema::new(vec![
-            Field::new("s_suppkey", DataType::Int64, false),
-            Field::new("s_name", DataType::Utf8, false),
-            Field::new("s_address", DataType::Utf8, false),
-            Field::new("s_nationkey", DataType::Int64, false),
-            Field::new("s_phone", DataType::Utf8, false),
-            Field::new("s_acctbal", DataType::Decimal128(15, 2), false),
-            Field::new("s_comment", DataType::Utf8, false),
-        ]),
-
-        "partsupp" => Schema::new(vec![
-            Field::new("ps_partkey", DataType::Int64, false),
-            Field::new("ps_suppkey", DataType::Int64, false),
-            Field::new("ps_availqty", DataType::Int32, false),
-            Field::new("ps_supplycost", DataType::Decimal128(15, 2), false),
-            Field::new("ps_comment", DataType::Utf8, false),
-        ]),
-
-        "customer" => Schema::new(vec![
-            Field::new("c_custkey", DataType::Int64, false),
-            Field::new("c_name", DataType::Utf8, false),
-            Field::new("c_address", DataType::Utf8, false),
-            Field::new("c_nationkey", DataType::Int64, false),
-            Field::new("c_phone", DataType::Utf8, false),
-            Field::new("c_acctbal", DataType::Decimal128(15, 2), false),
-            Field::new("c_mktsegment", DataType::Utf8, false),
-            Field::new("c_comment", DataType::Utf8, false),
-        ]),
-
-        "orders" => Schema::new(vec![
-            Field::new("o_orderkey", DataType::Int64, false),
-            Field::new("o_custkey", DataType::Int64, false),
-            Field::new("o_orderstatus", DataType::Utf8, false),
-            Field::new("o_totalprice", DataType::Decimal128(15, 2), false),
-            Field::new("o_orderdate", DataType::Date32, false),
-            Field::new("o_orderpriority", DataType::Utf8, false),
-            Field::new("o_clerk", DataType::Utf8, false),
-            Field::new("o_shippriority", DataType::Int32, false),
-            Field::new("o_comment", DataType::Utf8, false),
-        ]),
-
-        "lineitem" => Schema::new(vec![
-            Field::new("l_orderkey", DataType::Int64, false),
-            Field::new("l_partkey", DataType::Int64, false),
-            Field::new("l_suppkey", DataType::Int64, false),
-            Field::new("l_linenumber", DataType::Int32, false),
-            Field::new("l_quantity", DataType::Decimal128(15, 2), false),
-            Field::new("l_extendedprice", DataType::Decimal128(15, 2), false),
-            Field::new("l_discount", DataType::Decimal128(15, 2), false),
-            Field::new("l_tax", DataType::Decimal128(15, 2), false),
Field::new("l_returnflag", DataType::Utf8, false), - Field::new("l_linestatus", DataType::Utf8, false), - Field::new("l_shipdate", DataType::Date32, false), - Field::new("l_commitdate", DataType::Date32, false), - Field::new("l_receiptdate", DataType::Date32, false), - Field::new("l_shipinstruct", DataType::Utf8, false), - Field::new("l_shipmode", DataType::Utf8, false), - Field::new("l_comment", DataType::Utf8, false), - ]), - - "nation" => Schema::new(vec![ - Field::new("n_nationkey", DataType::Int64, false), - Field::new("n_name", DataType::Utf8, false), - Field::new("n_regionkey", DataType::Int64, false), - Field::new("n_comment", DataType::Utf8, false), - ]), - - "region" => Schema::new(vec![ - Field::new("r_regionkey", DataType::Int64, false), - Field::new("r_name", DataType::Utf8, false), - Field::new("r_comment", DataType::Utf8, false), - ]), - - _ => unimplemented!(), - } -} - #[derive(Debug, Serialize)] struct BenchmarkRun { /// Benchmark crate version @@ -611,43 +433,10 @@ struct QueryResult { #[cfg(test)] mod tests { use super::*; - use std::env; + use datafusion::sql::TableReference; use std::io::{BufRead, BufReader}; - use std::ops::{Div, Mul}; use std::sync::Arc; - use datafusion::arrow::array::*; - use datafusion::arrow::util::display::array_value_to_string; - use datafusion::logical_expr::expr::Cast; - use datafusion::logical_expr::Expr; - use datafusion::logical_expr::Expr::ScalarFunction; - use datafusion::sql::TableReference; - - const QUERY_LIMIT: [Option; 22] = [ - None, - Some(100), - Some(10), - None, - None, - None, - None, - None, - None, - Some(20), - None, - None, - None, - None, - None, - None, - None, - Some(100), - None, - None, - Some(100), - None, - ]; - #[tokio::test] async fn q1_expected_plan() -> Result<()> { expected_plan(1).await @@ -770,9 +559,9 @@ mod tests { async fn expected_plan(query: usize) -> Result<()> { let ctx = SessionContext::new(); - for table in TABLES { + for table in TPCH_TABLES { let table = table.to_string(); - let schema = get_schema(&table); + let schema = get_tpch_table_schema(&table); let mem_table = MemTable::try_new(Arc::new(schema), vec![])?; ctx.register_table( TableReference::from(table.as_str()), @@ -829,113 +618,140 @@ mod tests { Ok(str) } + #[cfg(feature = "ci")] #[tokio::test] - async fn q1() -> Result<()> { + async fn verify_q1() -> Result<()> { verify_query(1).await } + #[cfg(feature = "ci")] #[tokio::test] - async fn q2() -> Result<()> { + async fn verify_q2() -> Result<()> { verify_query(2).await } + #[cfg(feature = "ci")] #[tokio::test] - async fn q3() -> Result<()> { + async fn verify_q3() -> Result<()> { verify_query(3).await } + #[cfg(feature = "ci")] #[tokio::test] - async fn q4() -> Result<()> { + async fn verify_q4() -> Result<()> { verify_query(4).await } + #[cfg(feature = "ci")] #[tokio::test] - async fn q5() -> Result<()> { + async fn verify_q5() -> Result<()> { verify_query(5).await } + #[cfg(feature = "ci")] + #[ignore] // https://github.com/apache/arrow-datafusion/issues/4024 #[tokio::test] - async fn q6() -> Result<()> { + async fn verify_q6() -> Result<()> { verify_query(6).await } + #[cfg(feature = "ci")] #[tokio::test] - async fn q7() -> Result<()> { + async fn verify_q7() -> Result<()> { verify_query(7).await } + #[cfg(feature = "ci")] #[tokio::test] - async fn q8() -> Result<()> { + async fn verify_q8() -> Result<()> { verify_query(8).await } + #[cfg(feature = "ci")] + #[ignore] // TODO produces correct result but has rounding error #[tokio::test] - async fn q9() -> Result<()> { + async fn 
+    #[cfg(feature = "ci")]
+    #[ignore] // TODO produces correct result but has rounding error
     #[tokio::test]
-    async fn q9() -> Result<()> {
+    async fn verify_q9() -> Result<()> {
         verify_query(9).await
     }
 
+    #[cfg(feature = "ci")]
     #[tokio::test]
-    async fn q10() -> Result<()> {
+    async fn verify_q10() -> Result<()> {
         verify_query(10).await
     }
 
+    #[cfg(feature = "ci")]
+    #[ignore] // https://github.com/apache/arrow-datafusion/issues/4023
     #[tokio::test]
-    async fn q11() -> Result<()> {
+    async fn verify_q11() -> Result<()> {
         verify_query(11).await
     }
 
+    #[cfg(feature = "ci")]
     #[tokio::test]
-    async fn q12() -> Result<()> {
+    async fn verify_q12() -> Result<()> {
         verify_query(12).await
     }
 
+    #[cfg(feature = "ci")]
     #[tokio::test]
-    async fn q13() -> Result<()> {
+    async fn verify_q13() -> Result<()> {
         verify_query(13).await
     }
 
+    #[cfg(feature = "ci")]
+    #[ignore] // https://github.com/apache/arrow-datafusion/issues/4025
     #[tokio::test]
-    async fn q14() -> Result<()> {
+    async fn verify_q14() -> Result<()> {
         verify_query(14).await
     }
 
+    #[cfg(feature = "ci")]
     #[tokio::test]
-    async fn q15() -> Result<()> {
+    async fn verify_q15() -> Result<()> {
         verify_query(15).await
     }
 
+    #[cfg(feature = "ci")]
     #[tokio::test]
-    async fn q16() -> Result<()> {
+    async fn verify_q16() -> Result<()> {
         verify_query(16).await
     }
 
+    #[cfg(feature = "ci")]
+    #[ignore] // https://github.com/apache/arrow-datafusion/issues/4026
     #[tokio::test]
-    async fn q17() -> Result<()> {
+    async fn verify_q17() -> Result<()> {
         verify_query(17).await
     }
 
+    #[cfg(feature = "ci")]
     #[tokio::test]
-    async fn q18() -> Result<()> {
+    async fn verify_q18() -> Result<()> {
         verify_query(18).await
     }
 
+    #[cfg(feature = "ci")]
     #[tokio::test]
-    async fn q19() -> Result<()> {
+    async fn verify_q19() -> Result<()> {
         verify_query(19).await
     }
 
+    #[cfg(feature = "ci")]
     #[tokio::test]
-    async fn q20() -> Result<()> {
+    async fn verify_q20() -> Result<()> {
         verify_query(20).await
     }
 
+    #[cfg(feature = "ci")]
     #[tokio::test]
-    async fn q21() -> Result<()> {
+    async fn verify_q21() -> Result<()> {
         verify_query(21).await
     }
 
+    #[cfg(feature = "ci")]
     #[tokio::test]
-    async fn q22() -> Result<()> {
+    async fn verify_q22() -> Result<()> {
         verify_query(22).await
     }
 
@@ -1049,253 +865,6 @@ mod tests {
         run_query(22).await
     }
 
-    /// Specialised String representation
-    fn col_str(column: &ArrayRef, row_index: usize) -> String {
-        if column.is_null(row_index) {
-            return "NULL".to_string();
-        }
-
-        array_value_to_string(column, row_index).unwrap()
-    }
-
-    /// Converts the results into a 2d array of strings, `result[row][column]`
-    /// Special cases nulls to NULL for testing
-    fn result_vec(results: &[RecordBatch]) -> Vec<Vec<String>> {
-        let mut result = vec![];
-        for batch in results {
-            for row_index in 0..batch.num_rows() {
-                let row_vec = batch
-                    .columns()
-                    .iter()
-                    .map(|column| col_str(column, row_index))
-                    .collect();
-                result.push(row_vec);
-            }
-        }
-        result
-    }
-
-    fn get_answer_schema(n: usize) -> Schema {
-        match n {
-            1 => Schema::new(vec![
-                Field::new("l_returnflag", DataType::Utf8, true),
-                Field::new("l_linestatus", DataType::Utf8, true),
-                Field::new("sum_qty", DataType::Decimal128(15, 2), true),
-                Field::new("sum_base_price", DataType::Decimal128(15, 2), true),
-                Field::new("sum_disc_price", DataType::Decimal128(15, 2), true),
-                Field::new("sum_charge", DataType::Decimal128(15, 2), true),
-                Field::new("avg_qty", DataType::Decimal128(15, 2), true),
-                Field::new("avg_price", DataType::Decimal128(15, 2), true),
-                Field::new("avg_disc", DataType::Decimal128(15, 2), true),
-                Field::new("count_order", DataType::Int64, true),
-            ]),
-
-            2 => Schema::new(vec![
-                Field::new("s_acctbal", DataType::Decimal128(15, 2), true),
-                Field::new("s_name", DataType::Utf8, true),
-                Field::new("n_name", DataType::Utf8, true),
-                Field::new("p_partkey", DataType::Int64, true),
-                Field::new("p_mfgr", DataType::Utf8, true),
-                Field::new("s_address", DataType::Utf8, true),
-                Field::new("s_phone", DataType::Utf8, true),
-                Field::new("s_comment", DataType::Utf8, true),
-            ]),
-
-            3 => Schema::new(vec![
-                Field::new("l_orderkey", DataType::Int64, true),
-                Field::new("revenue", DataType::Decimal128(15, 2), true),
-                Field::new("o_orderdate", DataType::Date32, true),
-                Field::new("o_shippriority", DataType::Int32, true),
-            ]),
-
-            4 => Schema::new(vec![
-                Field::new("o_orderpriority", DataType::Utf8, true),
-                Field::new("order_count", DataType::Int64, true),
-            ]),
-
-            5 => Schema::new(vec![
-                Field::new("n_name", DataType::Utf8, true),
-                Field::new("revenue", DataType::Decimal128(15, 2), true),
-            ]),
-
-            6 => Schema::new(vec![Field::new(
-                "revenue",
-                DataType::Decimal128(15, 2),
-                true,
-            )]),
-
-            7 => Schema::new(vec![
-                Field::new("supp_nation", DataType::Utf8, true),
-                Field::new("cust_nation", DataType::Utf8, true),
-                Field::new("l_year", DataType::Int32, true),
-                Field::new("revenue", DataType::Decimal128(15, 2), true),
-            ]),
-
-            8 => Schema::new(vec![
-                Field::new("o_year", DataType::Int32, true),
-                Field::new("mkt_share", DataType::Decimal128(15, 2), true),
-            ]),
-
-            9 => Schema::new(vec![
-                Field::new("nation", DataType::Utf8, true),
-                Field::new("o_year", DataType::Int32, true),
-                Field::new("sum_profit", DataType::Decimal128(15, 2), true),
-            ]),
-
-            10 => Schema::new(vec![
-                Field::new("c_custkey", DataType::Int64, true),
-                Field::new("c_name", DataType::Utf8, true),
-                Field::new("revenue", DataType::Decimal128(15, 2), true),
-                Field::new("c_acctbal", DataType::Decimal128(15, 2), true),
-                Field::new("n_name", DataType::Utf8, true),
-                Field::new("c_address", DataType::Utf8, true),
-                Field::new("c_phone", DataType::Utf8, true),
-                Field::new("c_comment", DataType::Utf8, true),
-            ]),
-
-            11 => Schema::new(vec![
-                Field::new("ps_partkey", DataType::Int64, true),
-                Field::new("value", DataType::Decimal128(15, 2), true),
-            ]),
-
-            12 => Schema::new(vec![
-                Field::new("l_shipmode", DataType::Utf8, true),
-                Field::new("high_line_count", DataType::Int64, true),
-                Field::new("low_line_count", DataType::Int64, true),
-            ]),
-
-            13 => Schema::new(vec![
-                Field::new("c_count", DataType::Int64, true),
-                Field::new("custdist", DataType::Int64, true),
-            ]),
-
-            14 => Schema::new(vec![Field::new("promo_revenue", DataType::Float64, true)]),
-
-            15 => Schema::new(vec![
-                Field::new("s_suppkey", DataType::Int64, true),
-                Field::new("s_name", DataType::Utf8, true),
-                Field::new("s_address", DataType::Utf8, true),
-                Field::new("s_phone", DataType::Utf8, true),
-                Field::new("total_revenue", DataType::Decimal128(15, 2), true),
-            ]),
-
-            16 => Schema::new(vec![
-                Field::new("p_brand", DataType::Utf8, true),
-                Field::new("p_type", DataType::Utf8, true),
-                Field::new("p_size", DataType::Int32, true),
-                Field::new("supplier_cnt", DataType::Int64, true),
-            ]),
-
-            17 => Schema::new(vec![Field::new("avg_yearly", DataType::Float64, true)]),
-
-            18 => Schema::new(vec![
-                Field::new("c_name", DataType::Utf8, true),
-                Field::new("c_custkey", DataType::Int64, true),
-                Field::new("o_orderkey", DataType::Int64, true),
-                Field::new("o_orderdate", DataType::Date32, true),
-                Field::new("o_totalprice", DataType::Decimal128(15, 2), true),
-                Field::new("sum_l_quantity", DataType::Decimal128(15, 2), true),
-            ]),
-
-            19 => Schema::new(vec![Field::new(
-                "revenue",
-                DataType::Decimal128(15, 2),
-                true,
-            )]),
-
-            20 => Schema::new(vec![
-                Field::new("s_name", DataType::Utf8, true),
-                Field::new("s_address", DataType::Utf8, true),
-            ]),
-
-            21 => Schema::new(vec![
-                Field::new("s_name", DataType::Utf8, true),
-                Field::new("numwait", DataType::Int64, true),
-            ]),
-
-            22 => Schema::new(vec![
-                Field::new("cntrycode", DataType::Utf8, true),
-                Field::new("numcust", DataType::Int64, true),
-                Field::new("totacctbal", DataType::Decimal128(15, 2), true),
-            ]),
-
-            _ => unimplemented!(),
-        }
-    }
-
-    // convert expected schema to all utf8 so columns can be read as strings to be parsed separately
-    // this is due to the fact that the csv parser cannot handle leading/trailing spaces
-    fn string_schema(schema: Schema) -> Schema {
-        Schema::new(
-            schema
-                .fields()
-                .iter()
-                .map(|field| {
-                    Field::new(
-                        Field::name(field),
-                        DataType::Utf8,
-                        Field::is_nullable(field),
-                    )
-                })
-                .collect::<Vec<Field>>(),
-        )
-    }
-
-    async fn transform_actual_result(
-        result: Vec<RecordBatch>,
-        n: usize,
-    ) -> Result<Vec<RecordBatch>> {
-        // to compare the recorded answers to the answers we got back from running the query,
-        // we need to round the decimal columns and trim the Utf8 columns
-        let ctx = SessionContext::new();
-        let result_schema = result[0].schema();
-        let table = Arc::new(MemTable::try_new(result_schema.clone(), vec![result])?);
-        let mut df = ctx.read_table(table)?
-            .select(
-                result_schema
-                    .fields
-                    .iter()
-                    .map(|field| {
-                        match Field::data_type(field) {
-                            DataType::Decimal128(_, _) => {
-                                // if decimal, then round it to 2 decimal places like the answers
-                                // round() doesn't support the second argument for decimal places to round to
-                                // this can be simplified to remove the mul and div when
-                                // https://github.com/apache/arrow-datafusion/issues/2420 is completed
-                                // cast it back to an over-sized Decimal with 2 precision when done rounding
-                                let round = Box::new(ScalarFunction {
-                                    fun: datafusion::logical_expr::BuiltinScalarFunction::Round,
-                                    args: vec![col(Field::name(field)).mul(lit(100))],
-                                }.div(lit(100)));
-                                Expr::Alias(
-                                    Box::new(Expr::Cast(Cast::new(
-                                        round,
-                                        DataType::Decimal128(38, 2),
-                                    ))),
-                                    Field::name(field).to_string(),
-                                )
-                            }
-                            DataType::Utf8 => {
-                                // if string, then trim it like the answers got trimmed
-                                Expr::Alias(
-                                    Box::new(trim(col(Field::name(field)))),
-                                    Field::name(field).to_string(),
-                                )
-                            }
-                            _ => {
-                                col(Field::name(field))
-                            }
-                        }
-                    }).collect()
-            )?;
-        if let Some(x) = QUERY_LIMIT[n - 1] {
-            df = df.limit(0, Some(x))?;
-        }
-
-        let df = df.collect().await?;
-        Ok(df)
-    }
-
     async fn run_query(n: usize) -> Result<()> {
         // Tests running query with empty tables, to see whether they run successfully.
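All of the helpers removed above now live in the new `datafusion-benchmarks` library crate (`benchmarks/src/tpch.rs`, added below) and reach this binary through the `use datafusion_benchmarks::tpch::*;` import. A sketch of how any other consumer could reuse them; the function here is hypothetical:

```rust
use datafusion_benchmarks::tpch::{get_query_sql, get_tpch_table_schema, TPCH_TABLES};

// Hypothetical helper: metadata that used to be private to the tpch binary
// is now available to any crate that depends on the benchmarks library.
fn print_tpch_metadata() -> datafusion::error::Result<()> {
    for table in TPCH_TABLES {
        let schema = get_tpch_table_schema(table);
        println!("{}: {} columns", table, schema.fields().len());
    }
    // one entry per ';'-separated statement in queries/q1.sql
    let statements = get_query_sql(1)?;
    println!("q1 has {} statement(s)", statements.len());
    Ok(())
}
```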
@@ -1304,8 +873,8 @@ mod tests {
             .with_batch_size(10);
         let ctx = SessionContext::with_config(config);
 
-        for &table in TABLES {
-            let schema = get_schema(table);
+        for &table in TPCH_TABLES {
+            let schema = get_tpch_table_schema(table);
             let batch = RecordBatch::new_empty(Arc::new(schema.to_owned()));
 
             ctx.register_batch(table, batch)?;
@@ -1324,75 +893,95 @@ mod tests {
     /// * datatypes returned in columns is correct
     /// * the correct number of rows are returned
     /// * the content of the rows is correct
+    #[cfg(feature = "ci")]
     async fn verify_query(n: usize) -> Result<()> {
-        if let Ok(path) = env::var("TPCH_DATA") {
-            // load expected answers from tpch-dbgen
-            // read csv as all strings, trim and cast to expected type as the csv string
-            // to value parser does not handle data with leading/trailing spaces
-            let ctx = SessionContext::new();
-            let schema = string_schema(get_answer_schema(n));
-            let options = CsvReadOptions::new()
-                .schema(&schema)
-                .delimiter(b'|')
-                .file_extension(".out");
-            let df = ctx
-                .read_csv(&format!("{}/answers/q{}.out", path, n), options)
-                .await?;
-            let df = df.select(
-                get_answer_schema(n)
-                    .fields()
-                    .iter()
-                    .map(|field| {
-                        match Field::data_type(field) {
-                            DataType::Decimal128(_, _) => {
-                                // there's no support for casting from Utf8 to Decimal, so
-                                // we'll cast from Utf8 to Float64 to Decimal for Decimal types
-                                let inner_cast = Box::new(Expr::Cast(Cast::new(
-                                    Box::new(trim(col(Field::name(field)))),
-                                    DataType::Float64,
-                                )));
-                                Expr::Alias(
-                                    Box::new(Expr::Cast(Cast::new(
-                                        inner_cast,
-                                        Field::data_type(field).to_owned(),
-                                    ))),
-                                    Field::name(field).to_string(),
-                                )
-                            }
-                            _ => Expr::Alias(
-                                Box::new(Expr::Cast(Cast::new(
-                                    Box::new(trim(col(Field::name(field)))),
-                                    Field::data_type(field).to_owned(),
-                                ))),
-                                Field::name(field).to_string(),
-                            ),
-                        }
-                    })
-                    .collect::<Vec<Expr>>(),
-            )?;
-            let expected = df.collect().await?;
-
-            // run the query to compute actual results of the query
-            let opt = DataFusionBenchmarkOpt {
-                query: n,
-                debug: false,
-                iterations: 1,
-                partitions: 2,
-                batch_size: 8192,
-                path: PathBuf::from(path.to_string()),
-                file_format: "tbl".to_string(),
-                mem_table: false,
-                output_path: None,
-                disable_statistics: false,
-            };
-            let actual = benchmark_datafusion(opt).await?;
+        use datafusion::arrow::datatypes::{DataType, Field};
+        use datafusion::logical_expr::expr::Cast;
+        use datafusion::logical_expr::Expr;
+        use std::env;
+
+        let path = env::var("TPCH_DATA").unwrap_or("benchmarks/data".to_string());
+        if !Path::new(&path).exists() {
+            return Err(DataFusionError::Execution(format!(
+                "Benchmark data not found (set TPCH_DATA env var to override): {}",
+                path
+            )));
+        }
+
+        let answer_file = format!("{}/answers/q{}.out", path, n);
+        if !Path::new(&answer_file).exists() {
+            return Err(DataFusionError::Execution(format!(
+                "Expected results not found: {}",
+                answer_file
+            )));
+        }
+
+        // load expected answers from tpch-dbgen
+        // read csv as all strings, trim and cast to expected type as the csv string
+        // to value parser does not handle data with leading/trailing spaces
+        let ctx = SessionContext::new();
+        let schema = string_schema(get_answer_schema(n));
+        let options = CsvReadOptions::new()
+            .schema(&schema)
+            .delimiter(b'|')
+            .file_extension(".out");
+        let df = ctx.read_csv(&answer_file, options).await?;
+        let df = df.select(
+            get_answer_schema(n)
+                .fields()
+                .iter()
+                .map(|field| {
+                    match Field::data_type(field) {
+                        DataType::Decimal128(_, _) => {
+                            // there's no support for casting from Utf8 to Decimal, so
+                            // we'll cast from Utf8 to Float64 to Decimal for Decimal types
+                            let inner_cast = Box::new(Expr::Cast(Cast::new(
+                                Box::new(trim(col(Field::name(field)))),
+                                DataType::Float64,
+                            )));
+                            Expr::Alias(
+                                Box::new(Expr::Cast(Cast::new(
+                                    inner_cast,
+                                    Field::data_type(field).to_owned(),
+                                ))),
+                                Field::name(field).to_string(),
+                            )
+                        }
+                        _ => Expr::Alias(
+                            Box::new(Expr::Cast(Cast::new(
+                                Box::new(trim(col(Field::name(field)))),
+                                Field::data_type(field).to_owned(),
+                            ))),
+                            Field::name(field).to_string(),
+                        ),
+                    }
+                })
+                .collect::<Vec<Expr>>(),
+        )?;
+        let expected = df.collect().await?;
+
+        // run the query to compute actual results of the query
+        let opt = DataFusionBenchmarkOpt {
+            query: n,
+            debug: false,
+            iterations: 1,
+            partitions: 2,
+            batch_size: 8192,
+            path: PathBuf::from(path.to_string()),
+            file_format: "tbl".to_string(),
+            mem_table: false,
+            output_path: None,
+            disable_statistics: false,
+        };
+        let actual = benchmark_datafusion(opt).await?;
 
-            let transformed = transform_actual_result(actual, n).await?;
+        let transformed = transform_actual_result(actual, n).await?;
 
-            // assert schema data types match
-            let transformed_fields = &transformed[0].schema().fields;
-            let expected_fields = &expected[0].schema().fields;
-            let schema_matches = transformed_fields
+        // assert schema data types match
+        let transformed_fields = &transformed[0].schema().fields;
+        let expected_fields = &expected[0].schema().fields;
+        let schema_matches =
+            transformed_fields
                 .iter()
                 .zip(expected_fields.iter())
                 .all(|(t, e)| match t.data_type() {
@@ -1401,21 +990,18 @@ mod tests {
                     }
                     data_type => data_type == e.data_type(),
                 });
-            assert!(schema_matches);
+        assert!(schema_matches);
 
-            // convert both datasets to Vec<Vec<String>> for simple comparison
-            let expected_vec = result_vec(&expected);
-            let actual_vec = result_vec(&transformed);
+        // convert both datasets to Vec<Vec<String>> for simple comparison
+        let expected_vec = result_vec(&expected);
+        let actual_vec = result_vec(&transformed);
 
-            // basic result comparison
-            assert_eq!(expected_vec.len(), actual_vec.len());
+        // basic result comparison
+        assert_eq!(expected_vec.len(), actual_vec.len());
 
-            // compare each row. this works as all TPC-H queries have deterministically ordered results
-            for i in 0..actual_vec.len() {
-                assert_eq!(expected_vec[i], actual_vec[i]);
-            }
-        } else {
-            println!("TPCH_DATA environment variable not set, skipping test");
+        // compare each row. this works as all TPC-H queries have deterministically ordered results
+        for i in 0..actual_vec.len() {
+            assert_eq!(expected_vec[i], actual_vec[i]);
         }
 
         Ok(())
diff --git a/benchmarks/src/lib.rs b/benchmarks/src/lib.rs
new file mode 100644
index 000000000000..af1dd46fd42e
--- /dev/null
+++ b/benchmarks/src/lib.rs
@@ -0,0 +1,18 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+pub mod tpch;
diff --git a/benchmarks/src/tpch.rs b/benchmarks/src/tpch.rs
new file mode 100644
index 000000000000..46c53edf120e
--- /dev/null
+++ b/benchmarks/src/tpch.rs
@@ -0,0 +1,512 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::ArrayRef;
+use arrow::record_batch::RecordBatch;
+use std::fs;
+use std::ops::{Div, Mul};
+use std::path::Path;
+use std::sync::Arc;
+use std::time::Instant;
+
+use datafusion::arrow::util::display::array_value_to_string;
+use datafusion::logical_expr::Cast;
+use datafusion::prelude::*;
+use datafusion::{
+    arrow::datatypes::{DataType, Field, Schema},
+    datasource::MemTable,
+    error::{DataFusionError, Result},
+};
+use parquet::basic::Compression;
+use parquet::file::properties::WriterProperties;
+
+pub const TPCH_TABLES: &[&str] = &[
+    "part", "supplier", "partsupp", "customer", "orders", "lineitem", "nation", "region",
+];
+
+/// Get the schema for the benchmarks derived from TPC-H
+pub fn get_tpch_table_schema(table: &str) -> Schema {
+    // note that the schema intentionally uses signed integers so that any generated Parquet
+    // files can also be used to benchmark tools that only support signed integers, such as
+    // Apache Spark
+
+    match table {
+        "part" => Schema::new(vec![
+            Field::new("p_partkey", DataType::Int64, false),
+            Field::new("p_name", DataType::Utf8, false),
+            Field::new("p_mfgr", DataType::Utf8, false),
+            Field::new("p_brand", DataType::Utf8, false),
+            Field::new("p_type", DataType::Utf8, false),
+            Field::new("p_size", DataType::Int32, false),
+            Field::new("p_container", DataType::Utf8, false),
+            Field::new("p_retailprice", DataType::Decimal128(15, 2), false),
+            Field::new("p_comment", DataType::Utf8, false),
+        ]),
+
+        "supplier" => Schema::new(vec![
+            Field::new("s_suppkey", DataType::Int64, false),
+            Field::new("s_name", DataType::Utf8, false),
+            Field::new("s_address", DataType::Utf8, false),
+            Field::new("s_nationkey", DataType::Int64, false),
+            Field::new("s_phone", DataType::Utf8, false),
+            Field::new("s_acctbal", DataType::Decimal128(15, 2), false),
+            Field::new("s_comment", DataType::Utf8, false),
+        ]),
+
+        "partsupp" => Schema::new(vec![
+            Field::new("ps_partkey", DataType::Int64, false),
+            Field::new("ps_suppkey", DataType::Int64, false),
+            Field::new("ps_availqty", DataType::Int32, false),
+            Field::new("ps_supplycost", DataType::Decimal128(15, 2), false),
+            Field::new("ps_comment", DataType::Utf8, false),
+        ]),
+
+        "customer" => Schema::new(vec![
+            Field::new("c_custkey", DataType::Int64, false),
+            Field::new("c_name", DataType::Utf8, false),
+            Field::new("c_address", DataType::Utf8, false),
+            Field::new("c_nationkey", DataType::Int64, false),
+            Field::new("c_phone", DataType::Utf8, false),
+            Field::new("c_acctbal", DataType::Decimal128(15, 2), false),
+            Field::new("c_mktsegment", DataType::Utf8, false),
+            Field::new("c_comment", DataType::Utf8, false),
+        ]),
Field::new("c_mktsegment", DataType::Utf8, false), + Field::new("c_comment", DataType::Utf8, false), + ]), + + "orders" => Schema::new(vec![ + Field::new("o_orderkey", DataType::Int64, false), + Field::new("o_custkey", DataType::Int64, false), + Field::new("o_orderstatus", DataType::Utf8, false), + Field::new("o_totalprice", DataType::Decimal128(15, 2), false), + Field::new("o_orderdate", DataType::Date32, false), + Field::new("o_orderpriority", DataType::Utf8, false), + Field::new("o_clerk", DataType::Utf8, false), + Field::new("o_shippriority", DataType::Int32, false), + Field::new("o_comment", DataType::Utf8, false), + ]), + + "lineitem" => Schema::new(vec![ + Field::new("l_orderkey", DataType::Int64, false), + Field::new("l_partkey", DataType::Int64, false), + Field::new("l_suppkey", DataType::Int64, false), + Field::new("l_linenumber", DataType::Int32, false), + Field::new("l_quantity", DataType::Decimal128(15, 2), false), + Field::new("l_extendedprice", DataType::Decimal128(15, 2), false), + Field::new("l_discount", DataType::Decimal128(15, 2), false), + Field::new("l_tax", DataType::Decimal128(15, 2), false), + Field::new("l_returnflag", DataType::Utf8, false), + Field::new("l_linestatus", DataType::Utf8, false), + Field::new("l_shipdate", DataType::Date32, false), + Field::new("l_commitdate", DataType::Date32, false), + Field::new("l_receiptdate", DataType::Date32, false), + Field::new("l_shipinstruct", DataType::Utf8, false), + Field::new("l_shipmode", DataType::Utf8, false), + Field::new("l_comment", DataType::Utf8, false), + ]), + + "nation" => Schema::new(vec![ + Field::new("n_nationkey", DataType::Int64, false), + Field::new("n_name", DataType::Utf8, false), + Field::new("n_regionkey", DataType::Int64, false), + Field::new("n_comment", DataType::Utf8, false), + ]), + + "region" => Schema::new(vec![ + Field::new("r_regionkey", DataType::Int64, false), + Field::new("r_name", DataType::Utf8, false), + Field::new("r_comment", DataType::Utf8, false), + ]), + + _ => unimplemented!(), + } +} + +/// Get the expected schema for the results of a query +pub fn get_answer_schema(n: usize) -> Schema { + match n { + 1 => Schema::new(vec![ + Field::new("l_returnflag", DataType::Utf8, true), + Field::new("l_linestatus", DataType::Utf8, true), + Field::new("sum_qty", DataType::Decimal128(15, 2), true), + Field::new("sum_base_price", DataType::Decimal128(15, 2), true), + Field::new("sum_disc_price", DataType::Decimal128(15, 2), true), + Field::new("sum_charge", DataType::Decimal128(15, 2), true), + Field::new("avg_qty", DataType::Decimal128(15, 2), true), + Field::new("avg_price", DataType::Decimal128(15, 2), true), + Field::new("avg_disc", DataType::Decimal128(15, 2), true), + Field::new("count_order", DataType::Int64, true), + ]), + + 2 => Schema::new(vec![ + Field::new("s_acctbal", DataType::Decimal128(15, 2), true), + Field::new("s_name", DataType::Utf8, true), + Field::new("n_name", DataType::Utf8, true), + Field::new("p_partkey", DataType::Int64, true), + Field::new("p_mfgr", DataType::Utf8, true), + Field::new("s_address", DataType::Utf8, true), + Field::new("s_phone", DataType::Utf8, true), + Field::new("s_comment", DataType::Utf8, true), + ]), + + 3 => Schema::new(vec![ + Field::new("l_orderkey", DataType::Int64, true), + Field::new("revenue", DataType::Decimal128(15, 2), true), + Field::new("o_orderdate", DataType::Date32, true), + Field::new("o_shippriority", DataType::Int32, true), + ]), + + 4 => Schema::new(vec![ + Field::new("o_orderpriority", DataType::Utf8, true), + 
Field::new("order_count", DataType::Int64, true), + ]), + + 5 => Schema::new(vec![ + Field::new("n_name", DataType::Utf8, true), + Field::new("revenue", DataType::Decimal128(15, 2), true), + ]), + + 6 => Schema::new(vec![Field::new( + "revenue", + DataType::Decimal128(15, 2), + true, + )]), + + 7 => Schema::new(vec![ + Field::new("supp_nation", DataType::Utf8, true), + Field::new("cust_nation", DataType::Utf8, true), + Field::new("l_year", DataType::Int32, true), + Field::new("revenue", DataType::Decimal128(15, 2), true), + ]), + + 8 => Schema::new(vec![ + Field::new("o_year", DataType::Int32, true), + Field::new("mkt_share", DataType::Decimal128(15, 2), true), + ]), + + 9 => Schema::new(vec![ + Field::new("nation", DataType::Utf8, true), + Field::new("o_year", DataType::Int32, true), + Field::new("sum_profit", DataType::Decimal128(15, 2), true), + ]), + + 10 => Schema::new(vec![ + Field::new("c_custkey", DataType::Int64, true), + Field::new("c_name", DataType::Utf8, true), + Field::new("revenue", DataType::Decimal128(15, 2), true), + Field::new("c_acctbal", DataType::Decimal128(15, 2), true), + Field::new("n_name", DataType::Utf8, true), + Field::new("c_address", DataType::Utf8, true), + Field::new("c_phone", DataType::Utf8, true), + Field::new("c_comment", DataType::Utf8, true), + ]), + + 11 => Schema::new(vec![ + Field::new("ps_partkey", DataType::Int64, true), + Field::new("value", DataType::Decimal128(15, 2), true), + ]), + + 12 => Schema::new(vec![ + Field::new("l_shipmode", DataType::Utf8, true), + Field::new("high_line_count", DataType::Int64, true), + Field::new("low_line_count", DataType::Int64, true), + ]), + + 13 => Schema::new(vec![ + Field::new("c_count", DataType::Int64, true), + Field::new("custdist", DataType::Int64, true), + ]), + + 14 => Schema::new(vec![Field::new( + "promo_revenue", + DataType::Decimal128(38, 2), + true, + )]), + + 15 => Schema::new(vec![ + Field::new("s_suppkey", DataType::Int64, true), + Field::new("s_name", DataType::Utf8, true), + Field::new("s_address", DataType::Utf8, true), + Field::new("s_phone", DataType::Utf8, true), + Field::new("total_revenue", DataType::Decimal128(15, 2), true), + ]), + + 16 => Schema::new(vec![ + Field::new("p_brand", DataType::Utf8, true), + Field::new("p_type", DataType::Utf8, true), + Field::new("p_size", DataType::Int32, true), + Field::new("supplier_cnt", DataType::Int64, true), + ]), + + 17 => Schema::new(vec![Field::new( + "avg_yearly", + DataType::Decimal128(38, 2), + true, + )]), + + 18 => Schema::new(vec![ + Field::new("c_name", DataType::Utf8, true), + Field::new("c_custkey", DataType::Int64, true), + Field::new("o_orderkey", DataType::Int64, true), + Field::new("o_orderdate", DataType::Date32, true), + Field::new("o_totalprice", DataType::Decimal128(15, 2), true), + Field::new("sum_l_quantity", DataType::Decimal128(15, 2), true), + ]), + + 19 => Schema::new(vec![Field::new( + "revenue", + DataType::Decimal128(15, 2), + true, + )]), + + 20 => Schema::new(vec![ + Field::new("s_name", DataType::Utf8, true), + Field::new("s_address", DataType::Utf8, true), + ]), + + 21 => Schema::new(vec![ + Field::new("s_name", DataType::Utf8, true), + Field::new("numwait", DataType::Int64, true), + ]), + + 22 => Schema::new(vec![ + Field::new("cntrycode", DataType::Utf8, true), + Field::new("numcust", DataType::Int64, true), + Field::new("totacctbal", DataType::Decimal128(15, 2), true), + ]), + + _ => unimplemented!(), + } +} + +/// Get the SQL statements from the specified query file +pub fn get_query_sql(query: usize) -> 
+
+/// Get the SQL statements from the specified query file
+pub fn get_query_sql(query: usize) -> Result<Vec<String>> {
+    if query > 0 && query < 23 {
+        let possibilities = vec![
+            format!("queries/q{}.sql", query),
+            format!("benchmarks/queries/q{}.sql", query),
+        ];
+        let mut errors = vec![];
+        for filename in possibilities {
+            match fs::read_to_string(&filename) {
+                Ok(contents) => {
+                    return Ok(contents
+                        .split(';')
+                        .map(|s| s.trim())
+                        .filter(|s| !s.is_empty())
+                        .map(|s| s.to_string())
+                        .collect());
+                }
+                Err(e) => errors.push(format!("{}: {}", filename, e)),
+            };
+        }
+        Err(DataFusionError::Plan(format!(
+            "invalid query. Could not find query: {:?}",
+            errors
+        )))
+    } else {
+        Err(DataFusionError::Plan(
+            "invalid query. Expected value between 1 and 22".to_owned(),
+        ))
+    }
+}
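Because a query file may hold several ';'-separated statements, callers are expected to execute the returned strings in order and keep the last result. A sketch, assuming `get_query_sql` from this module; the driver function is hypothetical:

```rust
// Hypothetical driver: run each statement of a query file in sequence.
async fn run_query_file(
    ctx: &datafusion::prelude::SessionContext,
    query: usize,
) -> datafusion::error::Result<()> {
    for sql in get_query_sql(query)? {
        let df = ctx.sql(&sql).await?;
        let _batches = df.collect().await?;
    }
    Ok(())
}
```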
+
+/// Convert a tbl (csv) file to csv or parquet
+pub async fn convert_tbl(
+    input_path: &str,
+    output_path: &str,
+    file_format: &str,
+    partitions: usize,
+    batch_size: usize,
+    compression: Compression,
+) -> Result<()> {
+    let output_root_path = Path::new(output_path);
+    for table in TPCH_TABLES {
+        let start = Instant::now();
+        let schema = get_tpch_table_schema(table);
+
+        let input_path = format!("{}/{}.tbl", input_path, table);
+        let options = CsvReadOptions::new()
+            .schema(&schema)
+            .has_header(false)
+            .delimiter(b'|')
+            .file_extension(".tbl");
+
+        let config = SessionConfig::new().with_batch_size(batch_size);
+        let ctx = SessionContext::with_config(config);
+
+        // build plan to read the TBL file
+        let mut csv = ctx.read_csv(&input_path, options).await?;
+
+        // optionally, repartition the file
+        if partitions > 1 {
+            csv = csv.repartition(Partitioning::RoundRobinBatch(partitions))?
+        }
+
+        // create the physical plan
+        let csv = csv.to_logical_plan()?;
+        let csv = ctx.create_physical_plan(&csv).await?;
+
+        let output_path = output_root_path.join(table);
+        let output_path = output_path.to_str().unwrap().to_owned();
+
+        println!(
+            "Converting '{}' to {} files in directory '{}'",
+            &input_path, &file_format, &output_path
+        );
+        match file_format {
+            "csv" => ctx.write_csv(csv, output_path).await?,
+            "parquet" => {
+                let props = WriterProperties::builder()
+                    .set_compression(compression)
+                    .build();
+                ctx.write_parquet(csv, output_path, Some(props)).await?
+            }
+            other => {
+                return Err(DataFusionError::NotImplemented(format!(
+                    "Invalid output format: {}",
+                    other
+                )));
+            }
+        }
+        println!("Conversion completed in {} ms", start.elapsed().as_millis());
+    }
+
+    Ok(())
+}
+
+/// Converts the results into a 2d array of strings, `result[row][column]`
+/// Special cases nulls to NULL for testing
+pub fn result_vec(results: &[RecordBatch]) -> Vec<Vec<String>> {
+    let mut result = vec![];
+    for batch in results {
+        for row_index in 0..batch.num_rows() {
+            let row_vec = batch
+                .columns()
+                .iter()
+                .map(|column| col_str(column, row_index))
+                .collect();
+            result.push(row_vec);
+        }
+    }
+    result
+}
+
+/// convert expected schema to all utf8 so columns can be read as strings to be parsed separately
+/// this is due to the fact that the csv parser cannot handle leading/trailing spaces
+pub fn string_schema(schema: Schema) -> Schema {
+    Schema::new(
+        schema
+            .fields()
+            .iter()
+            .map(|field| {
+                Field::new(
+                    Field::name(field),
+                    DataType::Utf8,
+                    Field::is_nullable(field),
+                )
+            })
+            .collect::<Vec<Field>>(),
+    )
+}
+
+/// Specialised String representation
+fn col_str(column: &ArrayRef, row_index: usize) -> String {
+    if column.is_null(row_index) {
+        return "NULL".to_string();
+    }
+
+    array_value_to_string(column, row_index).unwrap()
+}
+
+pub async fn transform_actual_result(
+    result: Vec<RecordBatch>,
+    n: usize,
+) -> Result<Vec<RecordBatch>> {
+    // to compare the recorded answers to the answers we got back from running the query,
+    // we need to round the decimal columns and trim the Utf8 columns
+    let ctx = SessionContext::new();
+    let result_schema = result[0].schema();
+    let table = Arc::new(MemTable::try_new(result_schema.clone(), vec![result])?);
+    let mut df = ctx.read_table(table)?
+        .select(
+            result_schema
+                .fields
+                .iter()
+                .map(|field| {
+                    match Field::data_type(field) {
+                        DataType::Decimal128(_, _) => {
+                            // if decimal, then round it to 2 decimal places like the answers
+                            // round() doesn't support the second argument for decimal places to round to
+                            // this can be simplified to remove the mul and div when
+                            // https://github.com/apache/arrow-datafusion/issues/2420 is completed
+                            // cast it back to an over-sized Decimal with 2 precision when done rounding
+                            let round = Box::new(Expr::ScalarFunction {
+                                fun: datafusion::logical_expr::BuiltinScalarFunction::Round,
+                                args: vec![col(Field::name(field)).mul(lit(100))],
+                            }.div(lit(100)));
+                            Expr::Alias(
+                                Box::new(Expr::Cast(Cast::new(
+                                    round,
+                                    DataType::Decimal128(38, 2),
+                                ))),
+                                Field::name(field).to_string(),
+                            )
+                        }
+                        DataType::Utf8 => {
+                            // if string, then trim it like the answers got trimmed
+                            Expr::Alias(
+                                Box::new(trim(col(Field::name(field)))),
+                                Field::name(field).to_string(),
+                            )
+                        }
+                        _ => {
+                            col(Field::name(field))
+                        }
+                    }
+                }).collect()
+        )?;
+    if let Some(x) = QUERY_LIMIT[n - 1] {
+        df = df.limit(0, Some(x))?;
+    }
+
+    let df = df.collect().await?;
+    Ok(df)
+}
+
+pub const QUERY_LIMIT: [Option<usize>; 22] = [
+    None,
+    Some(100),
+    Some(10),
+    None,
+    None,
+    None,
+    None,
+    None,
+    None,
+    Some(20),
+    None,
+    None,
+    None,
+    None,
+    None,
+    None,
+    None,
+    Some(100),
+    None,
+    None,
+    Some(100),
+    None,
+];
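Taken together, converting dbgen output now needs only the library entry point. A sketch of calling `convert_tbl` directly with hypothetical paths; the compression values are the same ones the CLI mapping in `tpch.rs` accepts:

```rust
use datafusion_benchmarks::tpch::convert_tbl;
use parquet::basic::Compression;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    convert_tbl(
        "benchmarks/data",         // directory holding the *.tbl files
        "benchmarks/data/parquet", // output root; one directory per table
        "parquet",                 // or "csv"
        2,                         // partitions
        8192,                      // batch size
        Compression::SNAPPY,
    )
    .await
}
```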