Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/expected-plans/q7.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Sort: shipping.supp_nation ASC NULLS LAST, shipping.cust_nation ASC NULLS LAST,
Aggregate: groupBy=[[shipping.supp_nation, shipping.cust_nation, shipping.l_year]], aggr=[[SUM(shipping.volume)]]
Projection: shipping.supp_nation, shipping.cust_nation, shipping.l_year, shipping.volume, alias=shipping
Projection: n1.n_name AS supp_nation, n2.n_name AS cust_nation, datepart(Utf8("YEAR"), lineitem.l_shipdate) AS l_year, CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)) AS volume, alias=shipping
Filter: n1.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY") AND n2.n_name = Utf8("FRANCE")
Filter: n1.n_name = Utf8("FRANCE") OR n2.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY")
Inner Join: customer.c_nationkey = n2.n_nationkey
Inner Join: supplier.s_nationkey = n1.n_nationkey
Inner Join: orders.o_custkey = customer.c_custkey
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/src/physical_plan/file_format/row_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use arrow::record_batch::RecordBatch;
use datafusion_common::{Column, DataFusionError, Result, ScalarValue, ToDFSchema};
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};

use datafusion_expr::Expr;
use datafusion_expr::{Expr, Operator};
use datafusion_optimizer::utils::split_conjunction_owned;
use datafusion_physical_expr::execution_props::ExecutionProps;
use datafusion_physical_expr::{create_physical_expr, PhysicalExpr};
Expand Down Expand Up @@ -253,7 +253,7 @@ pub fn build_row_filter(
metadata: &ParquetMetaData,
reorder_predicates: bool,
) -> Result<Option<RowFilter>> {
let predicates = split_conjunction_owned(expr);
let predicates = split_conjunction_owned(expr, Operator::And);

let mut candidates: Vec<FilterCandidate> = predicates
.into_iter()
Expand Down
7 changes: 6 additions & 1 deletion datafusion/core/tests/sql/joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1468,10 +1468,15 @@ async fn reduce_left_join_2() -> Result<()> {
.expect(&msg);
let state = ctx.state();
let plan = state.optimize(&plan)?;

// filter expr: `t2.t2_int < 10 or (t1.t1_int > 2 and t2.t2_name != 'w')`
// could be write to: `(t1.t1_int > 2 or t2.t2_int < 10) and (t2.t2_name != 'w' or t2.t2_int < 10)`
// the right part `(t2.t2_name != 'w' or t2.t2_int < 10)` could be push down left join side and remove in filter.

let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR CAST(t1.t1_int AS Int64) > Int64(2) AND t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR CAST(t1.t1_int AS Int64) > Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
Expand Down
37 changes: 35 additions & 2 deletions datafusion/optimizer/src/filter_push_down.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

//! Filter Push Down optimizer rule ensures that filters are applied as early as possible in the plan

use crate::utils::{split_conjunction, CnfHelper};
use crate::{utils, OptimizerConfig, OptimizerRule};
use datafusion_common::{Column, DFSchema, DataFusionError, Result};
use datafusion_expr::{
Expand All @@ -28,6 +29,7 @@ use datafusion_expr::{
utils::{expr_to_columns, exprlist_to_columns, from_plan},
Expr, Operator, TableProviderFilterPushDown,
};
use log::error;
use std::collections::{HashMap, HashSet};
use std::iter::once;

Expand Down Expand Up @@ -530,7 +532,14 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
}
LogicalPlan::Analyze { .. } => push_down(&state, plan),
LogicalPlan::Filter(filter) => {
let predicates = utils::split_conjunction(filter.predicate());
let filter_cnf = filter.predicate().clone().rewrite(&mut CnfHelper::new());
let predicates = match filter_cnf {
Ok(ref expr) => split_conjunction(expr),
Err(e) => {
error!("Fail at CnfHelper rewrite: {}.", e);
split_conjunction(filter.predicate())
}
};

predicates
.into_iter()
Expand Down Expand Up @@ -953,6 +962,30 @@ mod tests {
Ok(())
}

#[test]
fn filter_keep_partial_agg() -> Result<()> {
let table_scan = test_table_scan()?;
let f1 = col("c").eq(lit(1i64)).and(col("b").gt(lit(2i64)));
let f2 = col("c").eq(lit(1i64)).and(col("b").gt(lit(3i64)));
let filter = f1.or(f2);
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
.filter(filter)?
.build()?;
// filter of aggregate is after aggregation since they are non-commutative
// (c =1 AND b > 2) OR (c = 1 AND b > 3)
// rewrite to CNF
// (c = 1 OR c = 1) [can pushDown] AND (c = 1 OR b > 3) AND (b > 2 OR C = 1) AND (b > 2 OR b > 3)

let expected = "\
Copy link
Member Author

@Ted-Jiang Ted-Jiang Oct 17, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It used to get not pushed:

Filter: test.c = Int64(1) AND b > Int64(2) OR test.c = Int64(1) AND b > Int64(3)
  Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]
    TableScan: test

Copy link
Member Author

@Ted-Jiang Ted-Jiang Oct 17, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alamb @andygrove @Dandandan Do you think this is a reasonable improvement 🤔?
If this ok, i will modify the another test😂

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is a good improvement for sure -- thank you

Filter: test.c = Int64(1) OR b > Int64(3) AND b > Int64(2) OR test.c = Int64(1) AND b > Int64(2) OR b > Int64(3)\
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
\n Filter: test.c = Int64(1) OR test.c = Int64(1)\
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Somewhere need simplify this

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be a rule in https://github.com/apache/arrow-datafusion/blob/37fe938261636bb7710b26cf26e2bbd010e1dbf0/datafusion/optimizer/src/simplify_expressions.rs#L264

That pass gets run several times

I recommend we file a follow on ticket for that particular rewrite (the same thing applies to a AND a in addition to a OR a

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Filed #3895

\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}

/// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written
#[test]
fn alias() -> Result<()> {
Expand Down Expand Up @@ -2344,7 +2377,7 @@ mod tests {
.filter(filter)?
.build()?;

let expected = "Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)\
let expected = "Filter: test.a = d OR test.b = e AND test.a = d OR test.c < UInt32(10) AND test.b > UInt32(1) OR test.b = e\
\n CrossJoin:\
\n Projection: test.a, test.b, test.c\
\n Filter: test.b > UInt32(1) OR test.c < UInt32(10)\
Expand Down
Loading