From b011f3c5725b632d3668084c02d4b43fd57013a2 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 7 Oct 2022 16:24:33 -0600 Subject: [PATCH 1/4] add failing test --- .../optimizer/tests/integration-test.rs | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index f6fe685ee282..a2bc28cd78fe 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -52,6 +52,26 @@ fn case_when() -> Result<()> { Ok(()) } +#[test] +fn subquery_filter_with_cast() -> Result<()> { + let sql = "SELECT col_int32 FROM test \ + WHERE col_int32 > (\ + SELECT AVG(col_int32) FROM test \ + WHERE col_utf8 BETWEEN '2002-05-08' \ + AND (cast('2002-05-08' as date) + interval '5 days')\ + )"; + let plan = test_sql(sql)?; + let expected = + "Projection: test.col_int32\n Filter: CAST(test.col_int32 AS Float64) > __sq_1.__value\ + \n CrossJoin:\n TableScan: test projection=[col_int32]\ + \n Projection: AVG(test.col_int32) AS __value, alias=__sq_1\ + \n Aggregate: groupBy=[[]], aggr=[[AVG(test.col_int32)]]\ + \n Filter: test.col_utf8 BETWEEN Utf8(\"2002-05-08\") AND Utf8(\"2002-05-13\")\ + \n TableScan: test projection=[col_int32, col_utf8]"; + assert_eq!(expected, format!("{:?}", plan)); + Ok(()) +} + #[test] fn case_when_aggregate() -> Result<()> { let sql = "SELECT col_utf8, SUM(CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END) AS n FROM test GROUP BY col_utf8"; From c7cdc13841893409aa19b0db3f4583d3102211f3 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 7 Oct 2022 16:51:08 -0600 Subject: [PATCH 2/4] update tests --- datafusion/core/tests/sql/subqueries.rs | 12 ++++++------ datafusion/optimizer/src/optimizer.rs | 4 ++++ datafusion/optimizer/tests/integration-test.rs | 5 +++-- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index f91018d8bf64..a5b246be4f0a 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -336,10 +336,10 @@ order by s_name; Projection: part.p_partkey AS p_partkey, alias=__sq_1 Filter: part.p_name LIKE Utf8("forest%") TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8("forest%")] - Projection: lineitem.l_partkey, lineitem.l_suppkey, CAST(Float64(0.5) AS Decimal128(38, 17)) * CAST(SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3 + Projection: lineitem.l_partkey, lineitem.l_suppkey, Decimal128(Some(50000000000000000),38,17) * CAST(SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3 Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]], aggr=[[SUM(lineitem.l_quantity)]] - Filter: lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS Date32) - TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS Date32)]"# + Filter: lineitem.l_shipdate >= Date32("8766") + TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("8766")]"# .to_string(); assert_eq!(actual, expected); @@ -393,8 +393,8 @@ order by cntrycode;"#; TableScan: orders projection=[o_custkey] Projection: AVG(customer.c_acctbal) AS __value, alias=__sq_1 Aggregate: groupBy=[[]], aggr=[[AVG(customer.c_acctbal)]] - Filter: CAST(customer.c_acctbal AS Decimal128(30, 15)) > CAST(Float64(0) AS Decimal128(30, 15)) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]) - TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[CAST(customer.c_acctbal AS Decimal128(30, 15)) > CAST(Float64(0) AS Decimal128(30, 15)), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])]"# + Filter: CAST(customer.c_acctbal AS Decimal128(30, 15)) > Decimal128(Some(0),30,15) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]) + TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[CAST(customer.c_acctbal AS Decimal128(30, 15)) > Decimal128(Some(0),30,15), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])]"# .to_string(); assert_eq!(actual, expected); @@ -453,7 +453,7 @@ order by value desc; TableScan: supplier projection=[s_suppkey, s_nationkey] Filter: nation.n_name = Utf8("GERMANY") TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")] - Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * CAST(Float64(0.0001) AS Decimal128(38, 17)) AS __value, alias=__sq_1 + Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * Decimal128(Some(10000000000000),38,17) AS __value, alias=__sq_1 Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]] Inner Join: supplier.s_nationkey = nation.n_nationkey Inner Join: partsupp.ps_suppkey = supplier.s_suppkey diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index aa10cd8a7dc2..87e4d1ffcd13 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -144,6 +144,10 @@ impl Optimizer { Arc::new(DecorrelateWhereIn::new()), Arc::new(ScalarSubqueryToJoin::new()), Arc::new(SubqueryFilterToJoin::new()), + // simplify expressions does not simplify expressions in subqueries, so we + // run it again after running the optimizations that potentially converted + // subqueries to joins + Arc::new(SimplifyExpressions::new()), Arc::new(EliminateFilter::new()), Arc::new(ReduceCrossJoin::new()), Arc::new(CommonSubexprEliminate::new()), diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index a2bc28cd78fe..2e27ab94c925 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -63,10 +63,11 @@ fn subquery_filter_with_cast() -> Result<()> { let plan = test_sql(sql)?; let expected = "Projection: test.col_int32\n Filter: CAST(test.col_int32 AS Float64) > __sq_1.__value\ - \n CrossJoin:\n TableScan: test projection=[col_int32]\ + \n CrossJoin:\ + \n TableScan: test projection=[col_int32]\ \n Projection: AVG(test.col_int32) AS __value, alias=__sq_1\ \n Aggregate: groupBy=[[]], aggr=[[AVG(test.col_int32)]]\ - \n Filter: test.col_utf8 BETWEEN Utf8(\"2002-05-08\") AND Utf8(\"2002-05-13\")\ + \n Filter: test.col_utf8 >= Utf8(\"2002-05-08\") AND test.col_utf8 <= Utf8(\"2002-05-13\")\ \n TableScan: test projection=[col_int32, col_utf8]"; assert_eq!(expected, format!("{:?}", plan)); Ok(()) From 42502e033889474bddb24de0ab0921a075b8bd3e Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 7 Oct 2022 16:53:55 -0600 Subject: [PATCH 3/4] document issue in test --- datafusion/optimizer/tests/integration-test.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index 2e27ab94c925..12a5b4447531 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -54,6 +54,7 @@ fn case_when() -> Result<()> { #[test] fn subquery_filter_with_cast() -> Result<()> { + // regression test for https://github.com/apache/arrow-datafusion/issues/3760 let sql = "SELECT col_int32 FROM test \ WHERE col_int32 > (\ SELECT AVG(col_int32) FROM test \ From 0ee1c838a99a635bd0b444e24cf2689e70157b37 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 7 Oct 2022 17:34:09 -0600 Subject: [PATCH 4/4] fix merge conflict --- datafusion/core/src/dataframe.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs index 06768b5631ca..a5caad176558 100644 --- a/datafusion/core/src/dataframe.rs +++ b/datafusion/core/src/dataframe.rs @@ -827,7 +827,6 @@ impl TableProvider for DataFrame { #[cfg(test)] mod tests { - use arrow::array::Int32Array; use std::vec; use super::*;