Skip to content

Commit

Permalink
Merge branch 'groupingSetsToSQLNew' into groupingSetsToSQLNewNewNew
Browse files Browse the repository at this point in the history
# Conflicts:
#	sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
  • Loading branch information
gatorsmile committed Mar 5, 2016
2 parents 59daa48 + b8786b2 commit 385c0d9
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 8 deletions.
14 changes: 7 additions & 7 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,13 +348,13 @@ def grouping_id(*cols):
grouping columns).
>>> df.cube("name").agg(grouping_id(), sum("age")).orderBy("name").show()
+-----+------------+--------+
| name|groupingid()|sum(age)|
+-----+------------+--------+
| null| 1| 7|
|Alice| 0| 2|
| Bob| 0| 5|
+-----+------------+--------+
+-----+-------------+--------+
| name|grouping_id()|sum(age)|
+-----+-------------+--------+
| null| 1| 7|
|Alice| 0| 2|
| Bob| 0| 5|
+-----+-------------+--------+
"""
sc = SparkContext._active_spark_context
jc = sc._jvm.functions.grouping_id(_to_seq(sc, cols, _to_java_column))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,5 @@ case class GroupingID(groupByExprs: Seq[Expression]) extends Expression with Une
override def children: Seq[Expression] = groupByExprs
override def dataType: DataType = IntegerType
override def nullable: Boolean = false
override def prettyName: String = "grouping_id"
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
import org.apache.spark.sql.catalyst.util.quoteIdentifier
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.types.{DataType, NullType}
import org.apache.spark.sql.types.{ByteType, DataType, IntegerType, NullType}

/**
* A place holder for generated SQL for subquery expression.
Expand Down Expand Up @@ -118,6 +118,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
case p: Project =>
projectToSQL(p, isDistinct = false)

case a @ Aggregate(_, _, e @ Expand(_, _, p: Project)) if isGroupingSet(a, e, p) =>
groupingSetToSQL(a, e, p)

case p: Aggregate =>
aggregateToSQL(p)

Expand Down Expand Up @@ -244,6 +247,77 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
)
}

private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
output1.size == output2.size &&
output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2))

private def isGroupingSet(a: Aggregate, e: Expand, p: Project): Boolean = {
assert(a.child == e && e.child == p)
a.groupingExpressions.forall(_.isInstanceOf[Attribute]) &&
sameOutput(e.output, p.child.output ++ a.groupingExpressions.map(_.asInstanceOf[Attribute]))
}

private def groupingSetToSQL(
agg: Aggregate,
expand: Expand,
project: Project): String = {
assert(agg.groupingExpressions.length > 1)

// The last column of Expand is always grouping ID
val gid = expand.output.last

val numOriginalOutput = project.child.output.length
// Assumption: Aggregate's groupingExpressions is composed of
// 1) the attributes of aliased group by expressions
// 2) gid, which is always the last one
val groupByAttributes = agg.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute])
// Assumption: Project's projectList is composed of
// 1) the original output (Project's child.output),
// 2) the aliased group by expressions.
val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child)
val groupingSQL = groupByExprs.map(_.sql).mkString(", ")

// a map from group by attributes to the original group by expressions.
val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs))

val groupingSet: Seq[Seq[Expression]] = expand.projections.map { project =>
// Assumption: expand.projections is composed of
// 1) the original output (Project's child.output),
// 2) group by attributes(or null literal)
// 3) gid, which is always the last one in each project in Expand
project.drop(numOriginalOutput).dropRight(1).collect {
case attr: Attribute if groupByAttrMap.contains(attr) => groupByAttrMap(attr)
}
}
val groupingSetSQL =
"GROUPING SETS(" +
groupingSet.map(e => s"(${e.map(_.sql).mkString(", ")})").mkString(", ") + ")"

val aggExprs = agg.aggregateExpressions.map { case expr =>
expr.transformDown {
// grouping_id() is converted to VirtualColumn.groupingIdName by Analyzer. Revert it back.
case ar: AttributeReference if ar == gid => GroupingID(Nil)
case ar: AttributeReference if groupByAttrMap.contains(ar) => groupByAttrMap(ar)
case a @ Cast(BitwiseAnd(
ShiftRight(ar: AttributeReference, Literal(value: Any, IntegerType)),
Literal(1, IntegerType)), ByteType) if ar == gid =>
// for converting an expression to its original SQL format grouping(col)
val idx = groupByExprs.length - 1 - value.asInstanceOf[Int]
groupByExprs.lift(idx).map(Grouping).getOrElse(a)
}
}

build(
"SELECT",
aggExprs.map(_.sql).mkString(", "),
if (agg.child == OneRowRelation) "" else "FROM",
toSQL(project.child),
"GROUP BY",
groupingSQL,
groupingSetSQL
)
}

object Canonicalizer extends RuleExecutor[LogicalPlan] {
override protected def batches: Seq[Batch] = Seq(
Batch("Canonicalizer", FixedPoint(100),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,149 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
checkHiveQl("SELECT DISTINCT id FROM parquet_t0")
}

test("rollup/cube #1") {
// Original logical plan:
// Aggregate [(key#17L % cast(5 as bigint))#47L,grouping__id#46],
// [(count(1),mode=Complete,isDistinct=false) AS cnt#43L,
// (key#17L % cast(5 as bigint))#47L AS _c1#45L,
// grouping__id#46 AS _c2#44]
// +- Expand [List(key#17L, value#18, (key#17L % cast(5 as bigint))#47L, 0),
// List(key#17L, value#18, null, 1)],
// [key#17L,value#18,(key#17L % cast(5 as bigint))#47L,grouping__id#46]
// +- Project [key#17L,
// value#18,
// (key#17L % cast(5 as bigint)) AS (key#17L % cast(5 as bigint))#47L]
// +- Subquery t1
// +- Relation[key#17L,value#18] ParquetRelation
// Converted SQL:
// SELECT count( 1) AS `cnt`,
// (`t1`.`key` % CAST(5 AS BIGINT)),
// grouping_id() AS `_c2`
// FROM `default`.`t1`
// GROUP BY (`t1`.`key` % CAST(5 AS BIGINT))
// GROUPING SETS (((`t1`.`key` % CAST(5 AS BIGINT))), ())
checkHiveQl(
"SELECT count(*) as cnt, key%5, grouping_id() FROM parquet_t1 GROUP BY key % 5 WITH ROLLUP")
checkHiveQl(
"SELECT count(*) as cnt, key%5, grouping_id() FROM parquet_t1 GROUP BY key % 5 WITH CUBE")
}

test("rollup/cube #2") {
checkHiveQl("SELECT key, value, count(value) FROM parquet_t1 GROUP BY key, value WITH ROLLUP")
checkHiveQl("SELECT key, value, count(value) FROM parquet_t1 GROUP BY key, value WITH CUBE")
}

test("rollup/cube #3") {
checkHiveQl(
"SELECT key, count(value), grouping_id() FROM parquet_t1 GROUP BY key, value WITH ROLLUP")
checkHiveQl(
"SELECT key, count(value), grouping_id() FROM parquet_t1 GROUP BY key, value WITH CUBE")
}

test("rollup/cube #4") {
checkHiveQl(
s"""
|SELECT count(*) as cnt, key % 5 as k1, key - 5 as k2, grouping_id() FROM parquet_t1
|GROUP BY key % 5, key - 5 WITH ROLLUP
""".stripMargin)
checkHiveQl(
s"""
|SELECT count(*) as cnt, key % 5 as k1, key - 5 as k2, grouping_id() FROM parquet_t1
|GROUP BY key % 5, key - 5 WITH CUBE
""".stripMargin)
}

test("rollup/cube #5") {
checkHiveQl(
s"""
|SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id(key % 5, key - 5) AS k3
|FROM (SELECT key, key%2, key - 5 FROM parquet_t1) t GROUP BY key%5, key-5
|WITH ROLLUP
""".stripMargin)
checkHiveQl(
s"""
|SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id(key % 5, key - 5) AS k3
|FROM (SELECT key, key % 2, key - 5 FROM parquet_t1) t GROUP BY key % 5, key - 5
|WITH CUBE
""".stripMargin)
}

test("rollup/cube #6") {
checkHiveQl("SELECT a, b, sum(c) FROM parquet_t2 GROUP BY ROLLUP(a, b) ORDER BY a, b")
checkHiveQl("SELECT a, b, sum(c) FROM parquet_t2 GROUP BY CUBE(a, b) ORDER BY a, b")
checkHiveQl("SELECT a, b, sum(a) FROM parquet_t2 GROUP BY ROLLUP(a, b) ORDER BY a, b")
checkHiveQl("SELECT a, b, sum(a) FROM parquet_t2 GROUP BY CUBE(a, b) ORDER BY a, b")
checkHiveQl("SELECT a + b, b, sum(a - b) FROM parquet_t2 GROUP BY a + b, b WITH ROLLUP")
checkHiveQl("SELECT a + b, b, sum(a - b) FROM parquet_t2 GROUP BY a + b, b WITH CUBE")
}

test("rollup/cube #7") {
checkHiveQl("SELECT a, b, grouping_id(a, b) FROM parquet_t2 GROUP BY cube(a, b)")
checkHiveQl("SELECT a, b, grouping(b) FROM parquet_t2 GROUP BY cube(a, b)")
checkHiveQl("SELECT a, b, grouping(a) FROM parquet_t2 GROUP BY cube(a, b)")
}

test("rollup/cube #8") {
// grouping_id() is part of another expression
checkHiveQl(
s"""
|SELECT hkey AS k1, value - 5 AS k2, hash(grouping_id()) AS hgid
|FROM (SELECT hash(key) as hkey, key as value FROM parquet_t1) t GROUP BY hkey, value-5
|WITH ROLLUP
""".stripMargin)
checkHiveQl(
s"""
|SELECT hkey AS k1, value - 5 AS k2, hash(grouping_id()) AS hgid
|FROM (SELECT hash(key) as hkey, key as value FROM parquet_t1) t GROUP BY hkey, value-5
|WITH CUBE
""".stripMargin)
}

test("rollup/cube #9") {
// self join is used as the child node of ROLLUP/CUBE with replaced quantifiers
checkHiveQl(
s"""
|SELECT t.key - 5, cnt, SUM(cnt)
|FROM (SELECT x.key, COUNT(*) as cnt
|FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key GROUP BY x.key) t
|GROUP BY cnt, t.key - 5
|WITH ROLLUP
""".stripMargin)
checkHiveQl(
s"""
|SELECT t.key - 5, cnt, SUM(cnt)
|FROM (SELECT x.key, COUNT(*) as cnt
|FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key GROUP BY x.key) t
|GROUP BY cnt, t.key - 5
|WITH CUBE
""".stripMargin)
}

test("grouping sets #1") {
checkHiveQl(
s"""
|SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id() AS k3
|FROM (SELECT key, key % 2, key - 5 FROM parquet_t1) t GROUP BY key % 5, key - 5
|GROUPING SETS (key % 5, key - 5)
""".stripMargin)
}

test("grouping sets #2") {
checkHiveQl(
"SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (a, b) ORDER BY a, b")
checkHiveQl(
"SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (a) ORDER BY a, b")
checkHiveQl(
"SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (b) ORDER BY a, b")
checkHiveQl(
"SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (()) ORDER BY a, b")
checkHiveQl(
s"""
|SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b
|GROUPING SETS ((), (a), (a, b)) ORDER BY a, b
""".stripMargin)
}

test("cluster by") {
checkHiveQl("SELECT id FROM parquet_t0 CLUSTER BY id")
}
Expand Down

0 comments on commit 385c0d9

Please sign in to comment.