Skip to content

Commit

Permalink
[SPARK-24497][SQL] add tests, fix nested WITH, fix Exchange reuse
Browse files Browse the repository at this point in the history
  • Loading branch information
peter-toth committed Jan 22, 2019
1 parent f5feb63 commit 24af7b2
Show file tree
Hide file tree
Showing 7 changed files with 1,857 additions and 208 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ class Analyzer(
object ResolveRecursiveReferneces extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
val recursiveTables = plan.collect {
case rt @ RecursiveTable(name, _, _) if rt.anchorResolved => name -> rt
case rt @ RecursiveTable(name, _, _, _) if rt.anchorResolved => name -> rt
}.toMap

plan.resolveOperatorsUp {
Expand All @@ -231,123 +231,139 @@ class Analyzer(
substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) {
case (resolved, (name, relation)) =>
val recursiveTableName = if (allowRecursion) Some(name) else None
resolved :+
name -> executeSameContext(substituteCTE(relation, resolved, recursiveTableName))
}, None)
val (substitutedPlan, recursiveReferenceFound) =
substituteCTE(relation, resolved, recursiveTableName)
val analyzedPlan = executeSameContext(substitutedPlan)
resolved :+ name -> (
if (recursiveReferenceFound) {
insertRecursiveTable(analyzedPlan, recursiveTableName.get)
} else {
analyzedPlan
})
}, None)._1
case other => other
}

def substituteCTE(
plan: LogicalPlan,
cteRelations: Seq[(String, LogicalPlan)],
recursiveTableName: Option[String]): LogicalPlan = {
def substitute(
plan: LogicalPlan,
inSubQuery: Boolean = false): (LogicalPlan, Boolean) = {
var recursiveReferenceFound = false

val newPlan = plan resolveOperatorsDown {
case u: UnresolvedRelation =>
val table = u.tableIdentifier.table

val recursiveReference = recursiveTableName.find(resolver(_, table)).map { name =>
if (inSubQuery) {
throw new AnalysisException(
s"Recursive reference ${name} can't be used in a subquery")
}

recursiveReferenceFound = true
recursiveTableName: Option[String]): (LogicalPlan, Boolean) = {
var recursiveReferenceFound = false

UnresolvedRecursiveReference(name)
}
val newPlan = plan resolveOperatorsDown {
case u: UnresolvedRelation =>
val table = u.tableIdentifier.table

recursiveReference
.orElse(cteRelations.find(x => resolver(x._1, table)).map(_._2))
.getOrElse(u)
val recursiveReference = recursiveTableName.find(resolver(_, table)).map { name =>
recursiveReferenceFound = true

case other =>
// This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE.
other transformExpressions {
case e: SubqueryExpression => e.withNewPlan(substitute(e.plan, true)._1)
}
}
UnresolvedRecursiveReference(name)
}

(newPlan, recursiveReferenceFound)
recursiveReference
.orElse(cteRelations.find(x => resolver(x._1, table)).map(_._2))
.getOrElse(u)
case w @ With(_, cteRelations, _) =>
w.copy(cteRelations = cteRelations.map {
case (name, sa @ SubqueryAlias(_, plan)) =>
val (substitutedPlan, recursiveReferenceFoundInCTE) =
substituteCTE(plan, Seq.empty, recursiveTableName)
recursiveReferenceFound |= recursiveReferenceFoundInCTE
(name, sa.copy(child = substitutedPlan))
})
case other =>
// This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE.
other transformExpressions {
case e: SubqueryExpression =>
val (substitutedPlan, recursiveReferenceFoundInSubQuery) =
substituteCTE(e.plan, cteRelations, recursiveTableName)

recursiveReferenceFound |= recursiveReferenceFoundInSubQuery
e.withNewPlan(substitutedPlan)
}
}

(newPlan, recursiveReferenceFound)
}

def insertRecursiveTable(plan: LogicalPlan, recursiveTableName: String): LogicalPlan =
plan match {
case SubqueryAlias(name, u: Union) if recursiveTableName.contains(name.identifier) =>
case sa @ SubqueryAlias(name, u: Union) if name.identifier == recursiveTableName =>
def combineUnions(union: Union): Seq[LogicalPlan] = union.children.flatMap {
case u: Union => combineUnions(u)
case o => Seq(o)
}

val substitutedTerms = combineUnions(u).map(substitute(_))
val (anchorTerms, recursiveTerms) = substitutedTerms.partition(!_._2)
val combinedTerms = combineUnions(u)
val (anchorTerms, recursiveTerms) = combinedTerms.partition(!_.collectFirst {
case UnresolvedRecursiveReference(name) if name == recursiveTableName => true
}.isDefined)

if (!recursiveTerms.isEmpty) {
if (anchorTerms.isEmpty) {
throw new AnalysisException("There should be at least 1 anchor term defined in the " +
s"recursive query ${recursiveTableName.get}")
s"recursive query ${recursiveTableName}")
}

val recursiveTermPlans = recursiveTerms.map(_._1)

def traversePlanAndCheck(
plan: LogicalPlan,
isRecursiveReferenceAllowed: Boolean = true): Boolean = plan match {
case UnresolvedRecursiveReference(name) if recursiveTableName.contains(name) =>
isRecursiveReferenceAllowed: Boolean = true): Int = plan match {
case UnresolvedRecursiveReference(name) if name == recursiveTableName =>
if (!isRecursiveReferenceAllowed) {
throw new AnalysisException(s"Recursive reference ${recursiveTableName.get} " +
throw new AnalysisException(s"Recursive reference ${recursiveTableName} " +
"cannot be used here. This can be caused by using it in a different join " +
"than inner or left outer or right outer, using it on inner side of an " +
"outer join or using it in an aggregate or with a distinct statement")
"outer join, using it with aggregate or distinct, using it in a subquery " +
"or using it multiple times in a recursive term.")
}
true
1
case Join(left, right, Inner, _, _) =>
val l = traversePlanAndCheck(left, isRecursiveReferenceAllowed)
val r = traversePlanAndCheck(right, isRecursiveReferenceAllowed)
if (l && r) {
throw new AnalysisException(s"Recursive reference ${recursiveTableName.get} " +
"cannot be used on both sides of an inner join")
}
l || r
traversePlanAndCheck(left, isRecursiveReferenceAllowed) +
traversePlanAndCheck(right, isRecursiveReferenceAllowed)
case Join(left, right, LeftOuter, _, _) =>
traversePlanAndCheck(left, isRecursiveReferenceAllowed) ||
traversePlanAndCheck(left, isRecursiveReferenceAllowed) +
traversePlanAndCheck(right, false)
case Join(left, right, RightOuter, _, _) =>
traversePlanAndCheck(left, false) ||
traversePlanAndCheck(left, false) +
traversePlanAndCheck(right, isRecursiveReferenceAllowed)
case Join(left, right, _, _, _) =>
traversePlanAndCheck(left, false) || traversePlanAndCheck(right, false)
traversePlanAndCheck(left, false) +
traversePlanAndCheck(right, false)
case Aggregate(_, _, child) => traversePlanAndCheck(child, false)
case Distinct(child) => traversePlanAndCheck(child, false)
case o =>
o.children.map(traversePlanAndCheck(_, isRecursiveReferenceAllowed)).contains(true)
o transformExpressions {
case se: SubqueryExpression =>
traversePlanAndCheck(se.plan, false)
se
}
o.children
.map(traversePlanAndCheck(_, isRecursiveReferenceAllowed))
.foldLeft(0)(_ + _)
}

recursiveTermPlans.foreach(traversePlanAndCheck(_))
recursiveTerms.foreach { recursiveTerm =>
if (traversePlanAndCheck(recursiveTerm) > 1) {
throw new AnalysisException(s"Recursive reference ${recursiveTableName} cannot " +
"be used multiple times in a recursive term")
}
}

RecursiveTable(
recursiveTableName.get,
SubqueryAlias(name, Union(anchorTerms.map(_._1))),
Union(recursiveTermPlans))
recursiveTableName,
sa.copy(child = Union(anchorTerms)),
Union(recursiveTerms),
None)
} else {
SubqueryAlias(name, Union(substitutedTerms.map(_._1)))
SubqueryAlias(recursiveTableName, Union(combinedTerms))
}

case _ =>
val (substitutedPlan, recursiveReferenceFound) = substitute(plan)

if (recursiveReferenceFound) {
throw new AnalysisException(s"Recursive query ${recursiveTableName.get} should " +
"contain UNION ALL statements only")
}

substitutedPlan
throw new AnalysisException(s"Recursive query ${recursiveTableName} should contain " +
"UNION ALL statements only. This can also be caused by ORDER BY or LIMIT keywords " +
"used on result of UNION ALL.")
}
}
}

/**
* Substitute child plan with WindowSpecDefinitions.
Expand Down Expand Up @@ -1656,7 +1672,7 @@ class Analyzer(
case RecursiveReference(name, _) =>
throw new AnalysisException(s"Recursive reference ${name} can't be used in an " +
"aggregate")
case RecursiveTable(_, _, recursiveTerm) =>
case RecursiveTable(_, _, recursiveTerm, _) =>
case o => o.children.map(traversePlanAndCheck)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ case class Subquery(child: LogicalPlan) extends OrderPreservingUnaryNode {
case class RecursiveTable(
name: String,
anchorTerm: LogicalPlan,
recursiveTerm: LogicalPlan) extends LogicalPlan {
recursiveTerm: LogicalPlan,
limit: Option[Long]) extends LogicalPlan {
override def children: Seq[LogicalPlan] = Seq(anchorTerm, recursiveTerm)

override def output: Seq[Attribute] = anchorTerm.output.map(_.withNullability(true))
Expand Down Expand Up @@ -553,7 +554,8 @@ case class With(

override def simpleString(maxFields: Int): String = {
val cteAliases = truncatedString(cteRelations.map(_._1), "[", ", ", "]", maxFields)
s"CTE $cteAliases"
val recursive = if (allowRecursion) " recursive" else ""
s"CTE$recursive $cteAliases"
}

override def innerChildren: Seq[LogicalPlan] = cteRelations.map(_._2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -608,9 +608,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.ProjectExec(projectList, planLater(child)) :: Nil
case logical.Filter(condition, child) =>
execution.FilterExec(condition, planLater(child)) :: Nil
case logical.RecursiveTable(name, anchorTerm, recursiveTerm) =>
case logical.RecursiveTable(name, anchorTerm, recursiveTerm, limit) =>
execution.RecursiveTableExec(
name, planLater(anchorTerm), planLater(recursiveTerm)) :: Nil
name, planLater(anchorTerm), planLater(recursiveTerm), limit) :: Nil
case logical.RecursiveReference(name, output) =>
execution.RecursiveReferenceExec(name, output) :: Nil
case f: logical.TypedFilter =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ShuffleExchangeExec}
import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.LongType
Expand Down Expand Up @@ -232,7 +232,8 @@ case class FilterExec(condition: Expression, child: SparkPlan)
case class RecursiveTableExec(
name: String,
anchorTerm: SparkPlan,
recursiveTerm: SparkPlan) extends SparkPlan {
recursiveTerm: SparkPlan,
limit: Option[Long]) extends SparkPlan { // TODO: how to implement limit?
override def children: Seq[SparkPlan] = Seq(anchorTerm, recursiveTerm)

override def output: Seq[Attribute] = anchorTerm.output
Expand All @@ -243,9 +244,10 @@ case class RecursiveTableExec(
var temp = anchorTerm.execute().map(_.copy()).cache()
var tempCount = temp.count()
var result = temp
var sumCount = tempCount
var level = 0
val levelLimit = conf.recursionLevelLimit
do {
while ((level == 0 || tempCount > 0) && limit.map(_ < sumCount).getOrElse(true)) {
if (level > levelLimit) {
throw new SparkException("Recursion level limit reached but query hasn't exhausted, try " +
s"increasing ${SQLConf.RECURSION_LEVEL_LIMIT.key}")
Expand All @@ -261,13 +263,17 @@ case class RecursiveTableExec(
if (level > 0) {
newRecursiveTerm.reset()
}
newRecursiveTerm.foreach {

def updateRecursiveTables(plan: SparkPlan): Unit = plan.foreach {
_ match {
case rr: RecursiveReferenceExec if rr.name == name => rr.recursiveTable = temp
case ReusedExchangeExec(_, child) => updateRecursiveTables(child)
case _ =>
}
}

updateRecursiveTables(newRecursiveTerm)

val newTemp = newRecursiveTerm.execute().map(_.copy()).cache()
tempCount = newTemp.count()
temp.unpersist()
Expand All @@ -276,7 +282,7 @@ case class RecursiveTableExec(
result = result.union(temp)

level = level + 1
} while (tempCount > 0)
}

result
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType

Expand Down Expand Up @@ -56,6 +56,10 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchan
child.execute()
}

override def doReset(): Unit = {
child.reset()
}

override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
child.executeBroadcast()
}
Expand Down Expand Up @@ -90,11 +94,22 @@ case class ReuseExchange(conf: SQLConf) extends Rule[SparkPlan] {
return plan
}
// Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls.
val exchanges = mutable.HashMap[StructType, ArrayBuffer[Exchange]]()
// TODO: document recursion related changes
val allExchanges = mutable.Stack[mutable.HashMap[StructType, ArrayBuffer[Exchange]]]()
allExchanges.push(mutable.HashMap[StructType, ArrayBuffer[Exchange]]())
val recursiveTables = mutable.Set.empty[String]
plan.transformUp {
case rr @ RecursiveReferenceExec(name, _) if !recursiveTables.contains(name) =>
allExchanges.push(mutable.HashMap[StructType, ArrayBuffer[Exchange]]())
recursiveTables += name
rr
case rt @ RecursiveTableExec(name, _, _, _) =>
allExchanges.pop()
recursiveTables -= name
rt
case exchange: Exchange =>
// the exchanges that have same results usually also have same schemas (same column names).
val sameSchema = exchanges.getOrElseUpdate(exchange.schema, ArrayBuffer[Exchange]())
val sameSchema = allExchanges.top.getOrElseUpdate(exchange.schema, ArrayBuffer[Exchange]())
val samePlan = sameSchema.find { e =>
exchange.sameResult(e)
}
Expand Down
Loading

0 comments on commit 24af7b2

Please sign in to comment.