From 4d33ee072272f9e8876ea05f0c069b2e9977835c Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Mon, 24 Oct 2022 11:10:43 +0800 Subject: [PATCH] [SPARK-36114][SQL] Support subqueries with correlated non-equality predicates ### What changes were proposed in this pull request? This PR supports correlated non-equality predicates in subqueries. It leverages the DecorrelateInnerQuery framework to decorrelate subqueries with non-equality predicates. DecorrelateInnerQuery inserts domain joins in the query plan and the rule RewriteCorrelatedScalarSubquery rewrites the domain joins into actual joins with the outer query. Note, correlated non-equality predicates can lead to query plans with non-equality join conditions, which may be planned as a broadcast NL join or cartesian product. ### Why are the changes needed? To improve subquery support in Spark. ### Does this PR introduce _any_ user-facing change? Yes. Before this PR, Spark does not allow correlated non-equality predicates in subqueries. For example: ```sql SELECT (SELECT min(c2) FROM t2 WHERE t1.c1 > t2.c1) FROM t1 ``` This will throw an exception: `Correlated column is not allowed in a non-equality predicate` After this PR, this query can run successfully. ### How was this patch tested? Unit tests and SQL query tests. Closes #38135 from allisonwang-db/spark-36114-non-equality-pred. Authored-by: allisonwang-db Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/CheckAnalysis.scala | 7 +- .../analysis/AnalysisErrorSuite.scala | 2 +- .../sql-tests/inputs/join-lateral.sql | 3 + .../scalar-subquery-select.sql | 45 ++++++++ .../sql-tests/results/join-lateral.sql.out | 9 ++ .../scalar-subquery-select.sql.out | 107 ++++++++++++++++++ .../sql-tests/results/udf/udf-except.sql.out | 17 +-- .../org/apache/spark/sql/SubquerySuite.scala | 59 ++++------ 8 files changed, 195 insertions(+), 54 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 4346f51b613a2..cad036a34e97c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -1066,7 +1066,12 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { // 1 | 2 | 4 // and the plan after rewrite will give the original query incorrect results. def failOnUnsupportedCorrelatedPredicate(predicates: Seq[Expression], p: LogicalPlan): Unit = { - if (predicates.nonEmpty) { + // Correlated non-equality predicates are only supported with the decorrelate + // inner query framework. Currently we only use this new framework for scalar + // and lateral subqueries. + val allowNonEqualityPredicates = + SQLConf.get.decorrelateInnerQueryEnabled && (isScalar || isLateral) + if (!allowNonEqualityPredicates && predicates.nonEmpty) { // Report a non-supported case as an exception p.failAnalysis( errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 8b71bb05550a6..c44a0852b85c3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -917,7 +917,7 @@ class AnalysisErrorSuite extends AnalysisTest { (And($"a" === $"c", Cast($"d", IntegerType) === $"c"), "CAST(d#x AS INT) = outer(c#x)")) conditions.foreach { case (cond, msg) => val plan = Project( - ScalarSubquery( + Exists( Aggregate(Nil, count(Literal(1)).as("cnt") :: Nil, Filter(cond, t1)) ).as("sub") :: Nil, diff --git a/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql b/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql index fc5776c46afdd..dc1a35072728f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql @@ -44,6 +44,9 @@ SELECT * FROM t1, LATERAL (SELECT c2 FROM t2 WHERE t1.c1 = t2.c1); -- lateral join with correlated non-equality predicates SELECT * FROM t1, LATERAL (SELECT c2 FROM t2 WHERE t1.c2 < t2.c2); +-- SPARK-36114: lateral join with aggregation and correlated non-equality predicates +SELECT * FROM t1, LATERAL (SELECT max(c2) AS m FROM t2 WHERE t1.c2 < t2.c2); + -- lateral join can reference preceding FROM clause items SELECT * FROM t1 JOIN t2 JOIN LATERAL (SELECT t1.c2 + t2.c2); -- expect error: cannot resolve `t2.c1` diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql index b999d1723c911..6d673f149cc95 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql @@ -190,3 +190,48 @@ SELECT c1, ( -- Multi-value subquery error SELECT (SELECT a FROM (SELECT 1 AS a UNION ALL SELECT 2 AS a) t) AS b; + +-- SPARK-36114: Support correlated non-equality predicates +CREATE OR REPLACE TEMP VIEW t1(c1, c2) AS (VALUES (0, 1), (1, 2)); +CREATE OR REPLACE TEMP VIEW t2(c1, c2) AS (VALUES (0, 2), (0, 3)); + +-- Neumann example Q2 +CREATE OR REPLACE TEMP VIEW students(id, name, major, year) AS (VALUES + (0, 'A', 'CS', 2022), + (1, 'B', 'CS', 2022), + (2, 'C', 'Math', 2022)); +CREATE OR REPLACE TEMP VIEW exams(sid, course, curriculum, grade, date) AS (VALUES + (0, 'C1', 'CS', 4, 2020), + (0, 'C2', 'CS', 3, 2021), + (1, 'C1', 'CS', 2, 2020), + (1, 'C2', 'CS', 1, 2021)); + +SELECT students.name, exams.course +FROM students, exams +WHERE students.id = exams.sid + AND (students.major = 'CS' OR students.major = 'Games Eng') + AND exams.grade >= ( + SELECT avg(exams.grade) + 1 + FROM exams + WHERE students.id = exams.sid + OR (exams.curriculum = students.major AND students.year > exams.date)); + +-- Correlated non-equality predicates +SELECT (SELECT min(c2) FROM t2 WHERE t1.c1 > t2.c1) FROM t1; +SELECT (SELECT min(c2) FROM t2 WHERE t1.c1 >= t2.c1 AND t1.c2 < t2.c2) FROM t1; + +-- Correlated non-equality predicates with the COUNT bug. +SELECT (SELECT count(*) FROM t2 WHERE t1.c1 > t2.c1) FROM t1; + +-- Correlated equality predicates that are not supported after SPARK-35080 +SELECT c, ( + SELECT count(*) + FROM (VALUES ('ab'), ('abc'), ('bc')) t2(c) + WHERE t1.c = substring(t2.c, 1, 1) +) FROM (VALUES ('a'), ('b')) t1(c); + +SELECT c, ( + SELECT count(*) + FROM (VALUES (0, 6), (1, 5), (2, 4), (3, 3)) t1(a, b) + WHERE a + b = c +) FROM (VALUES (6)) t2(c); diff --git a/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out b/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out index be07ba7bd9a1e..34c0543dfdda8 100644 --- a/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out @@ -272,6 +272,15 @@ struct 1 2 3 +-- !query +SELECT * FROM t1, LATERAL (SELECT max(c2) AS m FROM t2 WHERE t1.c2 < t2.c2) +-- !query schema +struct +-- !query output +0 1 3 +1 2 3 + + -- !query SELECT * FROM t1 JOIN t2 JOIN LATERAL (SELECT t1.c2 + t2.c2) -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out index d1e56786207ed..38ab365ef6941 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out @@ -433,3 +433,110 @@ org.apache.spark.SparkException "fragment" : "(SELECT a FROM (SELECT 1 AS a UNION ALL SELECT 2 AS a) t)" } ] } + + +-- !query +CREATE OR REPLACE TEMP VIEW t1(c1, c2) AS (VALUES (0, 1), (1, 2)) +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE OR REPLACE TEMP VIEW t2(c1, c2) AS (VALUES (0, 2), (0, 3)) +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE OR REPLACE TEMP VIEW students(id, name, major, year) AS (VALUES + (0, 'A', 'CS', 2022), + (1, 'B', 'CS', 2022), + (2, 'C', 'Math', 2022)) +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE OR REPLACE TEMP VIEW exams(sid, course, curriculum, grade, date) AS (VALUES + (0, 'C1', 'CS', 4, 2020), + (0, 'C2', 'CS', 3, 2021), + (1, 'C1', 'CS', 2, 2020), + (1, 'C2', 'CS', 1, 2021)) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT students.name, exams.course +FROM students, exams +WHERE students.id = exams.sid + AND (students.major = 'CS' OR students.major = 'Games Eng') + AND exams.grade >= ( + SELECT avg(exams.grade) + 1 + FROM exams + WHERE students.id = exams.sid + OR (exams.curriculum = students.major AND students.year > exams.date)) +-- !query schema +struct +-- !query output +A C1 + + +-- !query +SELECT (SELECT min(c2) FROM t2 WHERE t1.c1 > t2.c1) FROM t1 +-- !query schema +struct +-- !query output +2 +NULL + + +-- !query +SELECT (SELECT min(c2) FROM t2 WHERE t1.c1 >= t2.c1 AND t1.c2 < t2.c2) FROM t1 +-- !query schema +struct +-- !query output +2 +3 + + +-- !query +SELECT (SELECT count(*) FROM t2 WHERE t1.c1 > t2.c1) FROM t1 +-- !query schema +struct +-- !query output +0 +2 + + +-- !query +SELECT c, ( + SELECT count(*) + FROM (VALUES ('ab'), ('abc'), ('bc')) t2(c) + WHERE t1.c = substring(t2.c, 1, 1) +) FROM (VALUES ('a'), ('b')) t1(c) +-- !query schema +struct +-- !query output +a 2 +b 1 + + +-- !query +SELECT c, ( + SELECT count(*) + FROM (VALUES (0, 6), (1, 5), (2, 4), (3, 3)) t1(a, b) + WHERE a + b = c +) FROM (VALUES (6)) t2(c) +-- !query schema +struct +-- !query output +6 4 diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out index f532b0d41e344..14ecf98c7a831 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out @@ -97,19 +97,6 @@ WHERE udf(t1.v) >= (SELECT min(udf(t2.v)) FROM t2 WHERE t2.k = t1.k) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -{ - "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE", - "messageParameters" : { - "treeNode" : "(cast(udf(cast(k#x as string)) as string) = cast(udf(cast(outer(k#x) as string)) as string))\nFilter (cast(udf(cast(k#x as string)) as string) = cast(udf(cast(outer(k#x) as string)) as string))\n+- SubqueryAlias t2\n +- View (`t2`, [k#x,v#x])\n +- Project [cast(k#x as string) AS k#x, cast(v#x as int) AS v#x]\n +- Project [k#x, v#x]\n +- SubqueryAlias t2\n +- LocalRelation [k#x, v#x]\n" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 39, - "stopIndex" : 141, - "fragment" : "SELECT udf(max(udf(t2.v)))\n FROM t2\n WHERE udf(t2.k) = udf(t1.k)" - } ] -} +two diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index ecb4bfd0ec41b..9d326b92b939f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -66,6 +66,11 @@ class SubquerySuite extends QueryTest t.createOrReplaceTempView("t") } + private def checkNumJoins(plan: LogicalPlan, numJoins: Int): Unit = { + val joins = plan.collect { case j: Join => j } + assert(joins.size == numJoins) + } + test("SPARK-18854 numberedTreeString for subquery") { val df = sql("select * from range(10) where id not in " + "(select id from range(2) union all select id from range(2))") @@ -562,17 +567,10 @@ class SubquerySuite extends QueryTest } test("non-equal correlated scalar subquery") { - val exception = intercept[AnalysisException] { - sql("select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l l1") - } - checkErrorMatchPVals( - exception, - errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + - "CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE", - parameters = Map("treeNode" -> "(?s).*"), - sqlState = None, - context = ExpectedContext( - fragment = "select sum(b) from l l2 where l2.a < l1.a", start = 11, stop = 51)) + checkAnswer( + sql("select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l l1"), + Seq(Row(1, null), Row(1, null), Row(2, 4), Row(2, 4), Row(3, 6), Row(null, null), + Row(null, null), Row(6, 9))) } test("disjunctive correlated scalar subquery") { @@ -2105,25 +2103,17 @@ class SubquerySuite extends QueryTest } } - test("SPARK-38155: disallow distinct aggregate in lateral subqueries") { + test("SPARK-36114: distinct aggregate in lateral subqueries") { withTempView("t1", "t2") { Seq((0, 1)).toDF("c1", "c2").createOrReplaceTempView("t1") Seq((1, 2), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") - val exception = intercept[AnalysisException] { - sql("SELECT * FROM t1 JOIN LATERAL (SELECT DISTINCT c2 FROM t2 WHERE c1 > t1.c1)") - } - checkErrorMatchPVals( - exception, - errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + - "CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE", - parameters = Map("treeNode" -> "(?s).*"), - sqlState = None, - context = ExpectedContext( - fragment = "SELECT DISTINCT c2 FROM t2 WHERE c1 > t1.c1", start = 31, stop = 73)) + checkAnswer( + sql("SELECT * FROM t1 JOIN LATERAL (SELECT DISTINCT c2 FROM t2 WHERE c1 > t1.c1)"), + Row(0, 1, 2) :: Nil) } } - test("SPARK-38180: allow safe cast expressions in correlated equality conditions") { + test("SPARK-38180, SPARK-36114: allow safe cast expressions in correlated equality conditions") { withTempView("t1", "t2") { Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") Seq((0, 2), (0, 3)).toDF("c1", "c2").createOrReplaceTempView("t2") @@ -2139,19 +2129,14 @@ class SubquerySuite extends QueryTest |FROM (SELECT CAST(c1 AS STRING) a FROM t1) |""".stripMargin), Row(5) :: Row(null) :: Nil) - val exception1 = intercept[AnalysisException] { - sql( - """SELECT (SELECT SUM(c2) FROM t2 WHERE CAST(c1 AS SHORT) = a) - |FROM (SELECT CAST(c1 AS SHORT) a FROM t1)""".stripMargin) - } - checkErrorMatchPVals( - exception1, - errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + - "CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE", - parameters = Map("treeNode" -> "(?s).*"), - sqlState = None, - context = ExpectedContext( - fragment = "SELECT SUM(c2) FROM t2 WHERE CAST(c1 AS SHORT) = a", start = 8, stop = 57)) + // SPARK-36114: we now allow non-safe cast expressions in correlated predicates. + val df = sql( + """SELECT (SELECT SUM(c2) FROM t2 WHERE CAST(c1 AS SHORT) = a) + |FROM (SELECT CAST(c1 AS SHORT) a FROM t1) + |""".stripMargin) + checkAnswer(df, Row(5) :: Row(null) :: Nil) + // The optimized plan should have one left outer join and one domain (inner) join. + checkNumJoins(df.queryExecution.optimizedPlan, 2) } }