diff --git a/datafusion/core/tests/sql/predicates.rs b/datafusion/core/tests/sql/predicates.rs index 3c11b690d292..32365090a79c 100644 --- a/datafusion/core/tests/sql/predicates.rs +++ b/datafusion/core/tests/sql/predicates.rs @@ -427,11 +427,12 @@ async fn multiple_or_predicates() -> Result<()> { let expected =vec![ "Explain [plan_type:Utf8, plan:Utf8]", " Projection: #lineitem.l_partkey [l_partkey:Int64]", - " Projection: #part.p_partkey = #lineitem.l_partkey AS #part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey, #lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size [#part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey:Boolean;N, l_partkey:Int64, l_quantity:Float64, p_brand:Utf8, p_size:Int32]", - " Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= CAST(Int64(1) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(11) AS Float64) AND #part.p_size BETWEEN Int64(1) AND Int64(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= CAST(Int64(10) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(20) AS Float64) AND #part.p_size BETWEEN Int64(1) AND Int64(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= CAST(Int64(20) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(30) AS Float64) AND #part.p_size BETWEEN Int64(1) AND Int64(15) [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]", + " Projection: #part.p_partkey = #lineitem.l_partkey AS #part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey, #part.p_size >= Int32(1) AS #part.p_size >= Int32(1)Int32(1)#part.p_size, #lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size [#part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey:Boolean;N, #part.p_size >= Int32(1)Int32(1)#part.p_size:Boolean;N, l_partkey:Int64, l_quantity:Float64, p_brand:Utf8, p_size:Int32]", + " Filter: #part.p_partkey = 
#lineitem.l_partkey AND #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= CAST(Int64(1) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(11) AS Float64) AND #part.p_size <= Int32(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= CAST(Int64(10) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(20) AS Float64) AND #part.p_size <= Int32(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= CAST(Int64(20) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(30) AS Float64) AND #part.p_size <= Int32(15) [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]", " CrossJoin: [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]", " TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Float64]", - " TableScan: part projection=[p_partkey, p_brand, p_size] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]", + " Filter: #part.p_size >= Int32(1) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]", + " TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[#part.p_size >= Int32(1)] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]", ]; let formatted = plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index 06353167c275..54d1b24e81e7 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -495,10 +495,10 @@ async fn use_between_expression_in_select_query() -> Result<()> { .unwrap() .to_string(); - // Only test that the projection exprs arecorrect, rather than entire output + // Only test that the projection exprs are correct, rather than entire output let needle = "ProjectionExec: expr=[c1@0 >= 2 AND c1@0 <= 3 as test.c1 BETWEEN Int64(2) AND Int64(3)]"; assert_contains!(&formatted, needle); - let needle = "Projection: #test.c1 BETWEEN Int64(2) AND Int64(3)"; + let 
needle = "Projection: #test.c1 >= Int64(2) AND #test.c1 <= Int64(3)"; assert_contains!(&formatted, needle); Ok(()) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 239939f81d66..978b79d375d3 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -28,7 +28,7 @@ use datafusion_expr::{ utils::from_plan, Expr, ExprSchemable, }; -use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; /// A map from expression's identifier to tuple including @@ -271,12 +271,12 @@ fn to_arrays( /// Build the "intermediate" projection plan that evaluates the extracted common expressions. fn build_project_plan( input: LogicalPlan, - affected_id: HashSet, + affected_id: BTreeSet, expr_set: &ExprSet, ) -> Result { let mut project_exprs = vec![]; let mut fields = vec![]; - let mut fields_set = HashSet::new(); + let mut fields_set = BTreeSet::new(); for id in affected_id { match expr_set.get(&id) { @@ -320,7 +320,7 @@ fn rewrite_expr( expr_set: &mut ExprSet, optimizer_config: &OptimizerConfig, ) -> Result<(Vec>, LogicalPlan)> { - let mut affected_id = HashSet::::new(); + let mut affected_id = BTreeSet::::new(); let rewrote_exprs = exprs_list .iter() @@ -482,7 +482,7 @@ struct CommonSubexprRewriter<'a> { expr_set: &'a mut ExprSet, id_array: &'a [(usize, Identifier)], /// Which identifier is replaced. - affected_id: &'a mut HashSet, + affected_id: &'a mut BTreeSet, /// the max series number we have rewritten. 
Other expression nodes /// with smaller series number is already replaced and shouldn't @@ -561,7 +561,7 @@ fn replace_common_expr( expr: Expr, id_array: &[(usize, Identifier)], expr_set: &mut ExprSet, - affected_id: &mut HashSet, + affected_id: &mut BTreeSet, ) -> Result { expr.rewrite(&mut CommonSubexprRewriter { expr_set, @@ -752,7 +752,7 @@ mod test { #[test] fn redundant_project_fields() { let table_scan = test_table_scan().unwrap(); - let affected_id: HashSet = + let affected_id: BTreeSet = ["c+a".to_string(), "d+a".to_string()].into_iter().collect(); let expr_set = [ ("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)), @@ -764,7 +764,7 @@ mod test { build_project_plan(table_scan, affected_id.clone(), &expr_set).unwrap(); let project_2 = build_project_plan(project, affected_id, &expr_set).unwrap(); - let mut field_set = HashSet::new(); + let mut field_set = BTreeSet::new(); for field in project_2.schema().fields() { assert!(field_set.insert(field.qualified_name())); } @@ -779,7 +779,7 @@ mod test { .unwrap() .build() .unwrap(); - let affected_id: HashSet = + let affected_id: BTreeSet = ["c+a".to_string(), "d+a".to_string()].into_iter().collect(); let expr_set = [ ("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)), @@ -790,7 +790,7 @@ mod test { let project = build_project_plan(join, affected_id.clone(), &expr_set).unwrap(); let project_2 = build_project_plan(project, affected_id, &expr_set).unwrap(); - let mut field_set = HashSet::new(); + let mut field_set = BTreeSet::new(); for field in project_2.schema().fields() { assert!(field_set.insert(field.qualified_name())); } diff --git a/datafusion/optimizer/src/simplify_expressions.rs b/datafusion/optimizer/src/simplify_expressions.rs index d1afa3543147..aa87c5318ae4 100644 --- a/datafusion/optimizer/src/simplify_expressions.rs +++ b/datafusion/optimizer/src/simplify_expressions.rs @@ -164,8 +164,6 @@ fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool { /// returns the contained 
boolean value in `expr` as /// `Expr::Literal(ScalarValue::Boolean(v))`. -/// -/// panics if expr is not a literal boolean fn as_bool_lit(expr: Expr) -> Result> { match expr { Expr::Literal(ScalarValue::Boolean(v)) => Ok(v), @@ -502,7 +500,7 @@ impl<'a> ConstEvaluator<'a> { ColumnarValue::Array(a) => { if a.len() != 1 { Err(DataFusionError::Execution(format!( - "Could not evaluate the expressison, found a result of length {}", + "Could not evaluate the expression, found a result of length {}", a.len() ))) } else { @@ -803,6 +801,27 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { out_expr.rewrite(self)? } + // + // Rules for Between + // + + // a between 3 and 5 --> a >= 3 AND a <= 5 + // a not between 3 and 5 --> a < 3 OR a > 5 + Between { + expr, + low, + high, + negated, + } => { + if negated { + let l = *expr.clone(); + let r = *expr; + or(l.lt(*low), r.gt(*high)) + } else { + and(expr.clone().gt_eq(*low), expr.lt_eq(*high)) + } + } + expr => { // no additional rewrites possible expr @@ -1555,8 +1574,13 @@ mod tests { high: Box::new(lit(10)), }; let expr = expr.or(lit_bool_null()); - let result = simplify(expr.clone()); - assert_eq!(expr, result); + let result = simplify(expr); + + let expected_expr = or( + and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))), + lit_bool_null(), + ); + assert_eq!(expected_expr, result); } #[test] @@ -1579,8 +1603,8 @@ mod tests { assert_eq!(simplify(lit_bool_null().and(lit(false))), lit(false),); // c1 BETWEEN Int32(0) AND Int32(10) AND Boolean(NULL) - // it can be either NULL or FALSE depending on the value of `c1 BETWEEN Int32(0) AND Int32(10` - // and should not be rewritten + // it can be either NULL or FALSE depending on the value of `c1 BETWEEN Int32(0) AND Int32(10)` + // and the Boolean(NULL) should remain let expr = Expr::Between { expr: Box::new(col("c1")), negated: false, @@ -1588,8 +1612,40 @@ mod tests { high: Box::new(lit(10)), }; let expr = expr.and(lit_bool_null()); - let result = 
simplify(expr.clone()); - assert_eq!(expr, result); + let result = simplify(expr); + + let expected_expr = and( + and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))), + lit_bool_null(), + ); + assert_eq!(expected_expr, result); + } + + #[test] + fn simplify_expr_between() { + // c2 between 3 and 4 is c2 >= 3 and c2 <= 4 + let expr = Expr::Between { + expr: Box::new(col("c2")), + negated: false, + low: Box::new(lit(3)), + high: Box::new(lit(4)), + }; + assert_eq!( + simplify(expr), + and(col("c2").gt_eq(lit(3)), col("c2").lt_eq(lit(4))) + ); + + // c2 not between 3 and 4 is c2 < 3 or c2 > 4 + let expr = Expr::Between { + expr: Box::new(col("c2")), + negated: true, + low: Box::new(lit(3)), + high: Box::new(lit(4)), + }; + assert_eq!( + simplify(expr), + or(col("c2").lt(lit(3)), col("c2").gt(lit(4))) + ); } // ------------------------------ @@ -2167,7 +2223,7 @@ mod tests { .unwrap() .build() .unwrap(); - let expected = "Filter: #test.d NOT BETWEEN Int32(1) AND Int32(10) AS NOT test.d BETWEEN Int32(1) AND Int32(10)\ + let expected = "Filter: #test.d < Int32(1) OR #test.d > Int32(10) AS NOT test.d BETWEEN Int32(1) AND Int32(10)\ \n TableScan: test"; assert_optimized_plan_eq(&plan, expected); @@ -2188,7 +2244,7 @@ mod tests { .unwrap() .build() .unwrap(); - let expected = "Filter: #test.d BETWEEN Int32(1) AND Int32(10) AS NOT test.d NOT BETWEEN Int32(1) AND Int32(10)\ + let expected = "Filter: #test.d >= Int32(1) AND #test.d <= Int32(10) AS NOT test.d NOT BETWEEN Int32(1) AND Int32(10)\ \n TableScan: test"; assert_optimized_plan_eq(&plan, expected);