
[SPARK-45009][SQL] Decorrelate predicate subqueries in join condition #42725

Closed
@@ -55,7 +55,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
     Set(
       "PartitionPruning",
       "RewriteSubquery",
-      "Extract Python UDFs")
+      "Extract Python UDFs",
+      "Infer Filters")
@andylam-db (Contributor, Author) commented on Sep 6, 2023:

"Infer Filters" is not inherently idempotent. We need to exclude it from the idempotency check so that tests will pass.


  protected def fixedPoint =
    FixedPoint(
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreePattern.{EXISTS_SUBQUERY, IN_SUBQUERY, LATERAL_JOIN, LIST_SUBQUERY, PLAN_EXPRESSION, SCALAR_SUBQUERY}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{DECORRELATE_PREDICATE_SUBQUERIES_IN_JOIN_CONDITION, OPTIMIZE_UNCORRELATED_IN_SUBQUERIES_IN_JOIN_CONDITION}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

@@ -159,6 +160,66 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
Project(p.output, Filter(newCond.get, inputPlan))
}

// This case takes care of predicate subqueries in join conditions that are not pushed down
// to the children by [[PushDownPredicates]].
case j: Join if j.condition.exists(cond =>
SubqueryExpression.hasInOrCorrelatedExistsSubquery(cond)) &&
conf.getConf(DECORRELATE_PREDICATE_SUBQUERIES_IN_JOIN_CONDITION) =>

val optimizeUncorrelatedInSubqueries =
conf.getConf(OPTIMIZE_UNCORRELATED_IN_SUBQUERIES_IN_JOIN_CONDITION)
val relevantSubqueries = j.condition.get.collect {
case i: InSubquery if i.query.isCorrelated => i
case i: InSubquery if !i.query.isCorrelated && optimizeUncorrelatedInSubqueries => i
case e: Exists if e.isCorrelated => e
}
if (relevantSubqueries.isEmpty) {
j
} else {
// `subqueriesWithJoinInputReferenceInfo` is of type Seq[(Expression, Boolean, Boolean)]
// (1): Expression, the join predicate containing some predicate subquery we are interested
// in rewriting
// (2): Boolean, whether (1) references the left join input
// (3): Boolean, whether (1) references the right join input
val subqueriesWithJoinInputReferenceInfo = relevantSubqueries.map { e =>
val referenceLeft = e.references.intersect(j.left.outputSet).nonEmpty
val referenceRight = e.references.intersect(j.right.outputSet).nonEmpty
(e, referenceLeft, referenceRight)
}
val subqueriesReferencingBothJoinInputs = subqueriesWithJoinInputReferenceInfo
.filter(i => i._2 && i._3)

// Correlated subqueries in the join predicate that reference both join inputs are not
// currently supported.
if (subqueriesReferencingBothJoinInputs.nonEmpty) {
throw QueryCompilationErrors.unsupportedCorrelatedSubqueryInJoinConditionError(
subqueriesReferencingBothJoinInputs.map(_._1))
}
val subqueriesReferencingLeft = subqueriesWithJoinInputReferenceInfo.filter(_._2).map(_._1)
val subqueriesReferencingRight = subqueriesWithJoinInputReferenceInfo.filter(_._3).map(_._1)
var newCondition = j.condition.get
val newLeft = subqueriesReferencingLeft.foldLeft(j.left) {
case (p, e) =>
val (newCond, newInputPlan) = rewriteExistentialExpr(Seq(e), p)
// Update the join condition to rewrite the subquery expression
newCondition = newCondition.transform {
case expr if expr.fastEquals(e) => newCond.get
}
newInputPlan
}
val newRight = subqueriesReferencingRight.foldLeft(j.right) {
case (p, e) =>
val (newCond, newInputPlan) = rewriteExistentialExpr(Seq(e), p)
// Update the join condition to rewrite the subquery expression
newCondition = newCondition.transform {
case expr if expr.fastEquals(e) => newCond.get
}
newInputPlan
}
// Add a Project to remove the exists columns introduced by the new existence joins
Project(j.output, j.copy(left = newLeft, right = newRight, condition = Some(newCondition)))
}
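To make the new Join case concrete, here is one query shape it targets, sketched against hypothetical temp views t1, t2 and t3 (the tables, columns and data below are invented for illustration and are not part of the PR). The correlated EXISTS sits in the ON clause of a left outer join and references only the left join input, so it cannot be pushed into a child; with the new rule enabled it is decorrelated into the left child as an existence join and the ON clause is rewritten to test the resulting exists attribute.

import org.apache.spark.sql.SparkSession

object DecorrelateJoinConditionExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("decorrelate-sketch").getOrCreate()
    import spark.implicits._

    // Invented data, just enough to run the query.
    Seq((1, 10), (2, 20)).toDF("a", "b").createOrReplaceTempView("t1")
    Seq((1, 10), (3, 30)).toDF("a", "b").createOrReplaceTempView("t2")
    Seq((10, 100)).toDF("b", "c").createOrReplaceTempView("t3")

    // The correlated EXISTS references only t1 (the left join input) and, being part of a
    // left outer join's ON clause, is not pushed down; the new case rewrites it into an
    // existence join under the left child and substitutes the exists attribute in the
    // condition, keeping the original output via a final Project.
    val df = spark.sql("""
      SELECT t1.*
      FROM t1
      LEFT JOIN t2
        ON t1.a = t2.a
       AND EXISTS (SELECT 1 FROM t3 WHERE t3.b = t1.b)
    """)
    df.explain(true)
    spark.stop()
  }
}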

case u: UnaryNode if u.expressions.exists(
SubqueryExpression.hasInOrCorrelatedExistsSubquery) =>
var newChild = u.child
@@ -371,6 +432,8 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
} else {
newPlan
}
case j: Join if conf.getConf(DECORRELATE_PREDICATE_SUBQUERIES_IN_JOIN_CONDITION) =>
rewriteSubQueries(j)
// Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries.
case q: UnaryNode =>
rewriteSubQueries(q)
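The new Join case above only invokes rewriteSubQueries on the join node; the actual pulling-up of correlated predicates happens inside that helper. As a rough mental model only (Spark's real implementation works on Catalyst LogicalPlan and Expression trees, and the case classes below are made up), pulling up means splitting the subquery's predicates into the part it can keep and the correlated part that must later become a join condition:

object PullUpSketch {
  sealed trait Pred
  // e.g. t3.c = 100: refers only to columns of the subquery itself.
  case class Local(column: String, value: Int) extends Pred
  // e.g. t3.b = t1.b: compares a subquery column with an outer (join input) column.
  case class Correlated(inner: String, outer: String) extends Pred

  case class SubqueryPlan(table: String, predicates: Seq[Pred])

  // Keep the uncorrelated predicates in the subquery; hand the correlated ones back so a
  // later rewrite can re-apply them as join conditions against the outer plan.
  def pullUpCorrelated(sub: SubqueryPlan): (SubqueryPlan, Seq[Correlated]) = {
    val (correlated, local) = sub.predicates.partition(_.isInstanceOf[Correlated])
    (sub.copy(predicates = local), correlated.collect { case c: Correlated => c })
  }
}

For the EXISTS in the earlier sketch, this would leave `SELECT 1 FROM t3` behind and surface `t3.b = t1.b`, which the subsequent rewrite re-applies as part of the new join.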
@@ -2115,6 +2115,15 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
messageParameters = Map("expr" -> expr.sql, "dataType" -> dataType.typeName))
}

def unsupportedCorrelatedSubqueryInJoinConditionError(
unsupportedSubqueryExpressions: Seq[Expression]): Throwable = {
new AnalysisException(
errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." +
"UNSUPPORTED_CORRELATED_EXPRESSION_IN_JOIN_CONDITION",
messageParameters = Map("subqueryExpression" ->
unsupportedSubqueryExpressions.map(_.sql).mkString(", ")))
}
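For the error above, here is an illustrative query that is expected to hit it, assuming the same SparkSession and hypothetical temp views t1, t2, t3 as in the earlier sketch (none of this is taken from the PR itself). The EXISTS references t1 on the left and t2 on the right, so it cannot be decorrelated into a single join child.

// With spark.sql.optimizer.decorrelatePredicateSubqueriesInJoinPredicate.enabled=true this is
// expected to fail with UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.
// UNSUPPORTED_CORRELATED_EXPRESSION_IN_JOIN_CONDITION, because the subquery references both
// join inputs.
val referencesBothSides = spark.sql("""
  SELECT t1.*
  FROM t1
  JOIN t2
    ON t1.a = t2.a
   AND EXISTS (SELECT 1 FROM t3 WHERE t3.b = t1.b AND t3.c = t2.a)
""")
// The exception surfaces when the plan is optimized, e.g. on explain or collect.
referencesBothSides.explain(true)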

def functionCannotProcessInputError(
unbound: UnboundFunction,
arguments: Seq[Expression],
@@ -4378,6 +4378,24 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val DECORRELATE_PREDICATE_SUBQUERIES_IN_JOIN_CONDITION =
buildConf("spark.sql.optimizer.decorrelatePredicateSubqueriesInJoinPredicate.enabled")
.internal()
.doc("Decorrelate predicate (in and exists) subqueries with correlated references in join " +
"predicates.")
.version("4.0.0")
.booleanConf
.createWithDefault(true)
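A quick note on usage: the flag is internal and defaults to true, but it can still be toggled at runtime, which is handy for comparing optimized plans with and without the new rewrite. The snippet below assumes an existing SparkSession named spark.

// Disable the new rewrite, re-run the join query from the earlier sketch, then re-enable it.
spark.conf.set("spark.sql.optimizer.decorrelatePredicateSubqueriesInJoinPredicate.enabled", "false")
// ... spark.sql(...).explain(true) here to inspect the plan without the rewrite ...
spark.conf.set("spark.sql.optimizer.decorrelatePredicateSubqueriesInJoinPredicate.enabled", "true")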

val OPTIMIZE_UNCORRELATED_IN_SUBQUERIES_IN_JOIN_CONDITION =
buildConf("spark.sql.optimizer.optimizeUncorrelatedInSubqueriesInJoinCondition.enabled")
.internal()
.doc("When true, optimize uncorrelated IN subqueries in join predicates by rewriting them " +
"to joins.")
.version("4.0.0")
.booleanConf
.createWithDefault(false)
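To illustrate the shape this second flag is about (again using the hypothetical temp views and SparkSession from the earlier sketch): the IN subquery below is uncorrelated, since it references neither t1 nor t2, but it still sits in the join condition of a left outer join, where it cannot be pushed into the left child. With the flag enabled it is also picked up by the new Join case; with the default (false) it is left to the pre-existing handling.

spark.conf.set("spark.sql.optimizer.optimizeUncorrelatedInSubqueriesInJoinCondition.enabled", "true")
// Uncorrelated IN in the ON clause of a left outer join, referencing only the left input.
val uncorrelatedIn = spark.sql("""
  SELECT t1.*
  FROM t1
  LEFT JOIN t2
    ON t1.a = t2.a
   AND t1.b IN (SELECT b FROM t3)
""")
uncorrelatedIn.explain(true)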

val LEGACY_PERCENTILE_DISC_CALCULATION = buildConf("spark.sql.legacy.percentileDiscCalculation")
.internal()
.doc("If true, the old bogus percentile_disc calculation is used. The old calculation " +