From 280b3d1d5329e11cf9376582a72267cf25cf19ab Mon Sep 17 00:00:00 2001 From: yangjiang Date: Thu, 20 Oct 2022 15:52:01 +0800 Subject: [PATCH 01/14] Factorize common AND factors out of OR predicates to support filterPushDown as possible Signed-off-by: yangjiang --- benchmarks/expected-plans/q7.txt | 2 +- .../physical_plan/file_format/row_filter.rs | 4 +- datafusion/core/tests/sql/joins.rs | 7 +- datafusion/optimizer/src/filter_push_down.rs | 58 +++- datafusion/optimizer/src/utils.rs | 270 ++++++++++++++++-- 5 files changed, 311 insertions(+), 30 deletions(-) diff --git a/benchmarks/expected-plans/q7.txt b/benchmarks/expected-plans/q7.txt index a1d1806f9189..73fe8574a627 100644 --- a/benchmarks/expected-plans/q7.txt +++ b/benchmarks/expected-plans/q7.txt @@ -3,7 +3,7 @@ Sort: shipping.supp_nation ASC NULLS LAST, shipping.cust_nation ASC NULLS LAST, Aggregate: groupBy=[[shipping.supp_nation, shipping.cust_nation, shipping.l_year]], aggr=[[SUM(shipping.volume)]] Projection: shipping.supp_nation, shipping.cust_nation, shipping.l_year, shipping.volume, alias=shipping Projection: n1.n_name AS supp_nation, n2.n_name AS cust_nation, datepart(Utf8("YEAR"), lineitem.l_shipdate) AS l_year, CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)) AS volume, alias=shipping - Filter: n1.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY") AND n2.n_name = Utf8("FRANCE") + Filter: n1.n_name = Utf8("FRANCE") OR n2.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY") Inner Join: customer.c_nationkey = n2.n_nationkey Inner Join: supplier.s_nationkey = n1.n_nationkey Inner Join: orders.o_custkey = customer.c_custkey diff --git a/datafusion/core/src/physical_plan/file_format/row_filter.rs b/datafusion/core/src/physical_plan/file_format/row_filter.rs index dd9c8fb650fd..2ac55d368bf9 100644 --- a/datafusion/core/src/physical_plan/file_format/row_filter.rs +++ b/datafusion/core/src/physical_plan/file_format/row_filter.rs @@ -22,7 +22,7 @@ use arrow::record_batch::RecordBatch; use datafusion_common::{Column, DataFusionError, Result, ScalarValue, ToDFSchema}; use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; -use datafusion_expr::Expr; +use datafusion_expr::{Expr, Operator}; use datafusion_optimizer::utils::split_conjunction_owned; use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::{create_physical_expr, PhysicalExpr}; @@ -253,7 +253,7 @@ pub fn build_row_filter( metadata: &ParquetMetaData, reorder_predicates: bool, ) -> Result> { - let predicates = split_conjunction_owned(expr); + let predicates = split_conjunction_owned(expr, Operator::And); let mut candidates: Vec = predicates .into_iter() diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 2ff4947b3214..1ba8cf7ac42e 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -1468,10 +1468,15 @@ async fn reduce_left_join_2() -> Result<()> { .expect(&msg); let state = ctx.state(); let plan = state.optimize(&plan)?; + + // filter expr: `t2.t2_int < 10 or (t1.t1_int > 2 and t2.t2_name != 'w')` + // could be write to: `(t1.t1_int > 2 or t2.t2_int < 10) and (t2.t2_name != 'w' or t2.t2_int < 10)` + // the right part `(t2.t2_name != 'w' or t2.t2_int < 10)` could be push down left join side and remove in filter. + let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", " Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR CAST(t1.t1_int AS Int64) > Int64(2) AND t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR CAST(t1.t1_int AS Int64) > Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", " Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", diff --git a/datafusion/optimizer/src/filter_push_down.rs b/datafusion/optimizer/src/filter_push_down.rs index 6396f1fbfd6c..61e6d1261273 100644 --- a/datafusion/optimizer/src/filter_push_down.rs +++ b/datafusion/optimizer/src/filter_push_down.rs @@ -14,6 +14,7 @@ //! Filter Push Down optimizer rule ensures that filters are applied as early as possible in the plan +use crate::utils::{split_conjunction_owned, CnfHelper}; use crate::{utils, OptimizerConfig, OptimizerRule}; use datafusion_common::{Column, DFSchema, DataFusionError, Result}; use datafusion_expr::{ @@ -28,6 +29,7 @@ use datafusion_expr::{ utils::{expr_to_columns, exprlist_to_columns, from_plan}, Expr, Operator, TableProviderFilterPushDown, }; +use log::error; use std::collections::{HashMap, HashSet}; use std::iter::once; @@ -70,6 +72,7 @@ type Predicates<'a> = (Vec<&'a Expr>, Vec<&'a HashSet>); struct State { // (predicate, columns on the predicate) filters: Vec, + use_cnf_rewrite: bool, } impl State { @@ -80,6 +83,11 @@ impl State { .zip(predicates.1) .for_each(|(expr, cols)| self.filters.push((expr.clone(), cols.clone()))) } + + fn with_cnf_rewrite(mut self) -> Self { + self.use_cnf_rewrite = true; + self + } } /// returns all predicates in `state` that depend on any of `used_columns` @@ -457,7 +465,7 @@ fn optimize_join( // Build new filter states using pushable predicates // from current optimizer states and from ON clause. // Then recursively call optimization for both join inputs - let mut left_state = State { filters: vec![] }; + let mut left_state = State::default(); left_state.append_predicates(to_left); left_state.append_predicates(on_to_left); or_to_left @@ -472,7 +480,7 @@ fn optimize_join( .for_each(|(expr, cols)| left_state.filters.push((expr, cols))); let left = optimize(left, left_state)?; - let mut right_state = State { filters: vec![] }; + let mut right_state = State::default(); right_state.append_predicates(to_right); right_state.append_predicates(on_to_right); or_to_right @@ -530,14 +538,26 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { } LogicalPlan::Analyze { .. } => push_down(&state, plan), LogicalPlan::Filter(filter) => { - let predicates = utils::split_conjunction(filter.predicate()); + let predicates = if state.use_cnf_rewrite { + let filter_cnf = + filter.predicate().clone().rewrite(&mut CnfHelper::new()); + match filter_cnf { + Ok(ref expr) => split_conjunction_owned(expr.clone(), Operator::And), + Err(e) => { + error!("Fail at CnfHelper rewrite: {}.", e); + split_conjunction_owned(filter.predicate().clone(), Operator::And) + } + } + } else { + split_conjunction_owned(filter.predicate().clone(), Operator::And) + }; predicates .into_iter() .try_for_each::<_, Result<()>>(|predicate| { let mut columns: HashSet = HashSet::new(); - expr_to_columns(predicate, &mut columns)?; - state.filters.push((predicate.clone(), columns)); + expr_to_columns(&predicate, &mut columns)?; + state.filters.push((predicate, columns)); Ok(()) })?; @@ -797,7 +817,7 @@ impl OptimizerRule for FilterPushDown { plan: &LogicalPlan, _: &mut OptimizerConfig, ) -> Result { - optimize(plan, State::default()) + optimize(plan, State::default().with_cnf_rewrite()) } } @@ -953,6 +973,30 @@ mod tests { Ok(()) } + #[test] + fn filter_keep_partial_agg() -> Result<()> { + let table_scan = test_table_scan()?; + let f1 = col("c").eq(lit(1i64)).and(col("b").gt(lit(2i64))); + let f2 = col("c").eq(lit(1i64)).and(col("b").gt(lit(3i64))); + let filter = f1.or(f2); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])? + .filter(filter)? + .build()?; + // filter of aggregate is after aggregation since they are non-commutative + // (c =1 AND b > 2) OR (c = 1 AND b > 3) + // rewrite to CNF + // (c = 1 OR c = 1) [can pushDown] AND (c = 1 OR b > 3) AND (b > 2 OR C = 1) AND (b > 2 OR b > 3) + + let expected = "\ + Filter: test.c = Int64(1) OR b > Int64(3) AND b > Int64(2) OR test.c = Int64(1) AND b > Int64(2) OR b > Int64(3)\ + \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\ + \n Filter: test.c = Int64(1) OR test.c = Int64(1)\ + \n TableScan: test"; + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written #[test] fn alias() -> Result<()> { @@ -2344,7 +2388,7 @@ mod tests { .filter(filter)? .build()?; - let expected = "Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)\ + let expected = "Filter: test.a = d OR test.b = e AND test.a = d OR test.c < UInt32(10) AND test.b > UInt32(1) OR test.b = e\ \n CrossJoin:\ \n Projection: test.a, test.b, test.c\ \n Filter: test.b > UInt32(1) OR test.c < UInt32(10)\ diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 130df3e0e6ef..aaca27e092d3 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -21,7 +21,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; use datafusion_common::{plan_err, Column, DFSchemaRef}; use datafusion_expr::expr::BinaryExpr; -use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter}; +use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; use datafusion_expr::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion}; use datafusion_expr::{ and, col, @@ -84,7 +84,7 @@ fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<& /// /// # Example /// ``` -/// # use datafusion_expr::{col, lit}; +/// # use datafusion_expr::{col, lit, Operator}; /// # use datafusion_optimizer::utils::split_conjunction_owned; /// // a=1 AND b=2 /// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2))); @@ -96,23 +96,23 @@ fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<& /// ]; /// /// // use split_conjunction_owned to split them -/// assert_eq!(split_conjunction_owned(expr), split); +/// assert_eq!(split_conjunction_owned(expr, Operator::And), split); /// ``` -pub fn split_conjunction_owned(expr: Expr) -> Vec { - split_conjunction_owned_impl(expr, vec![]) +pub fn split_conjunction_owned(expr: Expr, op: Operator) -> Vec { + split_conjunction_owned_impl(expr, op, vec![]) } -fn split_conjunction_owned_impl(expr: Expr, mut exprs: Vec) -> Vec { +fn split_conjunction_owned_impl( + expr: Expr, + operator: Operator, + mut exprs: Vec, +) -> Vec { match expr { - Expr::BinaryExpr(BinaryExpr { - right, - op: Operator::And, - left, - }) => { - let exprs = split_conjunction_owned_impl(*left, exprs); - split_conjunction_owned_impl(*right, exprs) + Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => { + let exprs = split_conjunction_owned_impl(*left, Operator::And, exprs); + split_conjunction_owned_impl(*right, Operator::And, exprs) } - Expr::Alias(expr, _) => split_conjunction_owned_impl(*expr, exprs), + Expr::Alias(expr, _) => split_conjunction_owned_impl(*expr, Operator::And, exprs), other => { exprs.push(other); exprs @@ -120,6 +120,149 @@ fn split_conjunction_owned_impl(expr: Expr, mut exprs: Vec) -> Vec { } } +/// Converts an expression to conjunctive normal form (CNF). +/// +/// The following expression is in CNF: +/// `(a OR b) AND (c OR d)` +/// The following is not in CNF: +/// `(a AND b) OR c`. +/// But could be rewrite to a CNF expression: +/// `(a OR c) AND (b OR c)`. +/// +/// # Example +/// ``` +/// # use datafusion_expr::{col, lit}; +/// # use datafusion_expr::expr_rewriter::ExprRewritable; +/// # use datafusion_optimizer::utils::CnfHelper; +/// // (a=1 AND b=2)OR c = 3 +/// let expr1 = col("a").eq(lit(1)).and(col("b").eq(lit(2))); +/// let expr2 = col("c").eq(lit(3)); +/// let expr = expr1.or(expr2); +/// +/// //(a=1 or c=3)AND(b=2 or c=3) +/// let expr1 = col("a").eq(lit(1)).or(col("c").eq(lit(3))); +/// let expr2 = col("b").eq(lit(2)).or(col("c").eq(lit(3))); +/// let expect = expr1.and(expr2); +/// // use split_conjunction_owned to split them +/// assert_eq!(expr.rewrite(& mut CnfHelper::new()).unwrap(), expect); +/// ``` +/// +pub struct CnfHelper { + max_count: usize, + current_count: usize, + exprs: Vec, + original_expr: Option, +} + +impl CnfHelper { + pub fn new() -> Self { + CnfHelper { + max_count: 50, + current_count: 0, + exprs: vec![], + original_expr: None, + } + } + + pub fn new_with_max_count(max_count: usize) -> Self { + CnfHelper { + max_count, + current_count: 0, + exprs: vec![], + original_expr: None, + } + } + + fn increment_and_check_overload(&mut self) -> bool { + self.current_count += 1; + self.current_count >= self.max_count + } +} + +impl ExprRewriter for CnfHelper { + fn pre_visit(&mut self, expr: &Expr) -> Result { + let is_root = self.original_expr.is_none(); + if is_root { + self.original_expr = Some(expr.clone()); + } + match expr { + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + match op { + Operator::And => { + if self.increment_and_check_overload() { + return Ok(RewriteRecursion::Mutate); + } + } + // (a AND b) OR (c AND d) = (a OR b) AND (a OR c) AND (b OR c) AND (b OR d) + Operator::Or => { + // Avoid create to much Expr like in tpch q19. + let lc = split_conjunction_owned(*left.clone(), Operator::Or) + .into_iter() + .flat_map(|e| split_conjunction_owned(e, Operator::And)) + .count(); + let rc = split_conjunction_owned(*right.clone(), Operator::Or) + .into_iter() + .flat_map(|e| split_conjunction_owned(e, Operator::And)) + .count(); + self.current_count += lc * rc - 1; + if self.increment_and_check_overload() { + return Ok(RewriteRecursion::Mutate); + } + let left_and_split = + split_conjunction_owned(*left.clone(), Operator::And); + let right_and_split = + split_conjunction_owned(*right.clone(), Operator::And); + left_and_split.iter().for_each(|l| { + right_and_split.iter().for_each(|r| { + self.exprs.push(Expr::BinaryExpr(BinaryExpr { + left: Box::new(l.clone()), + op: Operator::Or, + right: Box::new(r.clone()), + })) + }) + }); + return Ok(RewriteRecursion::Mutate); + } + _ => { + if self.increment_and_check_overload() { + return Ok(RewriteRecursion::Mutate); + } + self.exprs.push(expr.clone()); + return Ok(RewriteRecursion::Stop); + } + } + } + other => { + if self.increment_and_check_overload() { + return Ok(RewriteRecursion::Mutate); + } + self.exprs.push(other.clone()); + return Ok(RewriteRecursion::Stop); + } + } + if is_root { + Ok(RewriteRecursion::Continue) + } else { + Ok(RewriteRecursion::Skip) + } + } + + fn mutate(&mut self, _expr: Expr) -> Result { + if self.current_count >= self.max_count { + Ok(self.original_expr.as_ref().unwrap().clone()) + } else { + Ok(conjunction(self.exprs.clone()) + .unwrap_or_else(|| self.original_expr.as_ref().unwrap().clone())) + } + } +} + +impl Default for CnfHelper { + fn default() -> Self { + Self::new() + } +} + /// Combines an array of filter expressions into a single filter /// expression consisting of the input filter expressions joined with /// logical AND. @@ -469,7 +612,7 @@ mod tests { use super::*; use arrow::datatypes::DataType; use datafusion_common::Column; - use datafusion_expr::{col, lit, utils::expr_to_columns}; + use datafusion_expr::{col, lit, or, utils::expr_to_columns}; use std::collections::HashSet; use std::ops::Add; @@ -510,13 +653,16 @@ mod tests { #[test] fn test_split_conjunction_owned() { let expr = col("a"); - assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); + assert_eq!( + split_conjunction_owned(expr.clone(), Operator::And), + vec![expr] + ); } #[test] fn test_split_conjunction_owned_two() { assert_eq!( - split_conjunction_owned(col("a").eq(lit(5)).and(col("b"))), + split_conjunction_owned(col("a").eq(lit(5)).and(col("b")), Operator::And), vec![col("a").eq(lit(5)), col("b")] ); } @@ -524,7 +670,10 @@ mod tests { #[test] fn test_split_conjunction_owned_alias() { assert_eq!( - split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))), + split_conjunction_owned( + col("a").eq(lit(5)).and(col("b").alias("the_alias")), + Operator::And + ), vec![ col("a").eq(lit(5)), // no alias on b @@ -570,7 +719,10 @@ mod tests { #[test] fn test_split_conjunction_owned_or() { let expr = col("a").eq(lit(5)).or(col("b")); - assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); + assert_eq!( + split_conjunction_owned(expr.clone(), Operator::And), + vec![expr] + ); } #[test] @@ -663,4 +815,84 @@ mod tests { "mismatch rewriting expr_from: {expr_from} to {rewrite_to}" ) } + + #[test] + fn test_rewrite_cnf() { + let a_1 = col("a").eq(lit(1i64)); + let a_2 = col("a").eq(lit(2i64)); + + let b_1 = col("b").eq(lit(1i64)); + let b_2 = col("b").eq(lit(2i64)); + + // Test rewrite on a1_and_b2 and a2_and_b1 -> not change + let mut helper = CnfHelper::new(); + let expr1 = and(a_1.clone(), b_2.clone()); + let expect = expr1.clone(); + let res = expr1.rewrite(&mut helper).unwrap(); + assert_eq!(expect, res); + + // Test rewrite on a1_and_b2 and a2_and_b1 -> (((a1 and b2) and a2) and b1) + let mut helper = CnfHelper::new(); + let expr1 = and(and(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone())); + let expect = and(a_1.clone(), b_2.clone()) + .and(a_2.clone()) + .and(b_1.clone()); + let res = expr1.rewrite(&mut helper).unwrap(); + assert_eq!(expect, res); + + // Test rewrite on a1_or_b2 -> not change + let mut helper = CnfHelper::new(); + let expr1 = or(a_1.clone(), b_2.clone()); + let expect = expr1.clone(); + let res = expr1.rewrite(&mut helper).unwrap(); + assert_eq!(expect, res); + + // Test rewrite on a1_and_b2 or a2_and_b1 -> a1_or_a2 and a1_or_b1 and b2_or_a2 and b2_or_b1 + let mut helper = CnfHelper::new(); + let expr1 = or(and(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone())); + let a1_or_a2 = or(a_1.clone(), a_2.clone()); + let a1_or_b1 = or(a_1.clone(), b_1.clone()); + let b2_or_a2 = or(b_2.clone(), a_2.clone()); + let b2_or_b1 = or(b_2.clone(), b_1.clone()); + let expect = and(a1_or_a2, a1_or_b1).and(b2_or_a2).and(b2_or_b1); + let res = expr1.rewrite(&mut helper).unwrap(); + assert_eq!(expect, res); + + // Test rewrite on a1_or_b2 or a2_and_b1 -> ( a1_or_a2 or a2 ) and (a1_or_a2 or b1) + let mut helper = CnfHelper::new(); + let a1_or_b2 = or(a_1.clone(), b_2.clone()); + let expr1 = or(or(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone())); + let expect = or(a1_or_b2.clone(), a_2.clone()).and(or(a1_or_b2, b_1.clone())); + let res = expr1.rewrite(&mut helper).unwrap(); + assert_eq!(expect, res); + + // Test rewrite on a1_or_b2 or a2_or_b1 -> not change + let mut helper = CnfHelper::new(); + let expr1 = or(or(a_1, b_2), or(a_2, b_1)); + let expect = expr1.clone(); + let res = expr1.rewrite(&mut helper).unwrap(); + assert_eq!(expect, res); + } + + #[test] + fn test_rewrite_cnf_overflow() { + // in this situation: + // AND = (a=1 and b=2) + // rewrite (AND * 10) or (AND * 10), it will produce 10 * 10 = 100 (a=1 or b=2) + // which cause size expansion. + + let mut expr1 = col("test1").eq(lit(1i64)); + let expr2 = col("test2").eq(lit(2i64)); + + for _i in 0..9 { + expr1 = expr1.clone().and(expr2.clone()); + } + let expr3 = expr1.clone(); + let expr = or(expr1, expr3); + let mut helper = CnfHelper::new(); + let res = expr.clone().rewrite(&mut helper).unwrap(); + assert_eq!(100, helper.current_count); + assert_eq!(res, expr); + assert!(helper.current_count >= helper.max_count); + } } From 0faf77632b9047d84008bc4cb858ee3bdbc2f598 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 20 Oct 2022 14:21:18 -0400 Subject: [PATCH 02/14] add split_binary_owned rather than change signature of split_conjuction_owned --- .../physical_plan/file_format/row_filter.rs | 4 +- datafusion/optimizer/src/filter_push_down.rs | 8 +- datafusion/optimizer/src/utils.rs | 95 +++++++++++++------ 3 files changed, 71 insertions(+), 36 deletions(-) diff --git a/datafusion/core/src/physical_plan/file_format/row_filter.rs b/datafusion/core/src/physical_plan/file_format/row_filter.rs index 2ac55d368bf9..dd9c8fb650fd 100644 --- a/datafusion/core/src/physical_plan/file_format/row_filter.rs +++ b/datafusion/core/src/physical_plan/file_format/row_filter.rs @@ -22,7 +22,7 @@ use arrow::record_batch::RecordBatch; use datafusion_common::{Column, DataFusionError, Result, ScalarValue, ToDFSchema}; use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; -use datafusion_expr::{Expr, Operator}; +use datafusion_expr::Expr; use datafusion_optimizer::utils::split_conjunction_owned; use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::{create_physical_expr, PhysicalExpr}; @@ -253,7 +253,7 @@ pub fn build_row_filter( metadata: &ParquetMetaData, reorder_predicates: bool, ) -> Result> { - let predicates = split_conjunction_owned(expr, Operator::And); + let predicates = split_conjunction_owned(expr); let mut candidates: Vec = predicates .into_iter() diff --git a/datafusion/optimizer/src/filter_push_down.rs b/datafusion/optimizer/src/filter_push_down.rs index 61e6d1261273..71c996afd6a6 100644 --- a/datafusion/optimizer/src/filter_push_down.rs +++ b/datafusion/optimizer/src/filter_push_down.rs @@ -14,7 +14,7 @@ //! Filter Push Down optimizer rule ensures that filters are applied as early as possible in the plan -use crate::utils::{split_conjunction_owned, CnfHelper}; +use crate::utils::{split_binary_owned, CnfHelper}; use crate::{utils, OptimizerConfig, OptimizerRule}; use datafusion_common::{Column, DFSchema, DataFusionError, Result}; use datafusion_expr::{ @@ -542,14 +542,14 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { let filter_cnf = filter.predicate().clone().rewrite(&mut CnfHelper::new()); match filter_cnf { - Ok(ref expr) => split_conjunction_owned(expr.clone(), Operator::And), + Ok(ref expr) => split_binary_owned(expr.clone(), Operator::And), Err(e) => { error!("Fail at CnfHelper rewrite: {}.", e); - split_conjunction_owned(filter.predicate().clone(), Operator::And) + split_binary_owned(filter.predicate().clone(), Operator::And) } } } else { - split_conjunction_owned(filter.predicate().clone(), Operator::And) + split_binary_owned(filter.predicate().clone(), Operator::And) }; predicates diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index aaca27e092d3..dd599c89368a 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -84,7 +84,7 @@ fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<& /// /// # Example /// ``` -/// # use datafusion_expr::{col, lit, Operator}; +/// # use datafusion_expr::{col, lit}; /// # use datafusion_optimizer::utils::split_conjunction_owned; /// // a=1 AND b=2 /// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2))); @@ -96,23 +96,49 @@ fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<& /// ]; /// /// // use split_conjunction_owned to split them -/// assert_eq!(split_conjunction_owned(expr, Operator::And), split); +/// assert_eq!(split_conjunction_owned(expr), split); /// ``` -pub fn split_conjunction_owned(expr: Expr, op: Operator) -> Vec { - split_conjunction_owned_impl(expr, op, vec![]) +pub fn split_conjunction_owned(expr: Expr) -> Vec { + split_binary_owned(expr, Operator::And) } -fn split_conjunction_owned_impl( +/// Splits an binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` +/// +/// This is often used to "split" expressions such as `col1 = 5 +/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`]; +/// +/// # Example +/// ``` +/// # use datafusion_expr::{col, lit, Operator}; +/// # use datafusion_optimizer::utils::split_binary_owned; +/// # use std::ops::Add; +/// // a=1 + b=2 +/// let expr = col("a").eq(lit(1)).add(col("b").eq(lit(2))); +/// +/// // [a=1, b=2] +/// let split = vec![ +/// col("a").eq(lit(1)), +/// col("b").eq(lit(2)), +/// ]; +/// +/// // use split_binary_owned to split them +/// assert_eq!(split_binary_owned(expr, Operator::Plus), split); +/// ``` +pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec { + split_binary_owned_impl(expr, op, vec![]) +} + +fn split_binary_owned_impl( expr: Expr, operator: Operator, mut exprs: Vec, ) -> Vec { match expr { Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => { - let exprs = split_conjunction_owned_impl(*left, Operator::And, exprs); - split_conjunction_owned_impl(*right, Operator::And, exprs) + let exprs = split_binary_owned_impl(*left, Operator::And, exprs); + split_binary_owned_impl(*right, Operator::And, exprs) } - Expr::Alias(expr, _) => split_conjunction_owned_impl(*expr, Operator::And, exprs), + Expr::Alias(expr, _) => split_binary_owned_impl(*expr, Operator::And, exprs), other => { exprs.push(other); exprs @@ -196,22 +222,22 @@ impl ExprRewriter for CnfHelper { // (a AND b) OR (c AND d) = (a OR b) AND (a OR c) AND (b OR c) AND (b OR d) Operator::Or => { // Avoid create to much Expr like in tpch q19. - let lc = split_conjunction_owned(*left.clone(), Operator::Or) + let lc = split_binary_owned(*left.clone(), Operator::Or) .into_iter() - .flat_map(|e| split_conjunction_owned(e, Operator::And)) + .flat_map(|e| split_binary_owned(e, Operator::And)) .count(); - let rc = split_conjunction_owned(*right.clone(), Operator::Or) + let rc = split_binary_owned(*right.clone(), Operator::Or) .into_iter() - .flat_map(|e| split_conjunction_owned(e, Operator::And)) + .flat_map(|e| split_binary_owned(e, Operator::And)) .count(); self.current_count += lc * rc - 1; if self.increment_and_check_overload() { return Ok(RewriteRecursion::Mutate); } let left_and_split = - split_conjunction_owned(*left.clone(), Operator::And); + split_binary_owned(*left.clone(), Operator::And); let right_and_split = - split_conjunction_owned(*right.clone(), Operator::And); + split_binary_owned(*right.clone(), Operator::And); left_and_split.iter().for_each(|l| { right_and_split.iter().for_each(|r| { self.exprs.push(Expr::BinaryExpr(BinaryExpr { @@ -651,18 +677,39 @@ mod tests { } #[test] - fn test_split_conjunction_owned() { + fn test_split_binary_owned() { let expr = col("a"); + assert_eq!(split_binary_owned(expr.clone(), Operator::And), vec![expr]); + } + + #[test] + fn test_split_binary_owned_two() { + assert_eq!( + split_binary_owned(col("a").eq(lit(5)).and(col("b")), Operator::And), + vec![col("a").eq(lit(5)), col("b")] + ); + } + + #[test] + fn test_split_binary_owned_different_op() { + let expr = col("a").eq(lit(5)).or(col("b")); assert_eq!( - split_conjunction_owned(expr.clone(), Operator::And), + // expr is connected by OR, but pass in AND + split_binary_owned(expr.clone(), Operator::And), vec![expr] ); } + #[test] + fn test_split_conjunction_owned() { + let expr = col("a"); + assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); + } + #[test] fn test_split_conjunction_owned_two() { assert_eq!( - split_conjunction_owned(col("a").eq(lit(5)).and(col("b")), Operator::And), + split_conjunction_owned(col("a").eq(lit(5)).and(col("b"))), vec![col("a").eq(lit(5)), col("b")] ); } @@ -670,10 +717,7 @@ mod tests { #[test] fn test_split_conjunction_owned_alias() { assert_eq!( - split_conjunction_owned( - col("a").eq(lit(5)).and(col("b").alias("the_alias")), - Operator::And - ), + split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias")),), vec![ col("a").eq(lit(5)), // no alias on b @@ -716,15 +760,6 @@ mod tests { assert_ne!(expr, Some(col("a").or(col("b").or(col("c"))))); } - #[test] - fn test_split_conjunction_owned_or() { - let expr = col("a").eq(lit(5)).or(col("b")); - assert_eq!( - split_conjunction_owned(expr.clone(), Operator::And), - vec![expr] - ); - } - #[test] fn test_collect_expr() -> Result<()> { let mut accum: HashSet = HashSet::new(); From 9f720c23697db3d0dae46a98bbdd884c1594f637 Mon Sep 17 00:00:00 2001 From: yangjiang Date: Sat, 22 Oct 2022 21:44:46 +0800 Subject: [PATCH 03/14] add `split_binary` and avoid some clone calls Signed-off-by: yangjiang --- datafusion/optimizer/src/filter_push_down.rs | 3 ++ datafusion/optimizer/src/utils.rs | 53 ++++++++++++++------ 2 files changed, 41 insertions(+), 15 deletions(-) diff --git a/datafusion/optimizer/src/filter_push_down.rs b/datafusion/optimizer/src/filter_push_down.rs index 71c996afd6a6..77e99ecad123 100644 --- a/datafusion/optimizer/src/filter_push_down.rs +++ b/datafusion/optimizer/src/filter_push_down.rs @@ -84,6 +84,9 @@ impl State { .for_each(|(expr, cols)| self.filters.push((expr.clone(), cols.clone()))) } + // set `true` means split the filter-exprs into CNF (see `CnfHelper`) + // to push more filter conditions down, it may cause filter-exprs size + // expansion. fn with_cnf_rewrite(mut self) -> Self { self.use_cnf_rewrite = true; self diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index dd599c89368a..839e3f0d7f1d 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -102,7 +102,7 @@ pub fn split_conjunction_owned(expr: Expr) -> Vec { split_binary_owned(expr, Operator::And) } -/// Splits an binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` +/// Splits an owned binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` /// /// This is often used to "split" expressions such as `col1 = 5 /// AND col2 = 10` into [`col1 = 5`, `col2 = 10`]; @@ -146,6 +146,31 @@ fn split_binary_owned_impl( } } +/// Splits an binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` +/// +/// See [`split_binary_owned`] for more details and an example. +pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> { + split_binary_impl(expr, op, vec![]) +} + +fn split_binary_impl<'a>( + expr: &'a Expr, + operator: Operator, + mut exprs: Vec<&'a Expr>, +) -> Vec<&'a Expr> { + match expr { + Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => { + let exprs = split_conjunction_impl(left, exprs); + split_conjunction_impl(right, exprs) + } + Expr::Alias(expr, _) => split_conjunction_impl(expr, exprs), + other => { + exprs.push(other); + exprs + } + } +} + /// Converts an expression to conjunctive normal form (CNF). /// /// The following expression is in CNF: @@ -222,31 +247,29 @@ impl ExprRewriter for CnfHelper { // (a AND b) OR (c AND d) = (a OR b) AND (a OR c) AND (b OR c) AND (b OR d) Operator::Or => { // Avoid create to much Expr like in tpch q19. - let lc = split_binary_owned(*left.clone(), Operator::Or) + let lc = split_binary(left, Operator::Or) .into_iter() - .flat_map(|e| split_binary_owned(e, Operator::And)) + .flat_map(|e| split_binary(e, Operator::And)) .count(); - let rc = split_binary_owned(*right.clone(), Operator::Or) + let rc = split_binary(right, Operator::Or) .into_iter() - .flat_map(|e| split_binary_owned(e, Operator::And)) + .flat_map(|e| split_binary(e, Operator::And)) .count(); self.current_count += lc * rc - 1; if self.increment_and_check_overload() { return Ok(RewriteRecursion::Mutate); } - let left_and_split = - split_binary_owned(*left.clone(), Operator::And); - let right_and_split = - split_binary_owned(*right.clone(), Operator::And); - left_and_split.iter().for_each(|l| { - right_and_split.iter().for_each(|r| { + let left_and_split = split_binary(left, Operator::And); + let right_and_split = split_binary(right, Operator::And); + for l in left_and_split { + for &r in &right_and_split { self.exprs.push(Expr::BinaryExpr(BinaryExpr { left: Box::new(l.clone()), op: Operator::Or, right: Box::new(r.clone()), - })) - }) - }); + })); + } + } return Ok(RewriteRecursion::Mutate); } _ => { @@ -278,7 +301,7 @@ impl ExprRewriter for CnfHelper { Ok(self.original_expr.as_ref().unwrap().clone()) } else { Ok(conjunction(self.exprs.clone()) - .unwrap_or_else(|| self.original_expr.as_ref().unwrap().clone())) + .unwrap_or_else(|| self.original_expr.take().unwrap())) } } } From 7cd016ce170bfc458d5c317b93dc3f3b0e9c5ca2 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 26 Oct 2022 08:37:00 -0400 Subject: [PATCH 04/14] Update plans --- benchmarks/expected-plans/q7.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/expected-plans/q7.txt b/benchmarks/expected-plans/q7.txt index 73fe8574a627..fad02c09881c 100644 --- a/benchmarks/expected-plans/q7.txt +++ b/benchmarks/expected-plans/q7.txt @@ -3,7 +3,7 @@ Sort: shipping.supp_nation ASC NULLS LAST, shipping.cust_nation ASC NULLS LAST, Aggregate: groupBy=[[shipping.supp_nation, shipping.cust_nation, shipping.l_year]], aggr=[[SUM(shipping.volume)]] Projection: shipping.supp_nation, shipping.cust_nation, shipping.l_year, shipping.volume, alias=shipping Projection: n1.n_name AS supp_nation, n2.n_name AS cust_nation, datepart(Utf8("YEAR"), lineitem.l_shipdate) AS l_year, CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)) AS volume, alias=shipping - Filter: n1.n_name = Utf8("FRANCE") OR n2.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY") + Filter: (n1.n_name = Utf8("FRANCE") OR n2.n_name = Utf8("FRANCE")) AND (n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY")) Inner Join: customer.c_nationkey = n2.n_nationkey Inner Join: supplier.s_nationkey = n1.n_nationkey Inner Join: orders.o_custkey = customer.c_custkey From e3ff3f7cee2f4d902d75ea0163e64481fc53859e Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 26 Oct 2022 08:41:01 -0400 Subject: [PATCH 05/14] Simplify tests --- datafusion/optimizer/src/utils.rs | 29 +++++++++++------------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 9f5c3ad69a90..ee15b8dc7c08 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -866,6 +866,11 @@ mod tests { ) } + fn cnf_rewrite(expr: Expr) -> Expr { + let mut helper = CnfHelper::new(); + expr.rewrite(&mut helper).unwrap() + } + #[test] fn test_rewrite_cnf() { let a_1 = col("a").eq(lit(1i64)); @@ -875,53 +880,41 @@ mod tests { let b_2 = col("b").eq(lit(2i64)); // Test rewrite on a1_and_b2 and a2_and_b1 -> not change - let mut helper = CnfHelper::new(); let expr1 = and(a_1.clone(), b_2.clone()); let expect = expr1.clone(); - let res = expr1.rewrite(&mut helper).unwrap(); - assert_eq!(expect, res); + assert_eq!(expect, cnf_rewrite(expr1)); // Test rewrite on a1_and_b2 and a2_and_b1 -> (((a1 and b2) and a2) and b1) - let mut helper = CnfHelper::new(); let expr1 = and(and(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone())); let expect = and(a_1.clone(), b_2.clone()) .and(a_2.clone()) .and(b_1.clone()); - let res = expr1.rewrite(&mut helper).unwrap(); - assert_eq!(expect, res); + assert_eq!(expect, cnf_rewrite(expr1)); // Test rewrite on a1_or_b2 -> not change - let mut helper = CnfHelper::new(); let expr1 = or(a_1.clone(), b_2.clone()); let expect = expr1.clone(); - let res = expr1.rewrite(&mut helper).unwrap(); - assert_eq!(expect, res); + assert_eq!(expect, cnf_rewrite(expr1)); // Test rewrite on a1_and_b2 or a2_and_b1 -> a1_or_a2 and a1_or_b1 and b2_or_a2 and b2_or_b1 - let mut helper = CnfHelper::new(); let expr1 = or(and(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone())); let a1_or_a2 = or(a_1.clone(), a_2.clone()); let a1_or_b1 = or(a_1.clone(), b_1.clone()); let b2_or_a2 = or(b_2.clone(), a_2.clone()); let b2_or_b1 = or(b_2.clone(), b_1.clone()); let expect = and(a1_or_a2, a1_or_b1).and(b2_or_a2).and(b2_or_b1); - let res = expr1.rewrite(&mut helper).unwrap(); - assert_eq!(expect, res); + assert_eq!(expect, cnf_rewrite(expr1)); // Test rewrite on a1_or_b2 or a2_and_b1 -> ( a1_or_a2 or a2 ) and (a1_or_a2 or b1) - let mut helper = CnfHelper::new(); let a1_or_b2 = or(a_1.clone(), b_2.clone()); let expr1 = or(or(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone())); let expect = or(a1_or_b2.clone(), a_2.clone()).and(or(a1_or_b2, b_1.clone())); - let res = expr1.rewrite(&mut helper).unwrap(); - assert_eq!(expect, res); + assert_eq!(expect, cnf_rewrite(expr1)); // Test rewrite on a1_or_b2 or a2_or_b1 -> not change - let mut helper = CnfHelper::new(); let expr1 = or(or(a_1, b_2), or(a_2, b_1)); let expect = expr1.clone(); - let res = expr1.rewrite(&mut helper).unwrap(); - assert_eq!(expect, res); + assert_eq!(expect, cnf_rewrite(expr1)); } #[test] From d70c7303b70436b1fd7620e392f8b9117fc027b0 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 26 Oct 2022 08:45:01 -0400 Subject: [PATCH 06/14] Update tests --- datafusion/optimizer/src/filter_push_down.rs | 23 ++++++++++---------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/datafusion/optimizer/src/filter_push_down.rs b/datafusion/optimizer/src/filter_push_down.rs index b7794c261a9d..4276138ef93d 100644 --- a/datafusion/optimizer/src/filter_push_down.rs +++ b/datafusion/optimizer/src/filter_push_down.rs @@ -992,10 +992,10 @@ mod tests { // (c = 1 OR c = 1) [can pushDown] AND (c = 1 OR b > 3) AND (b > 2 OR C = 1) AND (b > 2 OR b > 3) let expected = "\ - Filter: test.c = Int64(1) OR b > Int64(3) AND b > Int64(2) OR test.c = Int64(1) AND b > Int64(2) OR b > Int64(3)\ - \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\ - \n Filter: test.c = Int64(1) OR test.c = Int64(1)\ - \n TableScan: test"; + Filter: (test.c = Int64(1) OR b > Int64(3)) AND (b > Int64(2) OR test.c = Int64(1)) AND (b > Int64(2) OR b > Int64(3))\ + \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\ + \n Filter: test.c = Int64(1) OR test.c = Int64(1)\ + \n TableScan: test"; assert_optimized_plan_eq(&plan, expected); Ok(()) } @@ -2391,13 +2391,14 @@ mod tests { .filter(filter)? .build()?; - let expected = "Filter: test.a = d OR test.b = e AND test.a = d OR test.c < UInt32(10) AND test.b > UInt32(1) OR test.b = e\ - \n CrossJoin:\ - \n Projection: test.a, test.b, test.c\ - \n Filter: test.b > UInt32(1) OR test.c < UInt32(10)\ - \n TableScan: test\ - \n Projection: test1.a AS d, test1.a AS e\ - \n TableScan: test1"; + let expected = "\ + Filter: (test.a = d OR test.b = e) AND (test.a = d OR test.c < UInt32(10)) AND (test.b > UInt32(1) OR test.b = e)\ + \n CrossJoin:\ + \n Projection: test.a, test.b, test.c\ + \n Filter: test.b > UInt32(1) OR test.c < UInt32(10)\ + \n TableScan: test\ + \n Projection: test1.a AS d, test1.a AS e\ + \n TableScan: test1"; assert_optimized_plan_eq(&plan, expected); Ok(()) } From 564974eed168c9891433632519cf44df3a1fd219 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 26 Oct 2022 10:17:31 -0400 Subject: [PATCH 07/14] Rewrite CNF without recursion --- datafusion/optimizer/src/utils.rs | 142 ++++++++++++++++++++++++++++-- 1 file changed, 135 insertions(+), 7 deletions(-) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index ee15b8dc7c08..7a56967d9f99 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -29,7 +29,7 @@ use datafusion_expr::{ utils::from_plan, Expr, Operator, }; -use std::collections::HashSet; +use std::collections::{HashSet, VecDeque}; use std::sync::Arc; /// Convenience rule for writing optimizers: recursively invoke @@ -135,10 +135,10 @@ fn split_binary_owned_impl( ) -> Vec { match expr { Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => { - let exprs = split_binary_owned_impl(*left, Operator::And, exprs); - split_binary_owned_impl(*right, Operator::And, exprs) + let exprs = split_binary_owned_impl(*left, operator, exprs); + split_binary_owned_impl(*right, operator, exprs) } - Expr::Alias(expr, _) => split_binary_owned_impl(*expr, Operator::And, exprs), + Expr::Alias(expr, _) => split_binary_owned_impl(*expr, operator, exprs), other => { exprs.push(other); exprs @@ -230,6 +230,81 @@ impl CnfHelper { } } +// Given some number of lists of Exprs, returns a set of expressions +// where there is one expression from each that there is a single +// element from each of the +fn permutations(mut exprs: VecDeque>) -> Vec> { + let first = if let Some(first) = exprs.pop_front() { + first + } else { + return vec![]; + }; + + // base case: + if exprs.is_empty() { + first.into_iter().map(|e| vec![e]).collect() + } else { + first + .iter() + .flat_map(|expr| { + permutations(exprs.clone()) + .into_iter() + .map(|expr_list| { + // Create [expr, ...] for each permutation + std::iter::once(expr.clone()) + .chain(expr_list.into_iter()) + .collect::>() + }) + .collect::>>() + }) + .collect() + } +} + +fn cnf_rewrite(expr: Expr) -> Expr { + println!("AAL input:\n\n{}", expr); + + // Find all exprs joined by OR + let disjuncts = split_binary_owned(expr, Operator::Or); + println!("AAL disjuncts:\n\n{:#?}", disjuncts); + + // For each expr, find joined by AND + // A OR B OR C --> split each A, B and C + let disjunct_conjuncts: VecDeque> = disjuncts + .into_iter() + .map(|e| split_binary_owned(e, Operator::And)) + .collect::>(); + + // now we want to distribute the item + // A AND (B OR C) + // --> + // (A OR B) AND (A OR C) + + println!("AAL disjunct conjuncts:\n\n{:#?}", disjunct_conjuncts); + + // Decide if we want to distribute the clauses. Heuristic is + // chosen to avoid creating huge predicates + if disjunct_conjuncts.len() == 2 + && disjunct_conjuncts[0].len() == 2 + && disjunct_conjuncts[1].len() == 2 + { + // form the OR clauses( A OR B OR C ..) + let or_clauses = permutations(disjunct_conjuncts) + .into_iter() + .map(|exprs| disjunction(exprs).unwrap()); + conjunction(or_clauses).unwrap() + } + // otherwise reassemble the expression + else { + disjunction( + disjunct_conjuncts + .into_iter() + .map(|exprs| conjunction(exprs).unwrap()), + ) + .unwrap() + } +} + impl ExprRewriter for CnfHelper { fn pre_visit(&mut self, expr: &Expr) -> Result { let is_root = self.original_expr.is_none(); @@ -866,9 +941,62 @@ mod tests { ) } - fn cnf_rewrite(expr: Expr) -> Expr { - let mut helper = CnfHelper::new(); - expr.rewrite(&mut helper).unwrap() + // fn cnf_rewrite(expr: Expr) -> Expr { + // let mut helper = CnfHelper::new(); + // expr.rewrite(&mut helper).unwrap() + // } + + #[test] + fn test_permutations() { + assert_eq!(permutations(vec![].into()), vec![] as Vec>) + } + + #[test] + fn test_permutations_one() { + // [[a]] --> [[a]] + assert_eq!( + permutations(vec![vec![col("a")]].into()), + vec![vec![col("a")]] + ) + } + + #[test] + fn test_permutations_two() { + // [[a, b]] --> [[a], [b]] + assert_eq!( + permutations(vec![vec![col("a"), col("b")]].into()), + vec![vec![col("a")], vec![col("b")]] + ) + } + + #[test] + fn test_permutations_two_and_one() { + // [[a, b], [c]] --> [[a, c], [b, c]] + assert_eq!( + permutations(vec![vec![col("a"), col("b")], vec![col("c")]].into()), + vec![vec![col("a"), col("c")], vec![col("b"), col("c")]] + ) + } + + #[test] + fn test_permutations_two_and_one_and_two() { + // [[a, b], [c], [d, e]] --> [[a, c, d], [a, c, e], [b, c, d], [b, c, e]] + assert_eq!( + permutations( + vec![ + vec![col("a"), col("b")], + vec![col("c")], + vec![col("d"), col("e")] + ] + .into() + ), + vec![ + vec![col("a"), col("c"), col("d")], + vec![col("a"), col("c"), col("e")], + vec![col("b"), col("c"), col("d")], + vec![col("b"), col("c"), col("e")], + ] + ) } #[test] From 97afb9d354a944e2dc076bbdb69bc4a84b01f904 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 26 Oct 2022 10:21:56 -0400 Subject: [PATCH 08/14] Change heuristic --- datafusion/optimizer/src/utils.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 7a56967d9f99..1c96573d83fb 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -284,9 +284,9 @@ fn cnf_rewrite(expr: Expr) -> Expr { // Decide if we want to distribute the clauses. Heuristic is // chosen to avoid creating huge predicates - if disjunct_conjuncts.len() == 2 - && disjunct_conjuncts[0].len() == 2 - && disjunct_conjuncts[1].len() == 2 + let total_permutations = disjunct_conjuncts.iter().fold(1, |sz, exprs| sz * exprs.len()); + + if total_permutations < 10 { // form the OR clauses( A OR B OR C ..) let or_clauses = permutations(disjunct_conjuncts) From 2fd4166ae98b93deb1cf46d431862dc913eb685c Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 26 Oct 2022 11:01:26 -0400 Subject: [PATCH 09/14] Keep cleaning up --- datafusion/optimizer/src/utils.rs | 69 +++++++++++++++---------------- 1 file changed, 33 insertions(+), 36 deletions(-) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 1c96573d83fb..41f844b5e3ac 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -171,33 +171,6 @@ fn split_binary_impl<'a>( } } -/// Converts an expression to conjunctive normal form (CNF). -/// -/// The following expression is in CNF: -/// `(a OR b) AND (c OR d)` -/// The following is not in CNF: -/// `(a AND b) OR c`. -/// But could be rewrite to a CNF expression: -/// `(a OR c) AND (b OR c)`. -/// -/// # Example -/// ``` -/// # use datafusion_expr::{col, lit}; -/// # use datafusion_expr::expr_rewriter::ExprRewritable; -/// # use datafusion_optimizer::utils::CnfHelper; -/// // (a=1 AND b=2)OR c = 3 -/// let expr1 = col("a").eq(lit(1)).and(col("b").eq(lit(2))); -/// let expr2 = col("c").eq(lit(3)); -/// let expr = expr1.or(expr2); -/// -/// //(a=1 or c=3)AND(b=2 or c=3) -/// let expr1 = col("a").eq(lit(1)).or(col("c").eq(lit(3))); -/// let expr2 = col("b").eq(lit(2)).or(col("c").eq(lit(3))); -/// let expect = expr1.and(expr2); -/// // use split_conjunction_owned to split them -/// assert_eq!(expr.rewrite(& mut CnfHelper::new()).unwrap(), expect); -/// ``` -/// pub struct CnfHelper { max_count: usize, current_count: usize, @@ -261,7 +234,33 @@ fn permutations(mut exprs: VecDeque>) -> Vec> { } } -fn cnf_rewrite(expr: Expr) -> Expr { +/// Converts an expression to conjunctive normal form (CNF). +/// +/// The following expression is in CNF: +/// `(a OR b) AND (c OR d)` +/// +/// The following is not in CNF: +/// `(a AND b) OR c`. +/// +/// But could be rewrite to a CNF expression: +/// `(a OR c) AND (b OR c)`. +/// +/// # Example +/// ``` +/// # use datafusion_expr::{col, lit}; +/// # use datafusion_optimizer::utils::cnf_rewrite; +/// // (a=1 AND b=2)OR c = 3 +/// let expr1 = col("a").eq(lit(1)).and(col("b").eq(lit(2))); +/// let expr2 = col("c").eq(lit(3)); +/// let expr = expr1.or(expr2); +/// +/// //(a=1 or c=3)AND(b=2 or c=3) +/// let expr1 = col("a").eq(lit(1)).or(col("c").eq(lit(3))); +/// let expr2 = col("b").eq(lit(2)).or(col("c").eq(lit(3))); +/// let expect = expr1.and(expr2); +/// assert_eq!(expect, cnf_rewrite(expr)); +/// ``` +pub fn cnf_rewrite(expr: Expr) -> Expr { println!("AAL input:\n\n{}", expr); // Find all exprs joined by OR @@ -284,10 +283,11 @@ fn cnf_rewrite(expr: Expr) -> Expr { // Decide if we want to distribute the clauses. Heuristic is // chosen to avoid creating huge predicates - let total_permutations = disjunct_conjuncts.iter().fold(1, |sz, exprs| sz * exprs.len()); + let total_permutations = disjunct_conjuncts + .iter() + .fold(1, |sz, exprs| sz * exprs.len()); - if total_permutations < 10 - { + if total_permutations < 10 { // form the OR clauses( A OR B OR C ..) let or_clauses = permutations(disjunct_conjuncts) .into_iter() @@ -1060,10 +1060,7 @@ mod tests { } let expr3 = expr1.clone(); let expr = or(expr1, expr3); - let mut helper = CnfHelper::new(); - let res = expr.clone().rewrite(&mut helper).unwrap(); - assert_eq!(100, helper.current_count); - assert_eq!(res, expr); - assert!(helper.current_count >= helper.max_count); + + assert_eq!(expr, cnf_rewrite(expr.clone())); } } From 3f71f155d6201fa90bcbdc4269095da241faecaa Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 26 Oct 2022 11:27:43 -0400 Subject: [PATCH 10/14] tests pass --- datafusion/optimizer/src/utils.rs | 79 +++++++++++++++++-------------- 1 file changed, 43 insertions(+), 36 deletions(-) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 41f844b5e3ac..38d1874e348f 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -160,10 +160,10 @@ fn split_binary_impl<'a>( ) -> Vec<&'a Expr> { match expr { Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => { - let exprs = split_conjunction_impl(left, exprs); - split_conjunction_impl(right, exprs) + let exprs = split_binary_impl(left, operator, exprs); + split_binary_impl(right, operator, exprs) } - Expr::Alias(expr, _) => split_conjunction_impl(expr, exprs), + Expr::Alias(expr, _) => split_binary_impl(expr, operator, exprs), other => { exprs.push(other); exprs @@ -206,7 +206,7 @@ impl CnfHelper { // Given some number of lists of Exprs, returns a set of expressions // where there is one expression from each that there is a single // element from each of the -fn permutations(mut exprs: VecDeque>) -> Vec> { +fn permutations(mut exprs: VecDeque>) -> Vec> { let first = if let Some(first) = exprs.pop_front() { first } else { @@ -226,9 +226,9 @@ fn permutations(mut exprs: VecDeque>) -> Vec> { // Create [expr, ...] for each permutation std::iter::once(expr.clone()) .chain(expr_list.into_iter()) - .collect::>() + .collect::>() }) - .collect::>>() + .collect::>>() }) .collect() } @@ -264,14 +264,14 @@ pub fn cnf_rewrite(expr: Expr) -> Expr { println!("AAL input:\n\n{}", expr); // Find all exprs joined by OR - let disjuncts = split_binary_owned(expr, Operator::Or); + let disjuncts = split_binary(&expr, Operator::Or); println!("AAL disjuncts:\n\n{:#?}", disjuncts); // For each expr, find joined by AND // A OR B OR C --> split each A, B and C - let disjunct_conjuncts: VecDeque> = disjuncts + let disjunct_conjuncts: VecDeque> = disjuncts .into_iter() - .map(|e| split_binary_owned(e, Operator::And)) + .map(|e| split_binary(e, Operator::And)) .collect::>(); // now we want to distribute the item @@ -285,23 +285,22 @@ pub fn cnf_rewrite(expr: Expr) -> Expr { // chosen to avoid creating huge predicates let total_permutations = disjunct_conjuncts .iter() - .fold(1, |sz, exprs| sz * exprs.len()); + .fold(1usize, |sz, exprs| sz.saturating_mul(exprs.len())); + + println!("Total permutations: {total_permutations}"); - if total_permutations < 10 { + if disjunct_conjuncts.iter().any(|exprs| exprs.len() > 1) && total_permutations < 10 { + println!("Rewriting..."); // form the OR clauses( A OR B OR C ..) let or_clauses = permutations(disjunct_conjuncts) .into_iter() - .map(|exprs| disjunction(exprs).unwrap()); + .map(|exprs| disjunction(exprs.into_iter().cloned()).unwrap()); conjunction(or_clauses).unwrap() } - // otherwise reassemble the expression + // otherwise return the original expression else { - disjunction( - disjunct_conjuncts - .into_iter() - .map(|exprs| conjunction(exprs).unwrap()), - ) - .unwrap() + println!("returning original..."); + expr } } @@ -941,21 +940,16 @@ mod tests { ) } - // fn cnf_rewrite(expr: Expr) -> Expr { - // let mut helper = CnfHelper::new(); - // expr.rewrite(&mut helper).unwrap() - // } - #[test] fn test_permutations() { - assert_eq!(permutations(vec![].into()), vec![] as Vec>) + assert_eq!(make_permutations(vec![]), vec![] as Vec>) } #[test] fn test_permutations_one() { // [[a]] --> [[a]] assert_eq!( - permutations(vec![vec![col("a")]].into()), + make_permutations(vec![vec![col("a")]]), vec![vec![col("a")]] ) } @@ -964,7 +958,7 @@ mod tests { fn test_permutations_two() { // [[a, b]] --> [[a], [b]] assert_eq!( - permutations(vec![vec![col("a"), col("b")]].into()), + make_permutations(vec![vec![col("a"), col("b")]]), vec![vec![col("a")], vec![col("b")]] ) } @@ -973,7 +967,7 @@ mod tests { fn test_permutations_two_and_one() { // [[a, b], [c]] --> [[a, c], [b, c]] assert_eq!( - permutations(vec![vec![col("a"), col("b")], vec![col("c")]].into()), + make_permutations(vec![vec![col("a"), col("b")], vec![col("c")]]), vec![vec![col("a"), col("c")], vec![col("b"), col("c")]] ) } @@ -982,14 +976,11 @@ mod tests { fn test_permutations_two_and_one_and_two() { // [[a, b], [c], [d, e]] --> [[a, c, d], [a, c, e], [b, c, d], [b, c, e]] assert_eq!( - permutations( - vec![ - vec![col("a"), col("b")], - vec![col("c")], - vec![col("d"), col("e")] - ] - .into() - ), + make_permutations(vec![ + vec![col("a"), col("b")], + vec![col("c")], + vec![col("d"), col("e")] + ]), vec![ vec![col("a"), col("c"), col("d")], vec![col("a"), col("c"), col("e")], @@ -999,6 +990,22 @@ mod tests { ) } + /// call permutations with owned `Expr`s + fn make_permutations(exprs: impl IntoIterator>) -> Vec> { + let exprs = exprs.into_iter().collect::>(); + + let exprs: VecDeque> = exprs + .iter() + .map(|exprs| exprs.iter().collect::>()) + .collect(); + + permutations(exprs) + .into_iter() + // copy &Expr --> Expr + .map(|exprs| exprs.into_iter().cloned().collect()) + .collect() + } + #[test] fn test_rewrite_cnf() { let a_1 = col("a").eq(lit(1i64)); From 16f1c1e30836853af12be918a21cf8385e3f91e8 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 26 Oct 2022 11:33:27 -0400 Subject: [PATCH 11/14] cleanup --- datafusion/optimizer/src/filter_push_down.rs | 18 +-- datafusion/optimizer/src/utils.rs | 138 +------------------ 2 files changed, 9 insertions(+), 147 deletions(-) diff --git a/datafusion/optimizer/src/filter_push_down.rs b/datafusion/optimizer/src/filter_push_down.rs index 4276138ef93d..1dfc791c68b1 100644 --- a/datafusion/optimizer/src/filter_push_down.rs +++ b/datafusion/optimizer/src/filter_push_down.rs @@ -14,7 +14,6 @@ //! Filter Push Down optimizer rule ensures that filters are applied as early as possible in the plan -use crate::utils::{split_binary_owned, CnfHelper}; use crate::{utils, OptimizerConfig, OptimizerRule}; use datafusion_common::{Column, DFSchema, DataFusionError, Result}; use datafusion_expr::{ @@ -29,7 +28,6 @@ use datafusion_expr::{ utils::{expr_to_columns, exprlist_to_columns, from_plan}, Expr, Operator, TableProviderFilterPushDown, }; -use log::error; use std::collections::{HashMap, HashSet}; use std::iter::once; @@ -541,21 +539,13 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { } LogicalPlan::Analyze { .. } => push_down(&state, plan), LogicalPlan::Filter(filter) => { - let predicates = if state.use_cnf_rewrite { - let filter_cnf = - filter.predicate().clone().rewrite(&mut CnfHelper::new()); - match filter_cnf { - Ok(ref expr) => split_binary_owned(expr.clone(), Operator::And), - Err(e) => { - error!("Fail at CnfHelper rewrite: {}.", e); - split_binary_owned(filter.predicate().clone(), Operator::And) - } - } + let predicate = if state.use_cnf_rewrite { + utils::cnf_rewrite(filter.predicate().clone()) } else { - split_binary_owned(filter.predicate().clone(), Operator::And) + filter.predicate().clone() }; - predicates + utils::split_conjunction_owned(predicate) .into_iter() .try_for_each::<_, Result<()>>(|predicate| { let mut columns: HashSet = HashSet::new(); diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 38d1874e348f..d605235eb27e 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -21,7 +21,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; use datafusion_common::{plan_err, Column, DFSchemaRef}; use datafusion_expr::expr::BinaryExpr; -use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; +use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter}; use datafusion_expr::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion}; use datafusion_expr::{ and, col, @@ -171,41 +171,9 @@ fn split_binary_impl<'a>( } } -pub struct CnfHelper { - max_count: usize, - current_count: usize, - exprs: Vec, - original_expr: Option, -} - -impl CnfHelper { - pub fn new() -> Self { - CnfHelper { - max_count: 50, - current_count: 0, - exprs: vec![], - original_expr: None, - } - } - - pub fn new_with_max_count(max_count: usize) -> Self { - CnfHelper { - max_count, - current_count: 0, - exprs: vec![], - original_expr: None, - } - } - - fn increment_and_check_overload(&mut self) -> bool { - self.current_count += 1; - self.current_count >= self.max_count - } -} - -// Given some number of lists of Exprs, returns a set of expressions -// where there is one expression from each that there is a single -// element from each of the +/// Given some number of lists of Exprs, returns a set of expressions +/// where there is one expression from each that there is a single +/// element from each of the fn permutations(mut exprs: VecDeque>) -> Vec> { let first = if let Some(first) = exprs.pop_front() { first @@ -261,36 +229,23 @@ fn permutations(mut exprs: VecDeque>) -> Vec> { /// assert_eq!(expect, cnf_rewrite(expr)); /// ``` pub fn cnf_rewrite(expr: Expr) -> Expr { - println!("AAL input:\n\n{}", expr); - // Find all exprs joined by OR let disjuncts = split_binary(&expr, Operator::Or); - println!("AAL disjuncts:\n\n{:#?}", disjuncts); - // For each expr, find joined by AND + // For each expr, split now on AND // A OR B OR C --> split each A, B and C let disjunct_conjuncts: VecDeque> = disjuncts .into_iter() .map(|e| split_binary(e, Operator::And)) .collect::>(); - // now we want to distribute the item - // A AND (B OR C) - // --> - // (A OR B) AND (A OR C) - - println!("AAL disjunct conjuncts:\n\n{:#?}", disjunct_conjuncts); - // Decide if we want to distribute the clauses. Heuristic is // chosen to avoid creating huge predicates let total_permutations = disjunct_conjuncts .iter() .fold(1usize, |sz, exprs| sz.saturating_mul(exprs.len())); - println!("Total permutations: {total_permutations}"); - if disjunct_conjuncts.iter().any(|exprs| exprs.len() > 1) && total_permutations < 10 { - println!("Rewriting..."); // form the OR clauses( A OR B OR C ..) let or_clauses = permutations(disjunct_conjuncts) .into_iter() @@ -299,93 +254,10 @@ pub fn cnf_rewrite(expr: Expr) -> Expr { } // otherwise return the original expression else { - println!("returning original..."); expr } } -impl ExprRewriter for CnfHelper { - fn pre_visit(&mut self, expr: &Expr) -> Result { - let is_root = self.original_expr.is_none(); - if is_root { - self.original_expr = Some(expr.clone()); - } - match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - match op { - Operator::And => { - if self.increment_and_check_overload() { - return Ok(RewriteRecursion::Mutate); - } - } - // (a AND b) OR (c AND d) = (a OR b) AND (a OR c) AND (b OR c) AND (b OR d) - Operator::Or => { - // Avoid create to much Expr like in tpch q19. - let lc = split_binary(left, Operator::Or) - .into_iter() - .flat_map(|e| split_binary(e, Operator::And)) - .count(); - let rc = split_binary(right, Operator::Or) - .into_iter() - .flat_map(|e| split_binary(e, Operator::And)) - .count(); - self.current_count += lc * rc - 1; - if self.increment_and_check_overload() { - return Ok(RewriteRecursion::Mutate); - } - let left_and_split = split_binary(left, Operator::And); - let right_and_split = split_binary(right, Operator::And); - for l in left_and_split { - for &r in &right_and_split { - self.exprs.push(Expr::BinaryExpr(BinaryExpr { - left: Box::new(l.clone()), - op: Operator::Or, - right: Box::new(r.clone()), - })); - } - } - return Ok(RewriteRecursion::Mutate); - } - _ => { - if self.increment_and_check_overload() { - return Ok(RewriteRecursion::Mutate); - } - self.exprs.push(expr.clone()); - return Ok(RewriteRecursion::Stop); - } - } - } - other => { - if self.increment_and_check_overload() { - return Ok(RewriteRecursion::Mutate); - } - self.exprs.push(other.clone()); - return Ok(RewriteRecursion::Stop); - } - } - if is_root { - Ok(RewriteRecursion::Continue) - } else { - Ok(RewriteRecursion::Skip) - } - } - - fn mutate(&mut self, _expr: Expr) -> Result { - if self.current_count >= self.max_count { - Ok(self.original_expr.as_ref().unwrap().clone()) - } else { - Ok(conjunction(self.exprs.clone()) - .unwrap_or_else(|| self.original_expr.take().unwrap())) - } - } -} - -impl Default for CnfHelper { - fn default() -> Self { - Self::new() - } -} - /// Combines an array of filter expressions into a single filter /// expression consisting of the input filter expressions joined with /// logical AND. From 76e91a7cfe8025753c82d499e2b609091f23fca2 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 26 Oct 2022 11:43:07 -0400 Subject: [PATCH 12/14] cleanups --- datafusion/optimizer/src/filter_push_down.rs | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/datafusion/optimizer/src/filter_push_down.rs b/datafusion/optimizer/src/filter_push_down.rs index 1dfc791c68b1..360044d04d48 100644 --- a/datafusion/optimizer/src/filter_push_down.rs +++ b/datafusion/optimizer/src/filter_push_down.rs @@ -70,7 +70,6 @@ type Predicates<'a> = (Vec<&'a Expr>, Vec<&'a HashSet>); struct State { // (predicate, columns on the predicate) filters: Vec, - use_cnf_rewrite: bool, } impl State { @@ -81,14 +80,6 @@ impl State { .zip(predicates.1) .for_each(|(expr, cols)| self.filters.push((expr.clone(), cols.clone()))) } - - // set `true` means split the filter-exprs into CNF (see `CnfHelper`) - // to push more filter conditions down, it may cause filter-exprs size - // expansion. - fn with_cnf_rewrite(mut self) -> Self { - self.use_cnf_rewrite = true; - self - } } /// returns all predicates in `state` that depend on any of `used_columns` @@ -539,11 +530,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { } LogicalPlan::Analyze { .. } => push_down(&state, plan), LogicalPlan::Filter(filter) => { - let predicate = if state.use_cnf_rewrite { - utils::cnf_rewrite(filter.predicate().clone()) - } else { - filter.predicate().clone() - }; + let predicate = utils::cnf_rewrite(filter.predicate().clone()); utils::split_conjunction_owned(predicate) .into_iter() @@ -810,7 +797,7 @@ impl OptimizerRule for FilterPushDown { plan: &LogicalPlan, _: &mut OptimizerConfig, ) -> Result { - optimize(plan, State::default().with_cnf_rewrite()) + optimize(plan, State::default()) } } From 50d066480aca205f219a26b03c1b46faec1a1ae0 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 26 Oct 2022 11:50:41 -0400 Subject: [PATCH 13/14] Clean up docstrings --- datafusion/optimizer/src/utils.rs | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index d605235eb27e..6329593658eb 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -171,9 +171,12 @@ fn split_binary_impl<'a>( } } -/// Given some number of lists of Exprs, returns a set of expressions -/// where there is one expression from each that there is a single -/// element from each of the +/// Given a list of lists of [`Expr`]s, returns a list of lists of +/// [`Expr`]s of expressions where there is one expression from each +/// from each of the input expressions +/// +/// For example, given the input `[[a, b], [c], [d, e]]` returns +/// `[a, c, d], [a, c, e], [b, c, d], [b, c, e]]`. fn permutations(mut exprs: VecDeque>) -> Vec> { let first = if let Some(first) = exprs.pop_front() { first @@ -202,7 +205,12 @@ fn permutations(mut exprs: VecDeque>) -> Vec> { } } -/// Converts an expression to conjunctive normal form (CNF). +const MAX_CNF_REWRITE_CONJUNCTS: usize = 10; + +/// Tries to convert an expression to conjunctive normal form (CNF). +/// +/// Does not convert the expression if the total number of conjuncts +/// (exprs ANDed together) would exceed [`MAX_CNF_REWRITE_CONJUNCTS`]. /// /// The following expression is in CNF: /// `(a OR b) AND (c OR d)` @@ -213,6 +221,7 @@ fn permutations(mut exprs: VecDeque>) -> Vec> { /// But could be rewrite to a CNF expression: /// `(a OR c) AND (b OR c)`. /// +/// /// # Example /// ``` /// # use datafusion_expr::{col, lit}; @@ -241,14 +250,16 @@ pub fn cnf_rewrite(expr: Expr) -> Expr { // Decide if we want to distribute the clauses. Heuristic is // chosen to avoid creating huge predicates - let total_permutations = disjunct_conjuncts + let num_conjuncts = disjunct_conjuncts .iter() .fold(1usize, |sz, exprs| sz.saturating_mul(exprs.len())); - if disjunct_conjuncts.iter().any(|exprs| exprs.len() > 1) && total_permutations < 10 { - // form the OR clauses( A OR B OR C ..) + if disjunct_conjuncts.iter().any(|exprs| exprs.len() > 1) + && num_conjuncts < MAX_CNF_REWRITE_CONJUNCTS + { let or_clauses = permutations(disjunct_conjuncts) .into_iter() + // form the OR clauses( A OR B OR C ..) .map(|exprs| disjunction(exprs.into_iter().cloned()).unwrap()); conjunction(or_clauses).unwrap() } @@ -862,7 +873,7 @@ mod tests { ) } - /// call permutations with owned `Expr`s + /// call permutations with owned `Expr`s for easier testing fn make_permutations(exprs: impl IntoIterator>) -> Vec> { let exprs = exprs.into_iter().collect::>(); From 6dd1ce32fd3304758582fcc9e4587fa401ba9f52 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 26 Oct 2022 12:00:37 -0400 Subject: [PATCH 14/14] Clippy, restore missing test --- datafusion/optimizer/src/utils.rs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 6329593658eb..c5496b5237f4 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -189,13 +189,13 @@ fn permutations(mut exprs: VecDeque>) -> Vec> { first.into_iter().map(|e| vec![e]).collect() } else { first - .iter() + .into_iter() .flat_map(|expr| { permutations(exprs.clone()) .into_iter() .map(|expr_list| { // Create [expr, ...] for each permutation - std::iter::once(expr.clone()) + std::iter::once(expr) .chain(expr_list.into_iter()) .collect::>() }) @@ -698,7 +698,7 @@ mod tests { #[test] fn test_split_conjunction_owned_alias() { assert_eq!( - split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias")),), + split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))), vec![ col("a").eq(lit(5)), // no alias on b @@ -741,6 +741,12 @@ mod tests { assert_ne!(expr, Some(col("a").or(col("b").or(col("c"))))); } + #[test] + fn test_split_conjunction_owned_or() { + let expr = col("a").eq(lit(5)).or(col("b")); + assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); + } + #[test] fn test_collect_expr() -> Result<()> { let mut accum: HashSet = HashSet::new();