diff --git a/benchmarks/expected-plans/q7.txt b/benchmarks/expected-plans/q7.txt index a1d1806f9189..73fe8574a627 100644 --- a/benchmarks/expected-plans/q7.txt +++ b/benchmarks/expected-plans/q7.txt @@ -3,7 +3,7 @@ Sort: shipping.supp_nation ASC NULLS LAST, shipping.cust_nation ASC NULLS LAST, Aggregate: groupBy=[[shipping.supp_nation, shipping.cust_nation, shipping.l_year]], aggr=[[SUM(shipping.volume)]] Projection: shipping.supp_nation, shipping.cust_nation, shipping.l_year, shipping.volume, alias=shipping Projection: n1.n_name AS supp_nation, n2.n_name AS cust_nation, datepart(Utf8("YEAR"), lineitem.l_shipdate) AS l_year, CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)) AS volume, alias=shipping - Filter: n1.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY") AND n2.n_name = Utf8("FRANCE") + Filter: n1.n_name = Utf8("FRANCE") OR n2.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY") Inner Join: customer.c_nationkey = n2.n_nationkey Inner Join: supplier.s_nationkey = n1.n_nationkey Inner Join: orders.o_custkey = customer.c_custkey diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 2ff4947b3214..1ba8cf7ac42e 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -1468,10 +1468,15 @@ async fn reduce_left_join_2() -> Result<()> { .expect(&msg); let state = ctx.state(); let plan = state.optimize(&plan)?; + + // filter expr: `t2.t2_int < 10 or (t1.t1_int > 2 and t2.t2_name != 'w')` + // could be write to: `(t1.t1_int > 2 or t2.t2_int < 10) and (t2.t2_name != 'w' or t2.t2_int < 10)` + // the right part `(t2.t2_name != 'w' or t2.t2_int < 10)` could be push down left join side and remove in filter. + let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", " Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR CAST(t1.t1_int AS Int64) > Int64(2) AND t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR CAST(t1.t1_int AS Int64) > Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", " Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", diff --git a/datafusion/optimizer/src/filter_push_down.rs b/datafusion/optimizer/src/filter_push_down.rs index 6396f1fbfd6c..71c996afd6a6 100644 --- a/datafusion/optimizer/src/filter_push_down.rs +++ b/datafusion/optimizer/src/filter_push_down.rs @@ -14,6 +14,7 @@ //! Filter Push Down optimizer rule ensures that filters are applied as early as possible in the plan +use crate::utils::{split_binary_owned, CnfHelper}; use crate::{utils, OptimizerConfig, OptimizerRule}; use datafusion_common::{Column, DFSchema, DataFusionError, Result}; use datafusion_expr::{ @@ -28,6 +29,7 @@ use datafusion_expr::{ utils::{expr_to_columns, exprlist_to_columns, from_plan}, Expr, Operator, TableProviderFilterPushDown, }; +use log::error; use std::collections::{HashMap, HashSet}; use std::iter::once; @@ -70,6 +72,7 @@ type Predicates<'a> = (Vec<&'a Expr>, Vec<&'a HashSet>); struct State { // (predicate, columns on the predicate) filters: Vec, + use_cnf_rewrite: bool, } impl State { @@ -80,6 +83,11 @@ impl State { .zip(predicates.1) .for_each(|(expr, cols)| self.filters.push((expr.clone(), cols.clone()))) } + + fn with_cnf_rewrite(mut self) -> Self { + self.use_cnf_rewrite = true; + self + } } /// returns all predicates in `state` that depend on any of `used_columns` @@ -457,7 +465,7 @@ fn optimize_join( // Build new filter states using pushable predicates // from current optimizer states and from ON clause. // Then recursively call optimization for both join inputs - let mut left_state = State { filters: vec![] }; + let mut left_state = State::default(); left_state.append_predicates(to_left); left_state.append_predicates(on_to_left); or_to_left @@ -472,7 +480,7 @@ fn optimize_join( .for_each(|(expr, cols)| left_state.filters.push((expr, cols))); let left = optimize(left, left_state)?; - let mut right_state = State { filters: vec![] }; + let mut right_state = State::default(); right_state.append_predicates(to_right); right_state.append_predicates(on_to_right); or_to_right @@ -530,14 +538,26 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { } LogicalPlan::Analyze { .. } => push_down(&state, plan), LogicalPlan::Filter(filter) => { - let predicates = utils::split_conjunction(filter.predicate()); + let predicates = if state.use_cnf_rewrite { + let filter_cnf = + filter.predicate().clone().rewrite(&mut CnfHelper::new()); + match filter_cnf { + Ok(ref expr) => split_binary_owned(expr.clone(), Operator::And), + Err(e) => { + error!("Fail at CnfHelper rewrite: {}.", e); + split_binary_owned(filter.predicate().clone(), Operator::And) + } + } + } else { + split_binary_owned(filter.predicate().clone(), Operator::And) + }; predicates .into_iter() .try_for_each::<_, Result<()>>(|predicate| { let mut columns: HashSet = HashSet::new(); - expr_to_columns(predicate, &mut columns)?; - state.filters.push((predicate.clone(), columns)); + expr_to_columns(&predicate, &mut columns)?; + state.filters.push((predicate, columns)); Ok(()) })?; @@ -797,7 +817,7 @@ impl OptimizerRule for FilterPushDown { plan: &LogicalPlan, _: &mut OptimizerConfig, ) -> Result { - optimize(plan, State::default()) + optimize(plan, State::default().with_cnf_rewrite()) } } @@ -953,6 +973,30 @@ mod tests { Ok(()) } + #[test] + fn filter_keep_partial_agg() -> Result<()> { + let table_scan = test_table_scan()?; + let f1 = col("c").eq(lit(1i64)).and(col("b").gt(lit(2i64))); + let f2 = col("c").eq(lit(1i64)).and(col("b").gt(lit(3i64))); + let filter = f1.or(f2); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])? + .filter(filter)? + .build()?; + // filter of aggregate is after aggregation since they are non-commutative + // (c =1 AND b > 2) OR (c = 1 AND b > 3) + // rewrite to CNF + // (c = 1 OR c = 1) [can pushDown] AND (c = 1 OR b > 3) AND (b > 2 OR C = 1) AND (b > 2 OR b > 3) + + let expected = "\ + Filter: test.c = Int64(1) OR b > Int64(3) AND b > Int64(2) OR test.c = Int64(1) AND b > Int64(2) OR b > Int64(3)\ + \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\ + \n Filter: test.c = Int64(1) OR test.c = Int64(1)\ + \n TableScan: test"; + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written #[test] fn alias() -> Result<()> { @@ -2344,7 +2388,7 @@ mod tests { .filter(filter)? .build()?; - let expected = "Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)\ + let expected = "Filter: test.a = d OR test.b = e AND test.a = d OR test.c < UInt32(10) AND test.b > UInt32(1) OR test.b = e\ \n CrossJoin:\ \n Projection: test.a, test.b, test.c\ \n Filter: test.b > UInt32(1) OR test.c < UInt32(10)\ diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 130df3e0e6ef..dd599c89368a 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -21,7 +21,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; use datafusion_common::{plan_err, Column, DFSchemaRef}; use datafusion_expr::expr::BinaryExpr; -use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter}; +use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; use datafusion_expr::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion}; use datafusion_expr::{ and, col, @@ -99,20 +99,46 @@ fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<& /// assert_eq!(split_conjunction_owned(expr), split); /// ``` pub fn split_conjunction_owned(expr: Expr) -> Vec { - split_conjunction_owned_impl(expr, vec![]) + split_binary_owned(expr, Operator::And) } -fn split_conjunction_owned_impl(expr: Expr, mut exprs: Vec) -> Vec { +/// Splits an binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` +/// +/// This is often used to "split" expressions such as `col1 = 5 +/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`]; +/// +/// # Example +/// ``` +/// # use datafusion_expr::{col, lit, Operator}; +/// # use datafusion_optimizer::utils::split_binary_owned; +/// # use std::ops::Add; +/// // a=1 + b=2 +/// let expr = col("a").eq(lit(1)).add(col("b").eq(lit(2))); +/// +/// // [a=1, b=2] +/// let split = vec![ +/// col("a").eq(lit(1)), +/// col("b").eq(lit(2)), +/// ]; +/// +/// // use split_binary_owned to split them +/// assert_eq!(split_binary_owned(expr, Operator::Plus), split); +/// ``` +pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec { + split_binary_owned_impl(expr, op, vec![]) +} + +fn split_binary_owned_impl( + expr: Expr, + operator: Operator, + mut exprs: Vec, +) -> Vec { match expr { - Expr::BinaryExpr(BinaryExpr { - right, - op: Operator::And, - left, - }) => { - let exprs = split_conjunction_owned_impl(*left, exprs); - split_conjunction_owned_impl(*right, exprs) + Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => { + let exprs = split_binary_owned_impl(*left, Operator::And, exprs); + split_binary_owned_impl(*right, Operator::And, exprs) } - Expr::Alias(expr, _) => split_conjunction_owned_impl(*expr, exprs), + Expr::Alias(expr, _) => split_binary_owned_impl(*expr, Operator::And, exprs), other => { exprs.push(other); exprs @@ -120,6 +146,149 @@ fn split_conjunction_owned_impl(expr: Expr, mut exprs: Vec) -> Vec { } } +/// Converts an expression to conjunctive normal form (CNF). +/// +/// The following expression is in CNF: +/// `(a OR b) AND (c OR d)` +/// The following is not in CNF: +/// `(a AND b) OR c`. +/// But could be rewrite to a CNF expression: +/// `(a OR c) AND (b OR c)`. +/// +/// # Example +/// ``` +/// # use datafusion_expr::{col, lit}; +/// # use datafusion_expr::expr_rewriter::ExprRewritable; +/// # use datafusion_optimizer::utils::CnfHelper; +/// // (a=1 AND b=2)OR c = 3 +/// let expr1 = col("a").eq(lit(1)).and(col("b").eq(lit(2))); +/// let expr2 = col("c").eq(lit(3)); +/// let expr = expr1.or(expr2); +/// +/// //(a=1 or c=3)AND(b=2 or c=3) +/// let expr1 = col("a").eq(lit(1)).or(col("c").eq(lit(3))); +/// let expr2 = col("b").eq(lit(2)).or(col("c").eq(lit(3))); +/// let expect = expr1.and(expr2); +/// // use split_conjunction_owned to split them +/// assert_eq!(expr.rewrite(& mut CnfHelper::new()).unwrap(), expect); +/// ``` +/// +pub struct CnfHelper { + max_count: usize, + current_count: usize, + exprs: Vec, + original_expr: Option, +} + +impl CnfHelper { + pub fn new() -> Self { + CnfHelper { + max_count: 50, + current_count: 0, + exprs: vec![], + original_expr: None, + } + } + + pub fn new_with_max_count(max_count: usize) -> Self { + CnfHelper { + max_count, + current_count: 0, + exprs: vec![], + original_expr: None, + } + } + + fn increment_and_check_overload(&mut self) -> bool { + self.current_count += 1; + self.current_count >= self.max_count + } +} + +impl ExprRewriter for CnfHelper { + fn pre_visit(&mut self, expr: &Expr) -> Result { + let is_root = self.original_expr.is_none(); + if is_root { + self.original_expr = Some(expr.clone()); + } + match expr { + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + match op { + Operator::And => { + if self.increment_and_check_overload() { + return Ok(RewriteRecursion::Mutate); + } + } + // (a AND b) OR (c AND d) = (a OR b) AND (a OR c) AND (b OR c) AND (b OR d) + Operator::Or => { + // Avoid create to much Expr like in tpch q19. + let lc = split_binary_owned(*left.clone(), Operator::Or) + .into_iter() + .flat_map(|e| split_binary_owned(e, Operator::And)) + .count(); + let rc = split_binary_owned(*right.clone(), Operator::Or) + .into_iter() + .flat_map(|e| split_binary_owned(e, Operator::And)) + .count(); + self.current_count += lc * rc - 1; + if self.increment_and_check_overload() { + return Ok(RewriteRecursion::Mutate); + } + let left_and_split = + split_binary_owned(*left.clone(), Operator::And); + let right_and_split = + split_binary_owned(*right.clone(), Operator::And); + left_and_split.iter().for_each(|l| { + right_and_split.iter().for_each(|r| { + self.exprs.push(Expr::BinaryExpr(BinaryExpr { + left: Box::new(l.clone()), + op: Operator::Or, + right: Box::new(r.clone()), + })) + }) + }); + return Ok(RewriteRecursion::Mutate); + } + _ => { + if self.increment_and_check_overload() { + return Ok(RewriteRecursion::Mutate); + } + self.exprs.push(expr.clone()); + return Ok(RewriteRecursion::Stop); + } + } + } + other => { + if self.increment_and_check_overload() { + return Ok(RewriteRecursion::Mutate); + } + self.exprs.push(other.clone()); + return Ok(RewriteRecursion::Stop); + } + } + if is_root { + Ok(RewriteRecursion::Continue) + } else { + Ok(RewriteRecursion::Skip) + } + } + + fn mutate(&mut self, _expr: Expr) -> Result { + if self.current_count >= self.max_count { + Ok(self.original_expr.as_ref().unwrap().clone()) + } else { + Ok(conjunction(self.exprs.clone()) + .unwrap_or_else(|| self.original_expr.as_ref().unwrap().clone())) + } + } +} + +impl Default for CnfHelper { + fn default() -> Self { + Self::new() + } +} + /// Combines an array of filter expressions into a single filter /// expression consisting of the input filter expressions joined with /// logical AND. @@ -469,7 +638,7 @@ mod tests { use super::*; use arrow::datatypes::DataType; use datafusion_common::Column; - use datafusion_expr::{col, lit, utils::expr_to_columns}; + use datafusion_expr::{col, lit, or, utils::expr_to_columns}; use std::collections::HashSet; use std::ops::Add; @@ -507,6 +676,30 @@ mod tests { assert_eq!(result, vec![&expr]); } + #[test] + fn test_split_binary_owned() { + let expr = col("a"); + assert_eq!(split_binary_owned(expr.clone(), Operator::And), vec![expr]); + } + + #[test] + fn test_split_binary_owned_two() { + assert_eq!( + split_binary_owned(col("a").eq(lit(5)).and(col("b")), Operator::And), + vec![col("a").eq(lit(5)), col("b")] + ); + } + + #[test] + fn test_split_binary_owned_different_op() { + let expr = col("a").eq(lit(5)).or(col("b")); + assert_eq!( + // expr is connected by OR, but pass in AND + split_binary_owned(expr.clone(), Operator::And), + vec![expr] + ); + } + #[test] fn test_split_conjunction_owned() { let expr = col("a"); @@ -524,7 +717,7 @@ mod tests { #[test] fn test_split_conjunction_owned_alias() { assert_eq!( - split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))), + split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias")),), vec![ col("a").eq(lit(5)), // no alias on b @@ -567,12 +760,6 @@ mod tests { assert_ne!(expr, Some(col("a").or(col("b").or(col("c"))))); } - #[test] - fn test_split_conjunction_owned_or() { - let expr = col("a").eq(lit(5)).or(col("b")); - assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); - } - #[test] fn test_collect_expr() -> Result<()> { let mut accum: HashSet = HashSet::new(); @@ -663,4 +850,84 @@ mod tests { "mismatch rewriting expr_from: {expr_from} to {rewrite_to}" ) } + + #[test] + fn test_rewrite_cnf() { + let a_1 = col("a").eq(lit(1i64)); + let a_2 = col("a").eq(lit(2i64)); + + let b_1 = col("b").eq(lit(1i64)); + let b_2 = col("b").eq(lit(2i64)); + + // Test rewrite on a1_and_b2 and a2_and_b1 -> not change + let mut helper = CnfHelper::new(); + let expr1 = and(a_1.clone(), b_2.clone()); + let expect = expr1.clone(); + let res = expr1.rewrite(&mut helper).unwrap(); + assert_eq!(expect, res); + + // Test rewrite on a1_and_b2 and a2_and_b1 -> (((a1 and b2) and a2) and b1) + let mut helper = CnfHelper::new(); + let expr1 = and(and(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone())); + let expect = and(a_1.clone(), b_2.clone()) + .and(a_2.clone()) + .and(b_1.clone()); + let res = expr1.rewrite(&mut helper).unwrap(); + assert_eq!(expect, res); + + // Test rewrite on a1_or_b2 -> not change + let mut helper = CnfHelper::new(); + let expr1 = or(a_1.clone(), b_2.clone()); + let expect = expr1.clone(); + let res = expr1.rewrite(&mut helper).unwrap(); + assert_eq!(expect, res); + + // Test rewrite on a1_and_b2 or a2_and_b1 -> a1_or_a2 and a1_or_b1 and b2_or_a2 and b2_or_b1 + let mut helper = CnfHelper::new(); + let expr1 = or(and(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone())); + let a1_or_a2 = or(a_1.clone(), a_2.clone()); + let a1_or_b1 = or(a_1.clone(), b_1.clone()); + let b2_or_a2 = or(b_2.clone(), a_2.clone()); + let b2_or_b1 = or(b_2.clone(), b_1.clone()); + let expect = and(a1_or_a2, a1_or_b1).and(b2_or_a2).and(b2_or_b1); + let res = expr1.rewrite(&mut helper).unwrap(); + assert_eq!(expect, res); + + // Test rewrite on a1_or_b2 or a2_and_b1 -> ( a1_or_a2 or a2 ) and (a1_or_a2 or b1) + let mut helper = CnfHelper::new(); + let a1_or_b2 = or(a_1.clone(), b_2.clone()); + let expr1 = or(or(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone())); + let expect = or(a1_or_b2.clone(), a_2.clone()).and(or(a1_or_b2, b_1.clone())); + let res = expr1.rewrite(&mut helper).unwrap(); + assert_eq!(expect, res); + + // Test rewrite on a1_or_b2 or a2_or_b1 -> not change + let mut helper = CnfHelper::new(); + let expr1 = or(or(a_1, b_2), or(a_2, b_1)); + let expect = expr1.clone(); + let res = expr1.rewrite(&mut helper).unwrap(); + assert_eq!(expect, res); + } + + #[test] + fn test_rewrite_cnf_overflow() { + // in this situation: + // AND = (a=1 and b=2) + // rewrite (AND * 10) or (AND * 10), it will produce 10 * 10 = 100 (a=1 or b=2) + // which cause size expansion. + + let mut expr1 = col("test1").eq(lit(1i64)); + let expr2 = col("test2").eq(lit(2i64)); + + for _i in 0..9 { + expr1 = expr1.clone().and(expr2.clone()); + } + let expr3 = expr1.clone(); + let expr = or(expr1, expr3); + let mut helper = CnfHelper::new(); + let res = expr.clone().rewrite(&mut helper).unwrap(); + assert_eq!(100, helper.current_count); + assert_eq!(res, expr); + assert!(helper.current_count >= helper.max_count); + } }