2 changes: 1 addition & 1 deletion benchmarks/expected-plans/q7.txt
@@ -3,7 +3,7 @@ Sort: shipping.supp_nation ASC NULLS LAST, shipping.cust_nation ASC NULLS LAST,
Aggregate: groupBy=[[shipping.supp_nation, shipping.cust_nation, shipping.l_year]], aggr=[[SUM(shipping.volume)]]
Projection: shipping.supp_nation, shipping.cust_nation, shipping.l_year, shipping.volume, alias=shipping
Projection: n1.n_name AS supp_nation, n2.n_name AS cust_nation, datepart(Utf8("YEAR"), lineitem.l_shipdate) AS l_year, CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)) AS volume, alias=shipping
Filter: n1.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY") AND n2.n_name = Utf8("FRANCE")
Filter: n1.n_name = Utf8("FRANCE") OR n2.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY")
Inner Join: customer.c_nationkey = n2.n_nationkey
Inner Join: supplier.s_nationkey = n1.n_nationkey
Inner Join: orders.o_custkey = customer.c_custkey
7 changes: 6 additions & 1 deletion datafusion/core/tests/sql/joins.rs
@@ -1468,10 +1468,15 @@ async fn reduce_left_join_2() -> Result<()> {
.expect(&msg);
let state = ctx.state();
let plan = state.optimize(&plan)?;

// filter expr: `t2.t2_int < 10 or (t1.t1_int > 2 and t2.t2_name != 'w')`
// can be rewritten in CNF as: `(t1.t1_int > 2 or t2.t2_int < 10) and (t2.t2_name != 'w' or t2.t2_int < 10)`
// the second conjunct `(t2.t2_name != 'w' or t2.t2_int < 10)` references only t2 columns, so it can be
// pushed below the join to the t2 side and removed from the filter above the join.

let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR CAST(t1.t1_int AS Int64) > Int64(2) AND t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR CAST(t1.t1_int AS Int64) > Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
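
As a quick illustration of the comment in `reduce_left_join_2` above (a standalone sketch, not code from this change), the CNF form and the original filter agree on every input, and the second conjunct mentions only `t2` columns, which is what allows it to move below the join:

// Truth-table check that the rewrite described in the test comment preserves semantics.
fn main() {
    for t2_int_lt_10 in [false, true] {
        for t1_int_gt_2 in [false, true] {
            for t2_name_ne_w in [false, true] {
                // original: t2.t2_int < 10 OR (t1.t1_int > 2 AND t2.t2_name != 'w')
                let original = t2_int_lt_10 || (t1_int_gt_2 && t2_name_ne_w);
                // CNF: (t1.t1_int > 2 OR t2.t2_int < 10) AND (t2.t2_name != 'w' OR t2.t2_int < 10);
                // the second conjunct touches only t2 columns and can go below the join.
                let cnf = (t1_int_gt_2 || t2_int_lt_10) && (t2_name_ne_w || t2_int_lt_10);
                assert_eq!(original, cnf);
            }
        }
    }
}
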
58 changes: 51 additions & 7 deletions datafusion/optimizer/src/filter_push_down.rs
@@ -14,6 +14,7 @@

//! Filter Push Down optimizer rule ensures that filters are applied as early as possible in the plan

use crate::utils::{split_binary_owned, CnfHelper};
use crate::{utils, OptimizerConfig, OptimizerRule};
use datafusion_common::{Column, DFSchema, DataFusionError, Result};
use datafusion_expr::{
@@ -28,6 +29,7 @@ use datafusion_expr::{
utils::{expr_to_columns, exprlist_to_columns, from_plan},
Expr, Operator, TableProviderFilterPushDown,
};
use log::error;
use std::collections::{HashMap, HashSet};
use std::iter::once;

@@ -70,6 +72,7 @@ type Predicates<'a> = (Vec<&'a Expr>, Vec<&'a HashSet<Column>>);
struct State {
// (predicate, columns on the predicate)
filters: Vec<Predicate>,
use_cnf_rewrite: bool,
}

impl State {
@@ -80,6 +83,11 @@ impl State {
.zip(predicates.1)
.for_each(|(expr, cols)| self.filters.push((expr.clone(), cols.clone())))
}

fn with_cnf_rewrite(mut self) -> Self {
self.use_cnf_rewrite = true;
self
}
}

/// returns all predicates in `state` that depend on any of `used_columns`
@@ -457,7 +465,7 @@ fn optimize_join(
// Build new filter states using pushable predicates
// from current optimizer states and from ON clause.
// Then recursively call optimization for both join inputs
let mut left_state = State { filters: vec![] };
let mut left_state = State::default();
left_state.append_predicates(to_left);
left_state.append_predicates(on_to_left);
or_to_left
@@ -472,7 +480,7 @@
.for_each(|(expr, cols)| left_state.filters.push((expr, cols)));
let left = optimize(left, left_state)?;

let mut right_state = State { filters: vec![] };
let mut right_state = State::default();
right_state.append_predicates(to_right);
right_state.append_predicates(on_to_right);
or_to_right
@@ -530,14 +538,26 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
}
LogicalPlan::Analyze { .. } => push_down(&state, plan),
LogicalPlan::Filter(filter) => {
let predicates = utils::split_conjunction(filter.predicate());
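// Rewrite the predicate into conjunctive normal form (CNF) first, so that each
// AND-separated conjunct can be considered for push down on its own.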
let predicates = if state.use_cnf_rewrite {
let filter_cnf =
filter.predicate().clone().rewrite(&mut CnfHelper::new());
match filter_cnf {
Ok(ref expr) => split_binary_owned(expr.clone(), Operator::And),
Err(e) => {
error!("CnfHelper rewrite failed: {}.", e);
split_binary_owned(filter.predicate().clone(), Operator::And)
}
}
} else {
split_binary_owned(filter.predicate().clone(), Operator::And)
};

predicates
.into_iter()
.try_for_each::<_, Result<()>>(|predicate| {
let mut columns: HashSet<Column> = HashSet::new();
expr_to_columns(predicate, &mut columns)?;
state.filters.push((predicate.clone(), columns));
expr_to_columns(&predicate, &mut columns)?;
state.filters.push((predicate, columns));
Ok(())
})?;

@@ -797,7 +817,7 @@ impl OptimizerRule for FilterPushDown {
plan: &LogicalPlan,
_: &mut OptimizerConfig,
) -> Result<LogicalPlan> {
optimize(plan, State::default())
optimize(plan, State::default().with_cnf_rewrite())
}
}

@@ -953,6 +973,30 @@ mod tests {
Ok(())
}

#[test]
fn filter_keep_partial_agg() -> Result<()> {
let table_scan = test_table_scan()?;
let f1 = col("c").eq(lit(1i64)).and(col("b").gt(lit(2i64)));
let f2 = col("c").eq(lit(1i64)).and(col("b").gt(lit(3i64)));
let filter = f1.or(f2);
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
.filter(filter)?
.build()?;
// the filter on top of the aggregate cannot be pushed below it wholesale, since filter and
// aggregation are not commutative here
// (c = 1 AND b > 2) OR (c = 1 AND b > 3)
// rewritten to CNF becomes
// (c = 1 OR c = 1) [can be pushed down] AND (c = 1 OR b > 3) AND (b > 2 OR c = 1) AND (b > 2 OR b > 3)

let expected = "\
Filter: test.c = Int64(1) OR b > Int64(3) AND b > Int64(2) OR test.c = Int64(1) AND b > Int64(2) OR b > Int64(3)\
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
\n Filter: test.c = Int64(1) OR test.c = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}

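To make the comment in `filter_keep_partial_agg` concrete (an illustrative sketch with hypothetical names such as `pushable`, not code from this change): a conjunct can only be pushed below the aggregate when every column it references is already available below it, which is why `c = 1 OR c = 1` moves under the aggregate while the conjuncts that mention the aggregate alias `b` stay above it.

use std::collections::HashSet;

fn pushable(conjunct_columns: &HashSet<&str>, available_below: &HashSet<&str>) -> bool {
    // A predicate can move below an operator only if all of its columns exist there.
    conjunct_columns.is_subset(available_below)
}

fn main() {
    // Columns produced by the table scan, i.e. below the aggregate.
    let below: HashSet<&str> = ["test.a", "test.b", "test.c"].into_iter().collect();
    // `c = 1 OR c = 1` needs only test.c, so it is pushed below the aggregate.
    assert!(pushable(&["test.c"].into_iter().collect(), &below));
    // `c = 1 OR b > 3` needs the aggregate alias `b`, so it must stay above.
    assert!(!pushable(&["test.c", "b"].into_iter().collect(), &below));
}
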
/// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written
#[test]
fn alias() -> Result<()> {
@@ -2344,7 +2388,7 @@ mod tests {
.filter(filter)?
.build()?;

let expected = "Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)\
let expected = "Filter: test.a = d OR test.b = e AND test.a = d OR test.c < UInt32(10) AND test.b > UInt32(1) OR test.b = e\
\n CrossJoin:\
\n Projection: test.a, test.b, test.c\
\n Filter: test.b > UInt32(1) OR test.c < UInt32(10)\
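
For reference, here is a minimal sketch (with hypothetical names such as `Bexpr` and `to_cnf`, not the actual `CnfHelper` implementation) of the kind of CNF rewrite this rule now applies: OR is distributed over AND so that the result is a conjunction whose parts can be examined for push down one by one. A production rewriter would also need to guard against the exponential growth this naive distribution can cause.

#[derive(Clone, Debug)]
enum Bexpr {
    Leaf(&'static str),
    And(Box<Bexpr>, Box<Bexpr>),
    Or(Box<Bexpr>, Box<Bexpr>),
}

fn to_cnf(e: Bexpr) -> Bexpr {
    match e {
        leaf @ Bexpr::Leaf(_) => leaf,
        Bexpr::And(l, r) => Bexpr::And(Box::new(to_cnf(*l)), Box::new(to_cnf(*r))),
        Bexpr::Or(l, r) => match (to_cnf(*l), to_cnf(*r)) {
            // (a AND b) OR c  =>  (a OR c) AND (b OR c)
            (Bexpr::And(a, b), c) => Bexpr::And(
                Box::new(to_cnf(Bexpr::Or(a, Box::new(c.clone())))),
                Box::new(to_cnf(Bexpr::Or(b, Box::new(c)))),
            ),
            // a OR (b AND c)  =>  (a OR b) AND (a OR c)
            (a, Bexpr::And(b, c)) => Bexpr::And(
                Box::new(to_cnf(Bexpr::Or(Box::new(a.clone()), b))),
                Box::new(to_cnf(Bexpr::Or(Box::new(a), c))),
            ),
            // neither side is an AND: already a CNF clause
            (a, b) => Bexpr::Or(Box::new(a), Box::new(b)),
        },
    }
}

fn main() {
    // (t1.x > 2 AND t2.y != 'w') OR t2.z < 10
    let e = Bexpr::Or(
        Box::new(Bexpr::And(
            Box::new(Bexpr::Leaf("t1.x > 2")),
            Box::new(Bexpr::Leaf("t2.y != 'w'")),
        )),
        Box::new(Bexpr::Leaf("t2.z < 10")),
    );
    // Debug output corresponds to (t1.x > 2 OR t2.z < 10) AND (t2.y != 'w' OR t2.z < 10)
    println!("{:?}", to_cnf(e));
}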