diff --git a/benchmarks/expected-plans/q7.txt b/benchmarks/expected-plans/q7.txt index a1d1806f9189..fad02c09881c 100644 --- a/benchmarks/expected-plans/q7.txt +++ b/benchmarks/expected-plans/q7.txt @@ -3,7 +3,7 @@ Sort: shipping.supp_nation ASC NULLS LAST, shipping.cust_nation ASC NULLS LAST, Aggregate: groupBy=[[shipping.supp_nation, shipping.cust_nation, shipping.l_year]], aggr=[[SUM(shipping.volume)]] Projection: shipping.supp_nation, shipping.cust_nation, shipping.l_year, shipping.volume, alias=shipping Projection: n1.n_name AS supp_nation, n2.n_name AS cust_nation, datepart(Utf8("YEAR"), lineitem.l_shipdate) AS l_year, CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)) AS volume, alias=shipping - Filter: n1.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY") AND n2.n_name = Utf8("FRANCE") + Filter: (n1.n_name = Utf8("FRANCE") OR n2.n_name = Utf8("FRANCE")) AND (n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY")) Inner Join: customer.c_nationkey = n2.n_nationkey Inner Join: supplier.s_nationkey = n1.n_nationkey Inner Join: orders.o_custkey = customer.c_custkey diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 2ff4947b3214..1ba8cf7ac42e 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -1468,10 +1468,15 @@ async fn reduce_left_join_2() -> Result<()> { .expect(&msg); let state = ctx.state(); let plan = state.optimize(&plan)?; + + // filter expr: `t2.t2_int < 10 or (t1.t1_int > 2 and t2.t2_name != 'w')` + // can be rewritten as: `(t1.t1_int > 2 or t2.t2_int < 10) and (t2.t2_name != 'w' or t2.t2_int < 10)` + // the right part `(t2.t2_name != 'w' or t2.t2_int < 10)` can be pushed down to the left join side and removed from the filter. 
+ let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", " Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR CAST(t1.t1_int AS Int64) > Int64(2) AND t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR CAST(t1.t1_int AS Int64) > Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", " Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", diff --git a/datafusion/optimizer/src/filter_push_down.rs b/datafusion/optimizer/src/filter_push_down.rs index 148ae6715ddb..360044d04d48 100644 --- a/datafusion/optimizer/src/filter_push_down.rs +++ b/datafusion/optimizer/src/filter_push_down.rs @@ -457,7 +457,7 @@ fn optimize_join( // Build new filter states using pushable predicates // from current optimizer states and from ON clause. 
// Then recursively call optimization for both join inputs - let mut left_state = State { filters: vec![] }; + let mut left_state = State::default(); left_state.append_predicates(to_left); left_state.append_predicates(on_to_left); or_to_left @@ -472,7 +472,7 @@ fn optimize_join( .for_each(|(expr, cols)| left_state.filters.push((expr, cols))); let left = optimize(left, left_state)?; - let mut right_state = State { filters: vec![] }; + let mut right_state = State::default(); right_state.append_predicates(to_right); right_state.append_predicates(on_to_right); or_to_right @@ -530,14 +530,14 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { } LogicalPlan::Analyze { .. } => push_down(&state, plan), LogicalPlan::Filter(filter) => { - let predicates = utils::split_conjunction(filter.predicate()); + let predicate = utils::cnf_rewrite(filter.predicate().clone()); - predicates + utils::split_conjunction_owned(predicate) .into_iter() .try_for_each::<_, Result<()>>(|predicate| { let mut columns: HashSet = HashSet::new(); - expr_to_columns(predicate, &mut columns)?; - state.filters.push((predicate.clone(), columns)); + expr_to_columns(&predicate, &mut columns)?; + state.filters.push((predicate, columns)); Ok(()) })?; @@ -953,6 +953,30 @@ mod tests { Ok(()) } + #[test] + fn filter_keep_partial_agg() -> Result<()> { + let table_scan = test_table_scan()?; + let f1 = col("c").eq(lit(1i64)).and(col("b").gt(lit(2i64))); + let f2 = col("c").eq(lit(1i64)).and(col("b").gt(lit(3i64))); + let filter = f1.or(f2); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])? + .filter(filter)? 
+ .build()?; + // filter of aggregate is after aggregation since they are non-commutative + // (c =1 AND b > 2) OR (c = 1 AND b > 3) + // rewrite to CNF + // (c = 1 OR c = 1) [can pushDown] AND (c = 1 OR b > 3) AND (b > 2 OR C = 1) AND (b > 2 OR b > 3) + + let expected = "\ + Filter: (test.c = Int64(1) OR b > Int64(3)) AND (b > Int64(2) OR test.c = Int64(1)) AND (b > Int64(2) OR b > Int64(3))\ + \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\ + \n Filter: test.c = Int64(1) OR test.c = Int64(1)\ + \n TableScan: test"; + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written #[test] fn alias() -> Result<()> { @@ -2344,13 +2368,14 @@ mod tests { .filter(filter)? .build()?; - let expected = "Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)\ - \n CrossJoin:\ - \n Projection: test.a, test.b, test.c\ - \n Filter: test.b > UInt32(1) OR test.c < UInt32(10)\ - \n TableScan: test\ - \n Projection: test1.a AS d, test1.a AS e\ - \n TableScan: test1"; + let expected = "\ + Filter: (test.a = d OR test.b = e) AND (test.a = d OR test.c < UInt32(10)) AND (test.b > UInt32(1) OR test.b = e)\ + \n CrossJoin:\ + \n Projection: test.a, test.b, test.c\ + \n Filter: test.b > UInt32(1) OR test.c < UInt32(10)\ + \n TableScan: test\ + \n Projection: test1.a AS d, test1.a AS e\ + \n TableScan: test1"; assert_optimized_plan_eq(&plan, expected); Ok(()) } diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 4eda6e3e3cc2..c5496b5237f4 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -29,7 +29,7 @@ use datafusion_expr::{ utils::from_plan, Expr, Operator, }; -use std::collections::HashSet; +use std::collections::{HashSet, VecDeque}; use std::sync::Arc; /// Convenience rule for writing optimizers: recursively invoke @@ -99,20 +99,46 @@ fn split_conjunction_impl<'a>(expr: 
&'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<& /// assert_eq!(split_conjunction_owned(expr), split); /// ``` pub fn split_conjunction_owned(expr: Expr) -> Vec { - split_conjunction_owned_impl(expr, vec![]) + split_binary_owned(expr, Operator::And) } -fn split_conjunction_owned_impl(expr: Expr, mut exprs: Vec) -> Vec { +/// Splits an owned binary operator tree [`Expr`] such as `A <OP> B <OP> C` => `[A, B, C]` +/// +/// This is often used to "split" expressions such as `col1 = 5 +/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`]; +/// +/// # Example +/// ``` +/// # use datafusion_expr::{col, lit, Operator}; +/// # use datafusion_optimizer::utils::split_binary_owned; +/// # use std::ops::Add; +/// // a=1 + b=2 +/// let expr = col("a").eq(lit(1)).add(col("b").eq(lit(2))); +/// +/// // [a=1, b=2] +/// let split = vec![ +/// col("a").eq(lit(1)), +/// col("b").eq(lit(2)), +/// ]; +/// +/// // use split_binary_owned to split them +/// assert_eq!(split_binary_owned(expr, Operator::Plus), split); +/// ``` +pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec { + split_binary_owned_impl(expr, op, vec![]) +} + +fn split_binary_owned_impl( + expr: Expr, + operator: Operator, + mut exprs: Vec, +) -> Vec { match expr { - Expr::BinaryExpr(BinaryExpr { - right, - op: Operator::And, - left, - }) => { - let exprs = split_conjunction_owned_impl(*left, exprs); - split_conjunction_owned_impl(*right, exprs) + Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => { + let exprs = split_binary_owned_impl(*left, operator, exprs); + split_binary_owned_impl(*right, operator, exprs) } - Expr::Alias(expr, _) => split_conjunction_owned_impl(*expr, exprs), + Expr::Alias(expr, _) => split_binary_owned_impl(*expr, operator, exprs), other => { exprs.push(other); exprs @@ -120,6 +146,129 @@ fn split_conjunction_owned_impl(expr: Expr, mut exprs: Vec) -> Vec { } } +/// Splits a binary operator tree [`Expr`] such as `A <OP> B <OP> C` => `[A, B, C]` +/// +/// See [`split_binary_owned`] for more details 
and an example. +pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> { + split_binary_impl(expr, op, vec![]) +} + +fn split_binary_impl<'a>( + expr: &'a Expr, + operator: Operator, + mut exprs: Vec<&'a Expr>, +) -> Vec<&'a Expr> { + match expr { + Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => { + let exprs = split_binary_impl(left, operator, exprs); + split_binary_impl(right, operator, exprs) + } + Expr::Alias(expr, _) => split_binary_impl(expr, operator, exprs), + other => { + exprs.push(other); + exprs + } + } +} + +/// Given a list of lists of [`Expr`]s, returns a list of lists of +/// [`Expr`]s where each output list contains one expression +/// from each of the input lists +/// +/// For example, given the input `[[a, b], [c], [d, e]]` returns +/// `[[a, c, d], [a, c, e], [b, c, d], [b, c, e]]`. +fn permutations(mut exprs: VecDeque>) -> Vec> { + let first = if let Some(first) = exprs.pop_front() { + first + } else { + return vec![]; + }; + + // base case: + if exprs.is_empty() { + first.into_iter().map(|e| vec![e]).collect() + } else { + first + .into_iter() + .flat_map(|expr| { + permutations(exprs.clone()) + .into_iter() + .map(|expr_list| { + // Create [expr, ...] for each permutation + std::iter::once(expr) + .chain(expr_list.into_iter()) + .collect::>() + }) + .collect::>>() + }) + .collect() + } +} + +const MAX_CNF_REWRITE_CONJUNCTS: usize = 10; + +/// Tries to convert an expression to conjunctive normal form (CNF). +/// +/// Does not convert the expression if the total number of conjuncts +/// (exprs ANDed together) would be [`MAX_CNF_REWRITE_CONJUNCTS`] or more. +/// +/// The following expression is in CNF: +/// `(a OR b) AND (c OR d)` +/// +/// The following is not in CNF: +/// `(a AND b) OR c`. +/// +/// But it can be rewritten to a CNF expression: +/// `(a OR c) AND (b OR c)`. 
+/// +/// +/// # Example +/// ``` +/// # use datafusion_expr::{col, lit}; +/// # use datafusion_optimizer::utils::cnf_rewrite; +/// // (a=1 AND b=2)OR c = 3 +/// let expr1 = col("a").eq(lit(1)).and(col("b").eq(lit(2))); +/// let expr2 = col("c").eq(lit(3)); +/// let expr = expr1.or(expr2); +/// +/// //(a=1 or c=3)AND(b=2 or c=3) +/// let expr1 = col("a").eq(lit(1)).or(col("c").eq(lit(3))); +/// let expr2 = col("b").eq(lit(2)).or(col("c").eq(lit(3))); +/// let expect = expr1.and(expr2); +/// assert_eq!(expect, cnf_rewrite(expr)); +/// ``` +pub fn cnf_rewrite(expr: Expr) -> Expr { + // Find all exprs joined by OR + let disjuncts = split_binary(&expr, Operator::Or); + + // For each expr, split now on AND + // A OR B OR C --> split each A, B and C + let disjunct_conjuncts: VecDeque> = disjuncts + .into_iter() + .map(|e| split_binary(e, Operator::And)) + .collect::>(); + + // Decide if we want to distribute the clauses. Heuristic is + // chosen to avoid creating huge predicates + let num_conjuncts = disjunct_conjuncts + .iter() + .fold(1usize, |sz, exprs| sz.saturating_mul(exprs.len())); + + if disjunct_conjuncts.iter().any(|exprs| exprs.len() > 1) + && num_conjuncts < MAX_CNF_REWRITE_CONJUNCTS + { + let or_clauses = permutations(disjunct_conjuncts) + .into_iter() + // form the OR clauses( A OR B OR C ..) + .map(|exprs| disjunction(exprs.into_iter().cloned()).unwrap()); + conjunction(or_clauses).unwrap() + } + // otherwise return the original expression + else { + expr + } +} + /// Combines an array of filter expressions into a single filter /// expression consisting of the input filter expressions joined with /// logical AND. 
@@ -470,7 +619,7 @@ mod tests { use arrow::datatypes::DataType; use datafusion_common::Column; use datafusion_expr::expr::Cast; - use datafusion_expr::{col, lit, utils::expr_to_columns}; + use datafusion_expr::{col, lit, or, utils::expr_to_columns}; use std::collections::HashSet; use std::ops::Add; @@ -508,6 +657,30 @@ mod tests { assert_eq!(result, vec![&expr]); } + #[test] + fn test_split_binary_owned() { + let expr = col("a"); + assert_eq!(split_binary_owned(expr.clone(), Operator::And), vec![expr]); + } + + #[test] + fn test_split_binary_owned_two() { + assert_eq!( + split_binary_owned(col("a").eq(lit(5)).and(col("b")), Operator::And), + vec![col("a").eq(lit(5)), col("b")] + ); + } + + #[test] + fn test_split_binary_owned_different_op() { + let expr = col("a").eq(lit(5)).or(col("b")); + assert_eq!( + // expr is connected by OR, but pass in AND + split_binary_owned(expr.clone(), Operator::And), + vec![expr] + ); + } + #[test] fn test_split_conjunction_owned() { let expr = col("a"); @@ -655,4 +828,135 @@ mod tests { "mismatch rewriting expr_from: {expr_from} to {rewrite_to}" ) } + + #[test] + fn test_permutations() { + assert_eq!(make_permutations(vec![]), vec![] as Vec>) + } + + #[test] + fn test_permutations_one() { + // [[a]] --> [[a]] + assert_eq!( + make_permutations(vec![vec![col("a")]]), + vec![vec![col("a")]] + ) + } + + #[test] + fn test_permutations_two() { + // [[a, b]] --> [[a], [b]] + assert_eq!( + make_permutations(vec![vec![col("a"), col("b")]]), + vec![vec![col("a")], vec![col("b")]] + ) + } + + #[test] + fn test_permutations_two_and_one() { + // [[a, b], [c]] --> [[a, c], [b, c]] + assert_eq!( + make_permutations(vec![vec![col("a"), col("b")], vec![col("c")]]), + vec![vec![col("a"), col("c")], vec![col("b"), col("c")]] + ) + } + + #[test] + fn test_permutations_two_and_one_and_two() { + // [[a, b], [c], [d, e]] --> [[a, c, d], [a, c, e], [b, c, d], [b, c, e]] + assert_eq!( + make_permutations(vec![ + vec![col("a"), col("b")], + vec![col("c")], + 
vec![col("d"), col("e")] + ]), + vec![ + vec![col("a"), col("c"), col("d")], + vec![col("a"), col("c"), col("e")], + vec![col("b"), col("c"), col("d")], + vec![col("b"), col("c"), col("e")], + ] + ) + } + + /// call permutations with owned `Expr`s for easier testing + fn make_permutations(exprs: impl IntoIterator>) -> Vec> { + let exprs = exprs.into_iter().collect::>(); + + let exprs: VecDeque> = exprs + .iter() + .map(|exprs| exprs.iter().collect::>()) + .collect(); + + permutations(exprs) + .into_iter() + // copy &Expr --> Expr + .map(|exprs| exprs.into_iter().cloned().collect()) + .collect() + } + + #[test] + fn test_rewrite_cnf() { + let a_1 = col("a").eq(lit(1i64)); + let a_2 = col("a").eq(lit(2i64)); + + let b_1 = col("b").eq(lit(1i64)); + let b_2 = col("b").eq(lit(2i64)); + + // Test rewrite on a1_and_b2 -> unchanged + let expr1 = and(a_1.clone(), b_2.clone()); + let expect = expr1.clone(); + assert_eq!(expect, cnf_rewrite(expr1)); + + // Test rewrite on a1_and_b2 and a2_and_b1 -> (((a1 and b2) and a2) and b1) + let expr1 = and(and(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone())); + let expect = and(a_1.clone(), b_2.clone()) + .and(a_2.clone()) + .and(b_1.clone()); + assert_eq!(expect, cnf_rewrite(expr1)); + + // Test rewrite on a1_or_b2 -> unchanged + let expr1 = or(a_1.clone(), b_2.clone()); + let expect = expr1.clone(); + assert_eq!(expect, cnf_rewrite(expr1)); + + // Test rewrite on a1_and_b2 or a2_and_b1 -> a1_or_a2 and a1_or_b1 and b2_or_a2 and b2_or_b1 + let expr1 = or(and(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone())); + let a1_or_a2 = or(a_1.clone(), a_2.clone()); + let a1_or_b1 = or(a_1.clone(), b_1.clone()); + let b2_or_a2 = or(b_2.clone(), a_2.clone()); + let b2_or_b1 = or(b_2.clone(), b_1.clone()); + let expect = and(a1_or_a2, a1_or_b1).and(b2_or_a2).and(b2_or_b1); + assert_eq!(expect, cnf_rewrite(expr1)); + + // Test rewrite on a1_or_b2 or a2_and_b1 -> (a1_or_b2 or a2) and (a1_or_b2 or b1) + let a1_or_b2 = 
or(a_1.clone(), b_2.clone()); + let expr1 = or(or(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone())); + let expect = or(a1_or_b2.clone(), a_2.clone()).and(or(a1_or_b2, b_1.clone())); + assert_eq!(expect, cnf_rewrite(expr1)); + + // Test rewrite on a1_or_b2 or a2_or_b1 -> not change + let expr1 = or(or(a_1, b_2), or(a_2, b_1)); + let expect = expr1.clone(); + assert_eq!(expect, cnf_rewrite(expr1)); + } + + #[test] + fn test_rewrite_cnf_overflow() { + // in this situation: + // AND = (a=1 and b=2) + // rewrite (AND * 10) or (AND * 10), it will produce 10 * 10 = 100 (a=1 or b=2) + // which cause size expansion. + + let mut expr1 = col("test1").eq(lit(1i64)); + let expr2 = col("test2").eq(lit(2i64)); + + for _i in 0..9 { + expr1 = expr1.clone().and(expr2.clone()); + } + let expr3 = expr1.clone(); + let expr = or(expr1, expr3); + + assert_eq!(expr, cnf_rewrite(expr.clone())); + } }