-
Notifications
You must be signed in to change notification settings - Fork 29.1k
[SPARK-14785][SQL] Support correlated scalar subqueries #12822
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1827075
d189424
28e0878
84fff35
d9f1bc8
0ae7dee
831eaa8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer | |
|
|
||
| import scala.annotation.tailrec | ||
| import scala.collection.immutable.HashSet | ||
| import scala.collection.mutable.ArrayBuffer | ||
|
|
||
| import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} | ||
| import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, DistinctAggregationRewriter, EliminateSubqueryAliases, EmptyFunctionRegistry} | ||
|
|
@@ -100,6 +101,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) | |
| EliminateSorts, | ||
| SimplifyCasts, | ||
| SimplifyCaseConversionExpressions, | ||
| RewriteCorrelatedScalarSubquery, | ||
| EliminateSerialization) :: | ||
| Batch("Decimal Optimizations", fixedPoint, | ||
| DecimalAggregates) :: | ||
|
|
@@ -1081,7 +1083,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { | |
| assert(input.size >= 2) | ||
| if (input.size == 2) { | ||
| val (joinConditions, others) = conditions.partition( | ||
| e => !PredicateSubquery.hasPredicateSubquery(e)) | ||
| e => !SubqueryExpression.hasCorrelatedSubquery(e)) | ||
| val join = Join(input(0), input(1), Inner, joinConditions.reduceLeftOption(And)) | ||
| if (others.nonEmpty) { | ||
| Filter(others.reduceLeft(And), join) | ||
|
|
@@ -1101,7 +1103,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { | |
|
|
||
| val joinedRefs = left.outputSet ++ right.outputSet | ||
| val (joinConditions, others) = conditions.partition( | ||
| e => e.references.subsetOf(joinedRefs) && !PredicateSubquery.hasPredicateSubquery(e)) | ||
| e => e.references.subsetOf(joinedRefs) && !SubqueryExpression.hasCorrelatedSubquery(e)) | ||
| val joined = Join(left, right, Inner, joinConditions.reduceLeftOption(And)) | ||
|
|
||
| // should not have reference to same logical plan | ||
|
|
@@ -1134,7 +1136,7 @@ object OuterJoinElimination extends Rule[LogicalPlan] with PredicateHelper { | |
| * Returns whether the expression returns null or false when all inputs are nulls. | ||
| */ | ||
| private def canFilterOutNull(e: Expression): Boolean = { | ||
| if (!e.deterministic || PredicateSubquery.hasPredicateSubquery(e)) return false | ||
| if (!e.deterministic || SubqueryExpression.hasCorrelatedSubquery(e)) return false | ||
| val attributes = e.references.toSeq | ||
| val emptyRow = new GenericInternalRow(attributes.length) | ||
| val v = BindReferences.bindReference(e, attributes).eval(emptyRow) | ||
|
|
@@ -1203,7 +1205,6 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { | |
| case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) => | ||
| val (leftFilterConditions, rightFilterConditions, commonFilterCondition) = | ||
| split(splitConjunctivePredicates(filterCondition), left, right) | ||
|
|
||
| joinType match { | ||
| case Inner => | ||
| // push down the single side `where` condition into respective sides | ||
|
|
@@ -1212,7 +1213,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { | |
| val newRight = rightFilterConditions. | ||
| reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) | ||
| val (newJoinConditions, others) = | ||
| commonFilterCondition.partition(e => !PredicateSubquery.hasPredicateSubquery(e)) | ||
| commonFilterCondition.partition(e => !SubqueryExpression.hasCorrelatedSubquery(e)) | ||
| val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And) | ||
|
|
||
| val join = Join(newLeft, newRight, Inner, newJoinCond) | ||
|
|
@@ -1573,3 +1574,74 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { | |
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * This rule rewrites correlated [[ScalarSubquery]] expressions into LEFT OUTER joins. | ||
| */ | ||
| object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { | ||
| /** | ||
| * Extract all correlated scalar subqueries from an expression. The subqueries are collected using | ||
| * the given collector. The expression is rewritten and returned. | ||
| * | ||
| * Only a [[ScalarSubquery]] whose `children` are non-empty is matched. Each match is appended to | ||
| * `subqueries` and is replaced in the expression by `s.query.output.head`, i.e. the first output | ||
| * attribute of the subquery's plan. | ||
| * NOTE(review): the guard presumably identifies correlated subqueries because `children` holds | ||
| * the correlation conditions - confirm against the definition of ScalarSubquery. | ||
| * | ||
| * @param expression the expression to scan and rewrite. | ||
| * @param subqueries buffer that receives every matched subquery; it is mutated as a side effect. | ||
| * @tparam E static type of the input expression; the rewritten expression is cast back to it | ||
| *           with `asInstanceOf`, so the cast is unchecked. | ||
| * @return the rewritten expression. | ||
| */ | ||
| private def extractCorrelatedScalarSubqueries[E <: Expression]( | ||
| expression: E, | ||
| subqueries: ArrayBuffer[ScalarSubquery]): E = { | ||
| val newExpression = expression transform { | ||
| case s: ScalarSubquery if s.children.nonEmpty => | ||
| subqueries += s | ||
| s.query.output.head | ||
| } | ||
| newExpression.asInstanceOf[E] | ||
| } | ||
|
|
||
| /** | ||
| * Construct a new child plan by left joining the given subqueries to a base plan. | ||
| * | ||
| * The fold starts from `child`. For every subquery, the plan built so far is joined with the | ||
| * subquery's `query` using a LEFT OUTER [[Join]]; the join condition is the subquery's | ||
| * `conditions` combined with [[And]], or no condition when that sequence is empty (because | ||
| * `reduceOption` then yields `None`). | ||
| * | ||
| * Every join is topped by a [[Project]] that keeps the output of the plan built so far plus | ||
| * `query.output.head`. The subquery plan carries the join columns in addition to the value | ||
| * column we are interested in; the projection removes those join columns again. | ||
| * | ||
| * @param child the base plan the subqueries are joined to. | ||
| * @param subqueries the subqueries to join, processed in buffer order. | ||
| * @return `child` itself when `subqueries` is empty, otherwise the nested Project/Join plan. | ||
| */ | ||
| private def constructLeftJoins( | ||
| child: LogicalPlan, | ||
| subqueries: ArrayBuffer[ScalarSubquery]): LogicalPlan = { | ||
| subqueries.foldLeft(child) { | ||
| case (currentChild, ScalarSubquery(query, conditions, _)) => | ||
| Project( | ||
| currentChild.output :+ query.output.head, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we know that the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The query contains both the column we are interested in and the join columns. We don't want those, so we remove them
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see, thanks |
||
| Join(currentChild, query, LeftOuter, conditions.reduceOption(And))) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Rewrite [[Filter]], [[Project]] and [[Aggregate]] plans containing correlated scalar | ||
| * subqueries. | ||
| * | ||
| * For each of the three node types the subqueries are first pulled out of the node's | ||
| * expressions with `extractCorrelatedScalarSubqueries`. When none are found the original node | ||
| * is returned untouched. Otherwise the node is rebuilt on top of | ||
| * `constructLeftJoins(child, subqueries)`: | ||
| *  - [[Aggregate]]: the aggregate expressions are rewritten, and every grouping expression that | ||
| *    is `semanticEquals` to one of the extracted subqueries is replaced by that subquery's | ||
| *    `query.output.head`. Grouping expressions that merely contain a subquery are not touched | ||
| *    by this lookup. | ||
| *  - [[Project]]: the project list is rewritten. | ||
| *  - [[Filter]]: the condition is rewritten and the new filter is wrapped in a [[Project]] over | ||
| *    `f.output`, so the result exposes the same output as the original filter. | ||
| * | ||
| * @param plan the logical plan to rewrite. | ||
| * @return the plan produced by applying the cases above through `plan transform`. | ||
| */ | ||
| def apply(plan: LogicalPlan): LogicalPlan = plan transform { | ||
| case a @ Aggregate(grouping, expressions, child) => | ||
| val subqueries = ArrayBuffer.empty[ScalarSubquery] | ||
| val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries)) | ||
| if (subqueries.nonEmpty) { | ||
| // We currently only allow correlated subqueries in an aggregate if they are part of the | ||
| // grouping expressions. As a result we need to replace all the scalar subqueries in the | ||
| // grouping expressions by their result. | ||
| val newGrouping = grouping.map { e => | ||
| subqueries.find(_.semanticEquals(e)).map(_.query.output.head).getOrElse(e) | ||
| } | ||
| Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries)) | ||
| } else { | ||
| a | ||
| } | ||
| case p @ Project(expressions, child) => | ||
| val subqueries = ArrayBuffer.empty[ScalarSubquery] | ||
| val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries)) | ||
| if (subqueries.nonEmpty) { | ||
| Project(newExpressions, constructLeftJoins(child, subqueries)) | ||
| } else { | ||
| p | ||
| } | ||
| case f @ Filter(condition, child) => | ||
| val subqueries = ArrayBuffer.empty[ScalarSubquery] | ||
| val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries) | ||
| if (subqueries.nonEmpty) { | ||
| Project(f.output, Filter(newCondition, constructLeftJoins(child, subqueries))) | ||
| } else { | ||
| f | ||
| } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -109,7 +109,7 @@ case class Filter(condition: Expression, child: LogicalPlan) | |
|
|
||
| override protected def validConstraints: Set[Expression] = { | ||
| val predicates = splitConjunctivePredicates(condition) | ||
| .filterNot(PredicateSubquery.hasPredicateSubquery) | ||
| .filterNot(SubqueryExpression.hasCorrelatedSubquery) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @davies I changed the filter to prevent any correlated subquery from being propagated.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Oh I see, I missed this one in the latest changes. |
||
| child.constraints.union(predicates.toSet) | ||
| } | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can it have a Filter on top of an Aggregate (HAVING clause)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure I'll add it.