Skip to content

Commit

Permalink
feat: add optimize rule rewrite_disjunctive_predicate (#2858)
Browse files Browse the repository at this point in the history
* feat: add optimize rule: rewrite_disjunctive_predicate

* address comments and add tests

* Update datafusion/optimizer/src/rewrite_disjunctive_predicate.rs

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
xudong963 and alamb committed Jul 26, 2022
1 parent 0f19990 commit 4005076
Show file tree
Hide file tree
Showing 4 changed files with 412 additions and 0 deletions.
2 changes: 2 additions & 0 deletions datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ use datafusion_optimizer::decorrelate_scalar_subquery::DecorrelateScalarSubquery
use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists;
use datafusion_optimizer::decorrelate_where_in::DecorrelateWhereIn;
use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys;
use datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
use datafusion_sql::{
parser::DFParser,
planner::{ContextProvider, SqlToRel},
Expand Down Expand Up @@ -1367,6 +1368,7 @@ impl SessionState {
Arc::new(CommonSubexprEliminate::new()),
Arc::new(EliminateLimit::new()),
Arc::new(ProjectionPushDown::new()),
Arc::new(RewriteDisjunctivePredicate::new()),
];
if config.config_options.get_bool(OPT_FILTER_NULL_JOIN_KEYS) {
rules.push(Arc::new(FilterNullJoinKeys::default()));
Expand Down
56 changes: 56 additions & 0 deletions datafusion/core/tests/sql/predicates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,3 +386,59 @@ async fn csv_in_set_test() -> Result<()> {
assert_batches_sorted_eq!(expected, &actual);
Ok(())
}

/// Optimizes an `explain` query over `lineitem` and `part` whose `where`
/// clause is three OR-ed conjunctions that all repeat the same join
/// predicate, then compares the indented plan (with schema) line by line
/// against the expected text.
#[tokio::test]
async fn multiple_or_predicates() -> Result<()> {
    let ctx = SessionContext::new();
    register_tpch_csv(&ctx, "lineitem").await?;
    register_tpch_csv(&ctx, "part").await?;

    let query = "explain select
l_partkey
from
lineitem,
part
where
(
p_partkey = l_partkey
and p_brand = 'Brand#12'
and l_quantity >= 1 and l_quantity <= 1 + 10
and p_size between 1 and 5
)
or
(
p_partkey = l_partkey
and p_brand = 'Brand#23'
and l_quantity >= 10 and l_quantity <= 10 + 10
and p_size between 1 and 10
)
or
(
p_partkey = l_partkey
and p_brand = 'Brand#34'
and l_quantity >= 20 and l_quantity <= 20 + 10
and p_size between 1 and 15
)";

    // Build the unoptimized plan first; a failure here panics with the query text.
    let failure_context = format!("Creating logical plan for '{}'", query);
    let logical_plan = ctx.create_logical_plan(query).expect(&failure_context);

    // Run the optimizer rules registered on the session state.
    let session_state = ctx.state();
    let optimized_plan = session_state.optimize(&logical_plan)?;

    // The shared join predicate `#part.p_partkey = #lineitem.l_partkey` is
    // expected to be pulled out of the three OR branches, so the Filter line
    // below lists it a single time.
    let expected = vec![
        "Explain [plan_type:Utf8, plan:Utf8]",
        " Projection: #lineitem.l_partkey [l_partkey:Int64]",
        " Projection: #part.p_partkey = #lineitem.l_partkey AS BinaryExpr-=Column-lineitem.l_partkeyColumn-part.p_partkey, #lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size [BinaryExpr-=Column-lineitem.l_partkeyColumn-part.p_partkey:Boolean;N, l_partkey:Int64, l_quantity:Float64, p_brand:Utf8, p_size:Int32]",
        " Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= Int64(1) AND #lineitem.l_quantity <= Int64(11) AND #part.p_size BETWEEN Int64(1) AND Int64(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= Int64(10) AND #lineitem.l_quantity <= Int64(20) AND #part.p_size BETWEEN Int64(1) AND Int64(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= Int64(20) AND #lineitem.l_quantity <= Int64(30) AND #part.p_size BETWEEN Int64(1) AND Int64(15) [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
        " CrossJoin: [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
        " TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Float64]",
        " TableScan: part projection=[p_partkey, p_brand, p_size] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
    ];

    // Render the optimized plan and split it into trimmed lines for comparison.
    let rendered = optimized_plan.display_indent_schema().to_string();
    let actual = rendered.trim().lines().collect::<Vec<&str>>();

    assert_eq!(
        expected, actual,
        "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
        expected, actual
    );

    Ok(())
}
1 change: 1 addition & 0 deletions datafusion/optimizer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub mod single_distinct_to_groupby;
pub mod subquery_filter_to_join;
pub mod utils;

pub mod rewrite_disjunctive_predicate;
#[cfg(test)]
pub mod test;

Expand Down
Loading

0 comments on commit 4005076

Please sign in to comment.