Skip to content

Commit

Permalink
[SPARK-36114][SQL] Support subqueries with correlated non-equality pr…
Browse files Browse the repository at this point in the history
…edicates

<!--
Thanks for sending a pull request!  Here are some tips for you:
  1. If this is your first time, please read our contributor guidelines: https://spark.apache.org/contributing.html
  2. Ensure you have added or run the appropriate tests for your PR: https://spark.apache.org/developer-tools.html
  3. If the PR is unfinished, add '[WIP]' in your PR title, e.g., '[WIP][SPARK-XXXX] Your PR title ...'.
  4. Be sure to keep the PR description updated to reflect all changes.
  5. Please write your PR title to summarize what this PR proposes.
  6. If possible, provide a concise example to reproduce the issue for a faster review.
  7. If you want to add a new configuration, please read the guideline first for naming configurations in
     'core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala'.
  8. If you want to add or modify an error type or message, please read the guideline first in
     'core/src/main/resources/error/README.md'.
-->

### What changes were proposed in this pull request?
<!--
Please clarify what changes you are proposing. The purpose of this section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster reviews in your PR. See the examples below.
  1. If you refactor some codes with changing classes, showing the class hierarchy will help reviewers.
  2. If you fix some SQL features, you can provide some references of other DBMSes.
  3. If there is design documentation, please add the link.
  4. If there is a discussion in the mailing list, please add the link.
-->
This PR supports correlated non-equality predicates in subqueries. It leverages the DecorrelateInnerQuery framework to decorrelate subqueries with non-equality predicates. DecorrelateInnerQuery inserts domain joins in the query plan and the rule RewriteCorrelatedScalarSubquery rewrites the domain joins into actual joins with the outer query.

Note, correlated non-equality predicates can lead to query plans with non-equality join conditions, which may be planned as a broadcast NL join or cartesian product.

### Why are the changes needed?
<!--
Please clarify why the changes are needed. For instance,
  1. If you propose a new API, clarify the use case for a new API.
  2. If you fix a bug, you can clarify why it is a bug.
-->
To improve subquery support in Spark.

### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such as the documentation fix.
If yes, please clarify the previous behavior and the change this PR proposes - provide the console output, description and/or an example to show the behavior difference if possible.
If possible, please also clarify if this is a user-facing change compared to the released Spark versions or within the unreleased branches such as master.
If no, write 'No'.
-->
Yes. Before this PR, Spark does not allow correlated non-equality predicates in subqueries.
For example:
```sql
SELECT (SELECT min(c2) FROM t2 WHERE t1.c1 > t2.c1) FROM t1
```
This will throw an exception: `Correlated column is not allowed in a non-equality predicate`

After this PR, this query can run successfully.

### How was this patch tested?
<!--
If tests were added, say they were added here. Please make sure to add some test cases that check the changes thoroughly including negative and positive cases if possible.
If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future.
If tests were not added, please describe why they were not added and/or why it was difficult to add.
If benchmark tests were added, please run the benchmarks in GitHub Actions for the consistent environment, and the instructions could accord to: https://spark.apache.org/developer-tools.html#github-workflow-benchmarks.
-->
Unit tests and SQL query tests.

Closes #38135 from allisonwang-db/spark-36114-non-equality-pred.

Authored-by: allisonwang-db <allison.wang@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
allisonwang-db authored and cloud-fan committed Oct 24, 2022
1 parent 74c8264 commit 4d33ee0
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 54 deletions.
Expand Up @@ -1066,7 +1066,12 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
// 1 | 2 | 4
// and the plan after rewrite will give the original query incorrect results.
def failOnUnsupportedCorrelatedPredicate(predicates: Seq[Expression], p: LogicalPlan): Unit = {
if (predicates.nonEmpty) {
// Correlated non-equality predicates are only supported with the decorrelate
// inner query framework. Currently we only use this new framework for scalar
// and lateral subqueries.
val allowNonEqualityPredicates =
SQLConf.get.decorrelateInnerQueryEnabled && (isScalar || isLateral)
if (!allowNonEqualityPredicates && predicates.nonEmpty) {
// Report a non-supported case as an exception
p.failAnalysis(
errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." +
Expand Down
Expand Up @@ -917,7 +917,7 @@ class AnalysisErrorSuite extends AnalysisTest {
(And($"a" === $"c", Cast($"d", IntegerType) === $"c"), "CAST(d#x AS INT) = outer(c#x)"))
conditions.foreach { case (cond, msg) =>
val plan = Project(
ScalarSubquery(
Exists(
Aggregate(Nil, count(Literal(1)).as("cnt") :: Nil,
Filter(cond, t1))
).as("sub") :: Nil,
Expand Down
3 changes: 3 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql
Expand Up @@ -44,6 +44,9 @@ SELECT * FROM t1, LATERAL (SELECT c2 FROM t2 WHERE t1.c1 = t2.c1);
-- lateral join with correlated non-equality predicates
SELECT * FROM t1, LATERAL (SELECT c2 FROM t2 WHERE t1.c2 < t2.c2);

-- SPARK-36114: lateral join with aggregation and correlated non-equality predicates
SELECT * FROM t1, LATERAL (SELECT max(c2) AS m FROM t2 WHERE t1.c2 < t2.c2);

-- lateral join can reference preceding FROM clause items
SELECT * FROM t1 JOIN t2 JOIN LATERAL (SELECT t1.c2 + t2.c2);
-- expect error: cannot resolve `t2.c1`
Expand Down
Expand Up @@ -190,3 +190,48 @@ SELECT c1, (

-- Multi-value subquery error
SELECT (SELECT a FROM (SELECT 1 AS a UNION ALL SELECT 2 AS a) t) AS b;

-- SPARK-36114: Support correlated non-equality predicates
CREATE OR REPLACE TEMP VIEW t1(c1, c2) AS (VALUES (0, 1), (1, 2));
CREATE OR REPLACE TEMP VIEW t2(c1, c2) AS (VALUES (0, 2), (0, 3));

-- Neumann example Q2
CREATE OR REPLACE TEMP VIEW students(id, name, major, year) AS (VALUES
(0, 'A', 'CS', 2022),
(1, 'B', 'CS', 2022),
(2, 'C', 'Math', 2022));
CREATE OR REPLACE TEMP VIEW exams(sid, course, curriculum, grade, date) AS (VALUES
(0, 'C1', 'CS', 4, 2020),
(0, 'C2', 'CS', 3, 2021),
(1, 'C1', 'CS', 2, 2020),
(1, 'C2', 'CS', 1, 2021));

SELECT students.name, exams.course
FROM students, exams
WHERE students.id = exams.sid
AND (students.major = 'CS' OR students.major = 'Games Eng')
AND exams.grade >= (
SELECT avg(exams.grade) + 1
FROM exams
WHERE students.id = exams.sid
OR (exams.curriculum = students.major AND students.year > exams.date));

-- Correlated non-equality predicates
SELECT (SELECT min(c2) FROM t2 WHERE t1.c1 > t2.c1) FROM t1;
SELECT (SELECT min(c2) FROM t2 WHERE t1.c1 >= t2.c1 AND t1.c2 < t2.c2) FROM t1;

-- Correlated non-equality predicates with the COUNT bug.
SELECT (SELECT count(*) FROM t2 WHERE t1.c1 > t2.c1) FROM t1;

-- Correlated equality predicates that are not supported after SPARK-35080
SELECT c, (
SELECT count(*)
FROM (VALUES ('ab'), ('abc'), ('bc')) t2(c)
WHERE t1.c = substring(t2.c, 1, 1)
) FROM (VALUES ('a'), ('b')) t1(c);

SELECT c, (
SELECT count(*)
FROM (VALUES (0, 6), (1, 5), (2, 4), (3, 3)) t1(a, b)
WHERE a + b = c
) FROM (VALUES (6)) t2(c);
Expand Up @@ -272,6 +272,15 @@ struct<c1:int,c2:int,c2:int>
1 2 3


-- !query
SELECT * FROM t1, LATERAL (SELECT max(c2) AS m FROM t2 WHERE t1.c2 < t2.c2)
-- !query schema
struct<c1:int,c2:int,m:int>
-- !query output
0 1 3
1 2 3


-- !query
SELECT * FROM t1 JOIN t2 JOIN LATERAL (SELECT t1.c2 + t2.c2)
-- !query schema
Expand Down
Expand Up @@ -433,3 +433,110 @@ org.apache.spark.SparkException
"fragment" : "(SELECT a FROM (SELECT 1 AS a UNION ALL SELECT 2 AS a) t)"
} ]
}


-- !query
CREATE OR REPLACE TEMP VIEW t1(c1, c2) AS (VALUES (0, 1), (1, 2))
-- !query schema
struct<>
-- !query output



-- !query
CREATE OR REPLACE TEMP VIEW t2(c1, c2) AS (VALUES (0, 2), (0, 3))
-- !query schema
struct<>
-- !query output



-- !query
CREATE OR REPLACE TEMP VIEW students(id, name, major, year) AS (VALUES
(0, 'A', 'CS', 2022),
(1, 'B', 'CS', 2022),
(2, 'C', 'Math', 2022))
-- !query schema
struct<>
-- !query output



-- !query
CREATE OR REPLACE TEMP VIEW exams(sid, course, curriculum, grade, date) AS (VALUES
(0, 'C1', 'CS', 4, 2020),
(0, 'C2', 'CS', 3, 2021),
(1, 'C1', 'CS', 2, 2020),
(1, 'C2', 'CS', 1, 2021))
-- !query schema
struct<>
-- !query output



-- !query
SELECT students.name, exams.course
FROM students, exams
WHERE students.id = exams.sid
AND (students.major = 'CS' OR students.major = 'Games Eng')
AND exams.grade >= (
SELECT avg(exams.grade) + 1
FROM exams
WHERE students.id = exams.sid
OR (exams.curriculum = students.major AND students.year > exams.date))
-- !query schema
struct<name:string,course:string>
-- !query output
A C1


-- !query
SELECT (SELECT min(c2) FROM t2 WHERE t1.c1 > t2.c1) FROM t1
-- !query schema
struct<scalarsubquery(c1):int>
-- !query output
2
NULL


-- !query
SELECT (SELECT min(c2) FROM t2 WHERE t1.c1 >= t2.c1 AND t1.c2 < t2.c2) FROM t1
-- !query schema
struct<scalarsubquery(c1, c2):int>
-- !query output
2
3


-- !query
SELECT (SELECT count(*) FROM t2 WHERE t1.c1 > t2.c1) FROM t1
-- !query schema
struct<scalarsubquery(c1):bigint>
-- !query output
0
2


-- !query
SELECT c, (
SELECT count(*)
FROM (VALUES ('ab'), ('abc'), ('bc')) t2(c)
WHERE t1.c = substring(t2.c, 1, 1)
) FROM (VALUES ('a'), ('b')) t1(c)
-- !query schema
struct<c:string,scalarsubquery(c):bigint>
-- !query output
a 2
b 1


-- !query
SELECT c, (
SELECT count(*)
FROM (VALUES (0, 6), (1, 5), (2, 4), (3, 3)) t1(a, b)
WHERE a + b = c
) FROM (VALUES (6)) t2(c)
-- !query schema
struct<c:int,scalarsubquery(c):bigint>
-- !query output
6 4
Expand Up @@ -97,19 +97,6 @@ WHERE udf(t1.v) >= (SELECT min(udf(t2.v))
FROM t2
WHERE t2.k = t1.k)
-- !query schema
struct<>
struct<k:string>
-- !query output
org.apache.spark.sql.AnalysisException
{
"errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE",
"messageParameters" : {
"treeNode" : "(cast(udf(cast(k#x as string)) as string) = cast(udf(cast(outer(k#x) as string)) as string))\nFilter (cast(udf(cast(k#x as string)) as string) = cast(udf(cast(outer(k#x) as string)) as string))\n+- SubqueryAlias t2\n +- View (`t2`, [k#x,v#x])\n +- Project [cast(k#x as string) AS k#x, cast(v#x as int) AS v#x]\n +- Project [k#x, v#x]\n +- SubqueryAlias t2\n +- LocalRelation [k#x, v#x]\n"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 39,
"stopIndex" : 141,
"fragment" : "SELECT udf(max(udf(t2.v)))\n FROM t2\n WHERE udf(t2.k) = udf(t1.k)"
} ]
}
two
59 changes: 22 additions & 37 deletions sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
Expand Up @@ -66,6 +66,11 @@ class SubquerySuite extends QueryTest
t.createOrReplaceTempView("t")
}

private def checkNumJoins(plan: LogicalPlan, numJoins: Int): Unit = {
val joins = plan.collect { case j: Join => j }
assert(joins.size == numJoins)
}

test("SPARK-18854 numberedTreeString for subquery") {
val df = sql("select * from range(10) where id not in " +
"(select id from range(2) union all select id from range(2))")
Expand Down Expand Up @@ -562,17 +567,10 @@ class SubquerySuite extends QueryTest
}

test("non-equal correlated scalar subquery") {
val exception = intercept[AnalysisException] {
sql("select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l l1")
}
checkErrorMatchPVals(
exception,
errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." +
"CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE",
parameters = Map("treeNode" -> "(?s).*"),
sqlState = None,
context = ExpectedContext(
fragment = "select sum(b) from l l2 where l2.a < l1.a", start = 11, stop = 51))
checkAnswer(
sql("select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l l1"),
Seq(Row(1, null), Row(1, null), Row(2, 4), Row(2, 4), Row(3, 6), Row(null, null),
Row(null, null), Row(6, 9)))
}

test("disjunctive correlated scalar subquery") {
Expand Down Expand Up @@ -2105,25 +2103,17 @@ class SubquerySuite extends QueryTest
}
}

test("SPARK-38155: disallow distinct aggregate in lateral subqueries") {
test("SPARK-36114: distinct aggregate in lateral subqueries") {
withTempView("t1", "t2") {
Seq((0, 1)).toDF("c1", "c2").createOrReplaceTempView("t1")
Seq((1, 2), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2")
val exception = intercept[AnalysisException] {
sql("SELECT * FROM t1 JOIN LATERAL (SELECT DISTINCT c2 FROM t2 WHERE c1 > t1.c1)")
}
checkErrorMatchPVals(
exception,
errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." +
"CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE",
parameters = Map("treeNode" -> "(?s).*"),
sqlState = None,
context = ExpectedContext(
fragment = "SELECT DISTINCT c2 FROM t2 WHERE c1 > t1.c1", start = 31, stop = 73))
checkAnswer(
sql("SELECT * FROM t1 JOIN LATERAL (SELECT DISTINCT c2 FROM t2 WHERE c1 > t1.c1)"),
Row(0, 1, 2) :: Nil)
}
}

test("SPARK-38180: allow safe cast expressions in correlated equality conditions") {
test("SPARK-38180, SPARK-36114: allow safe cast expressions in correlated equality conditions") {
withTempView("t1", "t2") {
Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t1")
Seq((0, 2), (0, 3)).toDF("c1", "c2").createOrReplaceTempView("t2")
Expand All @@ -2139,19 +2129,14 @@ class SubquerySuite extends QueryTest
|FROM (SELECT CAST(c1 AS STRING) a FROM t1)
|""".stripMargin),
Row(5) :: Row(null) :: Nil)
val exception1 = intercept[AnalysisException] {
sql(
"""SELECT (SELECT SUM(c2) FROM t2 WHERE CAST(c1 AS SHORT) = a)
|FROM (SELECT CAST(c1 AS SHORT) a FROM t1)""".stripMargin)
}
checkErrorMatchPVals(
exception1,
errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." +
"CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE",
parameters = Map("treeNode" -> "(?s).*"),
sqlState = None,
context = ExpectedContext(
fragment = "SELECT SUM(c2) FROM t2 WHERE CAST(c1 AS SHORT) = a", start = 8, stop = 57))
// SPARK-36114: we now allow non-safe cast expressions in correlated predicates.
val df = sql(
"""SELECT (SELECT SUM(c2) FROM t2 WHERE CAST(c1 AS SHORT) = a)
|FROM (SELECT CAST(c1 AS SHORT) a FROM t1)
|""".stripMargin)
checkAnswer(df, Row(5) :: Row(null) :: Nil)
// The optimized plan should have one left outer join and one domain (inner) join.
checkNumJoins(df.queryExecution.optimizedPlan, 2)
}
}

Expand Down

0 comments on commit 4d33ee0

Please sign in to comment.