Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions benchmarks/src/bin/tpch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,6 @@ mod tests {
}

#[cfg(feature = "ci")]
#[ignore] // TODO produces correct result but has rounding error
#[tokio::test]
async fn verify_q9() -> Result<()> {
verify_query(9).await
Expand All @@ -681,7 +680,6 @@ mod tests {
}

#[cfg(feature = "ci")]
#[ignore] // https://github.com/apache/arrow-datafusion/issues/4023
#[tokio::test]
async fn verify_q11() -> Result<()> {
verify_query(11).await
Expand All @@ -700,7 +698,6 @@ mod tests {
}

#[cfg(feature = "ci")]
#[ignore] // https://github.com/apache/arrow-datafusion/issues/4025
#[tokio::test]
async fn verify_q14() -> Result<()> {
verify_query(14).await
Expand All @@ -719,7 +716,6 @@ mod tests {
}

#[cfg(feature = "ci")]
#[ignore] // https://github.com/apache/arrow-datafusion/issues/4026
#[tokio::test]
async fn verify_q17() -> Result<()> {
verify_query(17).await
Expand Down Expand Up @@ -896,8 +892,8 @@ mod tests {
#[cfg(feature = "ci")]
async fn verify_query(n: usize) -> Result<()> {
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::common::ScalarValue;
use datafusion::logical_expr::expr::Cast;
use datafusion::logical_expr::Expr;
use std::env;

let path = env::var("TPCH_DATA").unwrap_or("benchmarks/data".to_string());
Expand Down Expand Up @@ -990,7 +986,12 @@ mod tests {
}
data_type => data_type == e.data_type(),
});
assert!(schema_matches);
if !schema_matches {
panic!(
"expected_fields: {:?}\ntransformed_fields: {:?}",
expected_fields, transformed_fields
)
}

// convert both datasets to Vec<Vec<ScalarValue>> for simple comparison
let expected_vec = result_vec(&expected);
Expand All @@ -1000,8 +1001,26 @@ mod tests {
assert_eq!(expected_vec.len(), actual_vec.len());

// compare each row. this works as all TPC-H queries have deterministically ordered results
for i in 0..actual_vec.len() {
assert_eq!(expected_vec[i], actual_vec[i]);
// Walk the expected and actual rows in lockstep; the row counts were already
// asserted equal before this loop, so `zip` does not drop any rows.
for (expected_row, actual_row) in expected_vec.iter().zip(actual_vec.iter()) {
    // Both rows must have the same number of columns, otherwise `zip` below
    // would silently ignore the surplus columns of the longer row.
    assert_eq!(expected_row.len(), actual_row.len());

    // Compare every column of the row. The bound must come from the row itself:
    // `expected` is the list of record batches, so using its length here would
    // check only as many columns as there are batches.
    for (expected_value, actual_value) in expected_row.iter().zip(actual_row.iter()) {
        match (expected_value, actual_value) {
            (ScalarValue::Float64(Some(l)), ScalarValue::Float64(Some(r))) => {
                // allow for rounding errors until we move to decimal types
                // NOTE(review): a NaN difference is never greater than the
                // tolerance, so NaN values pass this check - confirm that is acceptable.
                let tolerance = 0.1;
                if (l - r).abs() > tolerance {
                    panic!(
                        "Expected: {}; Actual: {}; Tolerance: {}",
                        l, r, tolerance
                    )
                }
            }
            // Every other combination (including nulls) must match exactly;
            // the values are compared through their Debug representation.
            (l, r) => assert_eq!(format!("{:?}", l), format!("{:?}", r)),
        }
    }
}

Ok(())
Expand Down
59 changes: 39 additions & 20 deletions benchmarks/src/tpch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::ArrayRef;
use arrow::array::{
Array, ArrayRef, Date32Array, Decimal128Array, Float64Array, Int32Array, Int64Array,
StringArray,
};
use arrow::record_batch::RecordBatch;
use std::fs;
use std::ops::{Div, Mul};
use std::path::Path;
use std::sync::Arc;
use std::time::Instant;

use datafusion::arrow::util::display::array_value_to_string;
use datafusion::common::ScalarValue;
use datafusion::logical_expr::Cast;
use datafusion::prelude::*;
use datafusion::{
Expand Down Expand Up @@ -229,11 +232,7 @@ pub fn get_answer_schema(n: usize) -> Schema {
Field::new("custdist", DataType::Int64, true),
]),

14 => Schema::new(vec![Field::new(
"promo_revenue",
DataType::Decimal128(38, 2),
true,
)]),
14 => Schema::new(vec![Field::new("promo_revenue", DataType::Float64, true)]),

15 => Schema::new(vec![
Field::new("s_suppkey", DataType::Int64, true),
Expand All @@ -250,11 +249,7 @@ pub fn get_answer_schema(n: usize) -> Schema {
Field::new("supplier_cnt", DataType::Int64, true),
]),

17 => Schema::new(vec![Field::new(
"avg_yearly",
DataType::Decimal128(38, 2),
true,
)]),
17 => Schema::new(vec![Field::new("avg_yearly", DataType::Float64, true)]),

18 => Schema::new(vec![
Field::new("c_name", DataType::Utf8, true),
Expand Down Expand Up @@ -389,14 +384,14 @@ pub async fn convert_tbl(

/// Converts the results into a 2d array of scalar values, `result[row][column]`
/// Special cases nulls to `ScalarValue::Null` for testing
pub fn result_vec(results: &[RecordBatch]) -> Vec<Vec<String>> {
pub fn result_vec(results: &[RecordBatch]) -> Vec<Vec<ScalarValue>> {
let mut result = vec![];
for batch in results {
for row_index in 0..batch.num_rows() {
let row_vec = batch
.columns()
.iter()
.map(|column| col_str(column, row_index))
.map(|column| col_to_scalar(column, row_index))
.collect();
result.push(row_vec);
}
Expand All @@ -422,13 +417,37 @@ pub fn string_schema(schema: Schema) -> Schema {
)
}

/// Specialised String representation
fn col_str(column: &ArrayRef, row_index: usize) -> String {
fn col_to_scalar(column: &ArrayRef, row_index: usize) -> ScalarValue {
if column.is_null(row_index) {
return "NULL".to_string();
return ScalarValue::Null;
}
match column.data_type() {
DataType::Int32 => {
let array = column.as_any().downcast_ref::<Int32Array>().unwrap();
ScalarValue::Int32(Some(array.value(row_index)))
}
DataType::Int64 => {
let array = column.as_any().downcast_ref::<Int64Array>().unwrap();
ScalarValue::Int64(Some(array.value(row_index)))
}
DataType::Float64 => {
let array = column.as_any().downcast_ref::<Float64Array>().unwrap();
ScalarValue::Float64(Some(array.value(row_index)))
}
DataType::Decimal128(p, s) => {
let array = column.as_any().downcast_ref::<Decimal128Array>().unwrap();
ScalarValue::Decimal128(Some(array.value(row_index)), *p, *s)
}
DataType::Date32 => {
let array = column.as_any().downcast_ref::<Date32Array>().unwrap();
ScalarValue::Date32(Some(array.value(row_index)))
}
DataType::Utf8 => {
let array = column.as_any().downcast_ref::<StringArray>().unwrap();
ScalarValue::Utf8(Some(array.value(row_index).to_string()))
}
other => panic!("unexpected data type in benchmark: {}", other),
}

array_value_to_string(column, row_index).unwrap()
}

pub async fn transform_actual_result(
Expand Down Expand Up @@ -460,7 +479,7 @@ pub async fn transform_actual_result(
Expr::Alias(
Box::new(Expr::Cast(Cast::new(
round,
DataType::Decimal128(38, 2),
DataType::Decimal128(15, 2),
))),
Field::name(field).to_string(),
)
Expand Down