From e1919ccf1e0a878bfb320d41beeffbc7f4c9b5da Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 10 Feb 2016 10:18:11 -0800
Subject: [PATCH 1/5] push missing attributes in Sort

---
 .../sql/catalyst/analysis/Analyzer.scala      | 130 +++++++-----------
 .../sql/hive/execution/SQLQuerySuite.scala    |  15 ++
 2 files changed, 63 insertions(+), 82 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 4d53b232d5510..39c1874f8053d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -572,98 +572,65 @@ class Analyzer(
       // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
       case sa @ Sort(_, _, child: Aggregate) => sa
 
-      case s @ Sort(_, _, child) if !s.resolved && child.resolved =>
-        val (newOrdering, missingResolvableAttrs) = collectResolvableMissingAttrs(s.order, child)
-
-        if (missingResolvableAttrs.isEmpty) {
-          val unresolvableAttrs = s.order.filterNot(_.resolved)
-          logDebug(s"Failed to find $unresolvableAttrs in ${child.output.mkString(", ")}")
-          s // Nothing we can do here. Return original plan.
-        } else {
-          // Add the missing attributes into projectList of Project/Window or
-          // aggregateExpressions of Aggregate, if they are in the inputSet
-          // but not in the outputSet of the plan.
-          val newChild = child transformUp {
-            case p: Project =>
-              p.copy(projectList = p.projectList ++
-                missingResolvableAttrs.filter((p.inputSet -- p.outputSet).contains))
-            case w: Window =>
-              w.copy(projectList = w.projectList ++
-                missingResolvableAttrs.filter((w.inputSet -- w.outputSet).contains))
-            case a: Aggregate =>
-              val resolvableAttrs = missingResolvableAttrs.filter(a.groupingExpressions.contains)
-              val notResolvedAttrs = resolvableAttrs.filterNot(a.aggregateExpressions.contains)
-              val newAggregateExpressions = a.aggregateExpressions ++ notResolvedAttrs
-              a.copy(aggregateExpressions = newAggregateExpressions)
-            case o => o
-          }
-
+      case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
+        val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder])
+        val requiredAttrs = AttributeSet(newOrder).filter(_.resolved)
+        val missingAttrs = requiredAttrs -- child.outputSet
+        if (missingAttrs.nonEmpty) {
           // Add missing attributes and then project them away after the sort.
           Project(child.output,
-            Sort(newOrdering, s.global, newChild))
+            Sort(newOrder, s.global, addMissingAttr(child, missingAttrs)))
+        } else if (newOrder != order) {
+          s.copy(order = newOrder)
+        } else {
+          s
         }
     }
 
     /**
-     * Traverse the tree until resolving the sorting attributes
-     * Return all the resolvable missing sorting attributes
-     */
-    @tailrec
-    private def collectResolvableMissingAttrs(
-        ordering: Seq[SortOrder],
-        plan: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
+     * Add the missing attributes into projectList of Project/Window or aggregateExpressions of
+     * Aggregate.
+     */
+    private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = {
+      if (missingAttrs.isEmpty) {
+        return plan
+      }
       plan match {
-        // Only Windows and Project have projectList-like attribute.
-        case un: UnaryNode if un.isInstanceOf[Project] || un.isInstanceOf[Window] =>
-          val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, un, un.child)
-          // If missingAttrs is non empty, that means we got it and return it;
-          // Otherwise, continue to traverse the tree.
-          if (missingAttrs.nonEmpty) {
-            (newOrdering, missingAttrs)
-          } else {
-            collectResolvableMissingAttrs(ordering, un.child)
-          }
+        //
+        case p: Project =>
+          val missing = missingAttrs -- p.child.outputSet
+          Project(p.projectList ++ missingAttrs, addMissingAttr(p.child, missing))
+        case w: Window =>
+          val missing = missingAttrs -- w.child.outputSet
+          w.copy(projectList = w.projectList ++ missingAttrs,
+            child = addMissingAttr(w.child, missing))
         case a: Aggregate =>
-          val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, a, a.child)
-          // For Aggregate, all the order by columns must be specified in group by clauses
-          if (missingAttrs.nonEmpty &&
-              missingAttrs.forall(ar => a.groupingExpressions.exists(_.semanticEquals(ar)))) {
-            (newOrdering, missingAttrs)
-          } else {
-            // If missingAttrs is empty, we are unable to resolve any unresolved missing attributes
-            (Seq.empty[SortOrder], Seq.empty[Attribute])
+          // all the missing attributes should be grouping expressions
+          // TODO: push down AggregateExpression
+          missingAttrs.foreach { attr =>
+            if (!a.groupingExpressions.contains(attr)) {
+              throw new AnalysisException(s"Can't add $attr to ${a.simpleString}")
+            }
           }
-        // Jump over the following UnaryNode types
-        // The output of these types is the same as their child's output
-        case _: Distinct |
-             _: Filter |
-             _: RepartitionByExpression =>
-          collectResolvableMissingAttrs(ordering, plan.asInstanceOf[UnaryNode].child)
-        // If hitting the other unsupported operators, we are unable to resolve it.
-        case other => (Seq.empty[SortOrder], Seq.empty[Attribute])
+          val newAggregateExpressions = a.aggregateExpressions ++ missingAttrs
+          a.copy(aggregateExpressions = newAggregateExpressions)
+        case f: UnaryNode =>
+          addMissingAttr(f.child, missingAttrs)
+        case other =>
+          throw new AnalysisException(s"Can't add $missingAttrs to $other")
       }
     }
 
-    /**
-     * Try to resolve the sort ordering and returns it with a list of attributes that are missing
-     * from the plan but are present in the child.
-     */
-    private def resolveAndFindMissing(
-        ordering: Seq[SortOrder],
-        plan: LogicalPlan,
-        child: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
-      val newOrdering =
-        ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder])
-      // Construct a set that contains all of the attributes that we need to evaluate the
-      // ordering.
-      val requiredAttributes = AttributeSet(newOrdering).filter(_.resolved)
-      // Figure out which ones are missing from the projection, so that we can add them and
-      // remove them after the sort.
-      val missingInProject = requiredAttributes -- plan.outputSet
-      // It is important to return the new SortOrders here, instead of waiting for the standard
-      // resolving process as adding attributes to the project below can actually introduce
-      // ambiguity that was not present before.
-      (newOrdering, missingInProject.toSeq)
+    private def resolveExpressionRecursively(expr: Expression, plan: LogicalPlan): Expression = {
+      val resolved = resolveExpression(expr, plan)
+      if (resolved.resolved) {
+        resolved
+      } else {
+        plan match {
+          case u: UnaryNode => resolveExpressionRecursively(resolved, u.child)
+          case other => resolved
+        }
+      }
     }
   }
 
@@ -753,8 +720,7 @@ class Analyzer(
             filter
           }
 
-      case sort @ Sort(sortOrder, global, aggregate: Aggregate)
-        if aggregate.resolved =>
+      case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved =>
 
         // Try resolving the ordering as though it is in the aggregate clause.
         try {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 6048b8f5a3998..be864f79d6b7e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -978,6 +978,21 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
         ("d", 1),
         ("c", 2)
       ).map(i => Row(i._1, i._2)))
+
+    checkAnswer(
+      sql(
+        """
+          |select area, sum(product) / sum(sum(product)) over (partition by area) as c1
+          |from windowData group by area, month order by month, c1
+        """.stripMargin),
+      Seq(
+        ("d", 1.0),
+        ("a", 1.0),
+        ("b", 0.4666666666666667),
+        ("b", 0.5333333333333333),
+        ("c", 0.45),
+        ("c", 0.55)
+      ).map(i => Row(i._1, i._2)))
   }
 
   // todo: fix this test case by reimplementing the function ResolveAggregateFunctions

From bec639de876e7b4dfccbc17a7bb94c212443e5ff Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 10 Feb 2016 12:42:41 -0800
Subject: [PATCH 2/5] fix bug

---
 .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 39c1874f8053d..29eff07978b34 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -596,7 +596,6 @@ class Analyzer(
         return plan
       }
       plan match {
-        //
         case p: Project =>
           val missing = missingAttrs -- p.child.outputSet
           Project(p.projectList ++ missingAttrs, addMissingAttr(p.child, missing))
@@ -614,8 +613,8 @@ class Analyzer(
           }
           val newAggregateExpressions = a.aggregateExpressions ++ missingAttrs
           a.copy(aggregateExpressions = newAggregateExpressions)
-        case f: UnaryNode =>
-          addMissingAttr(f.child, missingAttrs)
+        case u: UnaryNode =>
+          u.withNewChildren(addMissingAttr(u.child, missingAttrs) :: Nil)
         case other =>
           throw new AnalysisException(s"Can't add $missingAttrs to $other")
       }

From c4607dd6f08801ff117768a89e3d13bb748b2d43 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 10 Feb 2016 21:33:42 -0800
Subject: [PATCH 3/5] fix test

---
 .../org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index ebf885a8fe484..f85ae24e0459b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -90,7 +90,7 @@ class AnalysisSuite extends AnalysisTest {
       .where(a > "str").select(a, b, c)
       .where(b > "str").select(a, b, c)
       .sortBy(b.asc, c.desc)
-      .select(a, b).select(a)
+      .select(a)
     checkAnalysis(plan1, expected1)
 
     // Case 2: all the missing attributes are in the leaf node

From dce38575730ae92972d8c15e4a6d13983eeb0392 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Thu, 11 Feb 2016 10:58:41 -0800
Subject: [PATCH 4/5] address comments

---
 .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 29eff07978b34..0f341e46936b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
-import scala.annotation.tailrec
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.AnalysisException
@@ -620,13 +619,18 @@ class Analyzer(
       }
     }
 
+    /**
+     * Resolve the expression on a specified logical plan and its child (recursively), until
+     * the expression is resolved or we meet a non-unary node or Subquery.
+     */
     private def resolveExpressionRecursively(expr: Expression, plan: LogicalPlan): Expression = {
       val resolved = resolveExpression(expr, plan)
       if (resolved.resolved) {
         resolved
       } else {
         plan match {
-          case u: UnaryNode => resolveExpressionRecursively(resolved, u.child)
+          case u: UnaryNode if !u.isInstanceOf[Subquery] =>
+            resolveExpressionRecursively(resolved, u.child)
           case other => resolved
         }
       }

From 890940af816de0133158e937b5af1c926f2161cc Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Fri, 12 Feb 2016 09:33:39 -0800
Subject: [PATCH 5/5] address comments

---
 .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 0f341e46936b7..651c1f113e696 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -606,7 +606,7 @@ class Analyzer(
           // all the missing attributes should be grouping expressions
          // TODO: push down AggregateExpression
           missingAttrs.foreach { attr =>
-            if (!a.groupingExpressions.contains(attr)) {
+            if (!a.groupingExpressions.exists(_.semanticEquals(attr))) {
               throw new AnalysisException(s"Can't add $attr to ${a.simpleString}")
             }
           }
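For reference, an illustrative sketch (not part of the patch series): the queries this series makes resolvable are sorts on attributes that the child operator does not output but that can be pushed into a Project, Window or Aggregate below the Sort and projected away afterwards. The snippet is written against the SQLQuerySuite harness used in PATCH 1/5, so the sql helper and the windowData temporary table are assumed to be in scope; it is not an excerpt from the patches.

  // Sketch only: month is a grouping column that is absent from the SELECT list, so the
  // ORDER BY below used to fail to resolve. With addMissingAttr the analyzer appends month
  // to the Aggregate's aggregateExpressions, sorts, and wraps the result in
  // Project(child.output) so the extra column is dropped again.
  val df = sql(
    """
      |select area, sum(product) as total
      |from windowData
      |group by area, month
      |order by month
    """.stripMargin)
  df.explain(true)  // analyzed plan: a Project on top of the Sort on top of the Aggregate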