From d7424eb2bf91bf7ca087c1c232850f7d8feaa69f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 2 Mar 2016 21:23:15 +0800 Subject: [PATCH 1/2] clean up --- .../sql/catalyst/analysis/Analyzer.scala | 34 +++++++------------ .../plans/logical/basicOperators.scala | 23 ++++--------- 2 files changed, 18 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 876aa0eae0e90..9630e9ab212d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -181,8 +181,8 @@ class Analyzer( case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => Aggregate(groups, assignAliases(aggs), child) - case g: GroupingAnalytics if g.child.resolved && hasUnresolvedAlias(g.aggregations) => - g.withNewAggs(assignAliases(g.aggregations)) + case g: GroupingSets if g.child.resolved && hasUnresolvedAlias(g.aggregations) => + g.copy(aggregations = assignAliases(g.aggregations)) case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) if child.resolved && hasUnresolvedAlias(groupByExprs) => @@ -250,30 +250,16 @@ class Analyzer( val nonNullBitmask = x.bitmasks.reduce(_ & _) - val attributeMap = groupByAliases.zipWithIndex.map { case (a, idx) => - if ((nonNullBitmask & 1 << idx) == 0) { - (a -> a.toAttribute.withNullability(true)) - } else { - (a -> a.toAttribute) - } - }.toMap + val groupByAttributes = groupByAliases.zipWithIndex.map { case (a, idx) => + a.toAttribute.withNullability((nonNullBitmask & 1 << idx) == 0) + } val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr => - // collect all the found AggregateExpression, so we can check an expression is part of - // any AggregateExpression or not. - val aggsBuffer = ArrayBuffer[Expression]() - // Returns whether the expression belongs to any expressions in `aggsBuffer` or not. - def isPartOfAggregation(e: Expression): Boolean = { - aggsBuffer.exists(a => a.find(_ eq e).isDefined) - } expr.transformDown { // AggregateExpression should be computed on the unmodified value of its argument // expressions, so we should not replace any references to grouping expression // inside it. - case e: AggregateExpression => - aggsBuffer += e - e - case e if isPartOfAggregation(e) => e + case e: AggregateExpression => e case e: GroupingID => if (e.groupByExprs.isEmpty || e.groupByExprs == x.groupByExprs) { gid @@ -292,12 +278,16 @@ class Analyzer( s"in grouping columns ${x.groupByExprs.mkString(",")}") } case e => - groupByAliases.find(_.child.semanticEquals(e)).map(attributeMap(_)).getOrElse(e) + val index = groupByAliases.indexWhere(_.child.semanticEquals(e)) + if (index == -1) { + e + } else { + groupByAttributes(index) + } }.asInstanceOf[NamedExpression] } val child = Project(x.child.output ++ groupByAliases, x.child) - val groupByAttributes = groupByAliases.map(attributeMap(_)) Aggregate( groupByAttributes :+ VirtualColumn.groupingIdAttribute, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index e81a0f9487469..522348735aadf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -533,20 +533,6 @@ case class Expand( } } -trait GroupingAnalytics extends UnaryNode { - - def groupByExprs: Seq[Expression] - def aggregations: Seq[NamedExpression] - - override def output: Seq[Attribute] = aggregations.map(_.toAttribute) - - // Needs to be unresolved before its translated to Aggregate + Expand because output attributes - // will change in analysis. - override lazy val resolved: Boolean = false - - def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics -} - /** * A GROUP BY clause with GROUPING SETS can generate a result set equivalent * to generated by a UNION ALL of multiple simple GROUP BY clauses. @@ -565,10 +551,13 @@ case class GroupingSets( bitmasks: Seq[Int], groupByExprs: Seq[Expression], child: LogicalPlan, - aggregations: Seq[NamedExpression]) extends GroupingAnalytics { + aggregations: Seq[NamedExpression]) extends UnaryNode { + + override def output: Seq[Attribute] = aggregations.map(_.toAttribute) - def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = - this.copy(aggregations = aggs) + // Needs to be unresolved before its translated to Aggregate + Expand because output attributes + // will change in analysis. + override lazy val resolved: Boolean = false } case class Pivot( From 9ba324685f9671119d09e2dd77d0d97fb5f3bd8a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 3 Mar 2016 08:59:17 +0800 Subject: [PATCH 2/2] fix a bug --- .../spark/sql/catalyst/analysis/Analyzer.scala | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9630e9ab212d3..36eb59ef5ef9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -255,11 +255,21 @@ class Analyzer( } val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr => + // collect all the found AggregateExpression, so we can check an expression is part of + // any AggregateExpression or not. + val aggsBuffer = ArrayBuffer[Expression]() + // Returns whether the expression belongs to any expressions in `aggsBuffer` or not. + def isPartOfAggregation(e: Expression): Boolean = { + aggsBuffer.exists(a => a.find(_ eq e).isDefined) + } expr.transformDown { // AggregateExpression should be computed on the unmodified value of its argument // expressions, so we should not replace any references to grouping expression // inside it. - case e: AggregateExpression => e + case e: AggregateExpression => + aggsBuffer += e + e + case e if isPartOfAggregation(e) => e case e: GroupingID => if (e.groupByExprs.isEmpty || e.groupByExprs == x.groupByExprs) { gid