diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f72cc7b89b9b4..71bad44c6dc08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -214,6 +214,8 @@ class Analyzer( plan.resolveOperatorsUp { case UnresolvedRecursiveReference(name) if recursiveTables.contains(name) => + // creating new instance of attributes here makes possible to avoid complex attribute + // handling in FoldablePropagation RecursiveReference(name, recursiveTables(name).output.map(_.newInstance())) case other => other } @@ -242,14 +244,7 @@ class Analyzer( def substitute( plan: LogicalPlan, inSubQuery: Boolean = false): (LogicalPlan, Boolean) = { - val references = mutable.Set.empty[UnresolvedRecursiveReference] - - def newReference(recursiveTableName: String) = { - val recursiveReference = UnresolvedRecursiveReference(recursiveTableName) - references += recursiveReference - - recursiveReference - } + var recursiveReferenceFound = false val newPlan = plan resolveOperatorsDown { case u: UnresolvedRelation => @@ -261,7 +256,9 @@ class Analyzer( s"Recursive reference ${name} can't be used in a subquery") } - newReference(name) + recursiveReferenceFound = true + + UnresolvedRecursiveReference(name) } recursiveReference @@ -275,11 +272,11 @@ class Analyzer( } } - (newPlan, !references.isEmpty) + (newPlan, recursiveReferenceFound) } plan match { - case SubqueryAlias(name, u: Union) if recursiveTableName.isDefined => + case SubqueryAlias(name, u: Union) if recursiveTableName.contains(name.identifier) => def combineUnions(union: Union): Seq[LogicalPlan] = union.children.flatMap { case u: Union => combineUnions(u) case o => Seq(o) @@ -290,8 +287,8 @@ class Analyzer( if (!recursiveTerms.isEmpty) { if (anchorTerms.isEmpty) { - throw new 
AnalysisException("There should be at least 1 anchor term defined in a " + - s"recursive query $name") + throw new AnalysisException("There should be at least 1 anchor term defined in the " + + s"recursive query ${recursiveTableName.get}") } val recursiveTermPlans = recursiveTerms.map(_._1) @@ -299,17 +296,20 @@ class Analyzer( def traversePlanAndCheck( plan: LogicalPlan, isRecursiveReferenceAllowed: Boolean = true): Boolean = plan match { - case UnresolvedRecursiveReference(name) => + case UnresolvedRecursiveReference(name) if recursiveTableName.contains(name) => if (!isRecursiveReferenceAllowed) { - throw new AnalysisException(s"Wrong usage of recursive reference ${name}") + throw new AnalysisException(s"Recursive reference ${recursiveTableName.get} " + + "cannot be used here. This can be caused by using it in a different join " + + "than inner or left outer or right outer, using it on inner side of an " + + "outer join or using it in an aggregate or with a distinct statement") } true case Join(left, right, Inner, _, _) => val l = traversePlanAndCheck(left, isRecursiveReferenceAllowed) val r = traversePlanAndCheck(right, isRecursiveReferenceAllowed) if (l && r) { - throw new AnalysisException("Recursive reference can't be used in on both " + - "side of an inner join") + throw new AnalysisException(s"Recursive reference ${recursiveTableName.get} " + + "cannot be used on both sides of an inner join") } l || r case Join(left, right, LeftOuter, _, _) => @@ -318,7 +318,7 @@ class Analyzer( case Join(left, right, RightOuter, _, _) => traversePlanAndCheck(left, false) || traversePlanAndCheck(right, isRecursiveReferenceAllowed) - case Join(left, right, FullOuter, _, _) => + case Join(left, right, _, _, _) => traversePlanAndCheck(left, false) || traversePlanAndCheck(right, false) case Aggregate(_, _, child) => traversePlanAndCheck(child, false) case Distinct(child) => traversePlanAndCheck(child, false) @@ -340,8 +340,8 @@ class Analyzer( val (substitutedPlan, 
recursiveReferenceFound) = substitute(plan) if (recursiveReferenceFound) { - throw new AnalysisException("Wrong usage of recursive reference " + - s"${recursiveTableName.get}") + throw new AnalysisException(s"Recursive query ${recursiveTableName.get} should " + + "contain UNION ALL statements only") } substitutedPlan diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index fdced27c447b6..6ad0b15133065 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -48,6 +48,17 @@ case class Subquery(child: LogicalPlan) extends OrderPreservingUnaryNode { override def output: Seq[Attribute] = child.output } +/** +* This node defines a table that contains one or more [[RecursiveReference]]s as child nodes +* referring to this table. It can be used to define a recursive CTE query and contains an anchor +* and a recursive term as children. The result of the anchor and the repeatedly executed recursive +* term are combined to form the final result. +* +* @param name name of the table +* @param anchorTerm this child is used for initializing the query +* @param recursiveTerm this child is used for extending the set of results with new rows based on +* the results of the previous iteration (or the anchor in the first iteration) +*/ case class RecursiveTable( name: String, anchorTerm: LogicalPlan, @@ -76,6 +87,12 @@ case class RecursiveTable( lazy val anchorResolved = anchorTerm.resolved } +/** +* This node represents a reference to a recursive table in CTE definitions. 
+* +* @param name the name of the table it refers to +* @param output the attributes of the recursive table +*/ case class RecursiveReference(name: String, output: Seq[Attribute]) extends LeafNode { override lazy val resolved = output.forall(_.resolved) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b1cd3d4c04c30..e0cc9cf94f0bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1626,9 +1626,10 @@ object SQLConf { .booleanConf .createWithDefault(true) - val RECURSION_LEVEL_LIMIT = buildConf("spark.sql.recursion.level.limit") + val RECURSION_LEVEL_LIMIT = buildConf("spark.sql.cte.recursion.level.limit") .internal() - .doc("Maximum level of recursion") + .doc("Maximum level of recursion that is allowed wile executing a recursive CTE definition." + + "If a query does not get exhausted before reaching this limit it fails.") .intConf .createWithDefault(100) } diff --git a/sql/core/src/test/resources/sql-tests/inputs/recursion.sql b/sql/core/src/test/resources/sql-tests/inputs/recursion.sql index 9fa9c6de698cb..9941e3ae9a303 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/recursion.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/recursion.sql @@ -18,7 +18,7 @@ WITH RECURSIVE r AS ( ) SELECT * FROM r ORDER BY level; --- unlimited recursion fails at spark.sql.recursion.level.limit level +-- unlimited recursion fails at spark.sql.cte.recursion.level.limit level WITH RECURSIVE r AS ( VALUES (0, 'A') AS T(level, data) UNION ALL diff --git a/sql/core/src/test/resources/sql-tests/results/recursion.sql.out b/sql/core/src/test/resources/sql-tests/results/recursion.sql.out index 32e98bc87de7f..1f9df9bb3cc02 100644 --- a/sql/core/src/test/resources/sql-tests/results/recursion.sql.out +++ 
b/sql/core/src/test/resources/sql-tests/results/recursion.sql.out @@ -50,7 +50,7 @@ SELECT * FROM r ORDER BY level struct<> -- !query 2 output org.apache.spark.SparkException -Recursion level limit reached but query hasn't exhausted, try increasing spark.sql.recursion.level.limit +Recursion level limit reached but query hasn't exhausted, try increasing spark.sql.cte.recursion.level.limit -- !query 3 @@ -184,7 +184,7 @@ SELECT * FROM r struct<> -- !query 7 output org.apache.spark.sql.AnalysisException -Wrong usage of recursive reference r; +Recursive query r should contain UNION ALL statements only; -- !query 8 @@ -198,7 +198,7 @@ SELECT * FROM r struct<> -- !query 8 output org.apache.spark.sql.AnalysisException -There should be at least 1 anchor term defined in a recursive query `r`; +There should be at least 1 anchor term defined in the recursive query r; -- !query 9 @@ -228,7 +228,7 @@ SELECT * FROM r struct<> -- !query 10 output org.apache.spark.sql.AnalysisException -Recursive reference can't be used in on both side of an inner join; +Recursive reference r cannot be used on both sides of an inner join; -- !query 11 @@ -244,7 +244,7 @@ SELECT * FROM r struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -Wrong usage of recursive reference r; +Recursive reference r cannot be used here. This can be caused by using it in a different join than inner or left outer or right outer, using it on inner side of an outer join or using it in an aggregate or with a distinct statement; -- !query 12 @@ -260,7 +260,7 @@ SELECT * FROM r struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -Wrong usage of recursive reference r; +Recursive reference r cannot be used here. 
This can be caused by using it in a different join than inner or left outer or right outer, using it on inner side of an outer join or using it in an aggregate or with a distinct statement; -- !query 13 @@ -297,7 +297,7 @@ SELECT * FROM r struct<> -- !query 14 output org.apache.spark.sql.AnalysisException -Wrong usage of recursive reference r; +Recursive reference r cannot be used here. This can be caused by using it in a different join than inner or left outer or right outer, using it on inner side of an outer join or using it in an aggregate or with a distinct statement; -- !query 15 @@ -338,7 +338,7 @@ SELECT * FROM r struct<> -- !query 17 output org.apache.spark.sql.AnalysisException -Wrong usage of recursive reference r; +Recursive reference r cannot be used here. This can be caused by using it in a different join than inner or left outer or right outer, using it on inner side of an outer join or using it in an aggregate or with a distinct statement; -- !query 18