-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
optimizer: add framework for the rule of pre-add cast to the literal in comparison binary #3185
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,6 +31,7 @@ use datafusion::physical_plan::{ | |
}; | ||
use datafusion::prelude::*; | ||
use datafusion::scalar::ScalarValue; | ||
use std::ops::Deref; | ||
use std::sync::Arc; | ||
|
||
fn create_batch(value: i32, num_rows: usize) -> Result<RecordBatch> { | ||
|
@@ -146,7 +147,20 @@ impl TableProvider for CustomProvider { | |
match &filters[0] { | ||
Expr::BinaryExpr { right, .. } => { | ||
let int_value = match &**right { | ||
Expr::Literal(ScalarValue::Int8(i)) => i.unwrap() as i64, | ||
Expr::Literal(ScalarValue::Int16(i)) => i.unwrap() as i64, | ||
Expr::Literal(ScalarValue::Int32(i)) => i.unwrap() as i64, | ||
Expr::Literal(ScalarValue::Int64(i)) => i.unwrap(), | ||
Expr::Cast { expr, data_type: _ } => match expr.deref() { | ||
Expr::Literal(lit_value) => match lit_value { | ||
ScalarValue::Int8(v) => v.unwrap() as i64, | ||
ScalarValue::Int16(v) => v.unwrap() as i64, | ||
ScalarValue::Int32(v) => v.unwrap() as i64, | ||
ScalarValue::Int64(v) => v.unwrap(), | ||
_ => unimplemented!(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This method returns There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Love it! I'm very excited to see datafusion doing this for new code 😃 |
||
}, | ||
_ => unimplemented!(), | ||
}, | ||
_ => unimplemented!(), | ||
}; | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -271,8 +271,8 @@ async fn csv_explain_plans() { | |
let expected = vec![ | ||
"Explain [plan_type:Utf8, plan:Utf8]", | ||
" Projection: #aggregate_test_100.c1 [c1:Utf8]", | ||
" Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int32]", | ||
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)] [c1:Utf8, c2:Int32]", | ||
" Filter: #aggregate_test_100.c2 > Int32(10) [c1:Utf8, c2:Int32]", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. after optimization, the INT64(10) will be cast to |
||
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)] [c1:Utf8, c2:Int32]", | ||
]; | ||
let formatted = plan.display_indent_schema().to_string(); | ||
let actual: Vec<&str> = formatted.trim().lines().collect(); | ||
|
@@ -286,8 +286,8 @@ async fn csv_explain_plans() { | |
let expected = vec![ | ||
"Explain", | ||
" Projection: #aggregate_test_100.c1", | ||
" Filter: #aggregate_test_100.c2 > Int64(10)", | ||
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]", | ||
" Filter: #aggregate_test_100.c2 > Int32(10)", | ||
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]", | ||
]; | ||
let formatted = plan.display_indent().to_string(); | ||
let actual: Vec<&str> = formatted.trim().lines().collect(); | ||
|
@@ -307,9 +307,9 @@ async fn csv_explain_plans() { | |
" 2[shape=box label=\"Explain\"]", | ||
" 3[shape=box label=\"Projection: #aggregate_test_100.c1\"]", | ||
" 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]", | ||
" 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\"]", | ||
" 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\"]", | ||
" 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]", | ||
" 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\"]", | ||
" 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\"]", | ||
" 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]", | ||
" }", | ||
" subgraph cluster_6", | ||
|
@@ -318,9 +318,9 @@ async fn csv_explain_plans() { | |
" 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", | ||
" 8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]", | ||
" 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]", | ||
" 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int32]\"]", | ||
" 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\\nSchema: [c1:Utf8, c2:Int32]\"]", | ||
" 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", | ||
" 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]", | ||
" 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]", | ||
" 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", | ||
" }", | ||
"}", | ||
|
@@ -349,7 +349,7 @@ async fn csv_explain_plans() { | |
// Since the plan contains path that are environmentally dependant (e.g. full path of the test file), only verify important content | ||
assert_contains!(&actual, "logical_plan"); | ||
assert_contains!(&actual, "Projection: #aggregate_test_100.c1"); | ||
assert_contains!(actual, "Filter: #aggregate_test_100.c2 > Int64(10)"); | ||
assert_contains!(actual, "Filter: #aggregate_test_100.c2 > Int32(10)"); | ||
} | ||
|
||
#[tokio::test] | ||
|
@@ -469,8 +469,8 @@ async fn csv_explain_verbose_plans() { | |
let expected = vec![ | ||
"Explain [plan_type:Utf8, plan:Utf8]", | ||
" Projection: #aggregate_test_100.c1 [c1:Utf8]", | ||
" Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int32]", | ||
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)] [c1:Utf8, c2:Int32]", | ||
" Filter: #aggregate_test_100.c2 > Int32(10) [c1:Utf8, c2:Int32]", | ||
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)] [c1:Utf8, c2:Int32]", | ||
]; | ||
let formatted = plan.display_indent_schema().to_string(); | ||
let actual: Vec<&str> = formatted.trim().lines().collect(); | ||
|
@@ -484,8 +484,8 @@ async fn csv_explain_verbose_plans() { | |
let expected = vec![ | ||
"Explain", | ||
" Projection: #aggregate_test_100.c1", | ||
" Filter: #aggregate_test_100.c2 > Int64(10)", | ||
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]", | ||
" Filter: #aggregate_test_100.c2 > Int32(10)", | ||
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]", | ||
]; | ||
let formatted = plan.display_indent().to_string(); | ||
let actual: Vec<&str> = formatted.trim().lines().collect(); | ||
|
@@ -505,9 +505,9 @@ async fn csv_explain_verbose_plans() { | |
" 2[shape=box label=\"Explain\"]", | ||
" 3[shape=box label=\"Projection: #aggregate_test_100.c1\"]", | ||
" 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]", | ||
" 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\"]", | ||
" 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\"]", | ||
" 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]", | ||
" 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\"]", | ||
" 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\"]", | ||
" 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]", | ||
" }", | ||
" subgraph cluster_6", | ||
|
@@ -516,9 +516,9 @@ async fn csv_explain_verbose_plans() { | |
" 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", | ||
" 8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]", | ||
" 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]", | ||
" 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int32]\"]", | ||
" 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\\nSchema: [c1:Utf8, c2:Int32]\"]", | ||
" 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", | ||
" 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]", | ||
" 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]", | ||
" 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", | ||
" }", | ||
"}", | ||
|
@@ -549,7 +549,7 @@ async fn csv_explain_verbose_plans() { | |
// important content | ||
assert_contains!(&actual, "logical_plan after projection_push_down"); | ||
assert_contains!(&actual, "physical_plan"); | ||
assert_contains!(&actual, "FilterExec: CAST(c2@1 AS Int64) > 10"); | ||
assert_contains!(&actual, "FilterExec: c2@1 > 10"); | ||
assert_contains!(actual, "ProjectionExec: expr=[c1@0 as c1]"); | ||
} | ||
|
||
|
@@ -745,7 +745,7 @@ async fn csv_explain() { | |
// then execute the physical plan and return the final explain results | ||
let ctx = SessionContext::new(); | ||
register_aggregate_csv_by_sql(&ctx).await; | ||
let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 > 10"; | ||
let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 > cast(10 as int)"; | ||
let actual = execute(&ctx, sql).await; | ||
let actual = normalize_vec_for_explain(actual); | ||
|
||
|
@@ -755,13 +755,13 @@ async fn csv_explain() { | |
vec![ | ||
"logical_plan", | ||
"Projection: #aggregate_test_100.c1\ | ||
\n Filter: #aggregate_test_100.c2 > Int64(10)\ | ||
\n TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]" | ||
\n Filter: #aggregate_test_100.c2 > Int32(10)\ | ||
\n TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]" | ||
], | ||
vec!["physical_plan", | ||
"ProjectionExec: expr=[c1@0 as c1]\ | ||
\n CoalesceBatchesExec: target_batch_size=4096\ | ||
\n FilterExec: CAST(c2@1 AS Int64) > 10\ | ||
\n FilterExec: c2@1 > 10\ | ||
\n RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES)\ | ||
\n CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, limit=None, projection=[c1, c2]\ | ||
\n" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you might want do avoid doing this for NULLs (aka
None
) valuesSomething like:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for you comments, I just follow the original implementation.
But I will follow your nice comments.