Skip to content

Commit

Permalink
[SPARK-40618][SQL] Fix bug in MergeScalarSubqueries rule with nested subqueries using reference tracking
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?
This PR reverts the previous fix (apache#38052) and adds subquery reference tracking to `MergeScalarSubqueries`, which restores the previous ability to merge independent nested subqueries.

### Why are the changes needed?
To restore the previous functionality while still fixing the bug discovered in https://issues.apache.org/jira/browse/SPARK-40618.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing and new UTs.

Closes apache#38093 from peter-toth/SPARK-40618-fix-mergescalarsubqueries.

Authored-by: Peter Toth <peter.toth@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
peter-toth authored and SandishKumarHN committed Dec 12, 2022
1 parent 00aa52f commit 6c6dd01
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.optimizer

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -126,8 +127,14 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
* merged as there can be subqueries that are different ([[checkIdenticalPlans]] is
* false) due to an extra [[Project]] node in one of them. In that case
* `attributes.size` remains 1 after merging, but the merged flag becomes true.
* @param references The set of cache indexes of all subqueries nested in this plan, including
* transitively nested ones.
*/
case class Header(attributes: Seq[Attribute], plan: LogicalPlan, merged: Boolean)
case class Header(
attributes: Seq[Attribute],
plan: LogicalPlan,
merged: Boolean,
references: Set[Int])

private def extractCommonScalarSubqueries(plan: LogicalPlan) = {
val cache = ArrayBuffer.empty[Header]
Expand Down Expand Up @@ -166,26 +173,39 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
// "Header".
private def cacheSubquery(plan: LogicalPlan, cache: ArrayBuffer[Header]): (Int, Int) = {
val output = plan.output.head
cache.zipWithIndex.collectFirst(Function.unlift { case (header, subqueryIndex) =>
checkIdenticalPlans(plan, header.plan).map { outputMap =>
val mappedOutput = mapAttributes(output, outputMap)
val headerIndex = header.attributes.indexWhere(_.exprId == mappedOutput.exprId)
subqueryIndex -> headerIndex
}.orElse(tryMergePlans(plan, header.plan).map {
case (mergedPlan, outputMap) =>
val references = mutable.HashSet.empty[Int]
plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY_REFERENCE)) {
case ssr: ScalarSubqueryReference =>
references += ssr.subqueryIndex
references ++= cache(ssr.subqueryIndex).references
ssr
}

cache.zipWithIndex.collectFirst(Function.unlift {
case (header, subqueryIndex) if !references.contains(subqueryIndex) =>
checkIdenticalPlans(plan, header.plan).map { outputMap =>
val mappedOutput = mapAttributes(output, outputMap)
var headerIndex = header.attributes.indexWhere(_.exprId == mappedOutput.exprId)
val newHeaderAttributes = if (headerIndex == -1) {
headerIndex = header.attributes.size
header.attributes :+ mappedOutput
} else {
header.attributes
}
cache(subqueryIndex) = Header(newHeaderAttributes, mergedPlan, true)
val headerIndex = header.attributes.indexWhere(_.exprId == mappedOutput.exprId)
subqueryIndex -> headerIndex
})
}.orElse{
tryMergePlans(plan, header.plan).map {
case (mergedPlan, outputMap) =>
val mappedOutput = mapAttributes(output, outputMap)
var headerIndex = header.attributes.indexWhere(_.exprId == mappedOutput.exprId)
val newHeaderAttributes = if (headerIndex == -1) {
headerIndex = header.attributes.size
header.attributes :+ mappedOutput
} else {
header.attributes
}
cache(subqueryIndex) =
Header(newHeaderAttributes, mergedPlan, true, header.references ++ references)
subqueryIndex -> headerIndex
}
}
case _ => None
}).getOrElse {
cache += Header(Seq(output), plan, false)
cache += Header(Seq(output), plan, false, references.toSet)
cache.length - 1 -> 0
}
}
Expand All @@ -210,12 +230,6 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
cachedPlan: LogicalPlan): Option[(LogicalPlan, AttributeMap[Attribute])] = {
checkIdenticalPlans(newPlan, cachedPlan).map(cachedPlan -> _).orElse(
(newPlan, cachedPlan) match {
case (_, _) if newPlan.containsPattern(SCALAR_SUBQUERY_REFERENCE) ||
cachedPlan.containsPattern(SCALAR_SUBQUERY_REFERENCE) =>
// Subquery expressions with nested subquery expressions within are not supported for now.
// TODO: support this optimization by collecting the transitive subquery references in the
// new plan and recording them in order to suppress merging the new plan into those.
None
case (np: Project, cp: Project) =>
tryMergePlans(np.child, cp.child).map { case (mergedChild, outputMap) =>
val (mergedProjectList, newOutputMap) =
Expand Down
35 changes: 29 additions & 6 deletions sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2251,7 +2251,7 @@ class SubquerySuite extends QueryTest
}
}

test("SPARK-40618: Do not merge scalar subqueries with nested subqueries inside") {
test("Merge non-correlated scalar subqueries from different parent plans") {
Seq(false, true).foreach { enableAQE =>
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) {
Expand Down Expand Up @@ -2283,13 +2283,13 @@ class SubquerySuite extends QueryTest
}

if (enableAQE) {
assert(subqueryIds.size == 4, "Missing or unexpected SubqueryExec in the plan")
assert(reusedSubqueryIds.size == 2,
"Missing or unexpected reused ReusedSubqueryExec in the plan")
} else {
assert(subqueryIds.size == 3, "Missing or unexpected SubqueryExec in the plan")
assert(reusedSubqueryIds.size == 3,
"Missing or unexpected reused ReusedSubqueryExec in the plan")
} else {
assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan")
assert(reusedSubqueryIds.size == 4,
"Missing or unexpected reused ReusedSubqueryExec in the plan")
}
}
}
Expand Down Expand Up @@ -2426,9 +2426,32 @@ class SubquerySuite extends QueryTest
// This test contains a subquery expression with another subquery expression nested inside.
// It acts as a regression test to ensure that the MergeScalarSubqueries rule does not attempt
// to merge them together.
withTable("t") {
withTable("t", "t2") {
sql("create table t(col int) using csv")
checkAnswer(sql("select(select sum((select sum(col) from t)) from t)"), Row(null))

checkAnswer(sql(
"""
|select
| (select sum(
| (select sum(
| (select sum(col) from t))
| from t))
| from t)
|""".stripMargin),
Row(null))

sql("create table t2(col int) using csv")
checkAnswer(sql(
"""
|select
| (select sum(
| (select sum(
| (select sum(col) from t))
| from t2))
| from t)
|""".stripMargin),
Row(null))
}
}
}

0 comments on commit 6c6dd01

Please sign in to comment.