Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][SPARK-24497][SQL] Support recursive SQL query #29210

Closed
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/sql-ref-ansi-compliance.md
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ Below is a list of all the keywords in Spark SQL.
|RECOVER|non-reserved|non-reserved|non-reserved|
|REDUCE|non-reserved|non-reserved|non-reserved|
|REFERENCES|reserved|non-reserved|reserved|
|RECURSIVE|reserved|non-reserved|reserved|
|REFRESH|non-reserved|non-reserved|non-reserved|
|REGEXP|non-reserved|non-reserved|not a keyword|
|RENAME|non-reserved|non-reserved|non-reserved|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ describeColName
;

ctes
: WITH namedQuery (',' namedQuery)*
: WITH RECURSIVE? namedQuery (',' namedQuery)*
;

namedQuery
Expand Down Expand Up @@ -1393,6 +1393,7 @@ nonReserved
| RECORDREADER
| RECORDWRITER
| RECOVER
| RECURSIVE
| REDUCE
| REFERENCES
| REFRESH
Expand Down Expand Up @@ -1653,6 +1654,7 @@ RANGE: 'RANGE';
RECORDREADER: 'RECORDREADER';
RECORDWRITER: 'RECORDWRITER';
RECOVER: 'RECOVER';
RECURSIVE: 'RECURSIVE';
REDUCE: 'REDUCE';
REFERENCES: 'REFERENCES';
REFRESH: 'REFRESH';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ class Analyzer(
ResolveRelations ::
ResolveTables ::
ResolveReferences ::
ResolveRecursiveReferences ::
ResolveCreateNamedStruct ::
ResolveDeserializer ::
ResolveNewInstance ::
Expand Down Expand Up @@ -1703,6 +1704,23 @@ class Analyzer(
}
}

/**
 * Resolves the [[UnresolvedRecursiveReference]]s of a [[RecursiveRelation]] once the anchor term
 * of that relation is resolved (i.e. once the output of the recursive relation is known).
 *
 * Only relations whose recursive term is still unresolved are touched. Every matching reference
 * is replaced with a [[RecursiveReference]] that carries fresh instances of the anchor term's
 * output attributes.
 */
object ResolveRecursiveReferences extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
    case relation @ RecursiveRelation(relationName, anchor, recursive)
        if anchor.resolved && !recursive.resolved =>
      relation.copy(recursiveTerm = bindSelfReferences(recursive, relationName, anchor))
  }

  /**
   * Replaces the unresolved references to `relationName` found in `term` with resolved ones
   * whose output is derived from `anchor`.
   */
  private def bindSelfReferences(
      term: LogicalPlan,
      relationName: String,
      anchor: LogicalPlan): LogicalPlan = {
    term.transform {
      case UnresolvedRecursiveReference(refName, accumulated) if refName == relationName =>
        // New attribute instances are generated for each reference that gets replaced.
        val freshOutput = anchor.output.map(_.newInstance())
        RecursiveReference(refName, freshOutput, accumulated)
    }
  }
}

/**
* In many dialects of SQL it is valid to use ordinal positions in order/sort by and group by
* clauses. This rule is to convert ordinal positions to the corresponding expressions in the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import scala.collection.mutable

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Except, RecursiveRelation, Union}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, With}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -59,7 +60,7 @@ object CTESubstitution extends Rule[LogicalPlan] {
startOfQuery: Boolean = true): Unit = {
val resolver = SQLConf.get.resolver
plan match {
case With(child, relations) =>
case With(child, relations, _) =>
val newNames = mutable.ArrayBuffer.empty[String]
newNames ++= outerCTERelationNames
relations.foreach {
Expand All @@ -86,7 +87,7 @@ object CTESubstitution extends Rule[LogicalPlan] {

private def legacyTraverseAndSubstituteCTE(plan: LogicalPlan): LogicalPlan = {
plan.resolveOperatorsUp {
case With(child, relations) =>
case With(child, relations, _) =>
val resolvedCTERelations = resolveCTERelations(relations, isLegacy = true)
substituteCTE(child, resolvedCTERelations)
}
Expand Down Expand Up @@ -135,8 +136,8 @@ object CTESubstitution extends Rule[LogicalPlan] {
*/
private def traverseAndSubstituteCTE(plan: LogicalPlan): LogicalPlan = {
plan.resolveOperatorsUp {
case With(child: LogicalPlan, relations) =>
val resolvedCTERelations = resolveCTERelations(relations, isLegacy = false)
case With(child: LogicalPlan, relations, allowRecursion) =>
val resolvedCTERelations = resolveCTERelations(relations, isLegacy = false, allowRecursion)
substituteCTE(child, resolvedCTERelations)

case other =>
Expand All @@ -148,7 +149,8 @@ object CTESubstitution extends Rule[LogicalPlan] {

private def resolveCTERelations(
relations: Seq[(String, SubqueryAlias)],
isLegacy: Boolean): Seq[(String, LogicalPlan)] = {
isLegacy: Boolean,
allowRecursion: Boolean = false): Seq[(String, LogicalPlan)] = {
val resolvedCTERelations = new mutable.ArrayBuffer[(String, LogicalPlan)](relations.size)
for ((name, relation) <- relations) {
val innerCTEResolved = if (isLegacy) {
Expand All @@ -161,8 +163,13 @@ object CTESubstitution extends Rule[LogicalPlan] {
// substitute CTE defined in `relation` first.
traverseAndSubstituteCTE(relation)
}
val recursionHandled = if (allowRecursion) {
handleRecursion(innerCTEResolved, name)
} else {
innerCTEResolved
}
// CTE definition can reference a previous one
resolvedCTERelations += (name -> substituteCTE(innerCTEResolved, resolvedCTERelations.toSeq))
resolvedCTERelations += (name -> substituteCTE(recursionHandled, resolvedCTERelations.toSeq))
}
resolvedCTERelations.toSeq
}
Expand All @@ -180,4 +187,114 @@ object CTESubstitution extends Rule[LogicalPlan] {
case e: SubqueryExpression => e.withNewPlan(substituteCTE(e.plan, cteRelations))
}
}

/**
 * Entry point of recursion handling, used only when recursion is allowed.
 *
 * First every reference to the CTE's own name inside its definition is replaced with an
 * [[UnresolvedRecursiveReference]]. If no such reference exists, the original plan is returned
 * untouched. Otherwise the definition must be a [[SubqueryAlias]] over a UNION or UNION ALL
 * (optionally wrapped in column aliases), in which case a [[RecursiveRelation]] is built from
 * it; any other structure is reported as an [[AnalysisException]].
 */
private def handleRecursion(plan: LogicalPlan, cteName: String) = {
  def unsupportedStructure(): Nothing = {
    throw new AnalysisException(s"Recursive query $cteName should contain UNION or UNION " +
      "ALL statements only. This error can also be caused by ORDER BY or LIMIT keywords " +
      "used on result of UNION or UNION ALL.")
  }

  // A CTE is treated as recursive only if it references itself at least once.
  val (planWithReferences, referenceCount) = insertRecursiveReferences(plan, cteName)
  if (referenceCount == 0) {
    plan
  } else {
    planWithReferences match {
      case SubqueryAlias(_, body) =>
        // Peel off the optional column alias list first ...
        val (columnNames, unaliased) = body match {
          case UnresolvedSubqueryColumnAliases(names, child) => (names, child)
          case other => (Seq.empty[String], other)
        }
        // ... then only UNION ALL (a bare Union) or UNION (Distinct over Union) may remain.
        unaliased match {
          case u: Union => insertRecursiveRelation(cteName, columnNames, false, u)
          case Distinct(u: Union) => insertRecursiveRelation(cteName, columnNames, true, u)
          case _ => unsupportedStructure()
        }
      case _ => unsupportedStructure()
    }
  }
}

/**
 * Replaces every relation that matches the name of the recursive CTE with an
 * [[UnresolvedRecursiveReference]] and counts the replacements. While doing so, the subqueries
 * of the other nodes are inspected as well and a reference to the CTE found there is reported
 * as an error.
 *
 * @return the rewritten plan together with the number of references that were inserted
 */
private def insertRecursiveReferences(plan: LogicalPlan, cteName: String): (LogicalPlan, Int) = {
  val resolver = SQLConf.get.resolver

  // Fails as soon as a subquery plan refers back to the CTE, otherwise keeps the traversal going.
  val rejectSelfReference: LogicalPlan => Boolean = {
    case UnresolvedRelation(Seq(table)) if resolver(cteName, table) =>
      throw new AnalysisException(s"Recursive query $cteName should not contain recursive " +
        "references in its subquery.")
    case _ => true
  }

  var referencesFound = 0
  val rewrittenPlan = plan.resolveOperators {
    case UnresolvedRelation(Seq(table)) if (resolver(cteName, table)) =>
      referencesFound += 1
      UnresolvedRecursiveReference(cteName, false)

    case node =>
      node.subqueries.foreach(checkAndTraverse(_, rejectSelfReference))
      node
  }

  (rewrittenPlan, referencesFound)
}

/**
 * Builds the [[RecursiveRelation]] of a recursive CTE from the UNION found in its definition.
 *
 * @param cteName name of the recursive CTE
 * @param columnNames column aliases to apply to the anchor term, empty if there are none
 * @param distinct true if the terms are combined with UNION, false if with UNION ALL
 * @param union the union whose first child is the anchor term and second child is the
 *              recursive term
 * @throws AnalysisException if the union does not have exactly two children or if the anchor
 *                           term contains a recursive reference to this CTE
 */
private def insertRecursiveRelation(
    cteName: String,
    columnNames: Seq[String],
    distinct: Boolean,
    union: Union) = {
  // `children` is only known to be a Seq, so destructure it with a Seq pattern. A List-only
  // pattern (`a :: b :: Nil`) would fail with a MatchError on any other Seq implementation.
  val (anchorTerm, recursiveTerm) = union.children match {
    case Seq(anchor, recursive) => (anchor, recursive)
    case _ =>
      throw new AnalysisException(s"Recursive query ${cteName} should contain one anchor term " +
        "and one recursive term connected with UNION or UNION ALL.")
  }

  // The anchor term shouldn't contain a recursive reference that matches the name of the CTE,
  // except if it is nested under another RecursiveRelation with the same name.
  checkAndTraverse(anchorTerm, {
    case UnresolvedRecursiveReference(name, _) if name == cteName =>
      throw new AnalysisException(s"Recursive query $cteName should not contain recursive " +
        "references in its anchor (first) term.")
    case RecursiveRelation(name, _, _) if name == cteName => false
    case _ => true
  })

  // The anchor term has a special role, its output columns are aliased if required.
  val aliasedAnchorTerm = SubqueryAlias(cteName,
    if (columnNames.nonEmpty) {
      UnresolvedSubqueryColumnAliases(columnNames, anchorTerm)
    } else {
      anchorTerm
    }
  )

  // If UNION combinator is used between the terms we extend the anchor with a DISTINCT and the
  // recursive term with an EXCEPT clause and a reference to the so far accumulated result.
  if (distinct) {
    RecursiveRelation(cteName, Distinct(aliasedAnchorTerm),
      Except(recursiveTerm, UnresolvedRecursiveReference(cteName, true), false))
  } else {
    RecursiveRelation(cteName, aliasedAnchorTerm, recursiveTerm)
  }
}

/**
 * Traverses the plan, including its subqueries, and runs `check` on every visited node. The
 * nodes below a given node (its children first, then its subqueries) are visited only if
 * `check` returned true for that node.
 */
private def checkAndTraverse(plan: LogicalPlan, check: LogicalPlan => Boolean): Unit = {
  val shouldDescend = check(plan)
  if (shouldDescend) {
    (plan.children ++ plan.subqueries).foreach(checkAndTraverse(_, check))
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,7 @@ trait CheckAnalysis extends PredicateHelper {
case _ => // Analysis successful!
}
}
checkRecursion(plan)
checkCollectedMetrics(plan)
extendedCheckRules.foreach(_(plan))
plan.foreachUp {
Expand All @@ -671,6 +672,62 @@ trait CheckAnalysis extends PredicateHelper {
plan.setAnalyzed()
}

/**
* Recursion according to SQL standard comes with several limitations due to the fact that only
* those operations are allowed where the new set of rows can be computed from the result of the
* previous iteration. This implies that a recursive reference can't be used in some kinds of
* joins and aggregations.
* A further constraint is that a recursive term can contain one recursive reference only (except
* for using it on different sides of a UNION).
*
* This check walks the plan and throws an [[AnalysisException]] if one of these restrictions is
* violated. It returns nothing and does not rewrite the plan.
*
* @param plan the plan (or subtree) to check
* @param allowedRecursiveReferencesAndCounts names of the recursive relations whose recursive term
* is being checked at this point, mapped to the number of references to them seen so far. The
* map is mutated during the traversal; a reference whose name is not in the map is rejected.
*/
private def checkRecursion(
plan: LogicalPlan,
allowedRecursiveReferencesAndCounts: mutable.Map[String, Int] = mutable.Map.empty): Unit = {
plan match {
case RecursiveRelation(name, anchorTerm, recursiveTerm) =>
// A relation with the same name must not already be under check further up in the plan.
if (allowedRecursiveReferencesAndCounts.contains(name)) {
throw new AnalysisException(s"Recursive CTE definition $name is already in use.")
}
// The anchor term is checked before `name` is registered, so a reference to this relation is
// not allowed in it. The name is registered (with a zero count) only for the recursive term
// and removed again afterwards.
checkRecursion(anchorTerm, allowedRecursiveReferencesAndCounts)
checkRecursion(recursiveTerm, allowedRecursiveReferencesAndCounts += name -> 0)
allowedRecursiveReferencesAndCounts -= name
// Only references whose third field is `false` are checked here.
case RecursiveReference(name, _, false) =>
if (!allowedRecursiveReferencesAndCounts.contains(name)) {
throw new AnalysisException(s"Recursive reference $name cannot be used here. This can " +
"be caused by using it on inner side of an outer join, using it with aggregate in a " +
"subquery or using it multiple times in a recursive term (except for using it on " +
"different sides of an UNION ALL).")
}
if (allowedRecursiveReferencesAndCounts(name) > 0) {
throw new AnalysisException(s"Recursive reference $name cannot be used multiple times " +
"in a recursive term.")
}

// Record this reference so that a second one sharing the same map is rejected.
allowedRecursiveReferencesAndCounts +=
name -> (allowedRecursiveReferencesAndCounts(name) + 1)
// Inner join: both sides share the current map, so references are allowed on either side and
// are counted together.
case Join(left, right, Inner, _, _) =>
checkRecursion(left, allowedRecursiveReferencesAndCounts)
checkRecursion(right, allowedRecursiveReferencesAndCounts)
// Left outer join: the right side is checked with an empty map, i.e. no reference is allowed
// there.
case Join(left, right, LeftOuter, _, _) =>
checkRecursion(left, allowedRecursiveReferencesAndCounts)
checkRecursion(right, mutable.Map.empty)
// Right outer join: the left side is checked with an empty map.
case Join(left, right, RightOuter, _, _) =>
checkRecursion(left, mutable.Map.empty)
checkRecursion(right, allowedRecursiveReferencesAndCounts)
// Any other join type: both sides are checked with an empty map.
case Join(left, right, _, _, _) =>
checkRecursion(left, mutable.Map.empty)
checkRecursion(right, mutable.Map.empty)
// The child of an aggregate is checked with an empty map.
case Aggregate(_, _, child) => checkRecursion(child, mutable.Map.empty)
// Every child of a union gets its own fresh map with the same names and the counts reset to 0,
// so counts are not shared between the sides of the union.
case Union(children, _, _) =>
children.foreach(checkRecursion(_,
mutable.Map(allowedRecursiveReferencesAndCounts.keys.map(name => name -> 0).toSeq: _*)))
// All other nodes pass the current map on to their children unchanged.
case o =>
o.children.foreach(checkRecursion(_, allowedRecursiveReferencesAndCounts))
}
}

/**
* Validates subquery expressions in the plan. Upon failure, returns an user facing error.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -550,3 +550,9 @@ case class UnresolvedHaving(
override lazy val resolved: Boolean = false
override def output: Seq[Attribute] = child.output
}

/**
 * Placeholder for a reference to a recursive CTE. It is never resolved and has no output; the
 * analyzer rule `ResolveRecursiveReferences` replaces it with a `RecursiveReference`.
 *
 * @param cteName name of the recursive CTE this node refers to
 * @param accumulated flag that is carried over unchanged to the `RecursiveReference` replacing
 *                    this node; `CTESubstitution` sets it to true for the reference to the so
 *                    far accumulated result that it adds when the terms are combined with UNION
 */
case class UnresolvedRecursiveReference(cteName: String, accumulated: Boolean) extends LeafNode {
  override def output: Seq[Attribute] = Nil

  override lazy val resolved: Boolean = false
}
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,9 @@ object ColumnPruning extends Rule[LogicalPlan] {

case NestedColumnAliasing(p) => p

// Don't prune columns of RecursiveRelation
case p @ Project(_, _: RecursiveRelation) => p

// for all other logical plans that inherit the output from their children
// Project over project is handled by the first case, skip it here.
case p @ Project(_, child) if !child.isInstanceOf[Project] =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
s"CTE definition can't have duplicate names: ${duplicates.mkString("'", "', '", "'")}.",
ctx)
}
With(plan, ctes.toSeq)
With(plan, ctes.toSeq, ctx.RECURSIVE() != null)
}

/**
Expand Down
Loading