From af6a0bfa581838def69ac44aaa0d5f40b20cc40e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 12 Mar 2016 12:41:29 +0800 Subject: [PATCH] QueryPlan sub-classes should override producedAttributes to fix missingInput --- .../scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 4 ++-- .../sql/catalyst/plans/logical/ScriptTransformation.scala | 2 +- .../spark/sql/catalyst/plans/logical/basicOperators.scala | 3 +-- .../main/scala/org/apache/spark/sql/execution/Expand.scala | 3 +-- .../apache/spark/sql/execution/python/EvaluatePython.scala | 3 +-- 5 files changed, 6 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index c222571a3464b..3fee5b37baf1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -110,7 +110,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT * All Attributes that appear in expressions from this operator. Note that this set does not * include attributes that are implicitly referenced by being passed through to the output tuple. */ - def references: AttributeSet = AttributeSet(expressions.flatMap(_.references)) + final def references: AttributeSet = AttributeSet(expressions.flatMap(_.references)) /** * The set of all attributes that are input to this operator by its children. @@ -128,7 +128,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT * Subclasses should override this method if they produce attributes internally as it is used by * assertions designed to prevent the construction of invalid plans. */ - def missingInput: AttributeSet = references -- inputSet -- producedAttributes + final def missingInput: AttributeSet = references -- inputSet -- producedAttributes /** * Runs [[transform]] with `rule` on all expressions present in this query operator. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala index 578027da776e5..e511abb0dbeec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala @@ -33,7 +33,7 @@ case class ScriptTransformation( output: Seq[Attribute], child: LogicalPlan, ioschema: ScriptInputOutputSchema) extends UnaryNode { - override def references: AttributeSet = AttributeSet(input.flatMap(_.references)) + override def producedAttributes: AttributeSet = AttributeSet(output) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 09ea3fea6a694..a2f8129bddad8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -519,8 +519,7 @@ case class Expand( output: Seq[Attribute], child: LogicalPlan) extends UnaryNode { - override def references: AttributeSet = - AttributeSet(projections.flatten.flatMap(_.references)) + override def producedAttributes: AttributeSet = AttributeSet(output) override def statistics: Statistics = { val sizeInBytes = super.statistics.sizeInBytes * projections.length diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index a84e180ad1dd8..108c81deedd28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -46,8 +46,7 @@ case class Expand( // as UNKNOWN partitioning override def outputPartitioning: Partitioning = UnknownPartitioning(0) - override def references: AttributeSet = - AttributeSet(projections.flatten.flatMap(_.references)) + override def producedAttributes: AttributeSet = AttributeSet(output) private[this] val projection = (exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index 8c46516594a2d..e213d52ffb930 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -45,8 +45,7 @@ case class EvaluatePython( def output: Seq[Attribute] = child.output :+ resultAttribute - // References should not include the produced attribute. - override def references: AttributeSet = udf.references + override def producedAttributes: AttributeSet = AttributeSet(resultAttribute) }