From bc0c0309b1ec0d5303b744025032070f00c2bc9c Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 20 Feb 2016 07:27:52 -0800 Subject: [PATCH 01/15] SQL generation support for cube, rollup, and grouping set --- .../sql/catalyst/expressions/grouping.scala | 7 ++ .../apache/spark/sql/hive/SQLBuilder.scala | 57 +++++++++++- .../sql/hive/LogicalPlanToSQLSuite.scala | 91 +++++++++++++++++++ 3 files changed, 154 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala index a204060630050..be9a46f70bc1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala @@ -63,4 +63,11 @@ case class GroupingID(groupByExprs: Seq[Expression]) extends Expression with Une override def children: Seq[Expression] = groupByExprs override def dataType: DataType = IntegerType override def nullable: Boolean = false + + // TODO: remove this when SPARK-12799 is resolved. That will provide a general to-sql solution + // for all the expressions. + override def sql: String = { + val childrenSQL = children.map(_.sql).mkString(", ") + s"grouping_id($childrenSQL)" + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index bf5edb4759fbd..7259b78b5b733 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -24,7 +24,7 @@ import scala.util.control.NonFatal import org.apache.spark.Logging import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.CollapseProject import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} @@ -74,6 +74,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case p: Project => projectToSQL(p, isDistinct = false) + case a @ Aggregate(_, _, e @ Expand(_, _, p: Project)) => + groupingSetToSQL(a, e, p) + case p: Aggregate => aggregateToSQL(p) @@ -172,6 +175,58 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi ) } + private def groupingSetToSQL( + plan: Aggregate, + expand: Expand, + project: Project): String = { + // In cube/rollup/groupingsets, Analyzer creates new aliases for all group by expressions. + // Since conversion from attribute back SQL ignore expression IDs, the alias of attribute + // references are ignored in aliasMap + val aliasMap = AttributeMap(project.projectList.collect { + case a @ Alias(child, name) if !child.isInstanceOf[AttributeReference] => (a.toAttribute, a) + }) + + val aggExprs = plan.aggregateExpressions.map{ + // VirtualColumn.groupingIdName is added by Analyzer, and thus remove it. 
+ case a @ Alias(child: AttributeReference, name) + if child.name == VirtualColumn.groupingIdName => + Alias(GroupingID(Nil), name)() + case a @ Alias(child: AttributeReference, name) if aliasMap.contains(child) => + aliasMap(child).child + case o => o + } + + val groupingExprs = plan.groupingExpressions.filterNot { + case a: NamedExpression => a.name == VirtualColumn.groupingIdName + case o => false + }.map { + case a: AttributeReference if aliasMap.contains(a) => aliasMap(a).child + case o => o + } + + val groupingSQL = groupingExprs.map(_.sql).mkString(", ") + + val groupingSet = expand.projections.map(_.filter { + case _: Literal => false + case e: Expression if plan.groupingExpressions.exists(_.semanticEquals(e)) => true + case _ => false + }.map { + case a: AttributeReference if aliasMap.contains(a) => aliasMap(a).child + case o => o + }) + + build( + "SELECT", + aggExprs.map(_.sql).mkString(", "), + if (plan.child == OneRowRelation) "" else "FROM", + toSQL(project.child), + if (groupingSQL.isEmpty) "" else "GROUP BY", + groupingSQL, + "GROUPING SETS", + "(" + groupingSet.map(e => s"(${e.map(_.sql).mkString(", ")})").mkString(", ") + ")" + ) + } + object Canonicalizer extends RuleExecutor[LogicalPlan] { override protected def batches: Seq[Batch] = Seq( Batch("Canonicalizer", FixedPoint(100), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index dc8ac7e47ffec..52fd4b24ed3e3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -137,6 +137,97 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkHiveQl("SELECT DISTINCT id FROM t0") } + test("rollup/cube #1") { + // Original logical plan: + // Aggregate [(key#17L % cast(5 as bigint))#47L,grouping__id#46], + // [(count(1),mode=Complete,isDistinct=false) AS cnt#43L, + // (key#17L % cast(5 as bigint))#47L AS _c1#45L, + // grouping__id#46 AS _c2#44] + // +- Expand [List(key#17L, value#18, (key#17L % cast(5 as bigint))#47L, 0), + // List(key#17L, value#18, null, 1)], + // [key#17L,value#18,(key#17L % cast(5 as bigint))#47L,grouping__id#46] + // +- Project [key#17L, + // value#18, + // (key#17L % cast(5 as bigint)) AS (key#17L % cast(5 as bigint))#47L] + // +- Subquery t1 + // +- Relation[key#17L,value#18] ParquetRelation + // Converted SQL: + // SELECT count( 1) AS `cnt`, + // (`t1`.`key` % CAST(5 AS BIGINT)), + // grouping_id() AS `_c2` + // FROM `default`.`t1` + // GROUP BY (`t1`.`key` % CAST(5 AS BIGINT)) + // GROUPING SETS (((`t1`.`key` % CAST(5 AS BIGINT))), ()) + checkHiveQl( + "SELECT count(*) as cnt, key % 5, grouping_id() FROM t1 GROUP BY key % 5 WITH ROLLUP") + checkHiveQl( + "SELECT count(*) as cnt, key % 5, grouping_id() FROM t1 GROUP BY key % 5 WITH CUBE") + } + + test("rollup/cube #2") { + checkHiveQl("SELECT key, value, count(value) FROM t1 GROUP BY key, value WITH ROLLUP") + checkHiveQl("SELECT key, value, count(value) FROM t1 GROUP BY key, value WITH CUBE") + } + + test("rollup/cube #3") { + checkHiveQl( + "SELECT key, count(value), grouping_id() FROM t1 GROUP BY key, value WITH ROLLUP") + checkHiveQl( + "SELECT key, count(value), grouping_id() FROM t1 GROUP BY key, value WITH CUBE") + } + + test("rollup/cube #4") { + checkHiveQl( + s""" + |SELECT count(*) as cnt, key % 5 as k1, key - 5 as k2, grouping_id() FROM t1 + |GROUP BY key % 5, key - 5 WITH ROLLUP + 
""".stripMargin) + checkHiveQl( + s""" + |SELECT count(*) as cnt, key % 5 as k1, key - 5 as k2, grouping_id() FROM t1 + |GROUP BY key % 5, key - 5 WITH CUBE + """.stripMargin) + } + + test("rollup/cube #5") { + checkHiveQl( + s""" + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 + |FROM (SELECT key, key%2, key - 5 FROM t1) t GROUP BY key%5, key-5 + |WITH ROLLUP + """.stripMargin) + checkHiveQl( + s""" + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 + |FROM (SELECT key, key%2, key - 5 FROM t1) t GROUP BY key%5, key-5 + |WITH CUBE + """.stripMargin) + } + + test("rollup/cube #6") { + checkHiveQl("SELECT a, b, sum(c) FROM t2 GROUP BY ROLLUP(a, b) ORDER BY a, b") + checkHiveQl("SELECT a, b, sum(c) FROM t2 GROUP BY CUBE(a, b) ORDER BY a, b") + checkHiveQl("SELECT a, b, sum(a) FROM t2 GROUP BY ROLLUP(a, b) ORDER BY a, b") + checkHiveQl("SELECT a, b, sum(a) FROM t2 GROUP BY CUBE(a, b) ORDER BY a, b") + checkHiveQl("SELECT a + b, b, sum(a - b) FROM t2 GROUP BY a + b, b WITH ROLLUP") + checkHiveQl("SELECT a + b, b, sum(a - b) FROM t2 GROUP BY a + b, b WITH CUBE") + } + + test("grouping sets #1") { + checkHiveQl( + s""" + |SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id() AS k3 + |FROM (SELECT key, key % 2, key - 5 FROM t1) t GROUP BY key % 5, key - 5 + |GROUPING SETS (key % 5, key - 5) + """.stripMargin) + } + + test("grouping sets #2") { + checkHiveQl("SELECT a, b, sum(c) FROM t2 GROUP BY a, b GROUPING SETS (a, b) ORDER BY a, b") + checkHiveQl("SELECT a, b, sum(c) FROM t2 GROUP BY a, b GROUPING SETS (a) ORDER BY a, b") + checkHiveQl("SELECT a, b, sum(c) FROM t2 GROUP BY a, b GROUPING SETS (b) ORDER BY a, b") + } + test("cluster by") { checkHiveQl("SELECT id FROM t0 CLUSTER BY id") } From a4083251381a033960f76b72b477adabac024faf Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 22 Feb 2016 09:19:23 -0800 Subject: [PATCH 02/15] override the pretty name of groupingid --- .../apache/spark/sql/catalyst/expressions/grouping.scala | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala index be9a46f70bc1a..437e417266fb3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala @@ -63,11 +63,5 @@ case class GroupingID(groupByExprs: Seq[Expression]) extends Expression with Une override def children: Seq[Expression] = groupByExprs override def dataType: DataType = IntegerType override def nullable: Boolean = false - - // TODO: remove this when SPARK-12799 is resolved. That will provide a general to-sql solution - // for all the expressions. 
- override def sql: String = { - val childrenSQL = children.map(_.sql).mkString(", ") - s"grouping_id($childrenSQL)" - } + override def prettyName: String = "grouping_id" } From 351421340247b9d75463bd4e384bfb86b21ff93a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 22 Feb 2016 20:40:35 -0800 Subject: [PATCH 03/15] fixed the test case --- python/pyspark/sql/functions.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index fdae05d98ceb6..3de4f1deac525 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -348,13 +348,13 @@ def grouping_id(*cols): grouping columns). >>> df.cube("name").agg(grouping_id(), sum("age")).orderBy("name").show() - +-----+------------+--------+ - | name|groupingid()|sum(age)| - +-----+------------+--------+ - | null| 1| 7| - |Alice| 0| 2| - | Bob| 0| 5| - +-----+------------+--------+ + +-----+-------------+--------+ + | name|grouping_id()|sum(age)| + +-----+-------------+--------+ + | null| 1| 7| + |Alice| 0| 2| + | Bob| 0| 5| + +-----+-------------+--------+ """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.grouping_id(_to_seq(sc, cols, _to_java_column)) From 6a3659361570c21be51466f4f4a3bc5a70386d13 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 23 Feb 2016 00:36:38 -0800 Subject: [PATCH 04/15] support grouping() --- .../apache/spark/sql/hive/SQLBuilder.scala | 34 +++++++++++++------ .../sql/hive/LogicalPlanToSQLSuite.scala | 22 +++++++----- 2 files changed, 38 insertions(+), 18 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 3c812a5f7e56b..c966a84b6c514 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} import org.apache.spark.sql.catalyst.util.quoteIdentifier import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.types.{ByteType, IntegerType} /** * A builder class used to convert a resolved logical plan into a SQL query string. Note that this @@ -198,17 +199,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case a @ Alias(child, name) if !child.isInstanceOf[AttributeReference] => (a.toAttribute, a) }) - val aggExprs = plan.aggregateExpressions.map{ - // VirtualColumn.groupingIdName is added by Analyzer, and thus remove it. - case a @ Alias(child: AttributeReference, name) - if child.name == VirtualColumn.groupingIdName => - Alias(GroupingID(Nil), name)() - case a @ Alias(child: AttributeReference, name) if aliasMap.contains(child) => - aliasMap(child).child - case o => o - } - val groupingExprs = plan.groupingExpressions.filterNot { + // VirtualColumn.groupingIdName is added by Analyzer, and thus remove it. case a: NamedExpression => a.name == VirtualColumn.groupingIdName case o => false }.map { @@ -227,6 +219,28 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case o => o }) + val aggExprs = plan.aggregateExpressions.map { + case a @ Alias(child: AttributeReference, name) + if child.name == VirtualColumn.groupingIdName => + // grouping_id() is converted to VirtualColumn.groupingIdName by Analyzer. Revert it back. 
+ Alias(GroupingID(Nil), name)() + case a @ Alias(_ @ Cast(BitwiseAnd( + ShiftRight(ar: AttributeReference, _ @ Literal(value: Any, IntegerType)), + Literal(1, IntegerType)), ByteType), name) + if ar.name == VirtualColumn.groupingIdName => + // for converting an expression to its original SQL format grouping(col) + val idx = groupingExprs.length - 1 - value.asInstanceOf[Int] + val groupingCol = groupingExprs.lift(idx) + if (groupingCol.isDefined) { + Grouping(groupingCol.get) + } else { + throw new UnsupportedOperationException(s"unsupported operator $a") + } + case a @ Alias(child: AttributeReference, name) if aliasMap.contains(child) => + aliasMap(child).child + case o => o + } + build( "SELECT", aggExprs.map(_.sql).mkString(", "), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index a7963ca2c0012..45a38ea13aec0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -185,23 +185,23 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { """.stripMargin) checkHiveQl( s""" - |SELECT count(*) as cnt, key % 5 as k1, key - 5 as k2, grouping_id() FROM t1 - |GROUP BY key % 5, key - 5 WITH CUBE + |SELECT count(*) as cnt, key % 5 as k1, key - 5 as k2, grouping_id() FROM t1 + |GROUP BY key % 5, key - 5 WITH CUBE """.stripMargin) } test("rollup/cube #5") { checkHiveQl( s""" - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 - |FROM (SELECT key, key%2, key - 5 FROM t1) t GROUP BY key%5, key-5 - |WITH ROLLUP + |SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id(key % 5, key - 5) AS k3 + |FROM (SELECT key, key%2, key - 5 FROM t1) t GROUP BY key%5, key-5 + |WITH ROLLUP """.stripMargin) checkHiveQl( s""" - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 - |FROM (SELECT key, key%2, key - 5 FROM t1) t GROUP BY key%5, key-5 - |WITH CUBE + |SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id(key % 5, key - 5) AS k3 + |FROM (SELECT key, key % 2, key - 5 FROM t1) t GROUP BY key % 5, key - 5 + |WITH CUBE """.stripMargin) } @@ -214,6 +214,12 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkHiveQl("SELECT a + b, b, sum(a - b) FROM t2 GROUP BY a + b, b WITH CUBE") } + test("rollup/cube #7") { + checkHiveQl("SELECT a, b, grouping_id(a, b) FROM t2 GROUP BY cube(a, b)") + checkHiveQl("SELECT a, b, grouping(b) FROM t2 GROUP BY cube(a, b)") + checkHiveQl("SELECT a, b, grouping(a) FROM t2 GROUP BY cube(a, b)") + } + test("grouping sets #1") { checkHiveQl( s""" From 1aef8494f9e131a584db5a93d9b20f78b5abe44c Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 23 Feb 2016 00:54:13 -0800 Subject: [PATCH 05/15] style fix. 
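
Reviewer context for the grouping() support added in the previous commit
(illustrative only, not part of this style-only change): the Analyzer lowers
grouping(col) to a bit test on the grouping ID column, so SQLBuilder recovers
the original column via idx = groupingExprs.length - 1 - shift. A minimal,
self-contained sketch of that arithmetic in plain Scala (names hypothetical):

    object GroupingBitDemo {
      // grouping(col) is 1 iff col is aggregated away in the current grouping set.
      def groupingBit(gid: Int, shift: Int): Int = (gid >> shift) & 1

      def main(args: Array[String]): Unit = {
        // GROUP BY a, b: two group-by columns, so a maps to shift 1, b to shift 0.
        // gid = 1 (binary 01) is the ROLLUP grouping set (a): a is kept, b is nulled.
        assert(groupingBit(gid = 1, shift = 1) == 0) // grouping(a): a is grouped
        assert(groupingBit(gid = 1, shift = 0) == 1) // grouping(b): b is not
      }
    }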
--- .../sql/hive/LogicalPlanToSQLSuite.scala | 49 ++++++++++--------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index 7d94671b0ba0d..611268adcc850 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -199,32 +199,32 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { // GROUP BY (`t1`.`key` % CAST(5 AS BIGINT)) // GROUPING SETS (((`t1`.`key` % CAST(5 AS BIGINT))), ()) checkHiveQl( - "SELECT count(*) as cnt, key % 5, grouping_id() FROM t1 GROUP BY key % 5 WITH ROLLUP") + "SELECT count(*) as cnt, key%5, grouping_id() FROM parquet_t1 GROUP BY key % 5 WITH ROLLUP") checkHiveQl( - "SELECT count(*) as cnt, key % 5, grouping_id() FROM t1 GROUP BY key % 5 WITH CUBE") + "SELECT count(*) as cnt, key%5, grouping_id() FROM parquet_t1 GROUP BY key % 5 WITH CUBE") } test("rollup/cube #2") { - checkHiveQl("SELECT key, value, count(value) FROM t1 GROUP BY key, value WITH ROLLUP") - checkHiveQl("SELECT key, value, count(value) FROM t1 GROUP BY key, value WITH CUBE") + checkHiveQl("SELECT key, value, count(value) FROM parquet_t1 GROUP BY key, value WITH ROLLUP") + checkHiveQl("SELECT key, value, count(value) FROM parquet_t1 GROUP BY key, value WITH CUBE") } test("rollup/cube #3") { checkHiveQl( - "SELECT key, count(value), grouping_id() FROM t1 GROUP BY key, value WITH ROLLUP") + "SELECT key, count(value), grouping_id() FROM parquet_t1 GROUP BY key, value WITH ROLLUP") checkHiveQl( - "SELECT key, count(value), grouping_id() FROM t1 GROUP BY key, value WITH CUBE") + "SELECT key, count(value), grouping_id() FROM parquet_t1 GROUP BY key, value WITH CUBE") } test("rollup/cube #4") { checkHiveQl( s""" - |SELECT count(*) as cnt, key % 5 as k1, key - 5 as k2, grouping_id() FROM t1 + |SELECT count(*) as cnt, key % 5 as k1, key - 5 as k2, grouping_id() FROM parquet_t1 |GROUP BY key % 5, key - 5 WITH ROLLUP """.stripMargin) checkHiveQl( s""" - |SELECT count(*) as cnt, key % 5 as k1, key - 5 as k2, grouping_id() FROM t1 + |SELECT count(*) as cnt, key % 5 as k1, key - 5 as k2, grouping_id() FROM parquet_t1 |GROUP BY key % 5, key - 5 WITH CUBE """.stripMargin) } @@ -233,45 +233,48 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkHiveQl( s""" |SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id(key % 5, key - 5) AS k3 - |FROM (SELECT key, key%2, key - 5 FROM t1) t GROUP BY key%5, key-5 + |FROM (SELECT key, key%2, key - 5 FROM parquet_t1) t GROUP BY key%5, key-5 |WITH ROLLUP """.stripMargin) checkHiveQl( s""" |SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id(key % 5, key - 5) AS k3 - |FROM (SELECT key, key % 2, key - 5 FROM t1) t GROUP BY key % 5, key - 5 + |FROM (SELECT key, key % 2, key - 5 FROM parquet_t1) t GROUP BY key % 5, key - 5 |WITH CUBE """.stripMargin) } test("rollup/cube #6") { - checkHiveQl("SELECT a, b, sum(c) FROM t2 GROUP BY ROLLUP(a, b) ORDER BY a, b") - checkHiveQl("SELECT a, b, sum(c) FROM t2 GROUP BY CUBE(a, b) ORDER BY a, b") - checkHiveQl("SELECT a, b, sum(a) FROM t2 GROUP BY ROLLUP(a, b) ORDER BY a, b") - checkHiveQl("SELECT a, b, sum(a) FROM t2 GROUP BY CUBE(a, b) ORDER BY a, b") - checkHiveQl("SELECT a + b, b, sum(a - b) FROM t2 GROUP BY a + b, b WITH ROLLUP") - checkHiveQl("SELECT a + b, b, sum(a - b) FROM t2 GROUP 
BY a + b, b WITH CUBE") + checkHiveQl("SELECT a, b, sum(c) FROM parquet_t2 GROUP BY ROLLUP(a, b) ORDER BY a, b") + checkHiveQl("SELECT a, b, sum(c) FROM parquet_t2 GROUP BY CUBE(a, b) ORDER BY a, b") + checkHiveQl("SELECT a, b, sum(a) FROM parquet_t2 GROUP BY ROLLUP(a, b) ORDER BY a, b") + checkHiveQl("SELECT a, b, sum(a) FROM parquet_t2 GROUP BY CUBE(a, b) ORDER BY a, b") + checkHiveQl("SELECT a + b, b, sum(a - b) FROM parquet_t2 GROUP BY a + b, b WITH ROLLUP") + checkHiveQl("SELECT a + b, b, sum(a - b) FROM parquet_t2 GROUP BY a + b, b WITH CUBE") } test("rollup/cube #7") { - checkHiveQl("SELECT a, b, grouping_id(a, b) FROM t2 GROUP BY cube(a, b)") - checkHiveQl("SELECT a, b, grouping(b) FROM t2 GROUP BY cube(a, b)") - checkHiveQl("SELECT a, b, grouping(a) FROM t2 GROUP BY cube(a, b)") + checkHiveQl("SELECT a, b, grouping_id(a, b) FROM parquet_t2 GROUP BY cube(a, b)") + checkHiveQl("SELECT a, b, grouping(b) FROM parquet_t2 GROUP BY cube(a, b)") + checkHiveQl("SELECT a, b, grouping(a) FROM parquet_t2 GROUP BY cube(a, b)") } test("grouping sets #1") { checkHiveQl( s""" |SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id() AS k3 - |FROM (SELECT key, key % 2, key - 5 FROM t1) t GROUP BY key % 5, key - 5 + |FROM (SELECT key, key % 2, key - 5 FROM parquet_t1) t GROUP BY key % 5, key - 5 |GROUPING SETS (key % 5, key - 5) """.stripMargin) } test("grouping sets #2") { - checkHiveQl("SELECT a, b, sum(c) FROM t2 GROUP BY a, b GROUPING SETS (a, b) ORDER BY a, b") - checkHiveQl("SELECT a, b, sum(c) FROM t2 GROUP BY a, b GROUPING SETS (a) ORDER BY a, b") - checkHiveQl("SELECT a, b, sum(c) FROM t2 GROUP BY a, b GROUPING SETS (b) ORDER BY a, b") + checkHiveQl( + "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (a, b) ORDER BY a, b") + checkHiveQl( + "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (a) ORDER BY a, b") + checkHiveQl( + "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (b) ORDER BY a, b") } test("cluster by") { From 6f79df19b57df66b810067156d271b5140df6543 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 2 Mar 2016 10:42:06 -0800 Subject: [PATCH 06/15] resolve comments. --- .../apache/spark/sql/hive/SQLBuilder.scala | 55 ++++++++++--------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 9987ec859156d..5deadbcc48bf3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -219,6 +219,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi plan: Aggregate, expand: Expand, project: Project): String = { + // The last column of Expand is always grouping ID + val gid = expand.output.last + // In cube/rollup/groupingsets, Analyzer creates new aliases for all group by expressions. // Since conversion from attribute back SQL ignore expression IDs, the alias of attribute // references are ignored in aliasMap @@ -228,7 +231,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi val groupingExprs = plan.groupingExpressions.filterNot { // VirtualColumn.groupingIdName is added by Analyzer, and thus remove it. 
- case a: NamedExpression => a.name == VirtualColumn.groupingIdName + case a: AttributeReference => a == gid case o => false }.map { case a: AttributeReference if aliasMap.contains(a) => aliasMap(a).child @@ -237,8 +240,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi val groupingSQL = groupingExprs.map(_.sql).mkString(", ") - val groupingSet = expand.projections.map(_.filter { - case _: Literal => false + val groupingSet = expand.projections.map(_.dropRight(1).filter { case e: Expression if plan.groupingExpressions.exists(_.semanticEquals(e)) => true case _ => false }.map { @@ -246,28 +248,32 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case o => o }) - val aggExprs = plan.aggregateExpressions.map { - case a @ Alias(child: AttributeReference, name) - if child.name == VirtualColumn.groupingIdName => - // grouping_id() is converted to VirtualColumn.groupingIdName by Analyzer. Revert it back. - Alias(GroupingID(Nil), name)() - case a @ Alias(_ @ Cast(BitwiseAnd( - ShiftRight(ar: AttributeReference, _ @ Literal(value: Any, IntegerType)), - Literal(1, IntegerType)), ByteType), name) - if ar.name == VirtualColumn.groupingIdName => - // for converting an expression to its original SQL format grouping(col) - val idx = groupingExprs.length - 1 - value.asInstanceOf[Int] - val groupingCol = groupingExprs.lift(idx) - if (groupingCol.isDefined) { - Grouping(groupingCol.get) - } else { - throw new UnsupportedOperationException(s"unsupported operator $a") - } - case a @ Alias(child: AttributeReference, name) if aliasMap.contains(child) => - aliasMap(child).child - case o => o + val aggExprs = plan.aggregateExpressions.map { case expr => + expr.transformDown { + case a @ Alias(child: AttributeReference, name) if child eq gid => + // grouping_id() is converted to VirtualColumn.groupingIdName by Analyzer. Revert it back. + Alias(GroupingID(Nil), name)() + case a @ Alias(_ @ Cast(BitwiseAnd( + ShiftRight(ar: AttributeReference, _ @ Literal(value: Any, IntegerType)), + Literal(1, IntegerType)), ByteType), name) if ar == gid => + // for converting an expression to its original SQL format grouping(col) + val idx = groupingExprs.length - 1 - value.asInstanceOf[Int] + val groupingCol = groupingExprs.lift(idx) + if (groupingCol.isDefined) { + Grouping(groupingCol.get) + } else { + throw new UnsupportedOperationException(s"unsupported operator $a") + } + case a @ Alias(child: AttributeReference, name) if aliasMap.contains(child) => + aliasMap(child).child + case o => o + } } + val groupingSetSQL = + "GROUPING SETS(" + + groupingSet.map(e => s"(${e.map(_.sql).mkString(", ")})").mkString(", ") + ")" + build( "SELECT", aggExprs.map(_.sql).mkString(", "), @@ -275,8 +281,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi toSQL(project.child), if (groupingSQL.isEmpty) "" else "GROUP BY", groupingSQL, - "GROUPING SETS", - "(" + groupingSet.map(e => s"(${e.map(_.sql).mkString(", ")})").mkString(", ") + ")" + groupingSetSQL ) } From 37a9d8d4c86d1207b07bef4186fc23ad5bf3d75c Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 2 Mar 2016 20:04:47 -0800 Subject: [PATCH 07/15] resolve comments. 
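
This round recovers the original group-by expressions positionally: Project's
projectList is assumed to be the child's original output followed by one Alias
per group-by expression, and Aggregate's groupingExpressions to be the
corresponding attributes followed by the grouping ID. A rough sketch of the
zip that builds groupByAttrMap, with String stand-ins for Catalyst attributes
and expressions (the real code uses AttributeMap and Alias):

    val childOutput  = Seq("key", "value")                   // project.child.output
    val projectList  = Seq("key", "value", "key % 5")        // original output ++ aliased exprs
    val groupByExprs = projectList.drop(childOutput.length)  // Seq("key % 5")
    val groupByAttrs = Seq("_gen0")                          // groupingExpressions.dropRight(1)
    val groupByAttrMap = groupByAttrs.zip(groupByExprs).toMap
    // groupByAttrMap("_gen0") == "key % 5": attribute back to its defining expression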
--- .../apache/spark/sql/hive/SQLBuilder.scala | 19 ++++++++----------- .../sql/hive/LogicalPlanToSQLSuite.scala | 7 +++++++ 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 5deadbcc48bf3..95dff0137970d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -219,6 +219,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi plan: Aggregate, expand: Expand, project: Project): String = { + require(plan.groupingExpressions.length > 1) + // The last column of Expand is always grouping ID val gid = expand.output.last @@ -226,18 +228,13 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi // Since conversion from attribute back SQL ignore expression IDs, the alias of attribute // references are ignored in aliasMap val aliasMap = AttributeMap(project.projectList.collect { - case a @ Alias(child, name) if !child.isInstanceOf[AttributeReference] => (a.toAttribute, a) + case a @ Alias(child, name) => (a.toAttribute, a) }) - val groupingExprs = plan.groupingExpressions.filterNot { - // VirtualColumn.groupingIdName is added by Analyzer, and thus remove it. - case a: AttributeReference => a == gid - case o => false - }.map { - case a: AttributeReference if aliasMap.contains(a) => aliasMap(a).child - case o => o - } - + val groupByAttributes = plan.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute]) + val groupByAttrMap = AttributeMap(groupByAttributes.zip( + project.projectList.drop(project.child.output.length).map(_.asInstanceOf[Alias].child))) + val groupingExprs = groupByAttrMap.values.toArray val groupingSQL = groupingExprs.map(_.sql).mkString(", ") val groupingSet = expand.projections.map(_.dropRight(1).filter { @@ -279,7 +276,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi aggExprs.map(_.sql).mkString(", "), if (plan.child == OneRowRelation) "" else "FROM", toSQL(project.child), - if (groupingSQL.isEmpty) "" else "GROUP BY", + "GROUP BY", groupingSQL, groupingSetSQL ) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index e0b0ff193f995..34a044906a220 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -275,6 +275,13 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (a) ORDER BY a, b") checkHiveQl( "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (b) ORDER BY a, b") + checkHiveQl( + "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (()) ORDER BY a, b") + checkHiveQl( + s""" + |SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b + |GROUPING SETS ((), (a), (a, b)) ORDER BY a, b + """.stripMargin) } test("cluster by") { From 640e45cdbbe0d08acb26af715181dc4d16dd7893 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 2 Mar 2016 20:18:02 -0800 Subject: [PATCH 08/15] resolve comments. 
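
For context, a sketch of the Expand layout that the attribute rewrites below
rely on. "a" and "b" stand for the aliased group-by attributes, None for the
nulled-out slots, and the last column is always the grouping ID, where a set
bit means the corresponding column is nulled out:

    // Illustrative Expand.projections for SELECT ... GROUP BY CUBE(a, b)
    // over a child with output [k, v]:
    val projections = Seq(
      ("k", "v", Some("a"), Some("b"), 0), // grouping set (a, b)
      ("k", "v", Some("a"), None,      1), // grouping set (a)
      ("k", "v", None,      Some("b"), 2), // grouping set (b)
      ("k", "v", None,      None,      3)  // grouping set ()
    )

Dropping the leading child output and the trailing grouping ID from each row,
then keeping only the non-null group-by slots, yields exactly the grouping
sets that the recovered GROUPING SETS clause has to spell out.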
--- .../scala/org/apache/spark/sql/hive/SQLBuilder.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 95dff0137970d..7648035d91558 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -232,8 +232,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi }) val groupByAttributes = plan.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute]) - val groupByAttrMap = AttributeMap(groupByAttributes.zip( - project.projectList.drop(project.child.output.length).map(_.asInstanceOf[Alias].child))) + val groupByExprs = + project.projectList.drop(project.child.output.length).map(_.asInstanceOf[Alias].child) + val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs)) val groupingExprs = groupByAttrMap.values.toArray val groupingSQL = groupingExprs.map(_.sql).mkString(", ") @@ -261,8 +262,10 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi } else { throw new UnsupportedOperationException(s"unsupported operator $a") } - case a @ Alias(child: AttributeReference, name) if aliasMap.contains(child) => + case a @ Alias(child: AttributeReference, _) if aliasMap.contains(child) => aliasMap(child).child + case ar: AttributeReference if aliasMap.contains(ar) => + aliasMap(ar).child case o => o } } From 749be1b2de07603a7fc9961872954c24a584a1fb Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 2 Mar 2016 21:48:22 -0800 Subject: [PATCH 09/15] resolve comments. --- .../apache/spark/sql/hive/SQLBuilder.scala | 35 +++++++------------ 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 7648035d91558..77195e24458ec 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -224,27 +224,22 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi // The last column of Expand is always grouping ID val gid = expand.output.last - // In cube/rollup/groupingsets, Analyzer creates new aliases for all group by expressions. 
- // Since conversion from attribute back SQL ignore expression IDs, the alias of attribute - // references are ignored in aliasMap - val aliasMap = AttributeMap(project.projectList.collect { - case a @ Alias(child, name) => (a.toAttribute, a) - }) - val groupByAttributes = plan.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute]) val groupByExprs = project.projectList.drop(project.child.output.length).map(_.asInstanceOf[Alias].child) val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs)) val groupingExprs = groupByAttrMap.values.toArray + val groupingSQL = groupingExprs.map(_.sql).mkString(", ") - val groupingSet = expand.projections.map(_.dropRight(1).filter { - case e: Expression if plan.groupingExpressions.exists(_.semanticEquals(e)) => true - case _ => false - }.map { - case a: AttributeReference if aliasMap.contains(a) => aliasMap(a).child - case o => o - }) + val groupingSet = expand.projections.map { project => + project.dropRight(1).collect { + case attr: Attribute if groupByAttrMap.contains(attr) => groupByAttrMap(attr) + } + } + val groupingSetSQL = + "GROUPING SETS(" + + groupingSet.map(e => s"(${e.map(_.sql).mkString(", ")})").mkString(", ") + ")" val aggExprs = plan.aggregateExpressions.map { case expr => expr.transformDown { @@ -262,18 +257,14 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi } else { throw new UnsupportedOperationException(s"unsupported operator $a") } - case a @ Alias(child: AttributeReference, _) if aliasMap.contains(child) => - aliasMap(child).child - case ar: AttributeReference if aliasMap.contains(ar) => - aliasMap(ar).child + case a @ Alias(child: AttributeReference, _) if groupByAttrMap.contains(child) => + groupByAttrMap(child) + case ar: AttributeReference if groupByAttrMap.contains(ar) => + groupByAttrMap(ar) case o => o } } - val groupingSetSQL = - "GROUPING SETS(" + - groupingSet.map(e => s"(${e.map(_.sql).mkString(", ")})").mkString(", ") + ")" - build( "SELECT", aggExprs.map(_.sql).mkString(", "), From ae768fe7a27262e58913e50b654bfc9f23199b12 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 2 Mar 2016 22:02:13 -0800 Subject: [PATCH 10/15] resolve comments. 
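
With groupingSQL and groupingSetSQL in hand, the final statement is assembled
from trimmed segments by the existing build helper (per its doc comment, it
joins non-empty, trimmed segments with single spaces). A sketch of the
assembly under that assumption, with output shaped after the "rollup/cube #1"
test comment:

    // stand-in for SQLBuilder's private build helper (assumed behavior)
    def build(segments: String*): String =
      segments.map(_.trim).filter(_.nonEmpty).mkString(" ")

    val sql = build(
      "SELECT",
      "count(1) AS `cnt`, (`t1`.`key` % CAST(5 AS BIGINT)), grouping_id() AS `_c2`",
      "FROM", "`default`.`t1`",
      "GROUP BY", "(`t1`.`key` % CAST(5 AS BIGINT))",
      "GROUPING SETS(((`t1`.`key` % CAST(5 AS BIGINT))), ())")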
--- .../main/scala/org/apache/spark/sql/hive/SQLBuilder.scala | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 77195e24458ec..ed863af2e5af3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -228,9 +228,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi val groupByExprs = project.projectList.drop(project.child.output.length).map(_.asInstanceOf[Alias].child) val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs)) - val groupingExprs = groupByAttrMap.values.toArray - - val groupingSQL = groupingExprs.map(_.sql).mkString(", ") + val groupingSQL = groupByExprs.map(_.sql).mkString(", ") val groupingSet = expand.projections.map { project => project.dropRight(1).collect { @@ -250,8 +248,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi ShiftRight(ar: AttributeReference, _ @ Literal(value: Any, IntegerType)), Literal(1, IntegerType)), ByteType), name) if ar == gid => // for converting an expression to its original SQL format grouping(col) - val idx = groupingExprs.length - 1 - value.asInstanceOf[Int] - val groupingCol = groupingExprs.lift(idx) + val idx = groupByExprs.length - 1 - value.asInstanceOf[Int] + val groupingCol = groupByExprs.lift(idx) if (groupingCol.isDefined) { Grouping(groupingCol.get) } else { From 6cea658bdb76274edd22c3187e1885ae8f257b86 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 3 Mar 2016 10:03:26 -0800 Subject: [PATCH 11/15] resolve comments. --- .../apache/spark/sql/hive/SQLBuilder.scala | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index ed863af2e5af3..5b14c09d20b87 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -86,7 +86,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case p: Project => projectToSQL(p, isDistinct = false) - case a @ Aggregate(_, _, e @ Expand(_, _, p: Project)) => + case a @ Aggregate(_, _, e @ Expand(_, _, p: Project)) + if sameOutput(e.output, + p.child.output ++ a.groupingExpressions.map(_.asInstanceOf[Attribute])) => groupingSetToSQL(a, e, p) case p: Aggregate => @@ -185,6 +187,10 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi throw new UnsupportedOperationException(s"unsupported plan $node") } + private def sameOutput(left: Seq[Attribute], right: Seq[Attribute]): Boolean = + left.forall(a => right.exists(_.semanticEquals(a))) && + right.forall(a => left.exists(_.semanticEquals(a))) + /** * Turns a bunch of string segments into a single string and separate each segment by a space. * The segments are trimmed so only a single space appears in the separation. 
@@ -231,6 +237,10 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi val groupingSQL = groupByExprs.map(_.sql).mkString(", ") val groupingSet = expand.projections.map { project => + // Assumption: expand.projections are composed of + // 1) the original output (project.child.output), + // 2) group by attributes(or null literal) + // 3) gid, which is always the last one in each project project.dropRight(1).collect { case attr: Attribute if groupByAttrMap.contains(attr) => groupByAttrMap(attr) } @@ -241,9 +251,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi val aggExprs = plan.aggregateExpressions.map { case expr => expr.transformDown { - case a @ Alias(child: AttributeReference, name) if child eq gid => - // grouping_id() is converted to VirtualColumn.groupingIdName by Analyzer. Revert it back. - Alias(GroupingID(Nil), name)() + // grouping_id() is converted to VirtualColumn.groupingIdName by Analyzer. Revert it back. + case ar: AttributeReference if ar eq gid => GroupingID(Nil) case a @ Alias(_ @ Cast(BitwiseAnd( ShiftRight(ar: AttributeReference, _ @ Literal(value: Any, IntegerType)), Literal(1, IntegerType)), ByteType), name) if ar == gid => @@ -255,11 +264,10 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi } else { throw new UnsupportedOperationException(s"unsupported operator $a") } - case a @ Alias(child: AttributeReference, _) if groupByAttrMap.contains(child) => - groupByAttrMap(child) + case a @ Alias(ar: AttributeReference, _) if groupByAttrMap.contains(ar) => + groupByAttrMap(ar) case ar: AttributeReference if groupByAttrMap.contains(ar) => groupByAttrMap(ar) - case o => o } } From b1925e5feb83228db8ed8502306368a2d979e56b Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 3 Mar 2016 10:12:11 -0800 Subject: [PATCH 12/15] cleaned the comment. --- .../src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index b41ca4faf85a8..1f1776d3a6ee8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -259,9 +259,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi val groupingSet = expand.projections.map { project => // Assumption: expand.projections are composed of - // 1) the original output (project.child.output), + // 1) the original output (Project's child.output), // 2) group by attributes(or null literal) - // 3) gid, which is always the last one in each project + // 3) gid, which is always the last one in each project in Expand project.dropRight(1).collect { case attr: Attribute if groupByAttrMap.contains(attr) => groupByAttrMap(attr) } From 6f609fb2d844e2aaf4c809ef8c0fcd9e6eca38bb Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 3 Mar 2016 17:54:58 -0800 Subject: [PATCH 13/15] address comments. 
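
sameOutput is tightened from mutual containment to positional equality, since
Expand's output has to line up column-for-column with project.child.output ++
grouping attributes. A simplified before/after sketch with String stand-ins
(the real version compares Attributes via semanticEquals):

    // before this commit: order-insensitive mutual containment
    def sameOutputOld(a: Seq[String], b: Seq[String]): Boolean =
      a.forall(b.contains) && b.forall(a.contains)

    // after this commit: same length, same order
    def sameOutputNew(a: Seq[String], b: Seq[String]): Boolean =
      a.size == b.size && a.zip(b).forall { case (x, y) => x == y }

    assert(sameOutputOld(Seq("a", "b"), Seq("b", "a")))  // accepted by the old check
    assert(!sameOutputNew(Seq("a", "b"), Seq("b", "a"))) // rejected by the new one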
--- .../apache/spark/sql/hive/SQLBuilder.scala | 45 ++++++++++--------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 1f1776d3a6ee8..0d9a68342c643 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -208,9 +208,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi throw new UnsupportedOperationException(s"unsupported plan $node") } - private def sameOutput(left: Seq[Attribute], right: Seq[Attribute]): Boolean = - left.forall(a => right.exists(_.semanticEquals(a))) && - right.forall(a => left.exists(_.semanticEquals(a))) + private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean = + output1.size == output2.size && + output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) /** * Turns a bunch of string segments into a single string and separate each segment by a space. @@ -243,26 +243,34 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi } private def groupingSetToSQL( - plan: Aggregate, + agg: Aggregate, expand: Expand, project: Project): String = { - require(plan.groupingExpressions.length > 1) + assert(agg.groupingExpressions.length > 1) // The last column of Expand is always grouping ID val gid = expand.output.last - val groupByAttributes = plan.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute]) - val groupByExprs = - project.projectList.drop(project.child.output.length).map(_.asInstanceOf[Alias].child) - val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs)) + val numOriginalOutput = project.child.output.length + // Assumption: Aggregate's groupingExpressions is composed of + // 1) the group by attributes' aliases + // 2) gid, which is always the last one + val groupByAttributes = agg.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute]) + // Assumption: Project's projectList is composed of + // 1) the original output (Project's child.output), + // 2) the aliases of the original group by attributes, which could be expressions + val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child) val groupingSQL = groupByExprs.map(_.sql).mkString(", ") + // a map from the alias name to the original group by expresions/attributes + val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs)) + val groupingSet = expand.projections.map { project => - // Assumption: expand.projections are composed of + // Assumption: expand.projections is composed of // 1) the original output (Project's child.output), // 2) group by attributes(or null literal) // 3) gid, which is always the last one in each project in Expand - project.dropRight(1).collect { + project.drop(numOriginalOutput).dropRight(1).collect { case attr: Attribute if groupByAttrMap.contains(attr) => groupByAttrMap(attr) } } @@ -270,13 +278,14 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi "GROUPING SETS(" + groupingSet.map(e => s"(${e.map(_.sql).mkString(", ")})").mkString(", ") + ")" - val aggExprs = plan.aggregateExpressions.map { case expr => + val aggExprs = agg.aggregateExpressions.map { case expr => expr.transformDown { // grouping_id() is converted to VirtualColumn.groupingIdName by Analyzer. Revert it back. 
- case ar: AttributeReference if ar eq gid => GroupingID(Nil) - case a @ Alias(_ @ Cast(BitwiseAnd( + case ar: AttributeReference if ar == gid => GroupingID(Nil) + case ar: AttributeReference if groupByAttrMap.contains(ar) => groupByAttrMap(ar) + case a @ Cast(BitwiseAnd( ShiftRight(ar: AttributeReference, _ @ Literal(value: Any, IntegerType)), - Literal(1, IntegerType)), ByteType), name) if ar == gid => + Literal(1, IntegerType)), ByteType) if ar == gid => // for converting an expression to its original SQL format grouping(col) val idx = groupByExprs.length - 1 - value.asInstanceOf[Int] val groupingCol = groupByExprs.lift(idx) @@ -285,17 +294,13 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi } else { throw new UnsupportedOperationException(s"unsupported operator $a") } - case a @ Alias(ar: AttributeReference, _) if groupByAttrMap.contains(ar) => - groupByAttrMap(ar) - case ar: AttributeReference if groupByAttrMap.contains(ar) => - groupByAttrMap(ar) } } build( "SELECT", aggExprs.map(_.sql).mkString(", "), - if (plan.child == OneRowRelation) "" else "FROM", + if (agg.child == OneRowRelation) "" else "FROM", toSQL(project.child), "GROUP BY", groupingSQL, From 9eaca515a3a86f07ed4ca85ba6da080ad605d1c0 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 3 Mar 2016 19:15:43 -0800 Subject: [PATCH 14/15] address comments. --- .../apache/spark/sql/hive/SQLBuilder.scala | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 0d9a68342c643..f6d657abc96f6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -107,9 +107,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case p: Project => projectToSQL(p, isDistinct = false) - case a @ Aggregate(_, _, e @ Expand(_, _, p: Project)) - if sameOutput(e.output, - p.child.output ++ a.groupingExpressions.map(_.asInstanceOf[Attribute])) => + case a @ Aggregate(_, _, e @ Expand(_, _, p: Project)) if isGroupingSet(a, e, p) => groupingSetToSQL(a, e, p) case p: Aggregate => @@ -208,10 +206,6 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi throw new UnsupportedOperationException(s"unsupported plan $node") } - private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean = - output1.size == output2.size && - output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) - /** * Turns a bunch of string segments into a single string and separate each segment by a space. * The segments are trimmed so only a single space appears in the separation. 
@@ -242,6 +236,16 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi ) } + private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean = + output1.size == output2.size && + output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) + + private def isGroupingSet(a: Aggregate, e: Expand, p: Project): Boolean = { + assert(a.child == e && e.child == p) + a.groupingExpressions.forall(_.isInstanceOf[Attribute]) && + sameOutput(e.output, p.child.output ++ a.groupingExpressions.map(_.asInstanceOf[Attribute])) + } + private def groupingSetToSQL( agg: Aggregate, expand: Expand, @@ -253,16 +257,16 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi val numOriginalOutput = project.child.output.length // Assumption: Aggregate's groupingExpressions is composed of - // 1) the group by attributes' aliases + // 1) the attributes of aliased group by expressions // 2) gid, which is always the last one val groupByAttributes = agg.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute]) // Assumption: Project's projectList is composed of // 1) the original output (Project's child.output), - // 2) the aliases of the original group by attributes, which could be expressions + // 2) the aliased group by expressions. val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child) val groupingSQL = groupByExprs.map(_.sql).mkString(", ") - // a map from the alias name to the original group by expresions/attributes + // a map from group by attributes to the original group by expressions. val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs)) val groupingSet = expand.projections.map { project => From b8786b29c8e8058e6c765dc8cbef62657fa17a5a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 4 Mar 2016 21:48:38 -0800 Subject: [PATCH 15/15] address comments. --- .../apache/spark/sql/hive/SQLBuilder.scala | 11 ++---- .../sql/hive/LogicalPlanToSQLSuite.scala | 36 +++++++++++++++++++ 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index f6d657abc96f6..7ba1fc82f7d67 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -269,7 +269,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi // a map from group by attributes to the original group by expressions. 
val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs)) - val groupingSet = expand.projections.map { project => + val groupingSet: Seq[Seq[Expression]] = expand.projections.map { project => // Assumption: expand.projections is composed of // 1) the original output (Project's child.output), // 2) group by attributes(or null literal) @@ -288,16 +288,11 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case ar: AttributeReference if ar == gid => GroupingID(Nil) case ar: AttributeReference if groupByAttrMap.contains(ar) => groupByAttrMap(ar) case a @ Cast(BitwiseAnd( - ShiftRight(ar: AttributeReference, _ @ Literal(value: Any, IntegerType)), + ShiftRight(ar: AttributeReference, Literal(value: Any, IntegerType)), Literal(1, IntegerType)), ByteType) if ar == gid => // for converting an expression to its original SQL format grouping(col) val idx = groupByExprs.length - 1 - value.asInstanceOf[Int] - val groupingCol = groupByExprs.lift(idx) - if (groupingCol.isDefined) { - Grouping(groupingCol.get) - } else { - throw new UnsupportedOperationException(s"unsupported operator $a") - } + groupByExprs.lift(idx).map(Grouping).getOrElse(a) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index 282f897c8d2f6..f457d43e19a50 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -300,6 +300,42 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkHiveQl("SELECT a, b, grouping(a) FROM parquet_t2 GROUP BY cube(a, b)") } + test("rollup/cube #8") { + // grouping_id() is part of another expression + checkHiveQl( + s""" + |SELECT hkey AS k1, value - 5 AS k2, hash(grouping_id()) AS hgid + |FROM (SELECT hash(key) as hkey, key as value FROM parquet_t1) t GROUP BY hkey, value-5 + |WITH ROLLUP + """.stripMargin) + checkHiveQl( + s""" + |SELECT hkey AS k1, value - 5 AS k2, hash(grouping_id()) AS hgid + |FROM (SELECT hash(key) as hkey, key as value FROM parquet_t1) t GROUP BY hkey, value-5 + |WITH CUBE + """.stripMargin) + } + + test("rollup/cube #9") { + // self join is used as the child node of ROLLUP/CUBE with replaced quantifiers + checkHiveQl( + s""" + |SELECT t.key - 5, cnt, SUM(cnt) + |FROM (SELECT x.key, COUNT(*) as cnt + |FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key GROUP BY x.key) t + |GROUP BY cnt, t.key - 5 + |WITH ROLLUP + """.stripMargin) + checkHiveQl( + s""" + |SELECT t.key - 5, cnt, SUM(cnt) + |FROM (SELECT x.key, COUNT(*) as cnt + |FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key GROUP BY x.key) t + |GROUP BY cnt, t.key - 5 + |WITH CUBE + """.stripMargin) + } + test("grouping sets #1") { checkHiveQl( s"""