Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-35080][SQL] Only allow a subset of correlated equality predicates when a subquery is aggregated #32179

Closed
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -899,14 +899,72 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
// +- SubqueryAlias t1, `t1`
// +- Project [_1#73 AS c1#76, _2#74 AS c2#77]
// +- LocalRelation [_1#73, _2#74]
def failOnNonEqualCorrelatedPredicate(found: Boolean, p: LogicalPlan): Unit = {
if (found) {
// SPARK-35080: The same issue can happen to correlated equality predicates when
// they do not guarantee one-to-one mapping between inner and outer attributes.
// For example:
// Table:
// t1(a, b): [(0, 6), (1, 5), (2, 4)]
// t2(c): [(6)]
//
// Query:
// SELECT c, (SELECT COUNT(*) FROM t1 WHERE a + b = c) FROM t2
//
// Original subquery plan:
// Aggregate [count(1)]
// +- Filter ((a + b) = outer(c))
// +- LocalRelation [a, b]
//
// Plan after pulling up correlated predicates:
// Aggregate [a, b] [count(1), a, b]
// +- LocalRelation [a, b]
//
// Plan after rewrite:
// Project [c1, count(1)]
// +- Join LeftOuter ((a + b) = c)
// :- LocalRelation [c]
// +- Aggregate [a, b] [count(1), a, b]
// +- LocalRelation [a, b]
//
// The right hand side of the join transformed from the subquery will output
// count(1) | a | b
// 1 | 0 | 6
// 1 | 1 | 5
// 1 | 2 | 4
// and the plan after rewrite will give the original query incorrect results.
def failOnUnsupportedCorrelatedPredicate(predicates: Seq[Expression], p: LogicalPlan): Unit = {
if (predicates.nonEmpty) {
// Report a non-supported case as an exception
failAnalysis(s"Correlated column is not allowed in a non-equality predicate:\n$p")
failAnalysis(s"Correlated column is not allowed in predicate " +
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: drop s

s"${predicates.map(_.sql).mkString}:\n$p")
}
}

var foundNonEqualCorrelatedPred: Boolean = false
def containsAttribute(e: Expression): Boolean = {
e.find(_.isInstanceOf[Attribute]).isDefined
}

// Given a correlated predicate, check if it is either a non-equality predicate or
// equality predicate that does not guarantee one-on-one mapping between inner and
// outer attributes. When the correlated predicate does not contain any attribute
// (i.e. only has outer references), it is supported and should return false. E.G.:
// (a = outer(c)) -> false
// (outer(c) = outer(d)) -> false
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we have test case for it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes added a test case in SubquerySuite

// (a > outer(c)) -> true
// (a + b = outer(c)) -> true
// The last one is true because there can be multiple combinations of (a, b) that
// satisfy the equality condition. For example, if outer(c) = 0, then both (0, 0)
// and (-1, 1) can make the predicate evaluate to true.
def isUnsupportedPredicate(condition: Expression): Boolean = condition match {
// Only allow equality condition with one side being an attribute and another
// side being an expression without attributes from the inner query. Note
// OuterReference is a leaf node and will not be found here.
case Equality(_: Attribute, b) => containsAttribute(b)
case Equality(a, _: Attribute) => containsAttribute(a)
maropu marked this conversation as resolved.
Show resolved Hide resolved
case e @ Equality(_, _) => containsAttribute(e)
case _ => true
}

val unsupportedPredicates = mutable.ArrayBuffer.empty[Expression]

// Simplify the predicates before validating any unsupported correlation patterns in the plan.
AnalysisHelper.allowInvokingTransformsInAnalyzer { BooleanSimplification(sub).foreachUp {
Expand Down Expand Up @@ -949,22 +1007,17 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
// The other operator is Join. Filter can be anywhere in a correlated subquery.
case f: Filter =>
val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter)

// Find any non-equality correlated predicates
foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists {
case _: EqualTo | _: EqualNullSafe => false
case _ => true
}
unsupportedPredicates ++= correlated.filter(isUnsupportedPredicate)
failOnInvalidOuterReference(f)

// Aggregate cannot host any correlated expressions
// It can be on a correlation path if the correlation contains
// only equality correlated predicates.
// only supported correlated equality predicates.
// It cannot be on a correlation path if the correlation has
// non-equality correlated predicates.
case a: Aggregate =>
failOnInvalidOuterReference(a)
failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a)
failOnUnsupportedCorrelatedPredicate(unsupportedPredicates, a)
Copy link
Member

@dongjoon-hyun dongjoon-hyun Apr 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to cause a compilation error with Scala 2.13. Could you re-check with Scala 2.13, @allisonwang-db ? Please try to add .toSeq.

[error] /home/runner/work/spark/spark/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala:1020:46: type mismatch;
[error]  found   : scala.collection.mutable.ArrayBuffer[org.apache.spark.sql.catalyst.expressions.Expression]
[error]  required: Seq[org.apache.spark.sql.catalyst.expressions.Expression]
[error]         failOnUnsupportedCorrelatedPredicate(unsupportedPredicates, a)


// Join can host correlated expressions.
case j @ Join(left, right, joinType, _, _) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -767,4 +767,28 @@ class AnalysisErrorSuite extends AnalysisTest {
"using ordinal position or wrap it in first() (or first_value) if you don't care " +
"which value you get." :: Nil)
}

test("SPARK-35080: Unsupported correlated equality predicates in subquery") {
val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()
val c = AttributeReference("c", IntegerType)()
val t1 = LocalRelation(a, b)
val t2 = LocalRelation(c)
val conditions = Seq(
(abs($"a") === $"c", "abs(a) = outer(c)"),
(abs($"a") <=> $"c", "abs(a) <=> outer(c)"),
($"a" + 1 === $"c", "(a + 1) = outer(c)"),
($"a" + $"b" === $"c", "(a + b) = outer(c)"),
($"a" + $"c" === $"b", "(a + outer(c)) = b"),
(And($"a" === $"c", Cast($"a", IntegerType) === $"c"), "CAST(a AS INT) = outer(c)"))
conditions.foreach { case (cond, msg) =>
val plan = Project(
ScalarSubquery(
Aggregate(Nil, count(Literal(1)).as("cnt") :: Nil,
Filter(cond, t1))
).as("sub") :: Nil,
t2)
assertAnalysisError(plan, s"Correlated column is not allowed in predicate ($msg)" :: Nil)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,15 @@ WHERE udf(t1.v) >= (SELECT min(udf(t2.v))
FROM t2
WHERE t2.k = t1.k)
-- !query schema
struct<k:string>
struct<>
-- !query output
two
org.apache.spark.sql.AnalysisException
Correlated column is not allowed in predicate (CAST(udf(cast(k as string)) AS STRING) = CAST(udf(cast(outer(k#x) as string)) AS STRING)):
Aggregate [cast(udf(cast(max(cast(udf(cast(v#x as string)) as int)) as string)) as int) AS udf(max(udf(v)))#x]
+- Filter (cast(udf(cast(k#x as string)) as string) = cast(udf(cast(outer(k#x) as string)) as string))
+- SubqueryAlias t2
+- View (`t2`, [k#x,v#x])
+- Project [cast(k#x as string) AS k#x, cast(v#x as int) AS v#x]
+- Project [k#x, v#x]
+- SubqueryAlias t2
+- LocalRelation [k#x, v#x]
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
sql("select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l l1")
}
assert(msg1.getMessage.contains(
"Correlated column is not allowed in a non-equality predicate:"))
"Correlated column is not allowed in predicate (l2.a < outer(l1.a))"))
}

test("disjunctive correlated scalar subquery") {
Expand Down Expand Up @@ -1827,4 +1827,10 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
Row(0, 1, 1) :: Row(1, 2, null) :: Nil)
}
}

test("SPARK-35080: correlated equality predicates contain only outer references") {
checkAnswer(
sql("select c, d, (select count(*) from l where c + 1 = d) from t"),
Row(2, 3.0, 8) :: Row(2, 3.0, 8) :: Row(3, 2.0, 0) :: Row(4, 1.0, 0) :: Nil)
}
}