From 12eece1fe9421d546fd14a856efb9905f5fb5535 Mon Sep 17 00:00:00 2001 From: lvheyang Date: Wed, 21 Jul 2021 21:02:10 +0800 Subject: [PATCH 1/2] limit pruning rule to simple expression --- datafusion/src/physical_optimizer/pruning.rs | 105 ++++++++++++++----- 1 file changed, 79 insertions(+), 26 deletions(-) diff --git a/datafusion/src/physical_optimizer/pruning.rs b/datafusion/src/physical_optimizer/pruning.rs index 36253815414a..a504ebb3c56a 100644 --- a/datafusion/src/physical_optimizer/pruning.rs +++ b/datafusion/src/physical_optimizer/pruning.rs @@ -355,17 +355,18 @@ fn build_statistics_record_batch( struct PruningExpressionBuilder<'a> { column: Column, - column_expr: &'a Expr, - scalar_expr: &'a Expr, + column_expr: Expr, + op: Operator, + scalar_expr: Expr, field: &'a Field, required_columns: &'a mut RequiredStatColumns, - reverse_operator: bool, } impl<'a> PruningExpressionBuilder<'a> { fn try_new( left: &'a Expr, right: &'a Expr, + op: Operator, schema: &'a Schema, required_columns: &'a mut RequiredStatColumns, ) -> Result { @@ -374,10 +375,10 @@ impl<'a> PruningExpressionBuilder<'a> { utils::expr_to_columns(left, &mut left_columns)?; let mut right_columns = HashSet::::new(); utils::expr_to_columns(right, &mut right_columns)?; - let (column_expr, scalar_expr, columns, reverse_operator) = + let (column_expr, scalar_expr, columns, correct_operator) = match (left_columns.len(), right_columns.len()) { - (1, 0) => (left, right, left_columns, false), - (0, 1) => (right, left, right_columns, true), + (1, 0) => (left, right, left_columns, op), + (0, 1) => (right, left, right_columns, reverse_operator(op)), _ => { // if more than one column used in expression - not supported return Err(DataFusionError::Plan( @@ -386,6 +387,12 @@ impl<'a> PruningExpressionBuilder<'a> { )); } }; + + let (column_expr, correct_operator, scalar_expr) = + match rewrite_expr_compatible(column_expr, correct_operator, scalar_expr) { + Ok(ret) => ret, + Err(e) => return Err(e), + }; let 
column = columns.iter().next().unwrap().clone(); let field = match schema.column_with_name(&column.flat_name()) { Some((_, f)) => f, @@ -399,40 +406,67 @@ impl<'a> PruningExpressionBuilder<'a> { Ok(Self { column, column_expr, + op: correct_operator, scalar_expr, field, required_columns, - reverse_operator, }) } - fn correct_operator(&self, op: Operator) -> Operator { - if !self.reverse_operator { - return op; - } - - match op { - Operator::Lt => Operator::Gt, - Operator::Gt => Operator::Lt, - Operator::LtEq => Operator::GtEq, - Operator::GtEq => Operator::LtEq, - _ => op, - } + fn op(&self) -> Operator { + self.op } fn scalar_expr(&self) -> &Expr { - self.scalar_expr + &self.scalar_expr } fn min_column_expr(&mut self) -> Result { self.required_columns - .min_column_expr(&self.column, self.column_expr, self.field) + .min_column_expr(&self.column, &self.column_expr, self.field) } fn max_column_expr(&mut self) -> Result { self.required_columns - .max_column_expr(&self.column, self.column_expr, self.field) + .max_column_expr(&self.column, &self.column_expr, self.field) + } +} + +fn rewrite_expr_compatible( + column_expr: &Expr, + op: Operator, + scalar_expr: &Expr, +) -> Result<(Expr, Operator, Expr)> { + match column_expr { + Expr::Column(_) => Ok((column_expr.clone(), op, scalar_expr.clone())), + // Expr::BinaryExpr { .. 
} => todo!(), + // Expr::Not(_) => todo!(), + Expr::Negative(c) => match c.as_ref() { + Expr::Column(_) => Ok(( + c.as_ref().clone(), + reverse_operator(op), + Expr::Negative(Box::new(scalar_expr.clone())), + )), + _ => Err(DataFusionError::Plan(format!( + "negative withm complex expression {:?} is not supported", + column_expr + ))), + }, + // Expr::Between { expr, negated, low, high } => todo!(), + // Expr::Cast { expr, data_type } => todo!(), + // Expr::TryCast { expr, data_type } => todo!(), + // Expr::Sort { expr, asc, nulls_first } => todo!(), + // Expr::ScalarFunction { fun, args } => todo!(), + // Expr::InList { expr, list, negated } => todo!(), + // Expr::Wildcard => todo!(), + _ => { + return Err(DataFusionError::Plan(format!( + "column expression {:?} is not supported", + column_expr + ))) + } } + // Ok((column_expr.clone(), op, scalar_expr.clone())) } /// replaces a column with an old name with a new name in an expression @@ -455,6 +489,16 @@ fn rewrite_column_expr( utils::rewrite_expression(expr, &expressions) } +fn reverse_operator(op: Operator) -> Operator { + match op { + Operator::Lt => Operator::Gt, + Operator::Gt => Operator::Lt, + Operator::LtEq => Operator::GtEq, + Operator::GtEq => Operator::LtEq, + _ => op, + } +} + /// Given a column reference to `column`, returns a pruning /// expression in terms of the min and max that will evaluate to true /// if the column may contain values, and false if definitely does not @@ -541,7 +585,7 @@ fn build_predicate_expression( } let expr_builder = - PruningExpressionBuilder::try_new(left, right, schema, required_columns); + PruningExpressionBuilder::try_new(left, right, op, schema, required_columns); let mut expr_builder = match expr_builder { Ok(builder) => builder, // allow partial failure in predicate expression generation @@ -550,8 +594,13 @@ fn build_predicate_expression( return Ok(unhandled); } }; - let corrected_op = expr_builder.correct_operator(op); - let statistics_expr = match corrected_op { + + 
let statistics_expr = build_statistics_expr(&mut expr_builder).unwrap_or(unhandled); + Ok(statistics_expr) +} + +fn build_statistics_expr(expr_builder: &mut PruningExpressionBuilder) -> Result { + let statistics_expr = match expr_builder.op() { Operator::NotEq => { // column != literal => (min, max) = literal => // !(min != literal && max != literal) ==> @@ -596,7 +645,11 @@ fn build_predicate_expression( .lt_eq(expr_builder.scalar_expr().clone()) } // other expressions are not supported - _ => unhandled, + _ => { + return Err(DataFusionError::Plan(format!( + "other expressions than (neq, eq, gt, gteq, lt, lteq) are not superted" + ))) + } }; Ok(statistics_expr) } From 2b64d75fbe1b1565e2a20874b5510d3ffa99d7df Mon Sep 17 00:00:00 2001 From: lvheyang Date: Thu, 22 Jul 2021 14:22:07 +0800 Subject: [PATCH 2/2] add Not(bool col) support --- datafusion/src/physical_optimizer/pruning.rs | 278 ++++++++++++++----- datafusion/tests/parquet_pruning.rs | 269 +++++++++++++++++- 2 files changed, 481 insertions(+), 66 deletions(-) diff --git a/datafusion/src/physical_optimizer/pruning.rs b/datafusion/src/physical_optimizer/pruning.rs index a504ebb3c56a..c6f7647b70cf 100644 --- a/datafusion/src/physical_optimizer/pruning.rs +++ b/datafusion/src/physical_optimizer/pruning.rs @@ -389,7 +389,7 @@ impl<'a> PruningExpressionBuilder<'a> { }; let (column_expr, correct_operator, scalar_expr) = - match rewrite_expr_compatible(column_expr, correct_operator, scalar_expr) { + match rewrite_expr_to_prunable(column_expr, correct_operator, scalar_expr) { Ok(ret) => ret, Err(e) => return Err(e), }; @@ -432,15 +432,32 @@ impl<'a> PruningExpressionBuilder<'a> { } } -fn rewrite_expr_compatible( +/// This function is designed to rewrite the column_expr to +/// ensure the column_expr is monotonically increasing. +/// +/// For example, +/// 1. `col > 10` +/// 2. `-col > 10` should be rewritten to `col < -10` +/// 3. `!col = true` would be rewritten to `col = !true` +/// 4. 
`abs(a - 10) > 0` not supported +/// +/// More rewrite rules are still in progress. +fn rewrite_expr_to_prunable( column_expr: &Expr, op: Operator, scalar_expr: &Expr, ) -> Result<(Expr, Operator, Expr)> { + if !is_compare_op(op) { + return Err(DataFusionError::Plan( + "rewrite_expr_to_prunable only support compare expression".to_string(), + )); + } + match column_expr { + // `col > lit()` Expr::Column(_) => Ok((column_expr.clone(), op, scalar_expr.clone())), - // Expr::BinaryExpr { .. } => todo!(), - // Expr::Not(_) => todo!(), + + // `-col > lit()` --> `col < -lit()` Expr::Negative(c) => match c.as_ref() { Expr::Column(_) => Ok(( c.as_ref().clone(), @@ -448,17 +465,32 @@ fn rewrite_expr_compatible( Expr::Negative(Box::new(scalar_expr.clone())), )), _ => Err(DataFusionError::Plan(format!( - "negative withm complex expression {:?} is not supported", + "negative with complex expression {:?} is not supported", column_expr ))), }, - // Expr::Between { expr, negated, low, high } => todo!(), - // Expr::Cast { expr, data_type } => todo!(), - // Expr::TryCast { expr, data_type } => todo!(), - // Expr::Sort { expr, asc, nulls_first } => todo!(), - // Expr::ScalarFunction { fun, args } => todo!(), - // Expr::InList { expr, list, negated } => todo!(), - // Expr::Wildcard => todo!(), + + // `!col = true` --> `col = !true` + Expr::Not(c) => { + if op != Operator::Eq && op != Operator::NotEq { + return Err(DataFusionError::Plan( + "Not with operator other than Eq / NotEq is not supported" + .to_string(), + )); + } + return match c.as_ref() { + Expr::Column(_) => Ok(( + c.as_ref().clone(), + reverse_operator(op), + Expr::Not(Box::new(scalar_expr.clone())), + )), + _ => Err(DataFusionError::Plan(format!( + "Not with complex expression {:?} is not supported", + column_expr + ))), + }; + } + _ => { return Err(DataFusionError::Plan(format!( "column expression {:?} is not supported", @@ -469,6 +501,18 @@ fn rewrite_expr_compatible( // Ok((column_expr.clone(), op, 
scalar_expr.clone())) } +fn is_compare_op(op: Operator) -> bool { + matches!( + op, + Operator::Eq + | Operator::NotEq + | Operator::Lt + | Operator::LtEq + | Operator::Gt + | Operator::GtEq + ) +} + /// replaces a column with an old name with a new name in an expression fn rewrite_column_expr( expr: &Expr, @@ -600,57 +644,57 @@ fn build_predicate_expression( } fn build_statistics_expr(expr_builder: &mut PruningExpressionBuilder) -> Result { - let statistics_expr = match expr_builder.op() { - Operator::NotEq => { - // column != literal => (min, max) = literal => - // !(min != literal && max != literal) ==> - // min != literal || literal != max - let min_column_expr = expr_builder.min_column_expr()?; - let max_column_expr = expr_builder.max_column_expr()?; - min_column_expr - .not_eq(expr_builder.scalar_expr().clone()) - .or(expr_builder.scalar_expr().clone().not_eq(max_column_expr)) - } - Operator::Eq => { - // column = literal => (min, max) = literal => min <= literal && literal <= max - // (column / 2) = 4 => (column_min / 2) <= 4 && 4 <= (column_max / 2) - let min_column_expr = expr_builder.min_column_expr()?; - let max_column_expr = expr_builder.max_column_expr()?; - min_column_expr - .lt_eq(expr_builder.scalar_expr().clone()) - .and(expr_builder.scalar_expr().clone().lt_eq(max_column_expr)) - } - Operator::Gt => { - // column > literal => (min, max) > literal => max > literal - expr_builder - .max_column_expr()? - .gt(expr_builder.scalar_expr().clone()) - } - Operator::GtEq => { - // column >= literal => (min, max) >= literal => max >= literal - expr_builder - .max_column_expr()? - .gt_eq(expr_builder.scalar_expr().clone()) - } - Operator::Lt => { - // column < literal => (min, max) < literal => min < literal - expr_builder - .min_column_expr()? - .lt(expr_builder.scalar_expr().clone()) - } - Operator::LtEq => { - // column <= literal => (min, max) <= literal => min <= literal - expr_builder - .min_column_expr()? 
- .lt_eq(expr_builder.scalar_expr().clone()) - } - // other expressions are not supported - _ => { - return Err(DataFusionError::Plan(format!( - "other expressions than (neq, eq, gt, gteq, lt, lteq) are not superted" - ))) - } - }; + let statistics_expr = + match expr_builder.op() { + Operator::NotEq => { + // column != literal => (min, max) = literal => + // !(min != literal && max != literal) ==> + // min != literal || literal != max + let min_column_expr = expr_builder.min_column_expr()?; + let max_column_expr = expr_builder.max_column_expr()?; + min_column_expr + .not_eq(expr_builder.scalar_expr().clone()) + .or(expr_builder.scalar_expr().clone().not_eq(max_column_expr)) + } + Operator::Eq => { + // column = literal => (min, max) = literal => min <= literal && literal <= max + // (column / 2) = 4 => (column_min / 2) <= 4 && 4 <= (column_max / 2) + let min_column_expr = expr_builder.min_column_expr()?; + let max_column_expr = expr_builder.max_column_expr()?; + min_column_expr + .lt_eq(expr_builder.scalar_expr().clone()) + .and(expr_builder.scalar_expr().clone().lt_eq(max_column_expr)) + } + Operator::Gt => { + // column > literal => (min, max) > literal => max > literal + expr_builder + .max_column_expr()? + .gt(expr_builder.scalar_expr().clone()) + } + Operator::GtEq => { + // column >= literal => (min, max) >= literal => max >= literal + expr_builder + .max_column_expr()? + .gt_eq(expr_builder.scalar_expr().clone()) + } + Operator::Lt => { + // column < literal => (min, max) < literal => min < literal + expr_builder + .min_column_expr()? + .lt(expr_builder.scalar_expr().clone()) + } + Operator::LtEq => { + // column <= literal => (min, max) <= literal => min <= literal + expr_builder + .min_column_expr()? 
+ .lt_eq(expr_builder.scalar_expr().clone()) + } + // other expressions are not supported + _ => return Err(DataFusionError::Plan( + "expressions other than (neq, eq, gt, gteq, lt, lteq) are not superted" + .to_string(), + )), + }; Ok(statistics_expr) } @@ -1361,4 +1405,112 @@ mod tests { result ) } + + /// Creates setup for int32 chunk pruning + fn int32_setup() -> (SchemaRef, TestStatistics) { + let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)])); + + let statistics = TestStatistics::new().with( + "i", + ContainerStats::new_i32( + vec![Some(-5), Some(1), Some(-11), None, Some(1)], // min + vec![Some(5), Some(11), Some(-1), None, None], // max + ), + ); + (schema, statistics) + } + + #[test] + fn prune_int32_col_gt_zero() { + let (schema, statistics) = int32_setup(); + + // Expression "i > 0" and "-i < 0" + // i [-5, 5] ==> some rows could pass (must keep) + // i [1, 11] ==> all rows must pass (must keep) + // i [-11, -1] ==> no rows can pass (not keep) + // i [NULL, NULL] ==> unknown (must keep) + // i [1, NULL] ==> unknown (must keep) + let expected_ret = vec![true, true, false, true, true]; + + // i > 0 + let expr = col("i").gt(lit(0)); + let p = PruningPredicate::try_new(&expr, schema.clone()).unwrap(); + let result = p.prune(&statistics).unwrap(); + assert_eq!(result, expected_ret); + + // -i < 0 + let expr = Expr::Negative(Box::new(col("i"))).lt(lit(0)); + let p = PruningPredicate::try_new(&expr, schema).unwrap(); + let result = p.prune(&statistics).unwrap(); + assert_eq!(result, expected_ret); + } + + #[test] + fn prune_int32_col_lte_zero() { + let (schema, statistics) = int32_setup(); + + // Expression "i <= 0" and "-i >= 0" + // i [-5, 5] ==> some rows could pass (must keep) + // i [1, 11] ==> no rows can pass (not keep) + // i [-11, -1] ==> all rows must pass (must keep) + // i [NULL, NULL] ==> unknown (must keep) + // i [1, NULL] ==> no rows can pass (not keep) + let expected_ret = vec![true, false, true, true, false]; + + 
// i <= 0 + let expr = col("i").lt_eq(lit(0)); + let p = PruningPredicate::try_new(&expr, schema.clone()).unwrap(); + let result = p.prune(&statistics).unwrap(); + assert_eq!(result, expected_ret); + + // -i >= 0 + let expr = Expr::Negative(Box::new(col("i"))).gt_eq(lit(0)); + let p = PruningPredicate::try_new(&expr, schema).unwrap(); + let result = p.prune(&statistics).unwrap(); + assert_eq!(result, expected_ret); + } + + #[test] + fn prune_int32_col_eq_zero() { + let (schema, statistics) = int32_setup(); + + // Expression "i = 0" + // i [-5, 5] ==> some rows could pass (must keep) + // i [1, 11] ==> no rows can pass (not keep) + // i [-11, -1] ==> no rows can pass (not keep) + // i [NULL, NULL] ==> unknown (must keep) + // i [1, NULL] ==> no rows can pass (not keep) + let expected_ret = vec![true, false, false, true, false]; + + // i = 0 + let expr = col("i").eq(lit(0)); + let p = PruningPredicate::try_new(&expr, schema).unwrap(); + let result = p.prune(&statistics).unwrap(); + assert_eq!(result, expected_ret); + } + + #[test] + fn prune_int32_col_lt_neg_one() { + let (schema, statistics) = int32_setup(); + + // Expression "i > -1" and "-i < 1" + // i [-5, 5] ==> some rows could pass (must keep) + // i [1, 11] ==> all rows must pass (must keep) + // i [-11, -1] ==> no rows can pass (not keep) + // i [NULL, NULL] ==> unknown (must keep) + // i [1, NULL] ==> all rows must pass (must keep) + let expected_ret = vec![true, true, false, true, true]; + + // i > -1 + let expr = col("i").gt(lit(-1)); + let p = PruningPredicate::try_new(&expr, schema.clone()).unwrap(); + let result = p.prune(&statistics).unwrap(); + assert_eq!(result, expected_ret); + + // -i < 1 + let expr = Expr::Negative(Box::new(col("i"))).lt(lit(1)); + let p = PruningPredicate::try_new(&expr, schema).unwrap(); + let result = p.prune(&statistics).unwrap(); + assert_eq!(result, expected_ret); + } } diff --git a/datafusion/tests/parquet_pruning.rs b/datafusion/tests/parquet_pruning.rs index 
0838211f14f0..789f0810c983 100644 --- a/datafusion/tests/parquet_pruning.rs +++ b/datafusion/tests/parquet_pruning.rs @@ -21,10 +21,11 @@ use std::sync::Arc; use arrow::{ array::{ - Array, Date32Array, Date64Array, StringArray, TimestampMicrosecondArray, - TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, + Array, ArrayRef, Date32Array, Date64Array, Float64Array, Int32Array, StringArray, + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + TimestampSecondArray, }, - datatypes::{Field, Schema}, + datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, util::pretty::pretty_format_batches, }; @@ -177,6 +178,229 @@ async fn prune_disabled() { ); } +#[tokio::test] +async fn prune_int32_lt() { + let (expected_errors, expected_row_group_pruned, expected_results) = + (Some(0), Some(1), 11); + + // result of sql "SELECT * FROM t where i < 1" is same as + // "SELECT * FROM t where -i > -1" + let output = ContextWithParquet::new(Scenario::Int32) + .await + .query("SELECT * FROM t where i < 1") + .await; + + println!("{}", output.description()); + // This should prune out groups without error + assert_eq!(output.predicate_evaluation_errors(), expected_errors); + assert_eq!(output.row_groups_pruned(), expected_row_group_pruned); + assert_eq!( + output.result_rows, + expected_results, + "{}", + output.description() + ); + + let output = ContextWithParquet::new(Scenario::Int32) + .await + .query("SELECT * FROM t where -i > -1") + .await; + + println!("{}", output.description()); + // This should prune out groups without error + assert_eq!(output.predicate_evaluation_errors(), expected_errors); + assert_eq!(output.row_groups_pruned(), expected_row_group_pruned); + assert_eq!( + output.result_rows, + expected_results, + "{}", + output.description() + ); +} + +#[tokio::test] +async fn prune_int32_eq() { + // result of sql "SELECT * FROM t where i = 1" + let output = ContextWithParquet::new(Scenario::Int32) + .await +
.query("SELECT * FROM t where i = 1") + .await; + + println!("{}", output.description()); + // This should prune out groups without error + assert_eq!(output.predicate_evaluation_errors(), Some(0)); + assert_eq!(output.row_groups_pruned(), Some(3)); + assert_eq!(output.result_rows, 1, "{}", output.description()); +} + +#[tokio::test] +async fn prune_int32_scalar_fun_and_eq() { + // result of sql "SELECT * FROM t where abs(i) = 1 and i = 1" + // only use "i = 1" to prune + let output = ContextWithParquet::new(Scenario::Int32) + .await + .query("SELECT * FROM t where abs(i) = 1 and i = 1") + .await; + + println!("{}", output.description()); + // This should prune out groups without error + assert_eq!(output.predicate_evaluation_errors(), Some(0)); + assert_eq!(output.row_groups_pruned(), Some(3)); + assert_eq!(output.result_rows, 1, "{}", output.description()); +} + +#[tokio::test] +async fn prune_int32_scalar_fun() { + // result of sql "SELECT * FROM t where abs(i) = 1" is not supported + let output = ContextWithParquet::new(Scenario::Int32) + .await + .query("SELECT * FROM t where abs(i) = 1") + .await; + + println!("{}", output.description()); + // This should prune out groups with error, because there is no col to + // prune the row groups. + assert_eq!(output.predicate_evaluation_errors(), Some(1)); + assert_eq!(output.row_groups_pruned(), Some(0)); + assert_eq!(output.result_rows, 3, "{}", output.description()); +} + +#[tokio::test] +async fn prune_int32_complex_expr() { + // result of sql "SELECT * FROM t where i+1 = 1" is not supported + let output = ContextWithParquet::new(Scenario::Int32) + .await + .query("SELECT * FROM t where i+1 = 1") + .await; + + println!("{}", output.description()); + // This should prune out groups with error, because there is no col to + // prune the row groups.
+ assert_eq!(output.predicate_evaluation_errors(), Some(1)); + assert_eq!(output.row_groups_pruned(), Some(0)); + assert_eq!(output.result_rows, 2, "{}", output.description()); +} + +#[tokio::test] +async fn prune_int32_complex_expr_subtract() { + // result of sql "SELECT * FROM t where 1-i > 1" is not supported + let output = ContextWithParquet::new(Scenario::Int32) + .await + .query("SELECT * FROM t where 1-i > 1") + .await; + + println!("{}", output.description()); + // This should prune out groups with error, because there is no col to + // prune the row groups. + assert_eq!(output.predicate_evaluation_errors(), Some(1)); + assert_eq!(output.row_groups_pruned(), Some(0)); + assert_eq!(output.result_rows, 9, "{}", output.description()); +} + +#[tokio::test] +async fn prune_f64_lt() { + let (expected_errors, expected_row_group_pruned, expected_results) = + (Some(0), Some(1), 11); + + // result of sql "SELECT * FROM t where f < 1" is same as + // "SELECT * FROM t where -f > -1" + let output = ContextWithParquet::new(Scenario::Float64) + .await + .query("SELECT * FROM t where f < 1") + .await; + + println!("{}", output.description()); + // This should prune out groups without error + assert_eq!(output.predicate_evaluation_errors(), expected_errors); + assert_eq!(output.row_groups_pruned(), expected_row_group_pruned); + assert_eq!( + output.result_rows, + expected_results, + "{}", + output.description() + ); + + let output = ContextWithParquet::new(Scenario::Float64) + .await + .query("SELECT * FROM t where -f > -1") + .await; + + println!("{}", output.description()); + // This should prune out groups without error + assert_eq!(output.predicate_evaluation_errors(), expected_errors); + assert_eq!(output.row_groups_pruned(), expected_row_group_pruned); + assert_eq!( + output.result_rows, + expected_results, + "{}", + output.description() + ); +} + +#[tokio::test] +async fn prune_f64_scalar_fun_and_gt() { + // result of sql "SELECT * FROM t where abs(f - 1) <=
0.000001 and f >= 0.1" + // only use "f >= 0.1" to prune + let output = ContextWithParquet::new(Scenario::Float64) + .await + .query("SELECT * FROM t where abs(f - 1) <= 0.000001 and f >= 0.1") + .await; + + println!("{}", output.description()); + // This should prune out groups without error + assert_eq!(output.predicate_evaluation_errors(), Some(0)); + assert_eq!(output.row_groups_pruned(), Some(2)); + assert_eq!(output.result_rows, 1, "{}", output.description()); +} + +#[tokio::test] +async fn prune_f64_scalar_fun() { + // result of sql "SELECT * FROM t where abs(f-1) <= 0.000001" is not supported + let output = ContextWithParquet::new(Scenario::Float64) + .await + .query("SELECT * FROM t where abs(f-1) <= 0.000001") + .await; + + println!("{}", output.description()); + // This should prune out groups with error, because there is no col to + // prune the row groups. + assert_eq!(output.predicate_evaluation_errors(), Some(1)); + assert_eq!(output.row_groups_pruned(), Some(0)); + assert_eq!(output.result_rows, 1, "{}", output.description()); +} + +#[tokio::test] +async fn prune_f64_complex_expr() { + // result of sql "SELECT * FROM t where f+1 > 1.1" is not supported + let output = ContextWithParquet::new(Scenario::Float64) + .await + .query("SELECT * FROM t where f+1 > 1.1") + .await; + + println!("{}", output.description()); + // This should prune out groups with error, because there is no col to + // prune the row groups.
+ assert_eq!(output.predicate_evaluation_errors(), Some(1)); + assert_eq!(output.row_groups_pruned(), Some(0)); + assert_eq!(output.result_rows, 9, "{}", output.description()); +} + +#[tokio::test] +async fn prune_f64_complex_expr_subtract() { + // result of sql "SELECT * FROM t where 1-f > 1" is not supported + let output = ContextWithParquet::new(Scenario::Float64) + .await + .query("SELECT * FROM t where 1-f > 1") + .await; + + println!("{}", output.description()); + // This should prune out groups with error, because there is no col to + // prune the row groups. + assert_eq!(output.predicate_evaluation_errors(), Some(1)); + assert_eq!(output.row_groups_pruned(), Some(0)); + assert_eq!(output.result_rows, 9, "{}", output.description()); +} + // ---------------------- // Begin test fixture // ---------------------- @@ -185,6 +409,8 @@ async fn prune_disabled() { enum Scenario { Timestamps, Dates, + Int32, + Float64, } /// Test fixture that has an execution context that has an external @@ -370,6 +596,22 @@ async fn make_test_file(scenario: Scenario) -> NamedTempFile { make_date_batch(Duration::days(3600)), ] } + Scenario::Int32 => { + vec![ + make_int32_batch(-5, 0), + make_int32_batch(-4, 1), + make_int32_batch(0, 5), + make_int32_batch(5, 10), + ] + } + Scenario::Float64 => { + vec![ + make_f64_batch(vec![-5.0, -4.0, -3.0, -2.0, -1.0]), + make_f64_batch(vec![-4.0, -3.0, -2.0, -1.0, 0.0]), + make_f64_batch(vec![0.0, 1.0, 2.0, 3.0, 4.0]), + make_f64_batch(vec![5.0, 6.0, 7.0, 8.0, 9.0]), + ] + } }; let schema = batches[0].schema(); @@ -475,6 +717,27 @@ fn make_timestamp_batch(offset: Duration) -> RecordBatch { .unwrap() } +/// Return record batch with i32 sequence +/// +/// Columns are named +/// "i" -> Int32Array +fn make_int32_batch(start: i32, end: i32) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)])); + let v: Vec = (start..end).collect(); + let array = Arc::new(Int32Array::from(v)) as ArrayRef; +
RecordBatch::try_new(schema, vec![array.clone()]).unwrap() +} + +/// Return record batch with f64 vector +/// +/// Columns are named +/// "f" -> Float64Array +fn make_f64_batch(v: Vec) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Float64, true)])); + let array = Arc::new(Float64Array::from(v)) as ArrayRef; + RecordBatch::try_new(schema, vec![array.clone()]).unwrap() +} + /// Return record batch with a few rows of data for all of the supported date /// types with the specified offset (in days) ///