Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/expected-plans/q7.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Sort: shipping.supp_nation ASC NULLS LAST, shipping.cust_nation ASC NULLS LAST,
Aggregate: groupBy=[[shipping.supp_nation, shipping.cust_nation, shipping.l_year]], aggr=[[SUM(shipping.volume)]]
Projection: shipping.supp_nation, shipping.cust_nation, shipping.l_year, shipping.volume, alias=shipping
Projection: n1.n_name AS supp_nation, n2.n_name AS cust_nation, datepart(Utf8("YEAR"), lineitem.l_shipdate) AS l_year, CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)) AS volume, alias=shipping
Filter: n1.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY") AND n2.n_name = Utf8("FRANCE")
Filter: (n1.n_name = Utf8("FRANCE") OR n2.n_name = Utf8("FRANCE")) AND (n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY"))
Copy link
Contributor

@isidentical isidentical Oct 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've also wanted to check TPC-H (it shouldn't affect, but just to see if there is an unexpected regression). It seems like there aren't any regressions (873.27 ms vs 878.54 ms, only noise) 🚀

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@isidentical thanks for testing👍! I think this rewrite will make row_filter work, I think it will boost up query when row_filter is stable! 🤔

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for checking @isidentical -- I think there are only 25 rows in the NATION table in TPCH, so the ordering of predicates doesn't really matter for performance in that case 😆

Copy link
Member Author

@Ted-Jiang Ted-Jiang Oct 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

think there are only 25 rows in the NATION table in TPCH,

😂

Inner Join: customer.c_nationkey = n2.n_nationkey
Inner Join: supplier.s_nationkey = n1.n_nationkey
Inner Join: orders.o_custkey = customer.c_custkey
Expand Down
7 changes: 6 additions & 1 deletion datafusion/core/tests/sql/joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1468,10 +1468,15 @@ async fn reduce_left_join_2() -> Result<()> {
.expect(&msg);
let state = ctx.state();
let plan = state.optimize(&plan)?;

// filter expr: `t2.t2_int < 10 or (t1.t1_int > 2 and t2.t2_name != 'w')`
// could be write to: `(t1.t1_int > 2 or t2.t2_int < 10) and (t2.t2_name != 'w' or t2.t2_int < 10)`
// the right part `(t2.t2_name != 'w' or t2.t2_int < 10)` could be push down left join side and remove in filter.

let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR CAST(t1.t1_int AS Int64) > Int64(2) AND t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR CAST(t1.t1_int AS Int64) > Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
Expand Down
51 changes: 38 additions & 13 deletions datafusion/optimizer/src/filter_push_down.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ fn optimize_join(
// Build new filter states using pushable predicates
// from current optimizer states and from ON clause.
// Then recursively call optimization for both join inputs
let mut left_state = State { filters: vec![] };
let mut left_state = State::default();
left_state.append_predicates(to_left);
left_state.append_predicates(on_to_left);
or_to_left
Expand All @@ -472,7 +472,7 @@ fn optimize_join(
.for_each(|(expr, cols)| left_state.filters.push((expr, cols)));
let left = optimize(left, left_state)?;

let mut right_state = State { filters: vec![] };
let mut right_state = State::default();
right_state.append_predicates(to_right);
right_state.append_predicates(on_to_right);
or_to_right
Expand Down Expand Up @@ -530,14 +530,14 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
}
LogicalPlan::Analyze { .. } => push_down(&state, plan),
LogicalPlan::Filter(filter) => {
let predicates = utils::split_conjunction(filter.predicate());
let predicate = utils::cnf_rewrite(filter.predicate().clone());

predicates
utils::split_conjunction_owned(predicate)
.into_iter()
.try_for_each::<_, Result<()>>(|predicate| {
let mut columns: HashSet<Column> = HashSet::new();
expr_to_columns(predicate, &mut columns)?;
state.filters.push((predicate.clone(), columns));
expr_to_columns(&predicate, &mut columns)?;
state.filters.push((predicate, columns));
Ok(())
})?;

Expand Down Expand Up @@ -953,6 +953,30 @@ mod tests {
Ok(())
}

#[test]
fn filter_keep_partial_agg() -> Result<()> {
let table_scan = test_table_scan()?;
let f1 = col("c").eq(lit(1i64)).and(col("b").gt(lit(2i64)));
let f2 = col("c").eq(lit(1i64)).and(col("b").gt(lit(3i64)));
let filter = f1.or(f2);
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
.filter(filter)?
.build()?;
// filter of aggregate is after aggregation since they are non-commutative
// (c =1 AND b > 2) OR (c = 1 AND b > 3)
// rewrite to CNF
// (c = 1 OR c = 1) [can pushDown] AND (c = 1 OR b > 3) AND (b > 2 OR C = 1) AND (b > 2 OR b > 3)

let expected = "\
Filter: (test.c = Int64(1) OR b > Int64(3)) AND (b > Int64(2) OR test.c = Int64(1)) AND (b > Int64(2) OR b > Int64(3))\
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
\n Filter: test.c = Int64(1) OR test.c = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}

/// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written
#[test]
fn alias() -> Result<()> {
Expand Down Expand Up @@ -2344,13 +2368,14 @@ mod tests {
.filter(filter)?
.build()?;

let expected = "Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)\
\n CrossJoin:\
\n Projection: test.a, test.b, test.c\
\n Filter: test.b > UInt32(1) OR test.c < UInt32(10)\
\n TableScan: test\
\n Projection: test1.a AS d, test1.a AS e\
\n TableScan: test1";
let expected = "\
Filter: (test.a = d OR test.b = e) AND (test.a = d OR test.c < UInt32(10)) AND (test.b > UInt32(1) OR test.b = e)\
\n CrossJoin:\
\n Projection: test.a, test.b, test.c\
\n Filter: test.b > UInt32(1) OR test.c < UInt32(10)\
\n TableScan: test\
\n Projection: test1.a AS d, test1.a AS e\
\n TableScan: test1";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
Expand Down
Loading