Skip to content

Commit

Permalink
fix qualifiers
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-fan committed Mar 7, 2016
1 parent 3ce072b commit e037814
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,7 @@ object SizeBasedWindowFunction {
the window partition.""")
case class RowNumber() extends RowNumberLike {
override val evaluateExpression = rowNumber
override def sql: String = "ROW_NUMBER()"
}

/**
Expand All @@ -487,6 +488,7 @@ case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction {
// return the same value for equal values in the partition.
override val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow)
override val evaluateExpression = Divide(Cast(rowNumber, DoubleType), Cast(n, DoubleType))
override def sql: String = "CUME_DIST()"
}

/**
Expand Down Expand Up @@ -625,6 +627,7 @@ abstract class RankLike extends AggregateWindowFunction {
case class Rank(children: Seq[Expression]) extends RankLike {
def this() = this(Nil)
override def withOrder(order: Seq[Expression]): Rank = Rank(order)
override def sql: String = "RANK()"
}

/**
Expand All @@ -649,6 +652,7 @@ case class DenseRank(children: Seq[Expression]) extends RankLike {
override val updateExpressions = increaseRank +: children
override val aggBufferAttributes = rank +: orderAttrs
override val initialValues = zero +: orderInit
override def sql: String = "DENSE_RANK()"
}

/**
Expand All @@ -675,4 +679,5 @@ case class PercentRank(children: Seq[Expression]) extends RankLike with SizeBase
override val evaluateExpression = If(GreaterThan(n, one),
Divide(Cast(Subtract(rank, one), DoubleType), Cast(Subtract(n, one), DoubleType)),
Literal(0.0d))
override def sql: String = "PERCENT_RANK()"
}
107 changes: 68 additions & 39 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive

import java.util.concurrent.atomic.AtomicLong

import scala.collection.mutable
import scala.util.control.NonFatal

import org.apache.spark.Logging
Expand Down Expand Up @@ -107,7 +108,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
windowToSQL(w)

case Limit(limitExpr, child) =>
s"${toSQL(child)} LIMIT ${limitExpr.sql}"
s"${toSQL(child)} LIMIT ${exprToSQL(limitExpr, child)}"

case p: Sample if p.isTableSample =>
val fraction = math.min(100, math.max(0, (p.upperBound - p.lowerBound) * 100))
Expand All @@ -126,12 +127,12 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
build(toSQL(p.child), "TABLESAMPLE(" + fraction + " PERCENT)")
}

case p: Filter =>
val whereOrHaving = p.child match {
case Filter(condition, child) =>
val whereOrHaving = child match {
case _: Aggregate => "HAVING"
case _ => "WHERE"
}
build(toSQL(p.child), whereOrHaving, p.condition.sql)
build(toSQL(child), whereOrHaving, exprToSQL(condition, child))

case p @ Distinct(u: Union) if u.children.length > 1 =>
val childrenSql = u.children.map(c => s"(${toSQL(c)})")
Expand Down Expand Up @@ -166,7 +167,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
p.joinType.sql,
"JOIN",
toSQL(p.right),
p.condition.map(" ON " + _.sql).getOrElse(""))
p.condition.map(" ON " + exprToSQL(_, p)).getOrElse(""))

case p: MetastoreRelation =>
build(
Expand All @@ -176,20 +177,20 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi

case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _))
if orders.map(_.child) == partitionExprs =>
build(toSQL(child), "CLUSTER BY", partitionExprs.map(_.sql).mkString(", "))
build(toSQL(child), "CLUSTER BY", exprsToSQL(partitionExprs, child))

case p: Sort =>
build(
toSQL(p.child),
if (p.global) "ORDER BY" else "SORT BY",
p.order.map(_.sql).mkString(", ")
exprsToSQL(p.order, p.child)
)

case p: RepartitionByExpression =>
build(
toSQL(p.child),
"DISTRIBUTE BY",
p.partitionExpressions.map(_.sql).mkString(", ")
exprsToSQL(p.partitionExpressions, p.child)
)

case OneRowRelation =>
Expand All @@ -207,11 +208,55 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
private def build(segments: String*): String =
segments.map(_.trim).filter(_.nonEmpty).mkString(" ")

private def updateQualifier(
expr: Expression,
qualifiers: Seq[(String, AttributeSet)]): Expression = {
if (qualifiers.isEmpty) {
expr
} else {
expr transform {
case a: Attribute =>
val index = qualifiers.indexWhere {
case (_, inputAttributes) => inputAttributes.contains(a)
}
if (index == -1) {
a
} else {
a.withQualifiers(qualifiers(index)._1 :: Nil)
}
}
}
}

private def findQualifiers(input: LogicalPlan): Seq[(String, AttributeSet)] = {
val results = mutable.ArrayBuffer.empty[(String, AttributeSet)]
val nodes = mutable.Stack(input)

while (nodes.nonEmpty) {
val node = nodes.pop()
node match {
case SubqueryAlias(alias, child) => results += alias -> child.outputSet
case _ => node.children.foreach(nodes.push)
}
}

results.toSeq
}

private def exprToSQL(e: Expression, input: LogicalPlan): String = {
updateQualifier(e, findQualifiers(input)).sql
}

private def exprsToSQL(exprs: Seq[Expression], input: LogicalPlan): String = {
val qualifiers = findQualifiers(input)
exprs.map(updateQualifier(_, qualifiers).sql).mkString(", ")
}

private def projectToSQL(plan: Project, isDistinct: Boolean): String = {
build(
"SELECT",
if (isDistinct) "DISTINCT" else "",
plan.projectList.map(_.sql).mkString(", "),
exprsToSQL(plan.projectList, plan.child),
if (plan.child == OneRowRelation) "" else "FROM",
toSQL(plan.child)
)
Expand All @@ -221,7 +266,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
val groupingSQL = plan.groupingExpressions.map(_.sql).mkString(", ")
build(
"SELECT",
plan.aggregateExpressions.map(_.sql).mkString(", "),
exprsToSQL(plan.aggregateExpressions, plan.child),
if (plan.child == OneRowRelation) "" else "FROM",
toSQL(plan.child),
if (groupingSQL.isEmpty) "" else "GROUP BY",
Expand All @@ -244,11 +289,12 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
expand: Expand,
project: Project): String = {
assert(agg.groupingExpressions.length > 1)
val input = project.child

// The last column of Expand is always grouping ID
val gid = expand.output.last

val numOriginalOutput = project.child.output.length
val numOriginalOutput = input.output.length
// Assumption: Aggregate's groupingExpressions is composed of
// 1) the attributes of aliased group by expressions
// 2) gid, which is always the last one
Expand All @@ -257,7 +303,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
// 1) the original output (Project's child.output),
// 2) the aliased group by expressions.
val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child)
val groupingSQL = groupByExprs.map(_.sql).mkString(", ")
val groupingSQL = exprsToSQL(groupByExprs, input)

// a map from group by attributes to the original group by expressions.
val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs))
Expand All @@ -272,8 +318,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
}
}
val groupingSetSQL =
"GROUPING SETS(" +
groupingSet.map(e => s"(${e.map(_.sql).mkString(", ")})").mkString(", ") + ")"
"GROUPING SETS(" + groupingSet.map(e => "(" + exprsToSQL(e, input) + ")").mkString(", ") + ")"

val aggExprs = agg.aggregateExpressions.map { case expr =>
expr.transformDown {
Expand All @@ -291,9 +336,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi

build(
"SELECT",
aggExprs.map(_.sql).mkString(", "),
exprsToSQL(aggExprs, input),
if (agg.child == OneRowRelation) "" else "FROM",
toSQL(project.child),
toSQL(input),
"GROUP BY",
groupingSQL,
groupingSetSQL
Expand All @@ -303,7 +348,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
private def windowToSQL(w: Window): String = {
build(
"SELECT",
(w.child.output ++ w.windowExpressions).map(_.sql).mkString(", "),
exprsToSQL(w.child.output ++ w.windowExpressions, w.child),
if (w.child == OneRowRelation) "" else "FROM",
toSQL(w.child)
)
Expand Down Expand Up @@ -337,7 +382,10 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
// +- Aggregate ...
// +- MetastoreRelation default, src, None
case plan @ Project(_, Filter(_, _: Aggregate)) =>
wrapChildWithSubquery(plan)
plan.copy(child = SubqueryAlias(SQLBuilder.newSubqueryName, plan.child))

case w @ Window(_, _, _, _, Filter(_, _: Aggregate)) =>
w.copy(child = SubqueryAlias(SQLBuilder.newSubqueryName, w.child))

case plan @ Project(_,
_: SubqueryAlias
Expand All @@ -351,7 +399,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
) => plan

case plan: Project =>
wrapChildWithSubquery(plan)
plan.copy(child = SubqueryAlias(SQLBuilder.newSubqueryName, plan.child))

case w @ Window(_, _, _, _,
_: SubqueryAlias
Expand All @@ -365,26 +413,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
) => w

case w: Window =>
val alias = SQLBuilder.newSubqueryName
val childAttributes = w.child.outputSet
val qualified = w.windowExpressions.map(_.transform {
case a: Attribute if childAttributes.contains(a) =>
a.withQualifiers(alias :: Nil)
}.asInstanceOf[NamedExpression])

w.copy(windowExpressions = qualified, child = SubqueryAlias(alias, w.child))
}

def wrapChildWithSubquery(project: Project): Project = project match {
case Project(projectList, child) =>
val alias = SQLBuilder.newSubqueryName
val childAttributes = child.outputSet
val aliasedProjectList = projectList.map(_.transform {
case a: Attribute if childAttributes.contains(a) =>
a.withQualifiers(alias :: Nil)
}.asInstanceOf[NamedExpression])

Project(aliasedProjectList, SubqueryAlias(alias, child))
w.copy(child = SubqueryAlias(SQLBuilder.newSubqueryName, w.child))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -446,18 +446,88 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
)
}

test("window functions") {
checkHiveQl("SELECT a, SUM(b) OVER (PARTITION BY c ORDER BY d) AS sum FROM parquet_t2")
test("window basic") {
checkHiveQl("SELECT MAX(value) OVER (PARTITION BY key % 3) FROM parquet_t1")
checkHiveQl(
"""
|SELECT a + 1, SUM(b * 2) OVER (PARTITION BY c + d ORDER BY c - d) AS sum
|FROM parquet_t2
|SELECT key, value, ROUND(AVG(key) OVER (), 2)
|FROM parquet_t1 ORDER BY key
""".stripMargin)
checkHiveQl(
"""
|SELECT a, SUM(b) OVER w1 AS sum, AVG(b) over w2 AS avg
|FROM parquet_t2
|WINDOW w1 AS (PARTITION BY c ORDER BY d), w2 AS (PARTITION BY d ORDER BY c)
|SELECT value, MAX(key + 1) OVER (PARTITION BY key % 5 ORDER BY key % 7) AS max
|FROM parquet_t1
""".stripMargin)
}

test("window with different window specification") {
checkHiveQl(
"""
|SELECT key, value,
|DENSE_RANK() OVER (ORDER BY key, value) AS dr,
|MAX(value) OVER (PARTITION BY key ORDER BY key) AS max
|FROM parquet_t1
""".stripMargin)
}

test("window with the same window specification with aggregate + having") {
checkHiveQl(
"""
|SELECT key, value,
|MAX(value) OVER (PARTITION BY key % 5 ORDER BY key) AS max
|FROM parquet_t1 GROUP BY key, value HAVING key > 5
""".stripMargin)
}

test("window with the same window specification with aggregate functions") {
checkHiveQl(
"""
|SELECT key, value,
|MAX(value) OVER (PARTITION BY key % 5 ORDER BY key) AS max
|FROM parquet_t1 GROUP BY key, value
""".stripMargin)
}

test("window with the same window specification with aggregate") {
checkHiveQl(
"""
|SELECT key, value,
|DENSE_RANK() OVER (DISTRIBUTE BY key SORT BY key, value) AS dr,
|COUNT(key)
|FROM parquet_t1 GROUP BY key, value
""".stripMargin)
}

test("window with the same window specification without aggregate and filter") {
checkHiveQl(
"""
|SELECT key, value,
|DENSE_RANK() OVER (DISTRIBUTE BY key SORT BY key, value) AS dr,
|COUNT(key) OVER(DISTRIBUTE BY key SORT BY key, value) AS ca
|FROM parquet_t1
""".stripMargin)
}

test("window clause") {
checkHiveQl(
"""
|SELECT key, MAX(value) OVER w1 AS MAX, MIN(value) OVER w2 AS min
|FROM parquet_t1
|WINDOW w1 AS (PARTITION BY key % 5 ORDER BY key), w2 AS (PARTITION BY key % 6)
""".stripMargin)
}

test("special window functions") {
checkHiveQl(
"""
|SELECT
| RANK() OVER w,
| PERCENT_RANK() OVER w,
| DENSE_RANK() OVER w,
| ROW_NUMBER() OVER w,
| CUME_DIST() OVER w
|FROM parquet_t1
|WINDOW w AS (PARTITION BY key % 5 ORDER BY key)
""".stripMargin)
}
}

0 comments on commit e037814

Please sign in to comment.