diff --git a/benchmarks/expected-plans/q7.txt b/benchmarks/expected-plans/q7.txt index 73fe8574a627..a1d1806f9189 100644 --- a/benchmarks/expected-plans/q7.txt +++ b/benchmarks/expected-plans/q7.txt @@ -3,7 +3,7 @@ Sort: shipping.supp_nation ASC NULLS LAST, shipping.cust_nation ASC NULLS LAST, Aggregate: groupBy=[[shipping.supp_nation, shipping.cust_nation, shipping.l_year]], aggr=[[SUM(shipping.volume)]] Projection: shipping.supp_nation, shipping.cust_nation, shipping.l_year, shipping.volume, alias=shipping Projection: n1.n_name AS supp_nation, n2.n_name AS cust_nation, datepart(Utf8("YEAR"), lineitem.l_shipdate) AS l_year, CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)) AS volume, alias=shipping - Filter: n1.n_name = Utf8("FRANCE") OR n2.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY") + Filter: n1.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY") AND n2.n_name = Utf8("FRANCE") Inner Join: customer.c_nationkey = n2.n_nationkey Inner Join: supplier.s_nationkey = n1.n_nationkey Inner Join: orders.o_custkey = customer.c_custkey diff --git a/datafusion/core/src/physical_plan/file_format/row_filter.rs b/datafusion/core/src/physical_plan/file_format/row_filter.rs index 2ac55d368bf9..dd9c8fb650fd 100644 --- a/datafusion/core/src/physical_plan/file_format/row_filter.rs +++ b/datafusion/core/src/physical_plan/file_format/row_filter.rs @@ -22,7 +22,7 @@ use arrow::record_batch::RecordBatch; use datafusion_common::{Column, DataFusionError, Result, ScalarValue, ToDFSchema}; use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; -use datafusion_expr::{Expr, Operator}; +use datafusion_expr::Expr; use datafusion_optimizer::utils::split_conjunction_owned; use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::{create_physical_expr, PhysicalExpr}; @@ -253,7 +253,7 @@ pub fn build_row_filter( metadata: &ParquetMetaData, reorder_predicates: bool, ) -> Result> { - let predicates = split_conjunction_owned(expr, Operator::And); + let predicates = split_conjunction_owned(expr); let mut candidates: Vec = predicates .into_iter() diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 1ba8cf7ac42e..2ff4947b3214 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -1468,15 +1468,10 @@ async fn reduce_left_join_2() -> Result<()> { .expect(&msg); let state = ctx.state(); let plan = state.optimize(&plan)?; - - // filter expr: `t2.t2_int < 10 or (t1.t1_int > 2 and t2.t2_name != 'w')` - // could be write to: `(t1.t1_int > 2 or t2.t2_int < 10) and (t2.t2_name != 'w' or t2.t2_int < 10)` - // the right part `(t2.t2_name != 'w' or t2.t2_int < 10)` could be push down left join side and remove in filter. - let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", " Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR CAST(t1.t1_int AS Int64) > Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR CAST(t1.t1_int AS Int64) > Int64(2) AND t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", " Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", diff --git a/datafusion/optimizer/src/filter_push_down.rs b/datafusion/optimizer/src/filter_push_down.rs index a768b0a7fed5..6396f1fbfd6c 100644 --- a/datafusion/optimizer/src/filter_push_down.rs +++ b/datafusion/optimizer/src/filter_push_down.rs @@ -14,7 +14,6 @@ //! Filter Push Down optimizer rule ensures that filters are applied as early as possible in the plan -use crate::utils::{split_conjunction, CnfHelper}; use crate::{utils, OptimizerConfig, OptimizerRule}; use datafusion_common::{Column, DFSchema, DataFusionError, Result}; use datafusion_expr::{ @@ -29,7 +28,6 @@ use datafusion_expr::{ utils::{expr_to_columns, exprlist_to_columns, from_plan}, Expr, Operator, TableProviderFilterPushDown, }; -use log::error; use std::collections::{HashMap, HashSet}; use std::iter::once; @@ -532,14 +530,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { } LogicalPlan::Analyze { .. } => push_down(&state, plan), LogicalPlan::Filter(filter) => { - let filter_cnf = filter.predicate().clone().rewrite(&mut CnfHelper::new()); - let predicates = match filter_cnf { - Ok(ref expr) => split_conjunction(expr), - Err(e) => { - error!("Fail at CnfHelper rewrite: {}.", e); - split_conjunction(filter.predicate()) - } - }; + let predicates = utils::split_conjunction(filter.predicate()); predicates .into_iter() @@ -962,30 +953,6 @@ mod tests { Ok(()) } - #[test] - fn filter_keep_partial_agg() -> Result<()> { - let table_scan = test_table_scan()?; - let f1 = col("c").eq(lit(1i64)).and(col("b").gt(lit(2i64))); - let f2 = col("c").eq(lit(1i64)).and(col("b").gt(lit(3i64))); - let filter = f1.or(f2); - let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])? - .filter(filter)? - .build()?; - // filter of aggregate is after aggregation since they are non-commutative - // (c =1 AND b > 2) OR (c = 1 AND b > 3) - // rewrite to CNF - // (c = 1 OR c = 1) [can pushDown] AND (c = 1 OR b > 3) AND (b > 2 OR C = 1) AND (b > 2 OR b > 3) - - let expected = "\ - Filter: test.c = Int64(1) OR b > Int64(3) AND b > Int64(2) OR test.c = Int64(1) AND b > Int64(2) OR b > Int64(3)\ - \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\ - \n Filter: test.c = Int64(1) OR test.c = Int64(1)\ - \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) - } - /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written #[test] fn alias() -> Result<()> { @@ -2377,7 +2344,7 @@ mod tests { .filter(filter)? .build()?; - let expected = "Filter: test.a = d OR test.b = e AND test.a = d OR test.c < UInt32(10) AND test.b > UInt32(1) OR test.b = e\ + let expected = "Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)\ \n CrossJoin:\ \n Projection: test.a, test.b, test.c\ \n Filter: test.b > UInt32(1) OR test.c < UInt32(10)\ diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index f088085b8812..130df3e0e6ef 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -21,7 +21,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; use datafusion_common::{plan_err, Column, DFSchemaRef}; use datafusion_expr::expr::BinaryExpr; -use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; +use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter}; use datafusion_expr::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion}; use datafusion_expr::{ and, col, @@ -84,7 +84,7 @@ fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<& /// /// # Example /// ``` -/// # use datafusion_expr::{col, lit, Operator}; +/// # use datafusion_expr::{col, lit}; /// # use datafusion_optimizer::utils::split_conjunction_owned; /// // a=1 AND b=2 /// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2))); @@ -96,23 +96,23 @@ fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<& /// ]; /// /// // use split_conjunction_owned to split them -/// assert_eq!(split_conjunction_owned(expr, Operator::And), split); +/// assert_eq!(split_conjunction_owned(expr), split); /// ``` -pub fn split_conjunction_owned(expr: Expr, op: Operator) -> Vec { - split_conjunction_owned_impl(expr, op, vec![]) +pub fn split_conjunction_owned(expr: Expr) -> Vec { + split_conjunction_owned_impl(expr, vec![]) } -fn split_conjunction_owned_impl( - expr: Expr, - operator: Operator, - mut exprs: Vec, -) -> Vec { +fn split_conjunction_owned_impl(expr: Expr, mut exprs: Vec) -> Vec { match expr { - Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => { - let exprs = split_conjunction_owned_impl(*left, Operator::And, exprs); - split_conjunction_owned_impl(*right, Operator::And, exprs) + Expr::BinaryExpr(BinaryExpr { + right, + op: Operator::And, + left, + }) => { + let exprs = split_conjunction_owned_impl(*left, exprs); + split_conjunction_owned_impl(*right, exprs) } - Expr::Alias(expr, _) => split_conjunction_owned_impl(*expr, Operator::And, exprs), + Expr::Alias(expr, _) => split_conjunction_owned_impl(*expr, exprs), other => { exprs.push(other); exprs @@ -120,149 +120,6 @@ fn split_conjunction_owned_impl( } } -/// Converts an expression to conjunctive normal form (CNF). -/// -/// The following expression is in CNF: -/// `(a OR b) AND (c OR d)` -/// The following is not in CNF: -/// `(a AND b) OR c`. -/// But could be rewrite to a CNF expression: -/// `(a OR c) AND (b OR c)`. -/// -/// # Example -/// ``` -/// # use datafusion_expr::{col, lit}; -/// # use datafusion_expr::expr_rewriter::ExprRewritable; -/// # use datafusion_optimizer::utils::CnfHelper; -/// // (a=1 AND b=2)OR c = 3 -/// let expr1 = col("a").eq(lit(1)).and(col("b").eq(lit(2))); -/// let expr2 = col("c").eq(lit(3)); -/// let expr = expr1.or(expr2); -/// -/// //(a=1 or c=3)AND(b=2 or c=3) -/// let expr1 = col("a").eq(lit(1)).or(col("c").eq(lit(3))); -/// let expr2 = col("b").eq(lit(2)).or(col("c").eq(lit(3))); -/// let expect = expr1.and(expr2); -/// // use split_conjunction_owned to split them -/// assert_eq!(expr.rewrite(& mut CnfHelper::new()).unwrap(), expect); -/// ``` -/// -pub struct CnfHelper { - max_count: usize, - current_count: usize, - exprs: Vec, - original_expr: Option, -} - -impl CnfHelper { - pub fn new() -> Self { - CnfHelper { - max_count: 50, - current_count: 0, - exprs: vec![], - original_expr: None, - } - } - - pub fn new_with_max_count(max_count: usize) -> Self { - CnfHelper { - max_count, - current_count: 0, - exprs: vec![], - original_expr: None, - } - } - - fn increment_and_check_overload(&mut self) -> bool { - self.current_count += 1; - self.current_count >= self.max_count - } -} - -impl ExprRewriter for CnfHelper { - fn pre_visit(&mut self, expr: &Expr) -> Result { - let is_root = self.original_expr.is_none(); - if is_root { - self.original_expr = Some(expr.clone()); - } - match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - match op { - Operator::And => { - if self.increment_and_check_overload() { - return Ok(RewriteRecursion::Mutate); - } - } - // (a AND b) OR (c AND d) = (a OR b) AND (a OR c) AND (b OR c) AND (b OR d) - Operator::Or => { - let left_and_split = - split_conjunction_owned(*left.clone(), Operator::And); - let right_and_split = - split_conjunction_owned(*right.clone(), Operator::And); - // Avoid create to much Expr like in tpch q19. - let lc = split_conjunction_owned(*left.clone(), Operator::Or) - .into_iter() - .flat_map(|e| split_conjunction_owned(e, Operator::And)) - .count(); - let rc = split_conjunction_owned(*right.clone(), Operator::Or) - .into_iter() - .flat_map(|e| split_conjunction_owned(e, Operator::And)) - .count(); - self.current_count += lc * rc - 1; - if self.increment_and_check_overload() { - return Ok(RewriteRecursion::Mutate); - } - left_and_split.iter().for_each(|l| { - right_and_split.iter().for_each(|r| { - self.exprs.push(Expr::BinaryExpr(BinaryExpr { - left: Box::new(l.clone()), - op: Operator::Or, - right: Box::new(r.clone()), - })) - }) - }); - return Ok(RewriteRecursion::Mutate); - } - _ => { - if self.increment_and_check_overload() { - return Ok(RewriteRecursion::Mutate); - } - self.exprs.push(expr.clone()); - return Ok(RewriteRecursion::Stop); - } - } - } - other => { - if self.increment_and_check_overload() { - return Ok(RewriteRecursion::Mutate); - } - self.exprs.push(other.clone()); - return Ok(RewriteRecursion::Stop); - } - } - if is_root { - Ok(RewriteRecursion::Continue) - } else { - Ok(RewriteRecursion::Skip) - } - } - - fn mutate(&mut self, _expr: Expr) -> Result { - if self.current_count >= self.max_count { - Ok(self.original_expr.as_ref().unwrap().clone()) - } else { - Ok(conjunction(self.exprs.clone()) - .unwrap_or_else(|| self.original_expr.as_ref().unwrap().clone())) - } - } -} - -impl Default for CnfHelper { - fn default() -> Self { - Self::new() - } -} - /// Combines an array of filter expressions into a single filter /// expression consisting of the input filter expressions joined with /// logical AND. @@ -612,7 +469,7 @@ mod tests { use super::*; use arrow::datatypes::DataType; use datafusion_common::Column; - use datafusion_expr::{col, lit, or, utils::expr_to_columns}; + use datafusion_expr::{col, lit, utils::expr_to_columns}; use std::collections::HashSet; use std::ops::Add; @@ -653,16 +510,13 @@ mod tests { #[test] fn test_split_conjunction_owned() { let expr = col("a"); - assert_eq!( - split_conjunction_owned(expr.clone(), Operator::And), - vec![expr] - ); + assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); } #[test] fn test_split_conjunction_owned_two() { assert_eq!( - split_conjunction_owned(col("a").eq(lit(5)).and(col("b")), Operator::And), + split_conjunction_owned(col("a").eq(lit(5)).and(col("b"))), vec![col("a").eq(lit(5)), col("b")] ); } @@ -670,10 +524,7 @@ mod tests { #[test] fn test_split_conjunction_owned_alias() { assert_eq!( - split_conjunction_owned( - col("a").eq(lit(5)).and(col("b").alias("the_alias")), - Operator::And - ), + split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))), vec![ col("a").eq(lit(5)), // no alias on b @@ -719,10 +570,7 @@ mod tests { #[test] fn test_split_conjunction_owned_or() { let expr = col("a").eq(lit(5)).or(col("b")); - assert_eq!( - split_conjunction_owned(expr.clone(), Operator::And), - vec![expr] - ); + assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); } #[test] @@ -815,84 +663,4 @@ mod tests { "mismatch rewriting expr_from: {expr_from} to {rewrite_to}" ) } - - #[test] - fn test_rewrite_cnf() { - let a_1 = col("a").eq(lit(1i64)); - let a_2 = col("a").eq(lit(2i64)); - - let b_1 = col("b").eq(lit(1i64)); - let b_2 = col("b").eq(lit(2i64)); - - // Test rewrite on a1_and_b2 and a2_and_b1 -> not change - let mut helper = CnfHelper::new(); - let expr1 = and(a_1.clone(), b_2.clone()); - let expect = expr1.clone(); - let res = expr1.rewrite(&mut helper).unwrap(); - assert_eq!(expect, res); - - // Test rewrite on a1_and_b2 and a2_and_b1 -> (((a1 and b2) and a2) and b1) - let mut helper = CnfHelper::new(); - let expr1 = and(and(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone())); - let expect = and(a_1.clone(), b_2.clone()) - .and(a_2.clone()) - .and(b_1.clone()); - let res = expr1.rewrite(&mut helper).unwrap(); - assert_eq!(expect, res); - - // Test rewrite on a1_or_b2 -> not change - let mut helper = CnfHelper::new(); - let expr1 = or(a_1.clone(), b_2.clone()); - let expect = expr1.clone(); - let res = expr1.rewrite(&mut helper).unwrap(); - assert_eq!(expect, res); - - // Test rewrite on a1_and_b2 or a2_and_b1 -> a1_or_a2 and a1_or_b1 and b2_or_a2 and b2_or_b1 - let mut helper = CnfHelper::new(); - let expr1 = or(and(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone())); - let a1_or_a2 = or(a_1.clone(), a_2.clone()); - let a1_or_b1 = or(a_1.clone(), b_1.clone()); - let b2_or_a2 = or(b_2.clone(), a_2.clone()); - let b2_or_b1 = or(b_2.clone(), b_1.clone()); - let expect = and(a1_or_a2, a1_or_b1).and(b2_or_a2).and(b2_or_b1); - let res = expr1.rewrite(&mut helper).unwrap(); - assert_eq!(expect, res); - - // Test rewrite on a1_or_b2 or a2_and_b1 -> ( a1_or_a2 or a2 ) and (a1_or_a2 or b1) - let mut helper = CnfHelper::new(); - let a1_or_b2 = or(a_1.clone(), b_2.clone()); - let expr1 = or(or(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone())); - let expect = or(a1_or_b2.clone(), a_2.clone()).and(or(a1_or_b2, b_1.clone())); - let res = expr1.rewrite(&mut helper).unwrap(); - assert_eq!(expect, res); - - // Test rewrite on a1_or_b2 or a2_or_b1 -> not change - let mut helper = CnfHelper::new(); - let expr1 = or(or(a_1, b_2), or(a_2, b_1)); - let expect = expr1.clone(); - let res = expr1.rewrite(&mut helper).unwrap(); - assert_eq!(expect, res); - } - - #[test] - fn test_rewrite_cnf_overflow() { - // in this situation: - // AND = (a=1 and b=2) - // rewrite (AND * 10) or (AND * 10), it will produce 10 * 10 = 100 (a=1 or b=2) - // which cause size expansion. - - let mut expr1 = col("test1").eq(lit(1i64)); - let expr2 = col("test2").eq(lit(2i64)); - - for _i in 0..9 { - expr1 = expr1.clone().and(expr2.clone()); - } - let expr3 = expr1.clone(); - let expr = or(expr1, expr3); - let mut helper = CnfHelper::new(); - let res = expr.clone().rewrite(&mut helper).unwrap(); - assert_eq!(100, helper.current_count); - assert_eq!(res, expr); - assert!(helper.current_count >= helper.max_count); - } }