
Add optimizer rule for type coercion (binary operations only) #3222

Merged · 8 commits · Sep 6, 2022
datafusion/core/src/execution/context.rs · 4 additions & 0 deletions
@@ -111,6 +111,7 @@ use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys;
use datafusion_optimizer::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions;
use datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
use datafusion_optimizer::scalar_subquery_to_join::ScalarSubqueryToJoin;
use datafusion_optimizer::type_coercion::TypeCoercion;
use datafusion_sql::{
parser::DFParser,
planner::{ContextProvider, SqlToRel},
@@ -1433,6 +1434,9 @@ impl SessionState {
}
rules.push(Arc::new(ReduceOuterJoin::new()));
rules.push(Arc::new(FilterPushDown::new()));
// we do type coercion after filter push down so that we don't push CAST filters to Parquet
Contributor:

smart move, that would have been a hard bug to find!

Contributor:

I am confused by this comment. Can you explain why type coercion is done after the filter push-down optimizer rule?

I think the type coercion rule should run in an earlier stage.

Contributor:

For example, take the filter expr FLOAT32(C1) < FLOAT64(16). We should do type coercion first, converting the filter expr to CAST(FLOAT32(C1) AS FLOAT64) < FLOAT64(16), and then try to push the new filter expr down to the table scan operation.

If you don't do type coercion first, you will push the expr FLOAT32(C1) < FLOAT64(16) to the table scan. Can that be applied as a parquet filter or a pruning filter?

Member (author):

@liukun4515 This PR is ready for review now.

Member (author):

I filed #3289 for applying TypeCoercion before FilterPushDown. I think this PR would get too large to review if I made those changes here.

Contributor:

I have a partially written ticket (I will post it later this week) related to supporting CAST in pruning logic (which is part of what is pushed down to parquet). Perhaps this is also related.

// until https://github.com/apache/arrow-datafusion/issues/3289 is resolved
rules.push(Arc::new(TypeCoercion::new()));
rules.push(Arc::new(LimitPushDown::new()));
rules.push(Arc::new(SingleDistinctToGroupBy::new()));
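
The thread above is about where TypeCoercion sits relative to FilterPushDown. Below is a minimal sketch of the coercion step itself, using a toy expression type; the names and types are illustrative assumptions, not DataFusion's real Expr API.

// Toy model of what a type-coercion pass does to a binary comparison:
// if the two sides have different types, wrap the narrower side in a CAST
// so both sides share a common type.
#[derive(Debug, Clone, Copy)]
enum Ty {
    Float32,
    Float64,
}

#[derive(Debug)]
enum Expr {
    Column(&'static str, Ty),
    Literal(f64, Ty),
    Cast(Box<Expr>, Ty),
    Lt(Box<Expr>, Box<Expr>),
}

fn ty(e: &Expr) -> Ty {
    match e {
        Expr::Column(_, t) | Expr::Literal(_, t) | Expr::Cast(_, t) => *t,
        Expr::Lt(..) => unreachable!("comparisons are not nested in this sketch"),
    }
}

// Coerce `left < right` to a common type (Float64 wins over Float32).
fn coerce_lt(left: Expr, right: Expr) -> Expr {
    let (left, right) = match (ty(&left), ty(&right)) {
        (Ty::Float32, Ty::Float64) => (Expr::Cast(Box::new(left), Ty::Float64), right),
        (Ty::Float64, Ty::Float32) => (left, Expr::Cast(Box::new(right), Ty::Float64)),
        _ => (left, right),
    };
    Expr::Lt(Box::new(left), Box::new(right))
}

fn main() {
    // FLOAT32(c1) < FLOAT64(16) becomes CAST(c1 AS Float64) < FLOAT64(16)
    let coerced = coerce_lt(
        Expr::Column("c1", Ty::Float32),
        Expr::Literal(16.0, Ty::Float64),
    );
    println!("{:?}", coerced);
}

Running coercion before push-down (as proposed in #3289) would mean the predicate pushed to the scan already carries the CAST on the column side, which the pruning logic would then need to understand; running it after, as this PR does, keeps such CASTs out of the filters pushed to Parquet for now.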

datafusion/core/src/physical_plan/planner.rs · 6 additions & 6 deletions
@@ -128,13 +128,13 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
name += "END";
Ok(name)
}
Expr::Cast { expr, data_type } => {
let expr = create_physical_name(expr, false)?;
Ok(format!("CAST({} AS {:?})", expr, data_type))
Expr::Cast { expr, .. } => {
// CAST does not change the expression name
create_physical_name(expr, false)
}
Expr::TryCast { expr, data_type } => {
let expr = create_physical_name(expr, false)?;
Ok(format!("TRY_CAST({} AS {:?})", expr, data_type))
Expr::TryCast { expr, .. } => {
// CAST does not change the expression name
create_physical_name(expr, false)
}
Expr::Not(expr) => {
let expr = create_physical_name(expr, false)?;
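
The change above makes CAST and TRY_CAST transparent when computing an expression's display name. A standalone sketch of the new rule, with a toy Expr type rather than DataFusion's actual one:

// Minimal sketch: display names pass through CAST/TRY_CAST unchanged.
#[derive(Debug)]
enum Expr {
    Column(String),
    Cast(Box<Expr>),
    TryCast(Box<Expr>),
}

fn physical_name(e: &Expr) -> String {
    match e {
        Expr::Column(name) => name.clone(),
        // CAST/TRY_CAST do not change the expression name
        Expr::Cast(inner) | Expr::TryCast(inner) => physical_name(inner),
    }
}

fn main() {
    let e = Expr::Cast(Box::new(Expr::Column("test.b".to_string())));
    // Before this PR the name would have been "CAST(test.b AS Float64)".
    assert_eq!(physical_name(&e), "test.b");
}

This is why the expected outputs in the tests below lose their CAST(... AS ...) column headers.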
datafusion/core/tests/dataframe_functions.rs · 8 additions & 8 deletions
@@ -667,14 +667,14 @@ async fn test_fn_substr() -> Result<()> {
async fn test_cast() -> Result<()> {
let expr = cast(col("b"), DataType::Float64);
let expected = vec![
"+-------------------------+",
"| CAST(test.b AS Float64) |",
"+-------------------------+",
"| 1 |",
"| 10 |",
"| 10 |",
"| 100 |",
"+-------------------------+",
"+--------+",
"| test.b |",
Contributor:

isn't the original header better?
@alamb @andygrove

Contributor:

I personally don't think seeing the cast in the column name adds much value. Also, having no cast in the column header is consistent with postgres:

alamb=# select cast(1 as int);
 int4 
------
    1
(1 row)

alamb=# select cast(i as int) from foo;
 i 
---
 1
 2
 0
(3 rows)

"+--------+",
"| 1 |",
"| 10 |",
"| 10 |",
"| 100 |",
"+--------+",
];

assert_fn_batches!(expr, expected);
datafusion/core/tests/parquet_pruning.rs · 1 addition & 0 deletions
@@ -647,6 +647,7 @@ impl ContextWithParquet {
let pretty_input = pretty_format_batches(&input).unwrap().to_string();

let logical_plan = self.ctx.optimize(&logical_plan).expect("optimizing plan");

let physical_plan = self
.ctx
.create_physical_plan(&logical_plan)
datafusion/core/tests/sql/aggregates.rs · 24 additions & 5 deletions
@@ -462,11 +462,11 @@ async fn csv_query_external_table_sum() {
"SELECT SUM(CAST(c7 AS BIGINT)), SUM(CAST(c8 AS BIGINT)) FROM aggregate_test_100";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+-------------------------------------------+-------------------------------------------+",
"| SUM(CAST(aggregate_test_100.c7 AS Int64)) | SUM(CAST(aggregate_test_100.c8 AS Int64)) |",
"+-------------------------------------------+-------------------------------------------+",
"| 13060 | 3017641 |",
"+-------------------------------------------+-------------------------------------------+",
"+----------------------------+----------------------------+",
"| SUM(aggregate_test_100.c7) | SUM(aggregate_test_100.c8) |",
"+----------------------------+----------------------------+",
"| 13060 | 3017641 |",
"+----------------------------+----------------------------+",
];
assert_batches_eq!(expected, &actual);
}
@@ -555,6 +555,7 @@ async fn csv_query_count_one() {
}

#[tokio::test]
#[ignore] // https://github.com/apache/arrow-datafusion/issues/3353
async fn csv_query_approx_count() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
@@ -571,6 +572,24 @@ async fn csv_query_approx_count() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn csv_query_approx_count_dupe_expr_aliased() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql =
"SELECT approx_distinct(c9) a, approx_distinct(c9) b FROM aggregate_test_100";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+-----+-----+",
"| a | b |",
"+-----+-----+",
"| 100 | 100 |",
"+-----+-----+",
];
assert_batches_eq!(expected, &actual);
Ok(())
}

// This test executes the APPROX_PERCENTILE_CONT aggregation against the test
// data, asserting the estimated quantiles are ±5% their actual values.
//
datafusion/core/tests/sql/avro.rs · 32 additions & 32 deletions
@@ -37,18 +37,18 @@ async fn avro_query() {
let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+----+-----------------------------------------+",
"| id | CAST(alltypes_plain.string_col AS Utf8) |",
"+----+-----------------------------------------+",
"| 4 | 0 |",
"| 5 | 1 |",
"| 6 | 0 |",
"| 7 | 1 |",
"| 2 | 0 |",
"| 3 | 1 |",
"| 0 | 0 |",
"| 1 | 1 |",
"+----+-----------------------------------------+",
"+----+---------------------------+",
"| id | alltypes_plain.string_col |",
"+----+---------------------------+",
"| 4 | 0 |",
"| 5 | 1 |",
"| 6 | 0 |",
"| 7 | 1 |",
"| 2 | 0 |",
"| 3 | 1 |",
"| 0 | 0 |",
"| 1 | 1 |",
"+----+---------------------------+",
];

assert_batches_eq!(expected, &actual);
@@ -84,26 +84,26 @@ async fn avro_query_multiple_files() {
let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+----+-----------------------------------------+",
"| id | CAST(alltypes_plain.string_col AS Utf8) |",
"+----+-----------------------------------------+",
"| 4 | 0 |",
"| 5 | 1 |",
"| 6 | 0 |",
"| 7 | 1 |",
"| 2 | 0 |",
"| 3 | 1 |",
"| 0 | 0 |",
"| 1 | 1 |",
"| 4 | 0 |",
"| 5 | 1 |",
"| 6 | 0 |",
"| 7 | 1 |",
"| 2 | 0 |",
"| 3 | 1 |",
"| 0 | 0 |",
"| 1 | 1 |",
"+----+-----------------------------------------+",
"+----+---------------------------+",
"| id | alltypes_plain.string_col |",
"+----+---------------------------+",
"| 4 | 0 |",
"| 5 | 1 |",
"| 6 | 0 |",
"| 7 | 1 |",
"| 2 | 0 |",
"| 3 | 1 |",
"| 0 | 0 |",
"| 1 | 1 |",
"| 4 | 0 |",
"| 5 | 1 |",
"| 6 | 0 |",
"| 7 | 1 |",
"| 2 | 0 |",
"| 3 | 1 |",
"| 0 | 0 |",
"| 1 | 1 |",
"+----+---------------------------+",
];

assert_batches_eq!(expected, &actual);
datafusion/core/tests/sql/decimal.rs · 53 additions & 53 deletions
@@ -27,11 +27,11 @@ async fn decimal_cast() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
"+------------------------------------------+",
"| CAST(Float64(1.23) AS Decimal128(10, 4)) |",
"+------------------------------------------+",
"| 1.2300 |",
"+------------------------------------------+",
"+---------------+",
"| Float64(1.23) |",
"+---------------+",
"| 1.2300 |",
"+---------------+",
];
assert_batches_eq!(expected, &actual);

@@ -42,11 +42,11 @@
actual[0].schema().field(0).data_type()
);
let expected = vec![
"+---------------------------------------------------------------------+",
"| CAST(CAST(Float64(1.23) AS Decimal128(10, 3)) AS Decimal128(10, 4)) |",
"+---------------------------------------------------------------------+",
"| 1.2300 |",
"+---------------------------------------------------------------------+",
"+---------------+",
"| Float64(1.23) |",
"+---------------+",
"| 1.2300 |",
"+---------------+",
];
assert_batches_eq!(expected, &actual);

@@ -57,11 +57,11 @@
actual[0].schema().field(0).data_type()
);
let expected = vec![
"+--------------------------------------------+",
"| CAST(Float64(1.2345) AS Decimal128(24, 2)) |",
"+--------------------------------------------+",
"| 1.23 |",
"+--------------------------------------------+",
"+-----------------+",
"| Float64(1.2345) |",
"+-----------------+",
"| 1.23 |",
"+-----------------+",
];
assert_batches_eq!(expected, &actual);

@@ -550,25 +550,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
"+----------------------------------------------------------------+",
"| decimal_simple.c1 / CAST(Float64(0.00001) AS Decimal128(5, 5)) |",
"+----------------------------------------------------------------+",
"| 1.000000000000 |",
"| 2.000000000000 |",
"| 2.000000000000 |",
"| 3.000000000000 |",
"| 3.000000000000 |",
"| 3.000000000000 |",
"| 4.000000000000 |",
"| 4.000000000000 |",
"| 4.000000000000 |",
"| 4.000000000000 |",
"| 5.000000000000 |",
"| 5.000000000000 |",
"| 5.000000000000 |",
"| 5.000000000000 |",
"| 5.000000000000 |",
"+----------------------------------------------------------------+",
"+--------------------------------------+",
"| decimal_simple.c1 / Float64(0.00001) |",
"+--------------------------------------+",
"| 1.000000000000 |",
"| 2.000000000000 |",
"| 2.000000000000 |",
"| 3.000000000000 |",
"| 3.000000000000 |",
"| 3.000000000000 |",
"| 4.000000000000 |",
"| 4.000000000000 |",
"| 4.000000000000 |",
"| 4.000000000000 |",
"| 5.000000000000 |",
"| 5.000000000000 |",
"| 5.000000000000 |",
"| 5.000000000000 |",
"| 5.000000000000 |",
"+--------------------------------------+",
];
assert_batches_eq!(expected, &actual);

@@ -609,25 +609,25 @@
actual[0].schema().field(0).data_type()
);
let expected = vec![
"+----------------------------------------------------------------+",
"| decimal_simple.c5 % CAST(Float64(0.00001) AS Decimal128(5, 5)) |",
"+----------------------------------------------------------------+",
"| 0.0000040 |",
"| 0.0000050 |",
"| 0.0000090 |",
"| 0.0000020 |",
"| 0.0000050 |",
"| 0.0000010 |",
"| 0.0000040 |",
"| 0.0000000 |",
"| 0.0000000 |",
"| 0.0000040 |",
"| 0.0000020 |",
"| 0.0000080 |",
"| 0.0000030 |",
"| 0.0000080 |",
"| 0.0000000 |",
"+----------------------------------------------------------------+",
"+--------------------------------------+",
"| decimal_simple.c5 % Float64(0.00001) |",
"+--------------------------------------+",
"| 0.0000040 |",
"| 0.0000050 |",
"| 0.0000090 |",
"| 0.0000020 |",
"| 0.0000050 |",
"| 0.0000010 |",
"| 0.0000040 |",
"| 0.0000000 |",
"| 0.0000000 |",
"| 0.0000040 |",
"| 0.0000020 |",
"| 0.0000080 |",
"| 0.0000030 |",
"| 0.0000080 |",
"| 0.0000000 |",
"+--------------------------------------+",
];
assert_batches_eq!(expected, &actual);

datafusion/core/tests/sql/explain_analyze.rs · 3 additions & 3 deletions
@@ -653,7 +653,7 @@ order by
let expected = "\
Sort: #revenue DESC NULLS FIRST\
\n Projection: #customer.c_custkey, #customer.c_name, #SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue, #customer.c_acctbal, #nation.n_name, #customer.c_address, #customer.c_phone, #customer.c_comment\
\n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, #customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, #customer.c_comment]], aggr=[[SUM(#lineitem.l_extendedprice * Int64(1) - #lineitem.l_discount)]]\
\n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, #customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, #customer.c_comment]], aggr=[[SUM(#lineitem.l_extendedprice * CAST(Int64(1) AS Float64) - #lineitem.l_discount)]]\
\n Inner Join: #customer.c_nationkey = #nation.n_nationkey\
\n Inner Join: #orders.o_orderkey = #lineitem.l_orderkey\
\n Inner Join: #customer.c_custkey = #orders.o_custkey\
Expand All @@ -663,7 +663,7 @@ order by
\n Filter: #lineitem.l_returnflag = Utf8(\"R\")\
\n TableScan: lineitem projection=[l_orderkey, l_extendedprice, l_discount, l_returnflag], partial_filters=[#lineitem.l_returnflag = Utf8(\"R\")]\
\n TableScan: nation projection=[n_nationkey, n_name]";
assert_eq!(format!("{:?}", plan.unwrap()), expected);
assert_eq!(expected, format!("{:?}", plan.unwrap()),);

Ok(())
}
@@ -694,7 +694,7 @@ async fn test_physical_plan_display_indent() {
" RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 9000)",
" AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)]",
" CoalesceBatchesExec: target_batch_size=4096",
" FilterExec: c12@1 < CAST(10 AS Float64)",
" FilterExec: c12@1 < 10",
Member (author):

This is the physical plan, which no longer contains a cast here because the logical plan optimized out the cast of a literal value.

Contributor:

🎉 -- which I think is a good example of the value of this pass

" RepartitionExec: partitioning=RoundRobinBatch(9000)",
" CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, limit=None, projection=[c1, c12]",
];
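
As the thread above notes, the literal side of the comparison is coerced at plan time, so the physical filter compares against an already-typed literal. A toy sketch of that folding; this is a simplified model of the idea, not DataFusion's actual rule:

// When a comparison mixes a Float64 column with an integer literal, the
// optimizer can evaluate CAST(10 AS Float64) once during planning and embed
// the folded literal, so the physical plan shows `c12@1 < 10` rather than
// `c12@1 < CAST(10 AS Float64)`.
fn fold_int_literal_to_f64(lit: i64) -> f64 {
    lit as f64 // the cast happens once, at plan time
}

fn main() {
    let folded = fold_int_literal_to_f64(10);
    assert_eq!(folded, 10.0);
    println!("FilterExec: c12@1 < {}", folded); // prints: FilterExec: c12@1 < 10
}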
datafusion/core/tests/sql/expr.rs · 2 additions & 2 deletions
@@ -247,8 +247,8 @@ async fn query_not() -> Result<()> {
async fn csv_query_sum_cast() {
let ctx = SessionContext::new();
register_aggregate_csv_by_sql(&ctx).await;
// c8 = i32; c9 = i64
let sql = "SELECT c8 + c9 FROM aggregate_test_100";
// c8 = i32; c6 = i64
Contributor:

I made this change because #3359 changed the type of c9, so it is no longer i64 but u64.

let sql = "SELECT c8 + c6 FROM aggregate_test_100";
// check that the physical and logical schemas are equal
execute(&ctx, sql).await;
}