From 6f609fb2d844e2aaf4c809ef8c0fcd9e6eca38bb Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 3 Mar 2016 17:54:58 -0800 Subject: [PATCH] address comments. --- .../apache/spark/sql/hive/SQLBuilder.scala | 45 ++++++++++--------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 1f1776d3a6ee8..0d9a68342c643 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -208,9 +208,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi throw new UnsupportedOperationException(s"unsupported plan $node") } - private def sameOutput(left: Seq[Attribute], right: Seq[Attribute]): Boolean = - left.forall(a => right.exists(_.semanticEquals(a))) && - right.forall(a => left.exists(_.semanticEquals(a))) + private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean = + output1.size == output2.size && + output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) /** * Turns a bunch of string segments into a single string and separate each segment by a space. 
@@ -243,26 +243,34 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi } private def groupingSetToSQL( - plan: Aggregate, + agg: Aggregate, expand: Expand, project: Project): String = { - require(plan.groupingExpressions.length > 1) + assert(agg.groupingExpressions.length > 1) // The last column of Expand is always grouping ID val gid = expand.output.last - val groupByAttributes = plan.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute]) - val groupByExprs = - project.projectList.drop(project.child.output.length).map(_.asInstanceOf[Alias].child) - val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs)) + val numOriginalOutput = project.child.output.length + // Assumption: Aggregate's groupingExpressions is composed of + // 1) the group by attributes' aliases + // 2) gid, which is always the last one + val groupByAttributes = agg.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute]) + // Assumption: Project's projectList is composed of + // 1) the original output (Project's child.output), + // 2) the aliases of the original group by attributes, which could be expressions + val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child) val groupingSQL = groupByExprs.map(_.sql).mkString(", ") + // a map from the alias name to the original group by expressions/attributes + val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs)) + val groupingSet = expand.projections.map { project => - // Assumption: expand.projections are composed of + // Assumption: expand.projections is composed of // 1) the original output (Project's child.output), // 2) group by attributes(or null literal) // 3) gid, which is always the last one in each project in Expand - project.dropRight(1).collect { + project.drop(numOriginalOutput).dropRight(1).collect { case attr: Attribute if groupByAttrMap.contains(attr) => groupByAttrMap(attr) } } @@ -270,13 +278,14 @@ class 
SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi "GROUPING SETS(" + groupingSet.map(e => s"(${e.map(_.sql).mkString(", ")})").mkString(", ") + ")" - val aggExprs = plan.aggregateExpressions.map { case expr => + val aggExprs = agg.aggregateExpressions.map { case expr => expr.transformDown { // grouping_id() is converted to VirtualColumn.groupingIdName by Analyzer. Revert it back. - case ar: AttributeReference if ar eq gid => GroupingID(Nil) - case a @ Alias(_ @ Cast(BitwiseAnd( + case ar: AttributeReference if ar == gid => GroupingID(Nil) + case ar: AttributeReference if groupByAttrMap.contains(ar) => groupByAttrMap(ar) + case a @ Cast(BitwiseAnd( ShiftRight(ar: AttributeReference, _ @ Literal(value: Any, IntegerType)), - Literal(1, IntegerType)), ByteType), name) if ar == gid => + Literal(1, IntegerType)), ByteType) if ar == gid => // for converting an expression to its original SQL format grouping(col) val idx = groupByExprs.length - 1 - value.asInstanceOf[Int] val groupingCol = groupByExprs.lift(idx) @@ -285,17 +294,13 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi } else { throw new UnsupportedOperationException(s"unsupported operator $a") } - case a @ Alias(ar: AttributeReference, _) if groupByAttrMap.contains(ar) => - groupByAttrMap(ar) - case ar: AttributeReference if groupByAttrMap.contains(ar) => - groupByAttrMap(ar) } } build( "SELECT", aggExprs.map(_.sql).mkString(", "), - if (plan.child == OneRowRelation) "" else "FROM", + if (agg.child == OneRowRelation) "" else "FROM", toSQL(project.child), "GROUP BY", groupingSQL,