From fc26b49f1f46561a16effc6f3b6334d1024bb644 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 24 Aug 2014 13:36:22 -0700 Subject: [PATCH 1/7] Add AttributeSet class, remove references from Expression. --- .../sql/catalyst/analysis/Analyzer.scala | 6 +- .../sql/catalyst/analysis/unresolved.scala | 2 +- .../catalyst/expressions/AttributeSet.scala | 66 +++++++++++++++++++ .../catalyst/expressions/BoundAttribute.scala | 2 - .../sql/catalyst/expressions/Expression.scala | 6 +- .../spark/sql/catalyst/expressions/Rand.scala | 1 - .../sql/catalyst/expressions/ScalaUdf.scala | 1 - .../sql/catalyst/expressions/SortOrder.scala | 1 - .../catalyst/expressions/WrapDynamic.scala | 2 +- .../sql/catalyst/expressions/aggregates.scala | 25 ++++--- .../catalyst/expressions/complexTypes.scala | 2 +- .../sql/catalyst/expressions/generators.scala | 2 - .../sql/catalyst/expressions/literals.scala | 4 +- .../expressions/namedExpressions.scala | 6 +- .../catalyst/expressions/nullFunctions.scala | 3 - .../sql/catalyst/expressions/predicates.scala | 6 +- .../spark/sql/catalyst/expressions/sets.scala | 5 -- .../expressions/stringOperations.scala | 2 - .../sql/catalyst/optimizer/Optimizer.scala | 12 ++-- .../spark/sql/catalyst/plans/QueryPlan.scala | 4 +- .../plans/logical/basicOperators.scala | 16 ++--- .../plans/physical/partitioning.scala | 2 - .../sql/catalyst/trees/TreeNodeSuite.scala | 1 - .../spark/sql/execution/pythonUdfs.scala | 1 - .../spark/sql/hive/HiveStrategies.scala | 4 +- .../org/apache/spark/sql/hive/hiveUdfs.scala | 5 -- .../hive/execution/HiveResolutionSuite.scala | 9 ++- 27 files changed, 119 insertions(+), 77 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c18d7858f0a43..4a9524074132e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -132,7 +132,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved => val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) val resolved = unresolved.flatMap(child.resolveChildren) - val requiredAttributes = resolved.collect { case a: Attribute => a }.toSet + val requiredAttributes = AttributeSet(resolved.collect { case a: Attribute => a }) val missingInProject = requiredAttributes -- p.output if (missingInProject.nonEmpty) { @@ -152,8 +152,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool ) logDebug(s"Grouping expressions: $groupingRelation") - val resolved = unresolved.flatMap(groupingRelation.resolve).toSet - val missingInAggs = resolved -- a.outputSet + val resolved = unresolved.flatMap(groupingRelation.resolve) + val missingInAggs = resolved.filterNot(a.outputSet.contains) logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs") if (missingInAggs.nonEmpty) { // Add missing grouping exprs and then project them away after the sort. 
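[Illustration, not part of the patch: a minimal sketch of the AttributeSet arithmetic used above. The imports and the IntegerType column are assumptions made for the example; the point is that membership and set difference key on expression ids rather than names.]

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types.IntegerType

    // Two references to the same underlying column: same exprId, different name.
    val col = AttributeReference("a", IntegerType, nullable = true)()
    val renamed = AttributeReference("A", IntegerType, nullable = true)(col.exprId)

    val required = AttributeSet(Seq(col))
    assert(required.contains(renamed))          // membership keys on exprId, not name
    assert((required -- Seq(renamed)).isEmpty)  // as does set difference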
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index a0e25775da6dd..b5cf1c1cd2388 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -39,6 +39,7 @@ case class UnresolvedRelation( alias: Option[String] = None) extends LeafNode { override def output = Nil override lazy val resolved = false + def reference = Set.empty } /** @@ -66,7 +67,6 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E override def dataType = throw new UnresolvedException(this, "dataType") override def foldable = throw new UnresolvedException(this, "foldable") override def nullable = throw new UnresolvedException(this, "nullable") - override def references = children.flatMap(_.references).toSet override lazy val resolved = false // Unresolved functions are transient at compile time and don't get evaluated during execution. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala new file mode 100644 index 0000000000000..d61b0252ba722 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +class AttributeEquals(val a: Attribute) { + override def hashCode() = a.exprId.hashCode() + override def equals(other: Any) = other match { + case otherReference: AttributeEquals => a.exprId == otherReference.a.exprId + case otherAttribute => false + } +} + +object AttributeSet { + def apply(baseSet: Seq[Attribute]) = { + new AttributeSet(baseSet.map(new AttributeEquals(_)).toSet) + } + +// def apply(baseSet: Set[Attribute]) = { +// new AttributeSet(baseSet.map(new AttributeEquals(_))) +// } +} + +class AttributeSet(val baseSet: Set[AttributeEquals]) extends Traversable[Attribute] { + def contains(elem: NamedExpression): Boolean = + baseSet.contains(new AttributeEquals(elem.toAttribute)) + + def +(elem: Attribute): AttributeSet = + new AttributeSet(baseSet + new AttributeEquals(elem)) + + def -(elem: Attribute): AttributeSet = + new AttributeSet(baseSet - new AttributeEquals(elem)) + + def iterator: Iterator[Attribute] = baseSet.map(_.a).iterator + + def subsetOf(other: AttributeSet) = baseSet.subsetOf(other.baseSet) + + def --(other: Traversable[NamedExpression]) = + new AttributeSet(baseSet -- other.map(a => new AttributeEquals(a.toAttribute))) + + def ++(other: AttributeSet) = new AttributeSet(baseSet ++ other.baseSet) + + override def filter(f: Attribute => Boolean) = new AttributeSet(baseSet.filter(ae => f(ae.a))) + + def intersect(other: AttributeSet) = new AttributeSet(baseSet.intersect(other.baseSet)) + + override def nonEmpty = baseSet.nonEmpty + + override def toSeq = baseSet.toSeq.map(_.a) + + override def foreach[U](f: (Attribute) => U): Unit = baseSet.map(_.a).foreach(f) +} \ No newline at end of file diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 0913f15888780..54c6baf1af3bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -32,8 +32,6 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) type EvaluatedType = Any - override def references = Set.empty - override def toString = s"input[$ordinal]" override def eval(input: Row): Any = input(ordinal) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index ba62dabe3dd6a..70507e7ee2be8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -41,7 +41,7 @@ abstract class Expression extends TreeNode[Expression] { */ def foldable: Boolean = false def nullable: Boolean - def references: Set[Attribute] + def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator)) /** Returns the result of evaluating this expression on a given input Row */ def eval(input: Row = null): EvaluatedType @@ -230,8 +230,6 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express override def foldable = left.foldable && right.foldable - override def references = left.references ++ right.references - override def toString = s"($left $symbol $right)" } @@ -242,5 +240,5 @@ abstract class LeafExpression extends Expression with 
trees.LeafNode[Expression] abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => - override def references = child.references + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala index 38f836f0a1a0e..851db95b9177e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.types.DoubleType case object Rand extends LeafExpression { override def dataType = DoubleType override def nullable = false - override def references = Set.empty private[this] lazy val rand = new Random diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 95633dd0c9870..63ac2a608b6ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -24,7 +24,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi type EvaluatedType = Any - def references = children.flatMap(_.references).toSet def nullable = true /** This method has been generated by this script diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index d2b7685e73065..d00b2ac09745c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -31,7 +31,6 @@ case object Descending extends SortDirection case class SortOrder(child: Expression, direction: SortDirection) extends Expression with trees.UnaryNode[Expression] { - override def references = child.references override def dataType = child.dataType override def nullable = child.nullable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala index eb8898900d6a5..1eb55715794a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala @@ -35,7 +35,7 @@ case class WrapDynamic(children: Seq[Attribute]) extends Expression { type EvaluatedType = DynamicRow def nullable = false - def references = children.toSet + def dataType = DynamicType override def eval(input: Row): DynamicRow = input match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 613b87ca98d97..dbc0c2965a805 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -78,7 +78,7 @@ abstract class AggregateFunction /** Base should return the generic aggregate expression that this function is computing */ val base: AggregateExpression - override def references = base.references + override def nullable = 
base.nullable override def dataType = base.dataType @@ -89,7 +89,7 @@ abstract class AggregateFunction } case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references = child.references + override def nullable = true override def dataType = child.dataType override def toString = s"MIN($child)" @@ -119,7 +119,7 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr } case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references = child.references + override def nullable = true override def dataType = child.dataType override def toString = s"MAX($child)" @@ -149,7 +149,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr } case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references = child.references + override def nullable = false override def dataType = LongType override def toString = s"COUNT($child)" @@ -166,7 +166,7 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate def this() = this(null) override def children = expressions - override def references = expressions.flatMap(_.references).toSet + override def nullable = false override def dataType = LongType override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})" @@ -184,7 +184,6 @@ case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpress def this() = this(null) override def children = expressions - override def references = expressions.flatMap(_.references).toSet override def nullable = false override def dataType = ArrayType(expressions.head.dataType) override def toString = s"AddToHashSet(${expressions.mkString(",")})" @@ -219,7 +218,6 @@ case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression def this() = this(null) override def children = inputSet :: Nil - override def references = inputSet.references override def nullable = false override def dataType = LongType override def toString = s"CombineAndCount($inputSet)" @@ -248,7 +246,7 @@ case class CombineSetsAndCountFunction( case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) extends AggregateExpression with trees.UnaryNode[Expression] { - override def references = child.references + override def nullable = false override def dataType = child.dataType override def toString = s"APPROXIMATE COUNT(DISTINCT $child)" @@ -257,7 +255,7 @@ case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) extends AggregateExpression with trees.UnaryNode[Expression] { - override def references = child.references + override def nullable = false override def dataType = LongType override def toString = s"APPROXIMATE COUNT(DISTINCT $child)" @@ -266,7 +264,7 @@ case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references = child.references + override def nullable = false override def dataType = LongType override def toString = s"APPROXIMATE COUNT(DISTINCT $child)" @@ -284,7 +282,7 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) } case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references 
= child.references + override def nullable = false override def dataType = DoubleType override def toString = s"AVG($child)" @@ -304,7 +302,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN } case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references = child.references + override def nullable = false override def dataType = child.dataType override def toString = s"SUM($child)" @@ -322,7 +320,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ case class SumDistinct(child: Expression) extends AggregateExpression with trees.UnaryNode[Expression] { - override def references = child.references + override def nullable = false override def dataType = child.dataType override def toString = s"SUM(DISTINCT $child)" @@ -331,7 +329,6 @@ case class SumDistinct(child: Expression) } case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references = child.references override def nullable = true override def dataType = child.dataType override def toString = s"FIRST($child)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index c1154eb81c319..dafd745ec96c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -31,7 +31,7 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { /** `Null` is returned for invalid ordinals. */ override def nullable = true override def foldable = child.foldable && ordinal.foldable - override def references = children.flatMap(_.references).toSet + def dataType = child.dataType match { case ArrayType(dt, _) => dt case MapType(_, vt, _) => vt diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index e99c5b452d183..9c865254e0be9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -47,8 +47,6 @@ abstract class Generator extends Expression { override def nullable = false - override def references = children.flatMap(_.references).toSet - /** * Should be overridden by specific generators. Called only once for each instance to ensure * that rule application does not change the output schema of a generator. 
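[Sketch, not part of the patch: the deletions above rely on the new default `references` in Expression.scala, which unions the children's references; AttributeReference (further down) is the only leaf that contributes itself. A simplified, self-contained model of that recursion, using toy types rather than Catalyst's:]

    // Toy model: the default walks children; only leaf references add themselves.
    abstract class Expr {
      def children: Seq[Expr]
      def references: Set[String] = children.flatMap(_.references).toSet
    }
    case class Ref(name: String) extends Expr {
      def children = Nil
      override def references = Set(name)
    }
    case class Plus(left: Expr, right: Expr) extends Expr {
      def children = Seq(left, right)
    }

    // x + (y + x): references are collected with no per-node overrides.
    assert(Plus(Ref("x"), Plus(Ref("y"), Ref("x"))).references == Set("x", "y"))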
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index e15e16d633365..a8c2396d62632 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -52,7 +52,7 @@ case class Literal(value: Any, dataType: DataType) extends LeafExpression { override def foldable = true def nullable = value == null - def references = Set.empty + override def toString = if (value != null) value.toString else "null" @@ -66,8 +66,6 @@ case class MutableLiteral(var value: Any, nullable: Boolean = true) extends Leaf val dataType = Literal(value).dataType - def references = Set.empty - def update(expression: Expression, input: Row) = { value = expression.eval(input) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 02d04762629f5..7c4b9d4847e26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -62,7 +62,7 @@ abstract class Attribute extends NamedExpression { def toAttribute = this def newInstance: Attribute - override def references = Set(this) + } /** @@ -85,7 +85,7 @@ case class Alias(child: Expression, name: String) override def dataType = child.dataType override def nullable = child.nullable - override def references = child.references + override def toAttribute = { if (resolved) { @@ -116,6 +116,8 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil) extends Attribute with trees.LeafNode[Expression] { + override def references = AttributeSet(this :: Nil) + override def equals(other: Any) = other match { case ar: AttributeReference => exprId == ar.exprId && dataType == ar.dataType case _ => false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index e88c5d4fa178a..086d0a3e073e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -26,7 +26,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression { /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ def nullable = !children.exists(!_.nullable) - def references = children.flatMap(_.references).toSet // Coalesce is foldable if all children are foldable. 
override def foldable = !children.exists(!_.foldable) @@ -53,7 +52,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { - def references = child.references override def foldable = child.foldable def nullable = false @@ -65,7 +63,6 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr } case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { - def references = child.references override def foldable = child.foldable def nullable = false override def toString = s"IS NOT NULL $child" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 5976b0ddf3e03..1313ccd120c1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -85,7 +85,7 @@ case class Not(child: Expression) extends UnaryExpression with Predicate { */ case class In(value: Expression, list: Seq[Expression]) extends Predicate { def children = value +: list - def references = children.flatMap(_.references).toSet + def nullable = true // TODO: Figure out correct nullability semantics of IN. override def toString = s"$value IN ${list.mkString("(", ",", ")")}" @@ -197,7 +197,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi def children = predicate :: trueValue :: falseValue :: Nil override def nullable = trueValue.nullable || falseValue.nullable - def references = children.flatMap(_.references).toSet + override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType def dataType = { if (!resolved) { @@ -239,7 +239,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi case class CaseWhen(branches: Seq[Expression]) extends Expression { type EvaluatedType = Any def children = branches - def references = children.flatMap(_.references).toSet + def dataType = { if (!resolved) { throw new UnresolvedException(this, "cannot resolve due to differing types in some branches") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index e6c570b47bee2..3d4c4a8853c12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -26,8 +26,6 @@ import org.apache.spark.util.collection.OpenHashSet case class NewSet(elementType: DataType) extends LeafExpression { type EvaluatedType = Any - def references = Set.empty - def nullable = false // We are currently only using these Expressions internally for aggregation. 
However, if we ever @@ -53,9 +51,6 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { def nullable = set.nullable def dataType = set.dataType - - def references = (item.flatMap(_.references) ++ set.flatMap(_.references)).toSet - def eval(input: Row): Any = { val itemEval = item.eval(input) val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 97fc3a3b14b88..c2a3a5ca3ca8b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -226,8 +226,6 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends if (str.dataType == BinaryType) str.dataType else StringType } - def references = children.flatMap(_.references).toSet - override def children = str :: pos :: len :: Nil @inline diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 5f86d6047cb9c..90c9430c44935 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -59,14 +59,14 @@ object Optimizer extends RuleExecutor[LogicalPlan] { object ColumnPruning extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Eliminate attributes that are not needed to calculate the specified aggregates. - case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => + case a @ Aggregate(_, _, child) if (child.outputSet -- AttributeSet(a.references.toSeq)).nonEmpty => a.copy(child = Project(a.references.toSeq, child)) // Eliminate unneeded attributes from either side of a Join. case Project(projectList, Join(left, right, joinType, condition)) => // Collect the list of all references required either above or to evaluate the condition. - val allReferences: Set[Attribute] = - projectList.flatMap(_.references).toSet ++ condition.map(_.references).getOrElse(Set.empty) + val allReferences: AttributeSet = + AttributeSet(projectList.flatMap(_.references.iterator)) ++ condition.map(_.references).getOrElse(AttributeSet(Seq.empty)) /** Applies a projection only when the child is producing unnecessary attributes */ def pruneJoinChild(c: LogicalPlan) = prunedChild(c, allReferences) @@ -76,8 +76,8 @@ object ColumnPruning extends Rule[LogicalPlan] { // Eliminate unneeded attributes from right side of a LeftSemiJoin. case Join(left, right, LeftSemi, condition) => // Collect the list of all references required to evaluate the condition. 
- val allReferences: Set[Attribute] = - condition.map(_.references).getOrElse(Set.empty) + val allReferences: AttributeSet = + condition.map(_.references).getOrElse(AttributeSet(Seq.empty)) Join(left, prunedChild(right, allReferences), LeftSemi, condition) @@ -104,7 +104,7 @@ object ColumnPruning extends Rule[LogicalPlan] { } /** Applies a projection only when the child is producing unnecessary attributes */ - private def prunedChild(c: LogicalPlan, allReferences: Set[Attribute]) = + private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) { Project(allReferences.filter(c.outputSet.contains).toSeq, c) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 0988b0c6d990c..1e177e28f80b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression} import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.types.{ArrayType, DataType, StructField, StructType} @@ -29,7 +29,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy /** * Returns the set of attributes that are output by this node. */ - def outputSet: Set[Attribute] = output.toSet + def outputSet: AttributeSet = AttributeSet(output) /** * Runs [[transform]] with `rule` on all expressions present in this query operator. 
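[Sketch, not part of the patch: the pruning arithmetic that ColumnPruning and prunedChild now perform over AttributeSets. Column names and types are made up for the example.]

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types.{IntegerType, StringType}

    val id   = AttributeReference("id",   IntegerType, nullable = false)()
    val name = AttributeReference("name", StringType,  nullable = true)()

    val childOutput   = AttributeSet(Seq(id, name))  // c.outputSet
    val allReferences = AttributeSet(Seq(id))        // what the parent plan needs

    // `name` is produced but never referenced, so a pruning Project is inserted.
    assert((childOutput -- allReferences.filter(childOutput.contains)).nonEmpty)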
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 3cb407217c4c3..6ad605475b09c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.types._ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { def output = projectList.map(_.toAttribute) - def references = projectList.flatMap(_.references).toSet + def references = projectList.flatMap(_.references.iterator).toSet } /** @@ -61,12 +61,12 @@ case class Generate( if (join) child.output ++ generatorOutput else generatorOutput override def references = - if (join) child.outputSet else generator.references + if (join) child.outputSet.iterator.toSet else generator.references.iterator.toSet } case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { override def output = child.output - override def references = condition.references + override def references = condition.references.iterator.toSet } case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { @@ -86,7 +86,7 @@ case class Join( joinType: JoinType, condition: Option[Expression]) extends BinaryNode { - override def references = condition.map(_.references).getOrElse(Set.empty) + override def references = condition.map(_.references.iterator.toSet).getOrElse(Set.empty) override def output = { joinType match { @@ -143,7 +143,7 @@ case class WriteToFile( case class Sort(order: Seq[SortOrder], child: LogicalPlan) extends UnaryNode { override def output = child.output - override def references = order.flatMap(_.references).toSet + override def references = order.flatMap(_.references.iterator).toSet } case class Aggregate( @@ -154,12 +154,12 @@ case class Aggregate( override def output = aggregateExpressions.map(_.toAttribute) override def references = - (groupingExpressions ++ aggregateExpressions).flatMap(_.references).toSet + (groupingExpressions ++ aggregateExpressions).flatMap(_.references.iterator).toSet } case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output = child.output - override def references = limitExpr.references + override def references = limitExpr.references.iterator.toSet } case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { @@ -204,7 +204,7 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: case class Distinct(child: LogicalPlan) extends UnaryNode { override def output = child.output - override def references = child.outputSet + override def references = child.outputSet.iterator.toSet } case object NoRelation extends LeafNode { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 4bb022cf238af..2533af4c2a091 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -139,7 +139,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) with Partitioning { override def children = expressions - override def references = 
expressions.flatMap(_.references).toSet override def nullable = false override def dataType = IntegerType @@ -179,7 +178,6 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) with Partitioning { override def children = ordering - override def references = ordering.flatMap(_.references).toSet override def nullable = false override def dataType = IntegerType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 6344874538d67..296202543e2ca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.types.{StringType, NullType} case class Dummy(optKey: Option[Expression]) extends Expression { def children = optKey.toSeq - def references = Set.empty[Attribute] def nullable = true def dataType = NullType override lazy val resolved = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index b92091b560b1c..cbbe7cee84ce1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -49,7 +49,6 @@ private[spark] case class PythonUDF( override def toString = s"PythonUDF#$name(${children.mkString(",")})" def nullable: Boolean = true - def references: Set[Attribute] = children.flatMap(_.references).toSet override def eval(input: Row) = sys.error("PythonUDFs can not be directly evaluated.") } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 389ace726d205..eef4638c669ff 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -176,9 +176,9 @@ private[hive] trait HiveStrategies { case PhysicalOperation(projectList, predicates, relation: MetastoreRelation) => // Filter out all predicates that only deal with partition keys, these are given to the // hive table scan operator to be used for partition pruning. 
- val partitionKeyIds = relation.partitionKeys.map(_.exprId).toSet + val partitionKeyIds = AttributeSet(relation.partitionKeys) val (pruningPredicates, otherPredicates) = predicates.partition { - _.references.map(_.exprId).subsetOf(partitionKeyIds) + _.references.subsetOf(partitionKeyIds) } pruneFilterProject( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index c6497a15efa0c..7d1ad53d8bdb3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -88,7 +88,6 @@ private[hive] abstract class HiveUdf extends Expression with Logging with HiveFu type EvaluatedType = Any def nullable = true - def references = children.flatMap(_.references).toSet lazy val function = createFunction[UDFType]() @@ -229,8 +228,6 @@ private[hive] case class HiveGenericUdaf( def nullable: Boolean = true - def references: Set[Attribute] = children.map(_.references).flatten.toSet - override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})" def newInstance() = new HiveUdafFunction(functionClassName, children, this) @@ -253,8 +250,6 @@ private[hive] case class HiveGenericUdtf( children: Seq[Expression]) extends Generator with HiveInspectors with HiveFunctionFactory { - override def references = children.flatMap(_.references).toSet - @transient protected lazy val function: GenericUDTF = createFunction() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index 6b3ffd1c0ffe2..b6be6bc1bfefe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested]) case class Nested(a: Int, B: Int) +case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested]) /** * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. @@ -57,6 +57,13 @@ class HiveResolutionSuite extends HiveComparisonTest { .registerTempTable("caseSensitivityTest") sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") + + println(sql("SELECT * FROM casesensitivitytest one JOIN casesensitivitytest two ON one.a = two.a").queryExecution) + + sql("SELECT * FROM casesensitivitytest one JOIN casesensitivitytest two ON one.a = two.a").collect() + + // TODO: sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a") + } test("nested repeated resolution") { From d6e16bebd519a9ebeb045292ede936e4782a3df9 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 24 Aug 2014 14:13:02 -0700 Subject: [PATCH 2/7] Remove more instances of "def references" and normal sets of attributes. 
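[Illustration for this commit, hedged: Aggregate is the one operator that keeps an explicit `references` override, because ColumnPruning compares it against the child's outputSet. A sketch of what that override computes, with made-up columns; Sum and Alias are the catalyst expressions of the same names.]

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types.IntegerType

    val a = AttributeReference("a", IntegerType, nullable = false)()
    val b = AttributeReference("b", IntegerType, nullable = false)()

    val groupingExpressions  = Seq[Expression](a)
    val aggregateExpressions = Seq[NamedExpression](Alias(Sum(b), "sum_b")())

    // Union of everything the grouping and aggregate expressions touch.
    val refs = AttributeSet(
      groupingExpressions.flatMap(_.references) ++
      aggregateExpressions.flatMap(_.references))
    assert(refs == AttributeSet(Seq(a, b)))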
--- .../catalyst/expressions/AttributeSet.scala | 6 ++++ .../catalyst/plans/logical/LogicalPlan.scala | 11 +------ .../plans/logical/ScriptTransformation.scala | 4 +-- .../plans/logical/basicOperators.scala | 29 ++++--------------- .../catalyst/plans/logical/partitioning.scala | 4 --- .../plans/physical/partitioning.scala | 1 + .../org/apache/spark/sql/SQLContext.scala | 7 +++-- .../columnar/InMemoryColumnarTableScan.scala | 2 -- .../spark/sql/execution/SparkPlan.scala | 3 +- .../spark/sql/hive/HiveStrategies.scala | 4 +-- 10 files changed, 21 insertions(+), 50 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index d61b0252ba722..da9e43a10ae92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -36,6 +36,12 @@ object AttributeSet { } class AttributeSet(val baseSet: Set[AttributeEquals]) extends Traversable[Attribute] { + + override def equals(other: Any) = other match { + case otherSet: AttributeSet => baseSet.map(_.a).forall(otherSet.contains) + case _ => false + } + def contains(elem: NamedExpression): Boolean = baseSet.contains(new AttributeEquals(elem.toAttribute)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 278569f0cb14a..8616ac45b0e95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -45,17 +45,11 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { sizeInBytes = children.map(_.statistics).map(_.sizeInBytes).product ) - /** - * Returns the set of attributes that are referenced by this node - * during evaluation. - */ - def references: Set[Attribute] - /** * Returns the set of attributes that this node takes as * input from its children. */ - lazy val inputSet: Set[Attribute] = children.flatMap(_.output).toSet + lazy val inputSet: AttributeSet = AttributeSet(children.flatMap(_.output)) /** * Returns true if this expression and all its children have been resolved to a specific schema @@ -126,9 +120,6 @@ abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] { override lazy val statistics: Statistics = throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.") - - // Leaf nodes by definition cannot reference any input attributes. 
- override def references = Set.empty } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala index d3f9d0fb93237..f325578063f99 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala @@ -30,6 +30,4 @@ case class ScriptTransformation( input: Seq[Expression], script: String, output: Seq[Attribute], - child: LogicalPlan) extends UnaryNode { - def references = input.flatMap(_.references).toSet -} + child: LogicalPlan) extends UnaryNode \ No newline at end of file diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 6ad605475b09c..4adfb189372d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.types._ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { def output = projectList.map(_.toAttribute) - def references = projectList.flatMap(_.references.iterator).toSet } /** @@ -59,14 +58,10 @@ case class Generate( override def output = if (join) child.output ++ generatorOutput else generatorOutput - - override def references = - if (join) child.outputSet.iterator.toSet else generator.references.iterator.toSet } case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { override def output = child.output - override def references = condition.references.iterator.toSet } case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { @@ -76,8 +71,6 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override lazy val resolved = childrenResolved && !left.output.zip(right.output).exists { case (l,r) => l.dataType != r.dataType } - - override def references = Set.empty } case class Join( @@ -86,8 +79,6 @@ case class Join( joinType: JoinType, condition: Option[Expression]) extends BinaryNode { - override def references = condition.map(_.references.iterator.toSet).getOrElse(Set.empty) - override def output = { joinType match { case LeftSemi => @@ -106,8 +97,6 @@ case class Join( case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { def output = left.output - - def references = Set.empty } case class InsertIntoTable( @@ -118,7 +107,6 @@ case class InsertIntoTable( extends LogicalPlan { // The table being inserted into is a child for the purposes of transformations. 
override def children = table :: child :: Nil - override def references = Set.empty override def output = child.output override lazy val resolved = childrenResolved && child.output.zip(table.output).forall { @@ -130,20 +118,17 @@ case class InsertIntoCreatedTable( databaseName: Option[String], tableName: String, child: LogicalPlan) extends UnaryNode { - override def references = Set.empty override def output = child.output } case class WriteToFile( path: String, child: LogicalPlan) extends UnaryNode { - override def references = Set.empty override def output = child.output } case class Sort(order: Seq[SortOrder], child: LogicalPlan) extends UnaryNode { override def output = child.output - override def references = order.flatMap(_.references.iterator).toSet } case class Aggregate( @@ -152,19 +137,20 @@ case class Aggregate( child: LogicalPlan) extends UnaryNode { + /** The set of all AttributeReferences required for this aggregation. */ + def references = + AttributeSet( + groupingExpressions.flatMap(_.references) ++ aggregateExpressions.flatMap(_.references)) + override def output = aggregateExpressions.map(_.toAttribute) - override def references = - (groupingExpressions ++ aggregateExpressions).flatMap(_.references.iterator).toSet } case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output = child.output - override def references = limitExpr.references.iterator.toSet } case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { override def output = child.output.map(_.withQualifiers(alias :: Nil)) - override def references = Set.empty } /** @@ -191,20 +177,16 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode { a.qualifiers) case other => other } - - override def references = Set.empty } case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan) extends UnaryNode { override def output = child.output - override def references = Set.empty } case class Distinct(child: LogicalPlan) extends UnaryNode { override def output = child.output - override def references = child.outputSet.iterator.toSet } case object NoRelation extends LeafNode { @@ -213,5 +195,4 @@ case object NoRelation extends LeafNode { case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override def output = left.output - override def references = Set.empty } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala index 7146fbd540f29..72b0c5c8e7a26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala @@ -31,13 +31,9 @@ abstract class RedistributeData extends UnaryNode { case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan) extends RedistributeData { - - def references = sortExpressions.flatMap(_.references).toSet } case class Repartition(partitionExpressions: Seq[Expression], child: LogicalPlan) extends RedistributeData { - - def references = partitionExpressions.flatMap(_.references).toSet } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 2533af4c2a091..ccb0df113c063 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -71,6 +71,7 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { "An AllTuples should be used to represent a distribution that only has " + "a single partition.") + // TODO: This is not really valid... def clustering = ordering.map(_.child).toSet } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index af9f7c62a1d25..db20893bcbd52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -344,8 +344,8 @@ class SQLContext(@transient val sparkContext: SparkContext) prunePushedDownFilters: Seq[Expression] => Seq[Expression], scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = { - val projectSet = projectList.flatMap(_.references).toSet - val filterSet = filterPredicates.flatMap(_.references).toSet + val projectSet = AttributeSet(projectList.flatMap(_.references)) + val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) val filterCondition = prunePushedDownFilters(filterPredicates).reduceLeftOption(And) // Right now we still use a projection even if the only evaluation is applying an alias @@ -354,7 +354,8 @@ class SQLContext(@transient val sparkContext: SparkContext) // TODO: Decouple final output schema from expression evaluation so this copy can be // avoided safely. - if (projectList.toSet == projectSet && filterSet.subsetOf(projectSet)) { + if (AttributeSet(projectList.map(_.toAttribute)) == projectSet && + filterSet.subsetOf(projectSet)) { // When it is possible to just use column pruning to get the right projection and // when the columns of this projection are enough to evaluate all filter conditions, // just do a scan followed by a filter, with no extra project. 
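[Sketch, not part of the patch: the two cases the check above distinguishes, with a hypothetical single-column relation. A name-only projection passes and the extra Project is dropped; a computed alias introduces a fresh output attribute, so the Project must stay.]

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types.IntegerType

    val a = AttributeReference("a", IntegerType, nullable = false)()

    // SELECT a ...: output attributes and referenced attributes coincide.
    val plain = Seq[NamedExpression](a)
    assert(AttributeSet(plain.map(_.toAttribute)) ==
      AttributeSet(plain.flatMap(_.references)))

    // SELECT a + 1 AS b ...: `b` has a fresh exprId, so the sets differ and a
    // Project is still required to evaluate the expression.
    val computed = Seq[NamedExpression](Alias(Add(a, Literal(1, IntegerType)), "b")())
    assert(AttributeSet(computed.map(_.toAttribute)) !=
      AttributeSet(computed.flatMap(_.references)))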
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index e63b4903041f6..24e88eea3189e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -79,8 +79,6 @@ private[sql] case class InMemoryRelation( override def children = Seq.empty - override def references = Set.empty - override def newInstance() = { new InMemoryRelation( output.map(_.newInstance), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 21cbbc9772a00..7d33ea5b021e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -141,10 +141,9 @@ case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQ extends LogicalPlan with MultiInstanceRelation { def output = alreadyPlanned.output - override def references = Set.empty override def children = Nil - override final def newInstance: this.type = { + override final def newInstance(): this.type = { SparkLogicalPlan( alreadyPlanned match { case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index eef4638c669ff..10fa8314c9156 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -79,9 +79,9 @@ private[hive] trait HiveStrategies { hiveContext.convertMetastoreParquet => // Filter out all predicates that only deal with partition keys - val partitionKeyIds = relation.partitionKeys.map(_.exprId).toSet + val partitionsKeys = AttributeSet(relation.partitionKeys) val (pruningPredicates, otherPredicates) = predicates.partition { - _.references.map(_.exprId).subsetOf(partitionKeyIds) + _.references.subsetOf(partitionsKeys) } // We are going to throw the predicates and projection back at the whole optimization From cae5d22c518a6cd4698192eb3dc1569abe5dd41d Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 24 Aug 2014 14:32:11 -0700 Subject: [PATCH 3/7] remove more references implementations --- .../org/apache/spark/sql/catalyst/expressions/arithmetic.scala | 2 -- .../scala/org/apache/spark/sql/execution/debug/package.scala | 2 -- .../main/scala/org/apache/spark/sql/execution/pythonUdfs.scala | 1 - 3 files changed, 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 8d90614e4501a..fcd392e1c5d44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -93,8 +93,6 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { override def children = left :: right :: Nil - override def references = left.references ++ right.references - override def dataType = left.dataType override def eval(input: Row): Any = { diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index f31df051824d7..5b896c55b7393 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -58,8 +58,6 @@ package object debug { } private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode { - def references = Set.empty - def output = child.output implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index cbbe7cee84ce1..aef6ebf86b1eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -112,7 +112,6 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { case class EvaluatePython(udf: PythonUDF, child: LogicalPlan) extends logical.UnaryNode { val resultAttribute = AttributeReference("pythonUDF", udf.dataType, nullable=true)() - def references = Set.empty def output = child.output :+ resultAttribute } From d577cc7f2212761558f459ed9bea60f0476e2ffb Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 25 Aug 2014 19:10:56 -0700 Subject: [PATCH 4/7] Scaladoc --- .../catalyst/expressions/AttributeSet.scala | 46 ++++++++++++++----- 1 file changed, 35 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index da9e43a10ae92..624bf76de888f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -class AttributeEquals(val a: Attribute) { +protected class AttributeEquals(val a: Attribute) { override def hashCode() = a.exprId.hashCode() override def equals(other: Any) = other match { case otherReference: AttributeEquals => a.exprId == otherReference.a.exprId @@ -26,47 +26,71 @@ class AttributeEquals(val a: Attribute) { } object AttributeSet { + /** Constructs a new [[AttributeSet]] given a sequence of [[Attribute Attributes]]. */ def apply(baseSet: Seq[Attribute]) = { new AttributeSet(baseSet.map(new AttributeEquals(_)).toSet) } - -// def apply(baseSet: Set[Attribute]) = { -// new AttributeSet(baseSet.map(new AttributeEquals(_))) -// } } -class AttributeSet(val baseSet: Set[AttributeEquals]) extends Traversable[Attribute] { +/** + * A Set designed to hold [[AttributeReference]] objects that performs equality checking using + * expression id instead of standard Java equality. Using expression id means that these + * sets will correctly test for membership, even when the AttributeReferences in question differ + * cosmetically (e.g., the names have different capitalizations). + */ +class AttributeSet protected (val baseSet: Set[AttributeEquals]) extends Traversable[Attribute] { + /** Returns true if the members of this AttributeSet and other are the same.
*/ override def equals(other: Any) = other match { case otherSet: AttributeSet => baseSet.map(_.a).forall(otherSet.contains) case _ => false } + /** Returns true if this set contains an Attribute with the same expression id as `elem` */ def contains(elem: NamedExpression): Boolean = baseSet.contains(new AttributeEquals(elem.toAttribute)) + /** Returns a new [[AttributeSet]] that contains `elem` in addition to the current elements. */ def +(elem: Attribute): AttributeSet = new AttributeSet(baseSet + new AttributeEquals(elem)) + /** Returns a new [[AttributeSet]] that does not contain `elem`. */ def -(elem: Attribute): AttributeSet = new AttributeSet(baseSet - new AttributeEquals(elem)) + /** Returns an iterator containing all of the attributes in the set. */ def iterator: Iterator[Attribute] = baseSet.map(_.a).iterator + /** + * Returns true if the [[Attribute Attributes]] in this set are a subset of the Attributes in + * `other`. + */ def subsetOf(other: AttributeSet) = baseSet.subsetOf(other.baseSet) + /** + * Returns a new [[AttributeSet]] that does not contain any of the [[Attribute Attributes]] found + * in `other`. + */ def --(other: Traversable[NamedExpression]) = new AttributeSet(baseSet -- other.map(a => new AttributeEquals(a.toAttribute))) + /** + * Returns a new [[AttributeSet]] that contains all of the [[Attribute Attributes]] found + * in `other`. + */ def ++(other: AttributeSet) = new AttributeSet(baseSet ++ other.baseSet) + /** + * Returns a new [[AttributeSet]] containing only the [[Attribute Attributes]] where `f` evaluates to + * true. + */ override def filter(f: Attribute => Boolean) = new AttributeSet(baseSet.filter(ae => f(ae.a))) + /** + * Returns a new [[AttributeSet]] that only contains [[Attribute Attributes]] that are found in + * `this` and `other`. + */ def intersect(other: AttributeSet) = new AttributeSet(baseSet.intersect(other.baseSet)) - override def nonEmpty = baseSet.nonEmpty - - override def toSeq = baseSet.toSeq.map(_.a) - override def foreach[U](f: (Attribute) => U): Unit = baseSet.map(_.a).foreach(f) -} \ No newline at end of file +} From 40ce7f66e806c6621ebb69269caaae4b38d5fb50 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 25 Aug 2014 19:14:30 -0700 Subject: [PATCH 5/7] style --- .../spark/sql/catalyst/expressions/AttributeSet.scala | 2 +- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 6 ++++-- .../sql/catalyst/plans/logical/ScriptTransformation.scala | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index 624bf76de888f..fc73a46c34c0b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -51,7 +51,7 @@ class AttributeSet protected (val baseSet: Set[AttributeEquals]) extends Travers baseSet.contains(new AttributeEquals(elem.toAttribute)) /** Returns a new [[AttributeSet]] that contains `elem` in addition to the current elements. */ - def +(elem: Attribute): AttributeSet = + def +(elem: Attribute): AttributeSet = // scalastyle:ignore new AttributeSet(baseSet + new AttributeEquals(elem)) /** Returns a new [[AttributeSet]] that does not contain `elem`.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 90c9430c44935..ddd4b3755d629 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -59,14 +59,16 @@ object Optimizer extends RuleExecutor[LogicalPlan] {
 object ColumnPruning extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     // Eliminate attributes that are not needed to calculate the specified aggregates.
-    case a @ Aggregate(_, _, child) if (child.outputSet -- AttributeSet(a.references.toSeq)).nonEmpty =>
+    case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
       a.copy(child = Project(a.references.toSeq, child))
 
     // Eliminate unneeded attributes from either side of a Join.
     case Project(projectList, Join(left, right, joinType, condition)) =>
       // Collect the list of all references required either above or to evaluate the condition.
       val allReferences: AttributeSet =
-        AttributeSet(projectList.flatMap(_.references.iterator)) ++ condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
+        AttributeSet(
+          projectList.flatMap(_.references.iterator)) ++
+          condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
 
       /** Applies a projection only when the child is producing unnecessary attributes */
       def pruneJoinChild(c: LogicalPlan) = prunedChild(c, allReferences)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
index f325578063f99..4460c86ed9026 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
@@ -30,4 +30,4 @@ case class ScriptTransformation(
     input: Seq[Expression],
     script: String,
     output: Seq[Attribute],
-    child: LogicalPlan) extends UnaryNode
\ No newline at end of file
+    child: LogicalPlan) extends UnaryNode

From 3ae52880eb4913a7a28c58e77a8124671e6e87dd Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Tue, 26 Aug 2014 13:42:06 -0700
Subject: [PATCH 6/7] review comments

---
 .../org/apache/spark/sql/catalyst/analysis/unresolved.scala  | 1 -
 .../apache/spark/sql/catalyst/expressions/AttributeSet.scala | 5 +++++
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index b5cf1c1cd2388..a2c61c65487cb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -39,7 +39,6 @@ case class UnresolvedRelation(
     alias: Option[String] = None) extends LeafNode {
   override def output = Nil
   override lazy val resolved = false
-  def reference = Set.empty
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
index fc73a46c34c0b..73b5e74289f14 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -37,6 +37,11 @@ object AttributeSet {
  * expression id instead of standard Java equality. Using expression id means that these
  * sets will correctly test for membership, even when the AttributeReferences in question differ
  * cosmetically (e.g., the names have different capitalizations).
+ *
+ * Note that we do not override equality for Attribute references as it is really weird when
+ * `AttributeReference("a"...) == AttributeReference("b", ...)`. This tactic leads to broken tests,
+ * and also makes doing transformations hard (we always try to keep older trees instead of new ones
+ * when the transformation was a no-op).
  */
 class AttributeSet protected (val baseSet: Set[AttributeEquals]) extends Traversable[Attribute] {
 

From 1c0dae57e29a9723d7614e98a1b5084460117a6b Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Tue, 26 Aug 2014 14:59:44 -0700
Subject: [PATCH 7/7] work on serialization bug.

---
 .../spark/sql/catalyst/expressions/AttributeSet.scala | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
index 73b5e74289f14..c3a08bbdb6bc7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -43,7 +43,8 @@ object AttributeSet {
  * and also makes doing transformations hard (we always try to keep older trees instead of new ones
  * when the transformation was a no-op).
  */
-class AttributeSet protected (val baseSet: Set[AttributeEquals]) extends Traversable[Attribute] {
+class AttributeSet private (val baseSet: Set[AttributeEquals])
+  extends Traversable[Attribute] with Serializable {
 
   /** Returns true if the members of this AttributeSet and other are the same. */
   override def equals(other: Any) = other match {
@@ -98,4 +99,8 @@ class AttributeSet protected (val baseSet: Set[AttributeEquals]) extends Travers
   def intersect(other: AttributeSet) = new AttributeSet(baseSet.intersect(other.baseSet))
 
   override def foreach[U](f: (Attribute) => U): Unit = baseSet.map(_.a).foreach(f)
+
+  // We must force toSeq to be strict, otherwise we end up with a lazy [[Stream]] that captures
+  // all sorts of things in its closure.
+  override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq
 }
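
A note on the semantics introduced in [PATCH 4/7]: membership in an AttributeSet is keyed
on the expression id, not on the attribute's name. The sketch below is illustrative only;
it assumes the 2014-era Catalyst layout (org.apache.spark.sql.catalyst.types.IntegerType)
and the curried AttributeReference constructor whose second parameter list accepts an
exprId. With those assumptions, two cosmetically different references that share an
exprId are the same element of the set:

  import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet}
  import org.apache.spark.sql.catalyst.types.IntegerType

  object AttributeSetSemanticsDemo extends App {
    // Two cosmetically different references to the same logical column: the
    // second reuses the first's expression id but capitalizes the name differently.
    val id = AttributeReference("id", IntegerType, nullable = true)()
    val ID = AttributeReference("ID", IntegerType, nullable = true)(exprId = id.exprId)

    val set = AttributeSet(Seq(id))
    assert(set.contains(ID))           // true: membership checks only the exprId
    assert((set -- Seq(ID)).isEmpty)   // removal also works across cosmetic variants

    // A reference with a freshly generated exprId is a different element,
    // even though its name and data type match exactly.
    val fresh = AttributeReference("id", IntegerType, nullable = true)()
    assert(!set.contains(fresh))
  }

This is also why AttributeEquals wraps each attribute: its hashCode and equals delegate
to the exprId, so the underlying Set does the right thing without overriding equality on
AttributeReference itself.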
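
A note on [PATCH 7/7]: on Scala 2.10 through 2.12, TraversableOnce.toSeq is defined in
terms of toStream, so without the override the Seq returned by an AttributeSet can be a
lazy Stream whose unevaluated tail keeps a reference back to its producer; anything
reachable from the producer then rides along into any serialized task closure that
captures the Seq. A toy sketch of the hazard in plain Scala (not Spark code; all names
here are hypothetical):

  object StrictToSeqDemo extends App {
    // A producer holding state we would not want captured in a serialized closure.
    class Producer(heavy: Array[Byte]) {
      def firstBytes: Iterator[Byte] = heavy.iterator.take(3)
    }

    val p = new Producer(new Array[Byte](4 * 1024 * 1024))

    // Iterator.toSeq goes through toStream on these Scala versions: the result is
    // a lazy Stream whose unevaluated tail still references the iterator and,
    // through it, the 4MB array.
    val lazySeq: Seq[Byte] = p.firstBytes.toSeq

    // Forcing materialization, as the patch does with toArray.toSeq, yields a
    // strict Seq of three elements with no link back to the producer.
    val strictSeq: Seq[Byte] = p.firstBytes.toArray.toSeq

    println(lazySeq.getClass.getName)   // scala.collection.immutable.Stream$Cons
    println(strictSeq.getClass.getName) // a strict WrappedArray
  }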