From 31c534403d119b431cb6d493d5acc50db2ac73a6 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Fri, 4 Mar 2016 20:58:15 +0800
Subject: [PATCH 1/2] transformExpressions should exclude expressions that are
 not inside QueryPlan.expressions

---
 .../sql/catalyst/analysis/Analyzer.scala | 11 ------
 .../spark/sql/catalyst/plans/QueryPlan.scala | 18 ++++++---
 .../catalyst/plans/logical/LogicalPlan.scala | 2 +-
 .../spark/sql/catalyst/trees/TreeNode.scala | 37 ++++++++++++-------
 .../spark/sql/catalyst/plans/PlanTest.scala | 19 +++++++---
 5 files changed, 49 insertions(+), 38 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index fbbc3ee891c6b..db99c9310c78e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -510,17 +510,6 @@ class Analyzer(
 ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder])
 Sort(newOrdering, global, child)

- // A special case for Generate, because the output of Generate should not be resolved by
- // ResolveReferences. Attributes in the output will be resolved by ResolveGenerate.
- case g @ Generate(generator, join, outer, qualifier, output, child)
- if child.resolved && !generator.resolved =>
- val newG = resolveExpression(generator, child, throws = true)
- if (newG.fastEquals(generator)) {
- g
- } else {
- Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child)
- }
-
 // A special case for ObjectOperator, because the deserializer expressions in ObjectOperator
 // should be resolved by their corresponding attributes instead of children's output.
 case o: ObjectOperator if containsUnresolvedDeserializer(o.deserializers.map(_._1)) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 0e0453b517d92..a630ec7951bfb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.plans

 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.catalyst.trees.{TreeNodeRef, TreeNode}
 import org.apache.spark.sql.types.{DataType, StructType}

 abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanType] {
@@ -107,8 +107,14 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanTy
 */
 def missingInput: AttributeSet = references -- inputSet -- producedAttributes

+ private lazy val expressionRefs = expressions.map(new TreeNodeRef(_)).toSet
+
+ protected def isOneOfExpressions(e: Expression): Boolean = {
+ expressionRefs.contains(new TreeNodeRef(e))
+ }
+
 /**
- * Runs [[transform]] with `rule` on all expressions present in this query operator.
+ * Runs [[transform]] with `rule` on all expressions returned by [[expressions]] of this plan.
 * Users should not expect a specific directionality. If a specific directionality is needed,
 * transformExpressionsDown or transformExpressionsUp should be used.
* @@ -137,8 +143,8 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy } def recursiveTransform(arg: Any): AnyRef = arg match { - case e: Expression => transformExpressionDown(e) - case Some(e: Expression) => Some(transformExpressionDown(e)) + case e: Expression if isOneOfExpressions(e) => transformExpressionDown(e) + case Some(e: Expression) if isOneOfExpressions(e) => Some(transformExpressionDown(e)) case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map(recursiveTransform) @@ -171,8 +177,8 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy } def recursiveTransform(arg: Any): AnyRef = arg match { - case e: Expression => transformExpressionUp(e) - case Some(e: Expression) => Some(transformExpressionUp(e)) + case e: Expression if isOneOfExpressions(e) => transformExpressionUp(e) + case Some(e: Expression) if isOneOfExpressions(e) => Some(transformExpressionUp(e)) case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map(recursiveTransform) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 31e775d60f950..94153bf870f01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -155,7 +155,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { productIterator.map { // Children are checked using sameResult above. - case tn: TreeNode[_] if containsChild(tn) => null + case tn: TreeNode[_] if isOneOfChildren(tn) => null case e: Expression => cleanExpression(e) case s: Option[_] => s.map { case e: Expression => cleanExpression(e) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 6b7997e903a99..43a4e74b43b63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -81,7 +81,16 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { */ def children: Seq[BaseType] - lazy val containsChild: Set[TreeNode[_]] = children.toSet + protected lazy val childrenRefs = children.map(new TreeNodeRef(_)).toSet + + protected def isOneOfChildren(node: TreeNode[_]): Boolean = { + childrenRefs.contains(new TreeNodeRef(node)) + } + + protected def isSubsetOfChildren(nodes: Seq[_]): Boolean = nodes.forall { + case tn: TreeNode[_] => isOneOfChildren(tn) + case _ => false + } /** * Faster version of equality which short-circuits when two treeNodes are the same instance. @@ -168,7 +177,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { def mapChildren(f: BaseType => BaseType): BaseType = { var changed = false val newArgs = productIterator.map { - case arg: TreeNode[_] if containsChild(arg) => + case arg: TreeNode[_] if isOneOfChildren(arg) => val newChild = f(arg.asInstanceOf[BaseType]) if (newChild fastEquals arg) { arg @@ -195,7 +204,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case s: StructType => s // Don't convert struct types to some other type of Seq[StructField] // Handle Seq[TreeNode] in TreeNode parameters. 
case s: Seq[_] => s.map { - case arg: TreeNode[_] if containsChild(arg) => + case arg: TreeNode[_] if isOneOfChildren(arg) => val newChild = remainingNewChildren.remove(0) val oldChild = remainingOldChildren.remove(0) if (newChild fastEquals oldChild) { @@ -208,7 +217,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case null => null } case m: Map[_, _] => m.mapValues { - case arg: TreeNode[_] if containsChild(arg) => + case arg: TreeNode[_] if isOneOfChildren(arg) => val newChild = remainingNewChildren.remove(0) val oldChild = remainingOldChildren.remove(0) if (newChild fastEquals oldChild) { @@ -220,7 +229,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case nonChild: AnyRef => nonChild case null => null }.view.force // `mapValues` is lazy and we need to force it to materialize - case arg: TreeNode[_] if containsChild(arg) => + case arg: TreeNode[_] if isOneOfChildren(arg) => val newChild = remainingNewChildren.remove(0) val oldChild = remainingOldChildren.remove(0) if (newChild fastEquals oldChild) { @@ -297,7 +306,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { nextOperation: (BaseType, PartialFunction[BaseType, BaseType]) => BaseType): BaseType = { var changed = false val newArgs = productIterator.map { - case arg: TreeNode[_] if containsChild(arg) => + case arg: TreeNode[_] if isOneOfChildren(arg) => val newChild = nextOperation(arg.asInstanceOf[BaseType], rule) if (!(newChild fastEquals arg)) { changed = true @@ -305,7 +314,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } else { arg } - case Some(arg: TreeNode[_]) if containsChild(arg) => + case Some(arg: TreeNode[_]) if isOneOfChildren(arg) => val newChild = nextOperation(arg.asInstanceOf[BaseType], rule) if (!(newChild fastEquals arg)) { changed = true @@ -314,7 +323,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { Some(arg) } case m: Map[_, _] => m.mapValues { - case arg: TreeNode[_] if containsChild(arg) => + case arg: TreeNode[_] if isOneOfChildren(arg) => val newChild = nextOperation(arg.asInstanceOf[BaseType], rule) if (!(newChild fastEquals arg)) { changed = true @@ -326,7 +335,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { }.view.force // `mapValues` is lazy and we need to force it to materialize case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { - case arg: TreeNode[_] if containsChild(arg) => + case arg: TreeNode[_] if isOneOfChildren(arg) => val newChild = nextOperation(arg.asInstanceOf[BaseType], rule) if (!(newChild fastEquals arg)) { changed = true @@ -405,9 +414,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { /** Returns a string representing the arguments to this node, minus any children */ def argString: String = productIterator.flatMap { - case tn: TreeNode[_] if containsChild(tn) => Nil + case tn: TreeNode[_] if isOneOfChildren(tn) => Nil case tn: TreeNode[_] => s"${tn.simpleString}" :: Nil - case seq: Seq[BaseType] if seq.toSet.subsetOf(children.toSet) => Nil + case seq: Seq[_] if isSubsetOfChildren(seq) => Nil case seq: Seq[_] => seq.mkString("[", ",", "]") :: Nil case set: Set[_] => set.mkString("{", ",", "}") :: Nil case other => other :: Nil @@ -423,7 +432,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { /** * Returns a string representation of the nodes in this tree, where each operator is numbered. 
- * The numbers can be used with [[trees.TreeNode.apply apply]] to easily access specific subtrees.
+ * The numbers can be used with [[TreeNode.apply apply]] to easily access specific subtrees.
 */
 def numberedTreeString: String =
 treeString.split("\n").zipWithIndex.map { case (line, i) => f"$i%02d $line" }.mkString("\n")
@@ -571,9 +580,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
 fieldNames.zip(fieldValues).map {
 // If the field value is a child, then use an int to encode it, represents the index of
 // this child in all children.
- case (name, value: TreeNode[_]) if containsChild(value) =>
+ case (name, value: TreeNode[_]) if isOneOfChildren(value) =>
 name -> JInt(children.indexOf(value))
- case (name, value: Seq[BaseType]) if value.toSet.subsetOf(containsChild) =>
+ case (name, value: Seq[_]) if isSubsetOfChildren(value) =>
 name -> JArray(
 value.map(v => JInt(children.indexOf(v.asInstanceOf[TreeNode[_]]))).toList
 )
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index f9874088b5884..5841cc5fc7865 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -19,23 +19,30 @@ package org.apache.spark.sql.catalyst.plans

 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation}
+import org.apache.spark.sql.catalyst.plans.logical.{Generate, Filter, LogicalPlan, OneRowRelation}
 import org.apache.spark.sql.catalyst.util._

 /**
 * Provides helper methods for comparing plans.
 */
 abstract class PlanTest extends SparkFunSuite {
+
+ private def clearExprId(a: Attribute): Attribute = {
+ AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
+ }
+
 /**
 * Since attribute references are given globally unique ids during analysis,
 * we must normalize them to check if two different queries are identical.
 */
 protected def normalizeExprIds(plan: LogicalPlan) = {
- plan transformAllExpressions {
- case a: AttributeReference =>
- AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
- case a: Alias =>
- Alias(a.child, a.name)(exprId = ExprId(0))
+ plan.transform {
+ // A special case for Generate: its `generatorOutput` is not part of its expressions and is
+ // therefore not reachable via `transformAllExpressions`, so we handle it separately.
+ case g: Generate => g.copy(generatorOutput = g.generatorOutput.map(clearExprId))
+ }.transformAllExpressions {
+ case a: AttributeReference => clearExprId(a)
+ case a: Alias => Alias(a.child, a.name)(exprId = ExprId(0))
 }
 }

From 28cc1fd150bf43fdcd7c701eb3564f8136c46e26 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Fri, 4 Mar 2016 21:38:38 +0800
Subject: [PATCH 2/2] fix style

---
 .../scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 2 +-
 .../scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index a630ec7951bfb..97432f0fa0f8e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.plans

 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.trees.{TreeNodeRef, TreeNode}
+import org.apache.spark.sql.catalyst.trees.{TreeNode, TreeNodeRef}
 import org.apache.spark.sql.types.{DataType, StructType}

 abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanType] {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 5841cc5fc7865..7373cd77cbefc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans

 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Generate, Filter, LogicalPlan, OneRowRelation}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, Generate, LogicalPlan, OneRowRelation}
 import org.apache.spark.sql.catalyst.util._

 /**
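The heart of the first commit is the expressionRefs/isOneOfExpressions pair in QueryPlan: membership is keyed on reference identity (TreeNodeRef wraps a node so that hash-set lookups compare with eq), not on structural equality. That matters because Catalyst expressions are case classes with structural equals/hashCode, so a plain Set would conflate two equal-looking expressions that occupy different positions in the plan. Below is a minimal, self-contained sketch of that pattern; Ref, Lit, and RefDemo are illustrative names for this note, not Spark API.

// Minimal sketch of an identity-keyed wrapper in the spirit of TreeNodeRef.
// Equality is `eq` (same JVM object), so two structurally equal expressions
// at different positions in a plan stay distinct inside a Set.
final class Ref[A <: AnyRef](val obj: A) {
  override def equals(o: Any): Boolean = o match {
    case that: Ref[_] => that.obj eq obj
    case _ => false
  }
  override def hashCode: Int = System.identityHashCode(obj)
}

object RefDemo extends App {
  case class Lit(value: Int) // stand-in for a structurally comparable expression

  val inExpressions = Lit(1)
  val elsewhereInPlan = Lit(1) // equal by value, but a different occurrence

  val byValue = Set(inExpressions)
  val byIdentity = Set(new Ref(inExpressions))

  assert(byValue.contains(elsewhereInPlan)) // false positive under value equality
  assert(!byIdentity.contains(new Ref(elsewhereInPlan))) // correctly excluded
  assert(byIdentity.contains(new Ref(inExpressions))) // the original occurrence is found
}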
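TreeNode.isOneOfChildren applies the same idea to children. The old containsChild was children.toSet, so it used value equality, and a constructor argument that merely equals a child would be treated as one by mapChildren, withNewChildren, and the transform machinery. A self-contained toy showing the misclassification (Node, Leaf, Branch, and ChildDemo are stand-ins, not Catalyst classes):

// Toy tree: `bound` is NOT a child, it is ordinary operator state that
// happens to be structurally equal to the left child.
sealed trait Node { def children: Seq[Node] }
final case class Leaf(v: Int) extends Node { val children: Seq[Node] = Nil }
final case class Branch(left: Leaf, right: Leaf, bound: Leaf) extends Node {
  val children: Seq[Node] = Seq(left, right)
}

object ChildDemo extends App {
  val n = Branch(left = Leaf(1), right = Leaf(2), bound = Leaf(1))

  // Old-style check: value equality via a Set.
  val byValue: Set[Node] = n.children.toSet
  assert(byValue.contains(n.bound)) // `bound` is misclassified as a child

  // New-style check: reference identity, as isOneOfChildren does via TreeNodeRef.
  def isOneOfChildren(node: Node): Boolean = n.children.exists(_ eq node)
  assert(!isOneOfChildren(n.bound)) // correctly excluded
  assert(isOneOfChildren(n.left)) // real children still recognized
}

Under value equality, a mapChildren-style traversal would have rewritten `bound` as if it were a child; under identity it is left alone for the expression-level machinery to handle.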
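isSubsetOfChildren also fixes a quieter issue in the patterns it replaces: in `case seq: Seq[BaseType] if seq.toSet.subsetOf(children.toSet)`, the element type BaseType is erased at runtime, so the pattern matches any Seq, and the subset test again relies on value equality. The new helper inspects each element explicitly and compares by identity. A compact sketch of that shape (SubsetDemo and its types are illustrative):

object SubsetDemo extends App {
  sealed trait Node
  final case class Leaf(v: Int) extends Node

  val children: Seq[Node] = Seq(Leaf(1), Leaf(2))

  // Shape of the new check: test every element explicitly instead of trusting
  // an erased `Seq[BaseType]` pattern, and compare by identity, not equality.
  def isSubsetOfChildren(nodes: Seq[_]): Boolean = nodes.forall {
    case n: Node => children.exists(_ eq n)
    case _ => false
  }

  assert(isSubsetOfChildren(children.take(1))) // actual child instances pass
  assert(!isSubsetOfChildren(Seq("not a node"))) // an erased pattern could not reject this
  assert(!isSubsetOfChildren(Seq(Leaf(1)))) // equal to a child, but a new instance
}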
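Finally, the PlanTest change follows from the new semantics: Generate's generatorOutput is not part of its expressions, so after this patch transformAllExpressions no longer reaches it, and normalizeExprIds needs the extra plan-level pass shown in the diff. The self-contained toy below mirrors that two-pass structure; Attr, Scan, Gen, and normalize are illustrative stand-ins for the Catalyst classes.

final case class Attr(name: String, id: Long)

sealed trait Plan
final case class Scan(out: Attr) extends Plan
// Mirrors Generate: `genOut` is produced output metadata, not an expression.
final case class Gen(gen: Attr, genOut: Seq[Attr], child: Plan) extends Plan

object NormalizeDemo extends App {
  def clearId(a: Attr): Attr = a.copy(id = 0)

  // Stand-in for transformAllExpressions: visits `gen` and `out`, never `genOut`.
  def mapExpressions(p: Plan, f: Attr => Attr): Plan = p match {
    case Gen(g, genOut, c) => Gen(f(g), genOut, mapExpressions(c, f))
    case Scan(out) => Scan(f(out))
  }

  // Stand-in for the fixed normalizeExprIds: patch Gen first, then sweep.
  def normalize(p: Plan): Plan = {
    val patched = p match {
      case g: Gen => g.copy(genOut = g.genOut.map(clearId))
      case other => other
    }
    mapExpressions(patched, clearId)
  }

  val plan = Gen(Attr("explode(xs)", id = 3), Seq(Attr("col", id = 42)), Scan(Attr("xs", id = 7)))

  val exprOnly = mapExpressions(plan, clearId).asInstanceOf[Gen]
  assert(exprOnly.genOut.head.id == 42) // the expression sweep alone misses genOut

  val fixed = normalize(plan).asInstanceOf[Gen]
  assert(fixed.gen.id == 0 && fixed.genOut.head.id == 0) // both passes together normalize everything
}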