Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimizer: add framework for the rule of pre-add cast to the literal in comparison binary #3185

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ use datafusion_optimizer::decorrelate_scalar_subquery::DecorrelateScalarSubquery
use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists;
use datafusion_optimizer::decorrelate_where_in::DecorrelateWhereIn;
use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys;
use datafusion_optimizer::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions;
use datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
use datafusion_sql::{
parser::DFParser,
Expand Down Expand Up @@ -1360,6 +1361,7 @@ impl SessionState {
// Simplify expressions first to maximize the chance
// of applying other optimizations
Arc::new(SimplifyExpressions::new()),
Arc::new(PreCastLitInComparisonExpressions::new()),
Arc::new(DecorrelateWhereExists::new()),
Arc::new(DecorrelateWhereIn::new()),
Arc::new(DecorrelateScalarSubquery::new()),
Expand Down
34 changes: 32 additions & 2 deletions datafusion/core/tests/provider_filter_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ use datafusion::physical_plan::{
};
use datafusion::prelude::*;
use datafusion::scalar::ScalarValue;
use datafusion_common::DataFusionError;
use std::ops::Deref;
use std::sync::Arc;

fn create_batch(value: i32, num_rows: usize) -> Result<RecordBatch> {
Expand Down Expand Up @@ -146,8 +148,36 @@ impl TableProvider for CustomProvider {
match &filters[0] {
Expr::BinaryExpr { right, .. } => {
let int_value = match &**right {
Expr::Literal(ScalarValue::Int64(i)) => i.unwrap(),
_ => unimplemented!(),
Expr::Literal(ScalarValue::Int8(Some(i))) => *i as i64,
Expr::Literal(ScalarValue::Int16(Some(i))) => *i as i64,
Expr::Literal(ScalarValue::Int32(Some(i))) => *i as i64,
Expr::Literal(ScalarValue::Int64(Some(i))) => *i as i64,
Expr::Cast { expr, data_type: _ } => match expr.deref() {
Expr::Literal(lit_value) => match lit_value {
ScalarValue::Int8(Some(v)) => *v as i64,
ScalarValue::Int16(Some(v)) => *v as i64,
ScalarValue::Int32(Some(v)) => *v as i64,
ScalarValue::Int64(Some(v)) => *v,
other_value => {
return Err(DataFusionError::NotImplemented(format!(
"Do not support value {:?}",
other_value
)))
}
},
other_expr => {
return Err(DataFusionError::NotImplemented(format!(
"Do not support expr {:?}",
other_expr
)))
}
},
other_expr => {
return Err(DataFusionError::NotImplemented(format!(
"Do not support expr {:?}",
other_expr
)))
}
};

Ok(Arc::new(CustomPlan {
Expand Down
44 changes: 22 additions & 22 deletions datafusion/core/tests/sql/explain_analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,8 @@ async fn csv_explain_plans() {
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #aggregate_test_100.c1 [c1:Utf8]",
" Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int32]",
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)] [c1:Utf8, c2:Int32]",
" Filter: #aggregate_test_100.c2 > Int32(10) [c1:Utf8, c2:Int32]",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After optimization, the literal Int64(10) will be cast to Int32(10), because the left-hand side's type is Int32.

" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)] [c1:Utf8, c2:Int32]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand All @@ -286,8 +286,8 @@ async fn csv_explain_plans() {
let expected = vec![
"Explain",
" Projection: #aggregate_test_100.c1",
" Filter: #aggregate_test_100.c2 > Int64(10)",
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]",
" Filter: #aggregate_test_100.c2 > Int32(10)",
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]",
];
let formatted = plan.display_indent().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand All @@ -307,9 +307,9 @@ async fn csv_explain_plans() {
" 2[shape=box label=\"Explain\"]",
" 3[shape=box label=\"Projection: #aggregate_test_100.c1\"]",
" 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]",
" 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\"]",
" 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\"]",
" 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]",
" 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\"]",
" 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\"]",
" 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]",
" }",
" subgraph cluster_6",
Expand All @@ -318,9 +318,9 @@ async fn csv_explain_plans() {
" 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]",
" 8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]",
" 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]",
" 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
" 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
" 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]",
" 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]",
" 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]",
" 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]",
" }",
"}",
Expand Down Expand Up @@ -349,7 +349,7 @@ async fn csv_explain_plans() {
// Since the plan contains path that are environmentally dependant (e.g. full path of the test file), only verify important content
assert_contains!(&actual, "logical_plan");
assert_contains!(&actual, "Projection: #aggregate_test_100.c1");
assert_contains!(actual, "Filter: #aggregate_test_100.c2 > Int64(10)");
assert_contains!(actual, "Filter: #aggregate_test_100.c2 > Int32(10)");
}

#[tokio::test]
Expand Down Expand Up @@ -469,8 +469,8 @@ async fn csv_explain_verbose_plans() {
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #aggregate_test_100.c1 [c1:Utf8]",
" Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int32]",
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)] [c1:Utf8, c2:Int32]",
" Filter: #aggregate_test_100.c2 > Int32(10) [c1:Utf8, c2:Int32]",
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)] [c1:Utf8, c2:Int32]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand All @@ -484,8 +484,8 @@ async fn csv_explain_verbose_plans() {
let expected = vec![
"Explain",
" Projection: #aggregate_test_100.c1",
" Filter: #aggregate_test_100.c2 > Int64(10)",
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]",
" Filter: #aggregate_test_100.c2 > Int32(10)",
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]",
];
let formatted = plan.display_indent().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand All @@ -505,9 +505,9 @@ async fn csv_explain_verbose_plans() {
" 2[shape=box label=\"Explain\"]",
" 3[shape=box label=\"Projection: #aggregate_test_100.c1\"]",
" 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]",
" 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\"]",
" 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\"]",
" 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]",
" 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\"]",
" 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\"]",
" 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]",
" }",
" subgraph cluster_6",
Expand All @@ -516,9 +516,9 @@ async fn csv_explain_verbose_plans() {
" 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]",
" 8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]",
" 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]",
" 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
" 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
" 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]",
" 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]",
" 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]",
" 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]",
" }",
"}",
Expand Down Expand Up @@ -549,7 +549,7 @@ async fn csv_explain_verbose_plans() {
// important content
assert_contains!(&actual, "logical_plan after projection_push_down");
assert_contains!(&actual, "physical_plan");
assert_contains!(&actual, "FilterExec: CAST(c2@1 AS Int64) > 10");
assert_contains!(&actual, "FilterExec: c2@1 > 10");
assert_contains!(actual, "ProjectionExec: expr=[c1@0 as c1]");
}

Expand Down Expand Up @@ -745,7 +745,7 @@ async fn csv_explain() {
// then execute the physical plan and return the final explain results
let ctx = SessionContext::new();
register_aggregate_csv_by_sql(&ctx).await;
let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 > 10";
let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 > cast(10 as int)";
let actual = execute(&ctx, sql).await;
let actual = normalize_vec_for_explain(actual);

Expand All @@ -755,13 +755,13 @@ async fn csv_explain() {
vec![
"logical_plan",
"Projection: #aggregate_test_100.c1\
\n Filter: #aggregate_test_100.c2 > Int64(10)\
\n TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]"
\n Filter: #aggregate_test_100.c2 > Int32(10)\
\n TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]"
],
vec!["physical_plan",
"ProjectionExec: expr=[c1@0 as c1]\
\n CoalesceBatchesExec: target_batch_size=4096\
\n FilterExec: CAST(c2@1 AS Int64) > 10\
\n FilterExec: c2@1 > 10\
\n RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES)\
\n CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, limit=None, projection=[c1, c2]\
\n"
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/tests/sql/subqueries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ order by s_acctbal desc, n_name, s_name, p_partkey;"#;
Inner Join: #supplier.s_nationkey = #nation.n_nationkey
Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
Inner Join: #part.p_partkey = #partsupp.ps_partkey
Filter: #part.p_size = Int64(15) AND #part.p_type LIKE Utf8("%BRASS")
TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[#part.p_size = Int64(15), #part.p_type LIKE Utf8("%BRASS")]
Filter: #part.p_size = Int32(15) AND #part.p_type LIKE Utf8("%BRASS")
TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[#part.p_size = Int32(15), #part.p_type LIKE Utf8("%BRASS")]
TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost]
TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]
TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
Expand Down
1 change: 1 addition & 0 deletions datafusion/optimizer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub mod single_distinct_to_groupby;
pub mod subquery_filter_to_join;
pub mod utils;

pub mod pre_cast_lit_in_comparison;
pub mod rewrite_disjunctive_predicate;
#[cfg(test)]
pub mod test;
Expand Down