Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion datafusion/core/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,6 @@ impl TableProvider for DataFrame {

#[cfg(test)]
mod tests {
use arrow::array::Int32Array;
use std::vec;

use super::*;
Expand Down
12 changes: 6 additions & 6 deletions datafusion/core/tests/sql/subqueries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,10 +336,10 @@ order by s_name;
Projection: part.p_partkey AS p_partkey, alias=__sq_1
Filter: part.p_name LIKE Utf8("forest%")
TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8("forest%")]
Projection: lineitem.l_partkey, lineitem.l_suppkey, CAST(Float64(0.5) AS Decimal128(38, 17)) * CAST(SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3
Projection: lineitem.l_partkey, lineitem.l_suppkey, Decimal128(Some(50000000000000000),38,17) * CAST(SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3
Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]], aggr=[[SUM(lineitem.l_quantity)]]
Filter: lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS Date32)
TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS Date32)]"#
Filter: lineitem.l_shipdate >= Date32("8766")
TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("8766")]"#
.to_string();
assert_eq!(actual, expected);

Expand Down Expand Up @@ -393,8 +393,8 @@ order by cntrycode;"#;
TableScan: orders projection=[o_custkey]
Projection: AVG(customer.c_acctbal) AS __value, alias=__sq_1
Aggregate: groupBy=[[]], aggr=[[AVG(customer.c_acctbal)]]
Filter: CAST(customer.c_acctbal AS Decimal128(30, 15)) > CAST(Float64(0) AS Decimal128(30, 15)) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[CAST(customer.c_acctbal AS Decimal128(30, 15)) > CAST(Float64(0) AS Decimal128(30, 15)), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])]"#
Filter: CAST(customer.c_acctbal AS Decimal128(30, 15)) > Decimal128(Some(0),30,15) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[CAST(customer.c_acctbal AS Decimal128(30, 15)) > Decimal128(Some(0),30,15), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])]"#
.to_string();
assert_eq!(actual, expected);

Expand Down Expand Up @@ -453,7 +453,7 @@ order by value desc;
TableScan: supplier projection=[s_suppkey, s_nationkey]
Filter: nation.n_name = Utf8("GERMANY")
TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")]
Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * CAST(Float64(0.0001) AS Decimal128(38, 17)) AS __value, alias=__sq_1
Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * Decimal128(Some(10000000000000),38,17) AS __value, alias=__sq_1
Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]
Inner Join: supplier.s_nationkey = nation.n_nationkey
Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
Expand Down
4 changes: 4 additions & 0 deletions datafusion/optimizer/src/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ impl Optimizer {
Arc::new(DecorrelateWhereIn::new()),
Arc::new(ScalarSubqueryToJoin::new()),
Arc::new(SubqueryFilterToJoin::new()),
// simplify expressions does not simplify expressions in subqueries, so we
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's a solution to resolve this.
But We also has a way to resolve this issue by supporting the subquery in the rewriter of simplify expression.
Why not support the subquery in the rewriter of simplify expression?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR just reverts the regression and adds a regression test, but I agree, would be better to handle it in simplify_expressions. I filed #3770 for this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree #3770 is a better long term solution

// run it again after running the optimizations that potentially converted
// subqueries to joins
Arc::new(SimplifyExpressions::new()),
Arc::new(EliminateFilter::new()),
Arc::new(ReduceCrossJoin::new()),
Arc::new(CommonSubexprEliminate::new()),
Expand Down
22 changes: 22 additions & 0 deletions datafusion/optimizer/tests/integration-test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,28 @@ fn case_when() -> Result<()> {
Ok(())
}

#[test]
fn subquery_filter_with_cast() -> Result<()> {
// regression test for https://github.com/apache/arrow-datafusion/issues/3760
let sql = "SELECT col_int32 FROM test \
WHERE col_int32 > (\
SELECT AVG(col_int32) FROM test \
WHERE col_utf8 BETWEEN '2002-05-08' \
AND (cast('2002-05-08' as date) + interval '5 days')\
)";
let plan = test_sql(sql)?;
let expected =
"Projection: test.col_int32\n Filter: CAST(test.col_int32 AS Float64) > __sq_1.__value\
\n CrossJoin:\
\n TableScan: test projection=[col_int32]\
\n Projection: AVG(test.col_int32) AS __value, alias=__sq_1\
\n Aggregate: groupBy=[[]], aggr=[[AVG(test.col_int32)]]\
\n Filter: test.col_utf8 >= Utf8(\"2002-05-08\") AND test.col_utf8 <= Utf8(\"2002-05-13\")\
\n TableScan: test projection=[col_int32, col_utf8]";
assert_eq!(expected, format!("{:?}", plan));
Ok(())
}

#[test]
fn case_when_aggregate() -> Result<()> {
let sql = "SELECT col_utf8, SUM(CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END) AS n FROM test GROUP BY col_utf8";
Expand Down