diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index f6d657abc96f6..7ba1fc82f7d67 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -269,7 +269,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi // a map from group by attributes to the original group by expressions. val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs)) - val groupingSet = expand.projections.map { project => + val groupingSet: Seq[Seq[Expression]] = expand.projections.map { project => // Assumption: expand.projections is composed of // 1) the original output (Project's child.output), // 2) group by attributes(or null literal) @@ -288,16 +288,11 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case ar: AttributeReference if ar == gid => GroupingID(Nil) case ar: AttributeReference if groupByAttrMap.contains(ar) => groupByAttrMap(ar) case a @ Cast(BitwiseAnd( - ShiftRight(ar: AttributeReference, _ @ Literal(value: Any, IntegerType)), + ShiftRight(ar: AttributeReference, Literal(value: Any, IntegerType)), Literal(1, IntegerType)), ByteType) if ar == gid => // for converting an expression to its original SQL format grouping(col) val idx = groupByExprs.length - 1 - value.asInstanceOf[Int] - val groupingCol = groupByExprs.lift(idx) - if (groupingCol.isDefined) { - Grouping(groupingCol.get) - } else { - throw new UnsupportedOperationException(s"unsupported operator $a") - } + groupByExprs.lift(idx).map(Grouping).getOrElse(a) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index 282f897c8d2f6..f457d43e19a50 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -300,6 +300,42 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkHiveQl("SELECT a, b, grouping(a) FROM parquet_t2 GROUP BY cube(a, b)") } + test("rollup/cube #8") { + // grouping_id() is part of another expression + checkHiveQl( + s""" + |SELECT hkey AS k1, value - 5 AS k2, hash(grouping_id()) AS hgid + |FROM (SELECT hash(key) as hkey, key as value FROM parquet_t1) t GROUP BY hkey, value-5 + |WITH ROLLUP + """.stripMargin) + checkHiveQl( + s""" + |SELECT hkey AS k1, value - 5 AS k2, hash(grouping_id()) AS hgid + |FROM (SELECT hash(key) as hkey, key as value FROM parquet_t1) t GROUP BY hkey, value-5 + |WITH CUBE + """.stripMargin) + } + + test("rollup/cube #9") { + // self join is used as the child node of ROLLUP/CUBE with replaced quantifiers + checkHiveQl( + s""" + |SELECT t.key - 5, cnt, SUM(cnt) + |FROM (SELECT x.key, COUNT(*) as cnt + |FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key GROUP BY x.key) t + |GROUP BY cnt, t.key - 5 + |WITH ROLLUP + """.stripMargin) + checkHiveQl( + s""" + |SELECT t.key - 5, cnt, SUM(cnt) + |FROM (SELECT x.key, COUNT(*) as cnt + |FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key GROUP BY x.key) t + |GROUP BY cnt, t.key - 5 + |WITH CUBE + """.stripMargin) + } + test("grouping sets #1") { checkHiveQl( s"""