diff --git a/datafusion/src/physical_optimizer/pruning.rs b/datafusion/src/physical_optimizer/pruning.rs index 3a5a64c6f668..c65733bd7526 100644 --- a/datafusion/src/physical_optimizer/pruning.rs +++ b/datafusion/src/physical_optimizer/pruning.rs @@ -32,7 +32,7 @@ use std::{collections::HashSet, sync::Arc}; use arrow::{ array::{new_null_array, ArrayRef, BooleanArray}, - datatypes::{Field, Schema, SchemaRef}, + datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::RecordBatch, }; @@ -86,12 +86,8 @@ pub struct PruningPredicate { schema: SchemaRef, /// Actual pruning predicate (rewritten in terms of column min/max statistics) predicate_expr: Arc, - /// The statistics required to evaluate this predicate: - /// * The column name in the input schema - /// * Statistics type (e.g. Min or Max) - /// * The field the statistics value should be placed in for - /// pruning predicate evaluation - stat_column_req: Vec<(String, StatisticsType, Field)>, + /// The statistics required to evaluate this predicate + required_columns: RequiredStatColumns, } impl PruningPredicate { @@ -116,10 +112,10 @@ impl PruningPredicate { /// `(column_min / 2) <= 4 && 4 <= (column_max / 2))` pub fn try_new(expr: &Expr, schema: SchemaRef) -> Result { // build predicate expression once - let mut stat_column_req = Vec::<(String, StatisticsType, Field)>::new(); + let mut required_columns = RequiredStatColumns::new(); let logical_predicate_expr = - build_predicate_expression(expr, schema.as_ref(), &mut stat_column_req)?; - let stat_fields = stat_column_req + build_predicate_expression(expr, schema.as_ref(), &mut required_columns)?; + let stat_fields = required_columns .iter() .map(|(_, _, f)| f.clone()) .collect::>(); @@ -133,7 +129,7 @@ impl PruningPredicate { Ok(Self { schema, predicate_expr, - stat_column_req, + required_columns, }) } @@ -148,10 +144,16 @@ impl PruningPredicate { /// Note this function takes a slice of statistics as a parameter /// to amortize the cost of the evaluation of the 
predicate /// against a single record batch. + /// + /// Note: the predicate passed to `prune` should be simplified as + /// much as possible (e.g. this pass doesn't handle some + /// expressions like `b = false`, but it does handle the + /// simplified version `b`). The predicates are simplified via the + /// ConstantFolding optimizer pass. pub fn prune(&self, statistics: &S) -> Result> { // build statistics record batch let predicate_array = - build_statistics_record_batch(statistics, &self.stat_column_req) + build_statistics_record_batch(statistics, &self.required_columns) .and_then(|statistics_batch| { // execute predicate expression self.predicate_expr.evaluate(&statistics_batch) @@ -189,9 +191,100 @@ impl PruningPredicate { } } +/// Handles creating references to the min/max statistics +/// for columns as well as recording which statistics are needed +#[derive(Debug, Default, Clone)] +struct RequiredStatColumns { + /// The statistics required to evaluate this predicate: + /// * The column name in the input schema + /// * Statistics type (e.g. 
Min or Max) + /// * The field the statistics value should be placed in for + /// pruning predicate evaluation + columns: Vec<(String, StatisticsType, Field)>, +} + +impl RequiredStatColumns { + fn new() -> Self { + Self::default() + } + + /// Return an iterator over items in columns (see doc on + /// `self.columns` for details) + fn iter(&self) -> impl Iterator { + self.columns.iter() + } + + fn is_stat_column_missing( + &self, + column_name: &str, + statistics_type: StatisticsType, + ) -> bool { + !self + .columns + .iter() + .any(|(c, t, _f)| c == column_name && t == &statistics_type) + } + + /// Rewrites column_expr so that all appearances of column_name + /// are replaced with a reference to either the min or max + /// statistics column, while keeping track that a reference to the statistics + /// column is required + /// + /// for example, an expression like `col("foo") > 5`, when called + /// with Max would result in an expression like `col("foo_max") > + /// 5` with the appropriate entry noted in self.columns + fn stat_column_expr( + &mut self, + column_name: &str, + column_expr: &Expr, + field: &Field, + stat_type: StatisticsType, + suffix: &str, + ) -> Result { + let stat_column_name = format!("{}_{}", column_name, suffix); + let stat_field = Field::new( + stat_column_name.as_str(), + field.data_type().clone(), + field.is_nullable(), + ); + if self.is_stat_column_missing(column_name, stat_type) { + // only add statistics column if not previously added + self.columns + .push((column_name.to_string(), stat_type, stat_field)); + } + rewrite_column_expr(column_expr, column_name, stat_column_name.as_str()) + } + + /// rewrite col --> col_min + fn min_column_expr( + &mut self, + column_name: &str, + column_expr: &Expr, + field: &Field, + ) -> Result { + self.stat_column_expr(column_name, column_expr, field, StatisticsType::Min, "min") + } + + /// rewrite col --> col_max + fn max_column_expr( + &mut self, + column_name: &str, + column_expr: &Expr, + field: &Field, 
+ ) -> Result { + self.stat_column_expr(column_name, column_expr, field, StatisticsType::Max, "max") + } +} + +impl From> for RequiredStatColumns { + fn from(columns: Vec<(String, StatisticsType, Field)>) -> Self { + Self { columns } + } +} + /// Build a RecordBatch from a list of statistics, creating arrays, /// with one row for each PruningStatistics and columns specified in -/// in the stat_column_req parameter. +/// in the required_columns parameter. /// /// For example, if the requested columns are /// ```text @@ -216,12 +309,12 @@ impl PruningPredicate { /// ``` fn build_statistics_record_batch( statistics: &S, - stat_column_req: &[(String, StatisticsType, Field)], + required_columns: &RequiredStatColumns, ) -> Result { let mut fields = Vec::::new(); let mut arrays = Vec::::new(); // For each needed statistics column: - for (column_name, statistics_type, stat_field) in stat_column_req { + for (column_name, statistics_type, stat_field) in required_columns.iter() { let data_type = stat_field.data_type(); let num_containers = statistics.num_containers(); @@ -258,7 +351,7 @@ struct PruningExpressionBuilder<'a> { column_expr: &'a Expr, scalar_expr: &'a Expr, field: &'a Field, - stat_column_req: &'a mut Vec<(String, StatisticsType, Field)>, + required_columns: &'a mut RequiredStatColumns, reverse_operator: bool, } @@ -267,7 +360,7 @@ impl<'a> PruningExpressionBuilder<'a> { left: &'a Expr, right: &'a Expr, schema: &'a Schema, - stat_column_req: &'a mut Vec<(String, StatisticsType, Field)>, + required_columns: &'a mut RequiredStatColumns, ) -> Result { // find column name; input could be a more complicated expression let mut left_columns = HashSet::::new(); @@ -301,7 +394,7 @@ impl<'a> PruningExpressionBuilder<'a> { column_expr, scalar_expr, field, - stat_column_req, + required_columns, reverse_operator, }) } @@ -324,42 +417,20 @@ impl<'a> PruningExpressionBuilder<'a> { self.scalar_expr } - fn is_stat_column_missing(&self, statistics_type: StatisticsType) -> bool { - 
!self - .stat_column_req - .iter() - .any(|(c, t, _f)| c == &self.column_name && t == &statistics_type) - } - - fn stat_column_expr( - &mut self, - stat_type: StatisticsType, - suffix: &str, - ) -> Result { - let stat_column_name = format!("{}_{}", self.column_name, suffix); - let stat_field = Field::new( - stat_column_name.as_str(), - self.field.data_type().clone(), - self.field.is_nullable(), - ); - if self.is_stat_column_missing(stat_type) { - // only add statistics column if not previously added - self.stat_column_req - .push((self.column_name.clone(), stat_type, stat_field)); - } - rewrite_column_expr( - self.column_expr, - self.column_name.as_str(), - stat_column_name.as_str(), - ) - } - fn min_column_expr(&mut self) -> Result { - self.stat_column_expr(StatisticsType::Min, "min") + self.required_columns.min_column_expr( + &self.column_name, + &self.column_expr, + self.field, + ) } fn max_column_expr(&mut self) -> Result { - self.stat_column_expr(StatisticsType::Max, "max") + self.required_columns.max_column_expr( + &self.column_name, + &self.column_expr, + self.field, + ) } } @@ -383,6 +454,46 @@ fn rewrite_column_expr( utils::rewrite_expression(&expr, &expressions) } +/// Given a column reference to `column_name`, returns a pruning +/// expression in terms of the min and max that will evaluate to true +/// if the column may contain values, and false if definitely does not +/// contain values +fn build_single_column_expr( + column_name: &str, + schema: &Schema, + required_columns: &mut RequiredStatColumns, + is_not: bool, // if true, treat as !col +) -> Option { + use crate::logical_plan; + let field = schema.field_with_name(column_name).ok()?; + + if matches!(field.data_type(), &DataType::Boolean) { + let col_ref = logical_plan::col(column_name); + + let min = required_columns + .min_column_expr(column_name, &col_ref, field) + .ok()?; + let max = required_columns + .max_column_expr(column_name, &col_ref, field) + .ok()?; + + // remember -- we want an 
expression that is: + // TRUE: if there may be rows that match + // FALSE: if there are no rows that match + if is_not { + // The only way we know a column couldn't match is if both the min and max are true + // !(min && max) + Some((min.and(max)).not()) + } else { + // the only way we know a column couldn't match is if both the min and max are false + // !(!min && !max) --> min || max + Some(min.or(max)) + } + } else { + None + } +} + /// Translate logical filter expression into pruning predicate /// expression that will evaluate to FALSE if it can be determined no /// rows between the min/max values could pass the predicates. @@ -391,28 +502,47 @@ fn rewrite_column_expr( fn build_predicate_expression( expr: &Expr, schema: &Schema, - stat_column_req: &mut Vec<(String, StatisticsType, Field)>, + required_columns: &mut RequiredStatColumns, ) -> Result { use crate::logical_plan; + + // Returned for unsupported expressions. Such expressions are + // converted to TRUE. This can still be useful when multiple + // conditions are joined using AND such as: column > 10 AND TRUE + let unhandled = logical_plan::lit(true); + // predicate expression can only be a binary expression let (left, op, right) = match expr { Expr::BinaryExpr { left, op, right } => (left, *op, right), + Expr::Column(name) => { + let expr = build_single_column_expr(&name, schema, required_columns, false) + .unwrap_or(unhandled); + return Ok(expr); + } + // match !col (don't do so recursively) + Expr::Not(input) => { + if let Expr::Column(name) = input.as_ref() { + let expr = + build_single_column_expr(&name, schema, required_columns, true) + .unwrap_or(unhandled); + return Ok(expr); + } else { + return Ok(unhandled); + } + } _ => { - // unsupported expression - replace with TRUE - // this can still be useful when multiple conditions are joined using AND - // such as: column > 10 AND TRUE - return Ok(logical_plan::lit(true)); + return Ok(unhandled); } }; if op == Operator::And || op == Operator::Or { - 
let left_expr = build_predicate_expression(left, schema, stat_column_req)?; - let right_expr = build_predicate_expression(right, schema, stat_column_req)?; + let left_expr = build_predicate_expression(left, schema, required_columns)?; + let right_expr = build_predicate_expression(right, schema, required_columns)?; return Ok(logical_plan::binary_expr(left_expr, op, right_expr)); } let expr_builder = - PruningExpressionBuilder::try_new(left, right, schema, stat_column_req); + PruningExpressionBuilder::try_new(left, right, schema, required_columns); let mut expr_builder = match expr_builder { Ok(builder) => builder, // allow partial failure in predicate expression generation @@ -508,6 +638,16 @@ mod tests { } } + fn new_bool( + min: impl IntoIterator>, + max: impl IntoIterator>, + ) -> Self { + Self { + min: Arc::new(min.into_iter().collect::()), + max: Arc::new(max.into_iter().collect::()), + } + } + fn min(&self) -> Option { Some(self.min.clone()) } @@ -591,7 +731,7 @@ mod tests { #[test] fn test_build_statistics_record_batch() { // Request a record batch with of s1_min, s2_max, s3_max, s3_min - let stat_column_req = vec![ + let required_columns = RequiredStatColumns::from(vec![ // min of original column s1, named s1_min ( "s1".to_string(), @@ -616,7 +756,7 @@ mod tests { StatisticsType::Min, Field::new("s3_min", DataType::Utf8, true), ), - ]; + ]); let statistics = TestStatistics::new() .with( @@ -641,7 +781,8 @@ mod tests { ), ); - let batch = build_statistics_record_batch(&statistics, &stat_column_req).unwrap(); + let batch = + build_statistics_record_batch(&statistics, &required_columns).unwrap(); let expected = vec![ "+--------+--------+--------+--------+", "| s1_min | s2_max | s3_max | s3_min |", @@ -662,7 +803,7 @@ mod tests { // which is what Parquet does // Request a record batch with of s1_min as a timestamp - let stat_column_req = vec![( + let required_columns = RequiredStatColumns::from(vec![( "s1".to_string(), StatisticsType::Min, Field::new( @@ -670,7 
+811,7 @@ mod tests { DataType::Timestamp(TimeUnit::Nanosecond, None), true, ), - )]; + )]); // Note the statistics pass back i64 (not timestamp) let statistics = OneContainerStats { @@ -679,7 +820,8 @@ mod tests { num_containers: 1, }; - let batch = build_statistics_record_batch(&statistics, &stat_column_req).unwrap(); + let batch = + build_statistics_record_batch(&statistics, &required_columns).unwrap(); let expected = vec![ "+-------------------------------+", "| s1_min |", @@ -693,7 +835,7 @@ mod tests { #[test] fn test_build_statistics_no_stats() { - let stat_column_req = vec![]; + let required_columns = RequiredStatColumns::new(); let statistics = OneContainerStats { min_values: Some(Arc::new(Int64Array::from(vec![Some(10)]))), @@ -702,7 +844,7 @@ mod tests { }; let result = - build_statistics_record_batch(&statistics, &stat_column_req).unwrap_err(); + build_statistics_record_batch(&statistics, &required_columns).unwrap_err(); assert!( result.to_string().contains("Invalid argument error"), "{}", @@ -715,11 +857,11 @@ mod tests { // Test requesting a Utf8 column when the stats return some other type // Request a record batch with of s1_min as a timestamp - let stat_column_req = vec![( + let required_columns = RequiredStatColumns::from(vec![( "s1".to_string(), StatisticsType::Min, Field::new("s1_min", DataType::Utf8, true), - )]; + )]); // Note the statistics return binary (which can't be cast to string) let statistics = OneContainerStats { @@ -728,7 +870,8 @@ mod tests { num_containers: 1, }; - let batch = build_statistics_record_batch(&statistics, &stat_column_req).unwrap(); + let batch = + build_statistics_record_batch(&statistics, &required_columns).unwrap(); let expected = vec![ "+--------+", "| s1_min |", @@ -743,11 +886,11 @@ mod tests { #[test] fn test_build_statistics_inconsistent_length() { // return an inconsistent length to the actual statistics arrays - let stat_column_req = vec![( + let required_columns = RequiredStatColumns::from(vec![( 
"s1".to_string(), StatisticsType::Min, Field::new("s1_min", DataType::Int64, true), - )]; + )]); // Note the statistics pass back i64 (not timestamp) let statistics = OneContainerStats { @@ -757,7 +900,7 @@ mod tests { }; let result = - build_statistics_record_batch(&statistics, &stat_column_req).unwrap_err(); + build_statistics_record_batch(&statistics, &required_columns).unwrap_err(); assert!( result .to_string() @@ -774,12 +917,14 @@ mod tests { // test column on the left let expr = col("c1").eq(lit(1)); - let predicate_expr = build_predicate_expression(&expr, &schema, &mut vec![])?; + let predicate_expr = + build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?; assert_eq!(format!("{:?}", predicate_expr), expected_expr); // test column on the right let expr = lit(1).eq(col("c1")); - let predicate_expr = build_predicate_expression(&expr, &schema, &mut vec![])?; + let predicate_expr = + build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?; assert_eq!(format!("{:?}", predicate_expr), expected_expr); Ok(()) @@ -792,12 +937,14 @@ mod tests { // test column on the left let expr = col("c1").gt(lit(1)); - let predicate_expr = build_predicate_expression(&expr, &schema, &mut vec![])?; + let predicate_expr = + build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?; assert_eq!(format!("{:?}", predicate_expr), expected_expr); // test column on the right let expr = lit(1).lt(col("c1")); - let predicate_expr = build_predicate_expression(&expr, &schema, &mut vec![])?; + let predicate_expr = + build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?; assert_eq!(format!("{:?}", predicate_expr), expected_expr); Ok(()) @@ -810,11 +957,13 @@ mod tests { // test column on the left let expr = col("c1").gt_eq(lit(1)); - let predicate_expr = build_predicate_expression(&expr, &schema, &mut vec![])?; + let predicate_expr = + build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?; 
assert_eq!(format!("{:?}", predicate_expr), expected_expr); // test column on the right let expr = lit(1).lt_eq(col("c1")); - let predicate_expr = build_predicate_expression(&expr, &schema, &mut vec![])?; + let predicate_expr = + build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?; assert_eq!(format!("{:?}", predicate_expr), expected_expr); Ok(()) @@ -827,12 +976,14 @@ mod tests { // test column on the left let expr = col("c1").lt(lit(1)); - let predicate_expr = build_predicate_expression(&expr, &schema, &mut vec![])?; + let predicate_expr = + build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?; assert_eq!(format!("{:?}", predicate_expr), expected_expr); // test column on the right let expr = lit(1).gt(col("c1")); - let predicate_expr = build_predicate_expression(&expr, &schema, &mut vec![])?; + let predicate_expr = + build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?; assert_eq!(format!("{:?}", predicate_expr), expected_expr); Ok(()) @@ -845,11 +996,13 @@ mod tests { // test column on the left let expr = col("c1").lt_eq(lit(1)); - let predicate_expr = build_predicate_expression(&expr, &schema, &mut vec![])?; + let predicate_expr = + build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?; assert_eq!(format!("{:?}", predicate_expr), expected_expr); // test column on the right let expr = lit(1).gt_eq(col("c1")); - let predicate_expr = build_predicate_expression(&expr, &schema, &mut vec![])?; + let predicate_expr = + build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?; assert_eq!(format!("{:?}", predicate_expr), expected_expr); Ok(()) @@ -865,7 +1018,8 @@ mod tests { // test AND operator joining supported c1 < 1 expression and unsupported c2 > c3 expression let expr = col("c1").lt(lit(1)).and(col("c2").lt(col("c3"))); let expected_expr = "#c1_min Lt Int32(1) And Boolean(true)"; - let predicate_expr = build_predicate_expression(&expr, &schema, &mut 
vec![])?; + let predicate_expr = + build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?; assert_eq!(format!("{:?}", predicate_expr), expected_expr); Ok(()) @@ -880,46 +1034,101 @@ mod tests { // test OR operator joining supported c1 < 1 expression and unsupported c2 % 2 expression let expr = col("c1").lt(lit(1)).or(col("c2").modulus(lit(2))); let expected_expr = "#c1_min Lt Int32(1) Or Boolean(true)"; - let predicate_expr = build_predicate_expression(&expr, &schema, &mut vec![])?; + let predicate_expr = + build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?; assert_eq!(format!("{:?}", predicate_expr), expected_expr); Ok(()) } #[test] - fn row_group_predicate_stat_column_req() -> Result<()> { + fn row_group_predicate_not() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + let expected_expr = "Boolean(true)"; + + let expr = col("c1").not(); + let predicate_expr = + build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?; + assert_eq!(format!("{:?}", predicate_expr), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_not_bool() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Boolean, false)]); + let expected_expr = "NOT #c1_min And #c1_max"; + + let expr = col("c1").not(); + let predicate_expr = + build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?; + assert_eq!(format!("{:?}", predicate_expr), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_bool() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Boolean, false)]); + let expected_expr = "#c1_min Or #c1_max"; + + let expr = col("c1"); + let predicate_expr = + build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?; + assert_eq!(format!("{:?}", predicate_expr), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_lt_bool() -> Result<()> { + let schema = 
Schema::new(vec![Field::new("c1", DataType::Boolean, false)]); + let expected_expr = "#c1_min Lt Boolean(true)"; + + // DF doesn't support arithmetic on boolean columns so + // this predicate will error when evaluated + let expr = col("c1").lt(lit(true)); + let predicate_expr = + build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?; + assert_eq!(format!("{:?}", predicate_expr), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_required_columns() -> Result<()> { let schema = Schema::new(vec![ Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Int32, false), ]); - let mut stat_column_req = vec![]; + let mut required_columns = RequiredStatColumns::new(); // c1 < 1 and (c2 = 2 or c2 = 3) let expr = col("c1") .lt(lit(1)) .and(col("c2").eq(lit(2)).or(col("c2").eq(lit(3)))); let expected_expr = "#c1_min Lt Int32(1) And #c2_min LtEq Int32(2) And Int32(2) LtEq #c2_max Or #c2_min LtEq Int32(3) And Int32(3) LtEq #c2_max"; let predicate_expr = - build_predicate_expression(&expr, &schema, &mut stat_column_req)?; + build_predicate_expression(&expr, &schema, &mut required_columns)?; assert_eq!(format!("{:?}", predicate_expr), expected_expr); // c1 < 1 should add c1_min let c1_min_field = Field::new("c1_min", DataType::Int32, false); assert_eq!( - stat_column_req[0], + required_columns.columns[0], ("c1".to_owned(), StatisticsType::Min, c1_min_field) ); // c2 = 2 should add c2_min and c2_max let c2_min_field = Field::new("c2_min", DataType::Int32, false); assert_eq!( - stat_column_req[1], + required_columns.columns[1], ("c2".to_owned(), StatisticsType::Min, c2_min_field) ); let c2_max_field = Field::new("c2_max", DataType::Int32, false); assert_eq!( - stat_column_req[2], + required_columns.columns[2], ("c2".to_owned(), StatisticsType::Max, c2_max_field) ); // c2 = 3 shouldn't add any new statistics fields - assert_eq!(stat_column_req.len(), 3); + assert_eq!(required_columns.columns.len(), 3); Ok(()) } @@ -927,8 +1136,8 @@ 
mod tests { #[test] fn prune_api() { let schema = Arc::new(Schema::new(vec![ - Field::new("s1", DataType::Utf8, false), - Field::new("s2", DataType::Int32, false), + Field::new("s1", DataType::Utf8, true), + Field::new("s2", DataType::Int32, true), ])); // Prune using s2 > 5 @@ -953,4 +1162,92 @@ mod tests { assert_eq!(result, expected); } + + /// Creates setup for boolean chunk pruning + /// + /// For predicate "b1" (boolean expr) + /// b1 [false, false] ==> no rows can pass (not keep) + /// b1 [false, true] ==> some rows could pass (must keep) + /// b1 [true, true] ==> all rows must pass (must keep) + /// b1 [NULL, NULL] ==> unknown (must keep) + /// b1 [false, NULL] ==> unknown (must keep) + /// + /// For predicate "!b1" (boolean expr) + /// b1 [false, false] ==> all rows pass (must keep) + /// b1 [false, true] ==> some rows could pass (must keep) + /// b1 [true, true] ==> no rows can pass (not keep) + /// b1 [NULL, NULL] ==> unknown (must keep) + /// b1 [false, NULL] ==> unknown (must keep) + fn bool_setup() -> (SchemaRef, TestStatistics, Vec, Vec) { + let schema = + Arc::new(Schema::new(vec![Field::new("b1", DataType::Boolean, true)])); + + let statistics = TestStatistics::new().with( + "b1", + ContainerStats::new_bool( + vec![Some(false), Some(false), Some(true), None, Some(false)], // min + vec![Some(false), Some(true), Some(true), None, None], // max + ), + ); + let expected_true = vec![false, true, true, true, true]; + let expected_false = vec![true, true, false, true, true]; + + (schema, statistics, expected_true, expected_false) + } + + #[test] + fn prune_bool_column() { + let (schema, statistics, expected_true, _) = bool_setup(); + + // b1 + let expr = col("b1"); + let p = PruningPredicate::try_new(&expr, schema).unwrap(); + let result = p.prune(&statistics).unwrap(); + assert_eq!(result, expected_true); + } + + #[test] + fn prune_bool_not_column() { + let (schema, statistics, _, expected_false) = bool_setup(); + + // !b1 + let expr = col("b1").not(); + 
let p = PruningPredicate::try_new(&expr, schema).unwrap(); + let result = p.prune(&statistics).unwrap(); + assert_eq!(result, expected_false); + } + + #[test] + fn prune_bool_column_eq_true() { + let (schema, statistics, _, _) = bool_setup(); + + // b1 = true + let expr = col("b1").eq(lit(true)); + let p = PruningPredicate::try_new(&expr, schema).unwrap(); + let result = p.prune(&statistics).unwrap_err(); + assert!( + result.to_string().contains( + "Data type Boolean not supported for scalar operation on dyn array" + ), + "{}", + result + ) + } + + #[test] + fn prune_bool_not_column_eq_true() { + let (schema, statistics, _, _) = bool_setup(); + + // !b1 = true + let expr = col("b1").not().eq(lit(true)); + let p = PruningPredicate::try_new(&expr, schema).unwrap(); + let result = p.prune(&statistics).unwrap_err(); + assert!( + result.to_string().contains( + "Data type Boolean not supported for scalar operation on dyn array" + ), + "{}", + result + ) + } }