Skip to content

Commit

Permalink
[SPARK-24497][SQL] fix review findings 3
Browse files Browse the repository at this point in the history
  • Loading branch information
peter-toth committed Jan 16, 2019
1 parent 8f9a673 commit f5feb63
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ class Analyzer(

plan.resolveOperatorsUp {
case UnresolvedRecursiveReference(name) if recursiveTables.contains(name) =>
// creating new instances of the attributes here makes it possible to avoid complex attribute
// handling in FoldablePropagation
RecursiveReference(name, recursiveTables(name).output.map(_.newInstance()))
case other => other
}
Expand Down Expand Up @@ -242,14 +244,7 @@ class Analyzer(
def substitute(
plan: LogicalPlan,
inSubQuery: Boolean = false): (LogicalPlan, Boolean) = {
val references = mutable.Set.empty[UnresolvedRecursiveReference]

def newReference(recursiveTableName: String) = {
val recursiveReference = UnresolvedRecursiveReference(recursiveTableName)
references += recursiveReference

recursiveReference
}
var recursiveReferenceFound = false

val newPlan = plan resolveOperatorsDown {
case u: UnresolvedRelation =>
Expand All @@ -261,7 +256,9 @@ class Analyzer(
s"Recursive reference ${name} can't be used in a subquery")
}

newReference(name)
recursiveReferenceFound = true

UnresolvedRecursiveReference(name)
}

recursiveReference
Expand All @@ -275,11 +272,11 @@ class Analyzer(
}
}

(newPlan, !references.isEmpty)
(newPlan, recursiveReferenceFound)
}

plan match {
case SubqueryAlias(name, u: Union) if recursiveTableName.isDefined =>
case SubqueryAlias(name, u: Union) if recursiveTableName.contains(name.identifier) =>
def combineUnions(union: Union): Seq[LogicalPlan] = union.children.flatMap {
case u: Union => combineUnions(u)
case o => Seq(o)
Expand All @@ -290,26 +287,29 @@ class Analyzer(

if (!recursiveTerms.isEmpty) {
if (anchorTerms.isEmpty) {
throw new AnalysisException("There should be at least 1 anchor term defined in a " +
s"recursive query $name")
throw new AnalysisException("There should be at least 1 anchor term defined in the " +
s"recursive query ${recursiveTableName.get}")
}

val recursiveTermPlans = recursiveTerms.map(_._1)

def traversePlanAndCheck(
plan: LogicalPlan,
isRecursiveReferenceAllowed: Boolean = true): Boolean = plan match {
case UnresolvedRecursiveReference(name) =>
case UnresolvedRecursiveReference(name) if recursiveTableName.contains(name) =>
if (!isRecursiveReferenceAllowed) {
throw new AnalysisException(s"Wrong usage of recursive reference ${name}")
throw new AnalysisException(s"Recursive reference ${recursiveTableName.get} " +
"cannot be used here. This can be caused by using it in a different join " +
"than inner or left outer or right outer, using it on inner side of an " +
"outer join or using it in an aggregate or with a distinct statement")
}
true
case Join(left, right, Inner, _, _) =>
val l = traversePlanAndCheck(left, isRecursiveReferenceAllowed)
val r = traversePlanAndCheck(right, isRecursiveReferenceAllowed)
if (l && r) {
throw new AnalysisException("Recursive reference can't be used in on both " +
"side of an inner join")
throw new AnalysisException(s"Recursive reference ${recursiveTableName.get} " +
"cannot be used on both sides of an inner join")
}
l || r
case Join(left, right, LeftOuter, _, _) =>
Expand All @@ -318,7 +318,7 @@ class Analyzer(
case Join(left, right, RightOuter, _, _) =>
traversePlanAndCheck(left, false) ||
traversePlanAndCheck(right, isRecursiveReferenceAllowed)
case Join(left, right, FullOuter, _, _) =>
case Join(left, right, _, _, _) =>
traversePlanAndCheck(left, false) || traversePlanAndCheck(right, false)
case Aggregate(_, _, child) => traversePlanAndCheck(child, false)
case Distinct(child) => traversePlanAndCheck(child, false)
Expand All @@ -340,8 +340,8 @@ class Analyzer(
val (substitutedPlan, recursiveReferenceFound) = substitute(plan)

if (recursiveReferenceFound) {
throw new AnalysisException("Wrong usage of recursive reference " +
s"${recursiveTableName.get}")
throw new AnalysisException(s"Recursive query ${recursiveTableName.get} should " +
"contain UNION ALL statements only")
}

substitutedPlan
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,17 @@ case class Subquery(child: LogicalPlan) extends OrderPreservingUnaryNode {
override def output: Seq[Attribute] = child.output
}

/**
* This node defines a table that contains one or more [[RecursiveReference]]s as child nodes
* referring to this table. It can be used to define a recursive CTE query and contains an anchor
* and a recursive term as children. The result of the anchor and the repeatedly executed recursive
* term are combined to form the final result.
*
* @param name name of the table
* @param anchorTerm this child is used for initializing the query
* @param recursiveTerm this child is used for extending the set of results with new rows based on
* the results of the previous iteration (or the anchor in the first iteration)
*/
case class RecursiveTable(
name: String,
anchorTerm: LogicalPlan,
Expand Down Expand Up @@ -76,6 +87,12 @@ case class RecursiveTable(
lazy val anchorResolved = anchorTerm.resolved
}

/**
* This node represents a reference to a recursive table in CTE definitions.
*
* @param name the name of the table it refers to
* @param output the attributes of the recursive table
*/
case class RecursiveReference(name: String, output: Seq[Attribute]) extends LeafNode {
override lazy val resolved = output.forall(_.resolved)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1626,9 +1626,10 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val RECURSION_LEVEL_LIMIT = buildConf("spark.sql.recursion.level.limit")
val RECURSION_LEVEL_LIMIT = buildConf("spark.sql.cte.recursion.level.limit")
.internal()
.doc("Maximum level of recursion")
.doc("Maximum level of recursion that is allowed while executing a recursive CTE definition." +
" If a query does not get exhausted before reaching this limit it fails.")
.intConf
.createWithDefault(100)
}
Expand Down
2 changes: 1 addition & 1 deletion sql/core/src/test/resources/sql-tests/inputs/recursion.sql
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ WITH RECURSIVE r AS (
)
SELECT * FROM r ORDER BY level;

-- unlimited recursion fails at spark.sql.recursion.level.limit level
-- unlimited recursion fails at spark.sql.cte.recursion.level.limit level
WITH RECURSIVE r AS (
VALUES (0, 'A') AS T(level, data)
UNION ALL
Expand Down
16 changes: 8 additions & 8 deletions sql/core/src/test/resources/sql-tests/results/recursion.sql.out
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ SELECT * FROM r ORDER BY level
struct<>
-- !query 2 output
org.apache.spark.SparkException
Recursion level limit reached but query hasn't exhausted, try increasing spark.sql.recursion.level.limit
Recursion level limit reached but query hasn't exhausted, try increasing spark.sql.cte.recursion.level.limit


-- !query 3
Expand Down Expand Up @@ -184,7 +184,7 @@ SELECT * FROM r
struct<>
-- !query 7 output
org.apache.spark.sql.AnalysisException
Wrong usage of recursive reference r;
Recursive query r should contain UNION ALL statements only;


-- !query 8
Expand All @@ -198,7 +198,7 @@ SELECT * FROM r
struct<>
-- !query 8 output
org.apache.spark.sql.AnalysisException
There should be at least 1 anchor term defined in a recursive query `r`;
There should be at least 1 anchor term defined in the recursive query r;


-- !query 9
Expand Down Expand Up @@ -228,7 +228,7 @@ SELECT * FROM r
struct<>
-- !query 10 output
org.apache.spark.sql.AnalysisException
Recursive reference can't be used in on both side of an inner join;
Recursive reference r cannot be used on both sides of an inner join;


-- !query 11
Expand All @@ -244,7 +244,7 @@ SELECT * FROM r
struct<>
-- !query 11 output
org.apache.spark.sql.AnalysisException
Wrong usage of recursive reference r;
Recursive reference r cannot be used here. This can be caused by using it in a different join than inner or left outer or right outer, using it on inner side of an outer join or using it in an aggregate or with a distinct statement;


-- !query 12
Expand All @@ -260,7 +260,7 @@ SELECT * FROM r
struct<>
-- !query 12 output
org.apache.spark.sql.AnalysisException
Wrong usage of recursive reference r;
Recursive reference r cannot be used here. This can be caused by using it in a different join than inner or left outer or right outer, using it on inner side of an outer join or using it in an aggregate or with a distinct statement;


-- !query 13
Expand Down Expand Up @@ -297,7 +297,7 @@ SELECT * FROM r
struct<>
-- !query 14 output
org.apache.spark.sql.AnalysisException
Wrong usage of recursive reference r;
Recursive reference r cannot be used here. This can be caused by using it in a different join than inner or left outer or right outer, using it on inner side of an outer join or using it in an aggregate or with a distinct statement;


-- !query 15
Expand Down Expand Up @@ -338,7 +338,7 @@ SELECT * FROM r
struct<>
-- !query 17 output
org.apache.spark.sql.AnalysisException
Wrong usage of recursive reference r;
Recursive reference r cannot be used here. This can be caused by using it in a different join than inner or left outer or right outer, using it on inner side of an outer join or using it in an aggregate or with a distinct statement;


-- !query 18
Expand Down

0 comments on commit f5feb63

Please sign in to comment.