Skip to content

Commit

Permalink
address comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
gatorsmile committed Mar 4, 2016
1 parent b1925e5 commit 6f609fb
Showing 1 changed file with 25 additions and 20 deletions.
45 changes: 25 additions & 20 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
throw new UnsupportedOperationException(s"unsupported plan $node")
}

/**
 * Returns true when every attribute in `left` has some attribute in `right` for which
 * `semanticEquals` holds, and every attribute in `right` has such a counterpart in `left`.
 * Neither the position nor the number of occurrences of an attribute is compared.
 */
private def sameOutput(left: Seq[Attribute], right: Seq[Attribute]): Boolean =
left.forall(a => right.exists(_.semanticEquals(a))) &&
right.forall(a => left.exists(_.semanticEquals(a)))
/**
 * Returns true when `output1` and `output2` have the same number of attributes and the
 * attributes at each position satisfy `semanticEquals`. The comparison is positional, so
 * the same attributes in a different order do not count as the same output.
 */
private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
output1.size == output2.size &&
output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2))

/**
* Turns a bunch of string segments into a single string and separate each segment by a space.
Expand Down Expand Up @@ -243,40 +243,49 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
}

private def groupingSetToSQL(
plan: Aggregate,
agg: Aggregate,
expand: Expand,
project: Project): String = {
require(plan.groupingExpressions.length > 1)
assert(agg.groupingExpressions.length > 1)

// The last column of Expand is always grouping ID
val gid = expand.output.last

val groupByAttributes = plan.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute])
val groupByExprs =
project.projectList.drop(project.child.output.length).map(_.asInstanceOf[Alias].child)
val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs))
val numOriginalOutput = project.child.output.length
// Assumption: Aggregate's groupingExpressions is composed of
// 1) the group by attributes' aliases
// 2) gid, which is always the last one
val groupByAttributes = agg.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute])
// Assumption: Project's projectList is composed of
// 1) the original output (Project's child.output),
// 2) the aliases of the original group by attributes, which could be expressions
val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child)
val groupingSQL = groupByExprs.map(_.sql).mkString(", ")

// a map from the group by attributes' aliases to the original group by expressions/attributes
val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs))

val groupingSet = expand.projections.map { project =>
// Assumption: expand.projections are composed of
// Assumption: expand.projections is composed of
// 1) the original output (Project's child.output),
// 2) group by attributes (or null literal)
// 3) gid, which is always the last one in each project in Expand
project.dropRight(1).collect {
project.drop(numOriginalOutput).dropRight(1).collect {
case attr: Attribute if groupByAttrMap.contains(attr) => groupByAttrMap(attr)
}
}
val groupingSetSQL =
"GROUPING SETS(" +
groupingSet.map(e => s"(${e.map(_.sql).mkString(", ")})").mkString(", ") + ")"

val aggExprs = plan.aggregateExpressions.map { case expr =>
val aggExprs = agg.aggregateExpressions.map { case expr =>
expr.transformDown {
// grouping_id() is converted to VirtualColumn.groupingIdName by Analyzer. Revert it back.
case ar: AttributeReference if ar eq gid => GroupingID(Nil)
case a @ Alias(_ @ Cast(BitwiseAnd(
case ar: AttributeReference if ar == gid => GroupingID(Nil)
case ar: AttributeReference if groupByAttrMap.contains(ar) => groupByAttrMap(ar)
case a @ Cast(BitwiseAnd(
ShiftRight(ar: AttributeReference, _ @ Literal(value: Any, IntegerType)),
Literal(1, IntegerType)), ByteType), name) if ar == gid =>
Literal(1, IntegerType)), ByteType) if ar == gid =>
// for converting an expression to its original SQL format grouping(col)
val idx = groupByExprs.length - 1 - value.asInstanceOf[Int]
val groupingCol = groupByExprs.lift(idx)
Expand All @@ -285,17 +294,13 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
} else {
throw new UnsupportedOperationException(s"unsupported operator $a")
}
case a @ Alias(ar: AttributeReference, _) if groupByAttrMap.contains(ar) =>
groupByAttrMap(ar)
case ar: AttributeReference if groupByAttrMap.contains(ar) =>
groupByAttrMap(ar)
}
}

build(
"SELECT",
aggExprs.map(_.sql).mkString(", "),
if (plan.child == OneRowRelation) "" else "FROM",
if (agg.child == OneRowRelation) "" else "FROM",
toSQL(project.child),
"GROUP BY",
groupingSQL,
Expand Down

0 comments on commit 6f609fb

Please sign in to comment.