optimizer: add framework for the rule of pre-add cast to the literal in comparison binary (#3185)

* add rule pre add cast to literal

* address comments and fix clippy

* change panic to result
liukun4515 committed Aug 23, 2022
1 parent eedc787 commit 9ecf277
Showing 6 changed files with 370 additions and 26 deletions.
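
At a high level, the rule targets comparisons where an integer column is cast up only to match a wider literal (for example CAST(c2 AS Int64) > Int64(10) with c2: Int32) and, when the literal's value fits the column's type, rewrites the expression to c2 > Int32(10) so the per-row cast on the column disappears. Below is a minimal, self-contained sketch of the narrowing check — illustrative Rust only, not the DataFusion implementation:

// Standalone illustration of the core check behind the rule: an Int64
// literal can replace the cast on an Int32 column only if its value is
// exactly representable as Int32.
fn narrow_i64_literal_to_i32(lit: i64) -> Option<i32> {
    if lit >= i32::MIN as i64 && lit <= i32::MAX as i64 {
        Some(lit as i32)
    } else {
        None
    }
}

fn main() {
    // `c2 > 10` with c2: Int32 — the literal fits, so the plan can use
    // `c2 > Int32(10)` and drop `CAST(c2 AS Int64)`.
    assert_eq!(narrow_i64_literal_to_i32(10), Some(10));
    // A literal outside the Int32 range cannot be narrowed; the cast stays.
    assert_eq!(narrow_i64_literal_to_i32(i64::from(i32::MAX) + 1), None);
}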
2 changes: 2 additions & 0 deletions datafusion/core/src/execution/context.rs
@@ -106,6 +106,7 @@ use datafusion_optimizer::decorrelate_scalar_subquery::DecorrelateScalarSubquery
use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists;
use datafusion_optimizer::decorrelate_where_in::DecorrelateWhereIn;
use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys;
use datafusion_optimizer::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions;
use datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
use datafusion_sql::{
parser::DFParser,
@@ -1358,6 +1359,7 @@ impl SessionState {
// Simplify expressions first to maximize the chance
// of applying other optimizations
Arc::new(SimplifyExpressions::new()),
Arc::new(PreCastLitInComparisonExpressions::new()),
Arc::new(DecorrelateWhereExists::new()),
Arc::new(DecorrelateWhereIn::new()),
Arc::new(DecorrelateScalarSubquery::new()),
34 changes: 32 additions & 2 deletions datafusion/core/tests/provider_filter_pushdown.rs
@@ -31,6 +31,8 @@ use datafusion::physical_plan::{
};
use datafusion::prelude::*;
use datafusion::scalar::ScalarValue;
use datafusion_common::DataFusionError;
use std::ops::Deref;
use std::sync::Arc;

fn create_batch(value: i32, num_rows: usize) -> Result<RecordBatch> {
@@ -146,8 +148,36 @@ impl TableProvider for CustomProvider {
match &filters[0] {
Expr::BinaryExpr { right, .. } => {
let int_value = match &**right {
Expr::Literal(ScalarValue::Int64(i)) => i.unwrap(),
_ => unimplemented!(),
Expr::Literal(ScalarValue::Int8(Some(i))) => *i as i64,
Expr::Literal(ScalarValue::Int16(Some(i))) => *i as i64,
Expr::Literal(ScalarValue::Int32(Some(i))) => *i as i64,
Expr::Literal(ScalarValue::Int64(Some(i))) => *i as i64,
Expr::Cast { expr, data_type: _ } => match expr.deref() {
Expr::Literal(lit_value) => match lit_value {
ScalarValue::Int8(Some(v)) => *v as i64,
ScalarValue::Int16(Some(v)) => *v as i64,
ScalarValue::Int32(Some(v)) => *v as i64,
ScalarValue::Int64(Some(v)) => *v,
other_value => {
return Err(DataFusionError::NotImplemented(format!(
"Do not support value {:?}",
other_value
)))
}
},
other_expr => {
return Err(DataFusionError::NotImplemented(format!(
"Do not support expr {:?}",
other_expr
)))
}
},
other_expr => {
return Err(DataFusionError::NotImplemented(format!(
"Do not support expr {:?}",
other_expr
)))
}
};

Ok(Arc::new(CustomPlan {
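The test above now has to accept the pushed-down literal in either shape: already narrowed to the column's integer type, or still wrapped in a Cast when the rule does not fire. A simplified, hypothetical sketch of that "literal or cast-of-literal" extraction, using toy types rather than DataFusion's Expr:

// Toy stand-ins for Expr::Literal and Expr::Cast, to show the extraction
// pattern the test relies on: read an i64 out of a bare literal or out of
// a cast wrapped around one.
#[derive(Debug)]
enum Lit {
    I32(i32),
    I64(i64),
}

#[derive(Debug)]
enum SimpleExpr {
    Literal(Lit),
    Cast(Box<SimpleExpr>),
}

fn literal_as_i64(expr: &SimpleExpr) -> Option<i64> {
    match expr {
        SimpleExpr::Literal(Lit::I32(v)) => Some(i64::from(*v)),
        SimpleExpr::Literal(Lit::I64(v)) => Some(*v),
        // The optimizer may leave the literal wrapped in a cast; look through it.
        SimpleExpr::Cast(inner) => literal_as_i64(inner),
    }
}

fn main() {
    let plain = SimpleExpr::Literal(Lit::I64(10));
    let cast = SimpleExpr::Cast(Box::new(SimpleExpr::Literal(Lit::I32(10))));
    assert_eq!(literal_as_i64(&plain), Some(10));
    assert_eq!(literal_as_i64(&cast), Some(10));
}
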
44 changes: 22 additions & 22 deletions datafusion/core/tests/sql/explain_analyze.rs
@@ -271,8 +271,8 @@ async fn csv_explain_plans() {
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #aggregate_test_100.c1 [c1:Utf8]",
" Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int32]",
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)] [c1:Utf8, c2:Int32]",
" Filter: #aggregate_test_100.c2 > Int32(10) [c1:Utf8, c2:Int32]",
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)] [c1:Utf8, c2:Int32]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -286,8 +286,8 @@ async fn csv_explain_plans() {
let expected = vec![
"Explain",
" Projection: #aggregate_test_100.c1",
" Filter: #aggregate_test_100.c2 > Int64(10)",
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]",
" Filter: #aggregate_test_100.c2 > Int32(10)",
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]",
];
let formatted = plan.display_indent().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand All @@ -307,9 +307,9 @@ async fn csv_explain_plans() {
" 2[shape=box label=\"Explain\"]",
" 3[shape=box label=\"Projection: #aggregate_test_100.c1\"]",
" 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]",
" 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\"]",
" 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\"]",
" 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]",
" 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\"]",
" 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\"]",
" 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]",
" }",
" subgraph cluster_6",
@@ -318,9 +318,9 @@ async fn csv_explain_plans() {
" 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]",
" 8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]",
" 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]",
" 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
" 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
" 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]",
" 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]",
" 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]",
" 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]",
" }",
"}",
@@ -349,7 +349,7 @@ async fn csv_explain_plans() {
// Since the plan contains path that are environmentally dependant (e.g. full path of the test file), only verify important content
assert_contains!(&actual, "logical_plan");
assert_contains!(&actual, "Projection: #aggregate_test_100.c1");
assert_contains!(actual, "Filter: #aggregate_test_100.c2 > Int64(10)");
assert_contains!(actual, "Filter: #aggregate_test_100.c2 > Int32(10)");
}

#[tokio::test]
@@ -469,8 +469,8 @@ async fn csv_explain_verbose_plans() {
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #aggregate_test_100.c1 [c1:Utf8]",
" Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int32]",
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)] [c1:Utf8, c2:Int32]",
" Filter: #aggregate_test_100.c2 > Int32(10) [c1:Utf8, c2:Int32]",
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)] [c1:Utf8, c2:Int32]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -484,8 +484,8 @@ async fn csv_explain_verbose_plans() {
let expected = vec![
"Explain",
" Projection: #aggregate_test_100.c1",
" Filter: #aggregate_test_100.c2 > Int64(10)",
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]",
" Filter: #aggregate_test_100.c2 > Int32(10)",
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]",
];
let formatted = plan.display_indent().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -505,9 +505,9 @@ async fn csv_explain_verbose_plans() {
" 2[shape=box label=\"Explain\"]",
" 3[shape=box label=\"Projection: #aggregate_test_100.c1\"]",
" 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]",
" 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\"]",
" 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\"]",
" 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]",
" 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\"]",
" 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\"]",
" 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]",
" }",
" subgraph cluster_6",
@@ -516,9 +516,9 @@ async fn csv_explain_verbose_plans() {
" 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]",
" 8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]",
" 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]",
" 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
" 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
" 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]",
" 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]",
" 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]",
" 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]",
" }",
"}",
@@ -549,7 +549,7 @@ async fn csv_explain_verbose_plans() {
// important content
assert_contains!(&actual, "logical_plan after projection_push_down");
assert_contains!(&actual, "physical_plan");
assert_contains!(&actual, "FilterExec: CAST(c2@1 AS Int64) > 10");
assert_contains!(&actual, "FilterExec: c2@1 > 10");
assert_contains!(actual, "ProjectionExec: expr=[c1@0 as c1]");
}

@@ -745,7 +745,7 @@ async fn csv_explain() {
// then execute the physical plan and return the final explain results
let ctx = SessionContext::new();
register_aggregate_csv_by_sql(&ctx).await;
let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 > 10";
let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 > cast(10 as int)";
let actual = execute(&ctx, sql).await;
let actual = normalize_vec_for_explain(actual);

@@ -755,13 +755,13 @@
vec![
"logical_plan",
"Projection: #aggregate_test_100.c1\
\n Filter: #aggregate_test_100.c2 > Int64(10)\
\n TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]"
\n Filter: #aggregate_test_100.c2 > Int32(10)\
\n TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]"
],
vec!["physical_plan",
"ProjectionExec: expr=[c1@0 as c1]\
\n CoalesceBatchesExec: target_batch_size=4096\
\n FilterExec: CAST(c2@1 AS Int64) > 10\
\n FilterExec: c2@1 > 10\
\n RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES)\
\n CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, limit=None, projection=[c1, c2]\
\n"
4 changes: 2 additions & 2 deletions datafusion/core/tests/sql/subqueries.rs
@@ -147,8 +147,8 @@ order by s_acctbal desc, n_name, s_name, p_partkey;"#;
Inner Join: #supplier.s_nationkey = #nation.n_nationkey
Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
Inner Join: #part.p_partkey = #partsupp.ps_partkey
Filter: #part.p_size = Int64(15) AND #part.p_type LIKE Utf8("%BRASS")
TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[#part.p_size = Int64(15), #part.p_type LIKE Utf8("%BRASS")]
Filter: #part.p_size = Int32(15) AND #part.p_type LIKE Utf8("%BRASS")
TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[#part.p_size = Int32(15), #part.p_type LIKE Utf8("%BRASS")]
TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost]
TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]
TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
1 change: 1 addition & 0 deletions datafusion/optimizer/src/lib.rs
@@ -33,6 +33,7 @@ pub mod single_distinct_to_groupby;
pub mod subquery_filter_to_join;
pub mod utils;

pub mod pre_cast_lit_in_comparison;
pub mod rewrite_disjunctive_predicate;
#[cfg(test)]
pub mod test;
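
For orientation, a rough skeleton of how such a rule plugs into the optimizer crate. The import paths and the OptimizerRule signature below are assumptions based on the DataFusion optimizer API of this era, not the actual contents of pre_cast_lit_in_comparison.rs:

// Hypothetical skeleton only; paths and trait signature are assumed.
use datafusion_common::Result;
use datafusion_expr::LogicalPlan;
use datafusion_optimizer::{OptimizerConfig, OptimizerRule};

#[derive(Default)]
pub struct PreCastLitInComparisonExpressions {}

impl PreCastLitInComparisonExpressions {
    pub fn new() -> Self {
        Self::default()
    }
}

impl OptimizerRule for PreCastLitInComparisonExpressions {
    fn optimize(
        &self,
        plan: &LogicalPlan,
        _config: &mut OptimizerConfig,
    ) -> Result<LogicalPlan> {
        // The real rule walks the plan's expressions, finds comparison
        // BinaryExprs of the form `CAST(col) op literal`, and rewrites the
        // literal to the column's type when the value fits. This skeleton
        // just passes the plan through unchanged.
        Ok(plan.clone())
    }

    fn name(&self) -> &str {
        "pre_cast_lit_in_comparison"
    }
}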