[SPARK-12718][SPARK-13720][SQL] SQL generation support for window functions #11555

Status: Closed (wants to merge 9 commits)
@@ -1201,7 +1201,7 @@ class Analyzer(
         val withWindow = addWindow(windowExpressions, withFilter)

         // Finally, generate output columns according to the original projectList.
-        val finalProjectList = aggregateExprs.map (_.toAttribute)
+        val finalProjectList = aggregateExprs.map(_.toAttribute)
         Project(finalProjectList, withWindow)

       case p: LogicalPlan if !p.childrenResolved => p
@@ -1217,7 +1217,7 @@ class Analyzer(
         val withWindow = addWindow(windowExpressions, withAggregate)

         // Finally, generate output columns according to the original projectList.
-        val finalProjectList = aggregateExprs.map (_.toAttribute)
+        val finalProjectList = aggregateExprs.map(_.toAttribute)
         Project(finalProjectList, withWindow)

       // We only extract Window Expressions after all expressions of the Project
@@ -1232,7 +1232,7 @@ class Analyzer(
         val withWindow = addWindow(windowExpressions, withProject)

         // Finally, generate output columns according to the original projectList.
-        val finalProjectList = projectList.map (_.toAttribute)
+        val finalProjectList = projectList.map(_.toAttribute)
         Project(finalProjectList, withWindow)
     }
   }
@@ -57,7 +57,8 @@ case class SortOrder(child: Expression, direction: SortDirection)
   override def dataType: DataType = child.dataType
   override def nullable: Boolean = child.nullable

-  override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}"
+  override def toString: String = s"$child ${direction.sql}"
+  override def sql: String = child.sql + " " + direction.sql

   def isAscending: Boolean = direction == Ascending
 }
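For reference, a minimal standalone sketch of the string the new `sql` method yields; the `SortDirection` stand-ins and the sample column name are assumptions made for illustration:

```scala
object SortOrderSqlSketch {
  // Stand-ins for Catalyst's SortDirection; the "ASC"/"DESC" sql values are
  // an assumption based on the direction.sql usage in the diff above.
  sealed trait SortDirection { def sql: String }
  case object Ascending extends SortDirection { val sql = "ASC" }
  case object Descending extends SortDirection { val sql = "DESC" }

  // Mirrors the new SortOrder.sql: child SQL followed by the direction keyword.
  def sortOrderSql(childSql: String, direction: SortDirection): String =
    childSql + " " + direction.sql

  def main(args: Array[String]): Unit =
    println(sortOrderSql("a", Ascending)) // prints: a ASC
}
```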
@@ -18,7 +18,8 @@
 package org.apache.spark.sql.catalyst.expressions

 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.UnresolvedException
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{DeclarativeAggregate, NoOp}
 import org.apache.spark.sql.types._

@@ -30,6 +31,7 @@ sealed trait WindowSpec

 /**
  * The specification for a window function.
+ *
  * @param partitionSpec It defines the way that input rows are partitioned.
  * @param orderSpec It defines the ordering of rows in a partition.
  * @param frameSpecification It defines the window frame in a partition.
@@ -75,6 +77,22 @@ case class WindowSpecDefinition(
   override def nullable: Boolean = true
   override def foldable: Boolean = false
   override def dataType: DataType = throw new UnsupportedOperationException
+
+  override def sql: String = {
+    val partition = if (partitionSpec.isEmpty) {
+      ""
+    } else {
+      "PARTITION BY " + partitionSpec.map(_.sql).mkString(", ")
+    }
+
+    val order = if (orderSpec.isEmpty) {
+      ""
+    } else {
+      "ORDER BY " + orderSpec.map(_.sql).mkString(", ")
+    }
+
+    s"($partition $order ${frameSpecification.toString})"
+  }
 }
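A standalone sketch of the window-spec string assembly added above, operating on plain strings instead of Catalyst expressions (the column names and frame clause are hypothetical):

```scala
object WindowSpecSqlSketch {
  // Mirrors the new WindowSpecDefinition.sql on plain strings (illustration only).
  def windowSpecSql(
      partitionCols: Seq[String],
      orderCols: Seq[String],
      frameSql: String): String = {
    val partition =
      if (partitionCols.isEmpty) "" else "PARTITION BY " + partitionCols.mkString(", ")
    val order =
      if (orderCols.isEmpty) "" else "ORDER BY " + orderCols.mkString(", ")
    s"($partition $order $frameSql)"
  }

  def main(args: Array[String]): Unit = {
    // Prints: (PARTITION BY dept ORDER BY salary DESC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
    println(windowSpecSql(Seq("dept"), Seq("salary DESC"),
      "ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"))
  }
}
```

Note that when the partition or order clause is empty, the interpolation leaves extra spaces inside the parentheses; that is harmless in SQL, which is presumably why the method does not bother trimming.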

 /**
@@ -278,6 +296,7 @@ case class WindowExpression(
   override def nullable: Boolean = windowFunction.nullable

   override def toString: String = s"$windowFunction $windowSpec"
+  override def sql: String = windowFunction.sql + " OVER " + windowSpec.sql
 }

 /**
@@ -451,6 +470,7 @@ object SizeBasedWindowFunction {
    the window partition.""")
 case class RowNumber() extends RowNumberLike {
   override val evaluateExpression = rowNumber
+  override def sql: String = "ROW_NUMBER()"
 }

 /**
@@ -470,6 +490,7 @@ case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction {
   // return the same value for equal values in the partition.
   override val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow)
   override val evaluateExpression = Divide(Cast(rowNumber, DoubleType), Cast(n, DoubleType))
+  override def sql: String = "CUME_DIST()"
 }

 /**
@@ -499,12 +520,25 @@ case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction {
 case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindowFunction {
Comment (Contributor): Maybe I missed it, but where is the `sql` method for NTile?

Comment (Contributor): Ah, I see. Thanks!
   def this() = this(Literal(1))

   override def children: Seq[Expression] = Seq(buckets)

-  // Validate buckets. Note that this could be relaxed, the bucket value only needs to constant
-  // for each partition.
-  buckets.eval() match {
-    case b: Int if b > 0 => // Ok
-    case x => throw new AnalysisException(
-      "Buckets expression must be a foldable positive integer expression: $x")
-  }
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (!buckets.foldable) {
+      return TypeCheckFailure(s"Buckets expression must be foldable, but got $buckets")
+    }
+
+    if (buckets.dataType != IntegerType) {
+      return TypeCheckFailure(s"Buckets expression must be integer type, but got $buckets")
+    }
+
+    val i = buckets.eval().asInstanceOf[Int]
+    if (i > 0) {
+      TypeCheckSuccess
+    } else {
+      TypeCheckFailure(s"Buckets expression must be positive, but got: $i")
+    }
+  }

   private val bucket = AttributeReference("bucket", IntegerType, nullable = false)()
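The new code moves validation from construction time (the old `buckets.eval() match` block, whose error message was also missing its `s` interpolator) into Catalyst's type-check phase. A simplified standalone sketch of the same three-step validation, with `Either[String, Unit]` standing in for `TypeCheckResult`:

```scala
object NTileCheckSketch {
  // Either[String, Unit] stands in for TypeCheckFailure / TypeCheckSuccess.
  def checkBuckets(foldable: Boolean, isIntegerType: Boolean, value: => Any): Either[String, Unit] =
    if (!foldable) Left("Buckets expression must be foldable")
    else if (!isIntegerType) Left("Buckets expression must be integer type")
    else value.asInstanceOf[Int] match {
      case i if i > 0 => Right(())
      case i => Left(s"Buckets expression must be positive, but got: $i")
    }

  def main(args: Array[String]): Unit = {
    assert(checkBuckets(foldable = true, isIntegerType = true, 4).isRight)
    assert(checkBuckets(foldable = true, isIntegerType = true, 0).isLeft)  // non-positive
    assert(checkBuckets(foldable = false, isIntegerType = true, 4).isLeft) // e.g. a column reference
    println("all checks behaved as expected")
  }
}
```

The ordering matters: evaluation is only reached once foldability and the integer type have been established, so a non-constant or non-integer expression is never evaluated.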
@@ -608,6 +642,7 @@ abstract class RankLike extends AggregateWindowFunction {
 case class Rank(children: Seq[Expression]) extends RankLike {
   def this() = this(Nil)
   override def withOrder(order: Seq[Expression]): Rank = Rank(order)
+  override def sql: String = "RANK()"
 }

 /**
@@ -632,6 +667,7 @@ case class DenseRank(children: Seq[Expression]) extends RankLike {
   override val updateExpressions = increaseRank +: children
   override val aggBufferAttributes = rank +: orderAttrs
   override val initialValues = zero +: orderInit
+  override def sql: String = "DENSE_RANK()"
 }

 /**
@@ -658,4 +694,5 @@ case class PercentRank(children: Seq[Expression]) extends RankLike with SizeBasedWindowFunction {
   override val evaluateExpression = If(GreaterThan(n, one),
     Divide(Cast(Subtract(rank, one), DoubleType), Cast(Subtract(n, one), DoubleType)),
     Literal(0.0d))
+  override def sql: String = "PERCENT_RANK()"
 }
sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala (77 additions, 27 deletions)
@@ -42,7 +42,7 @@ case class SubqueryHolder(query: String) extends LeafExpression with Unevaluable
 }

 /**
- * A builder class used to convert a resolved logical plan into a SQL query string. Note that this
+ * A builder class used to convert a resolved logical plan into a SQL query string. Note that not
  * all resolved logical plan are convertible. They either don't have corresponding SQL
  * representations (e.g. logical plans that operate on local Scala collections), or are simply not
  * supported by this builder (yet).
@@ -103,6 +103,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
     case p: Aggregate =>
       aggregateToSQL(p)

+    case w: Window =>
+      windowToSQL(w)
+
     case Limit(limitExpr, child) =>
       s"${toSQL(child)} LIMIT ${limitExpr.sql}"
@@ -123,12 +126,12 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
         build(toSQL(p.child), "TABLESAMPLE(" + fraction + " PERCENT)")
       }

-    case p: Filter =>
-      val whereOrHaving = p.child match {
+    case Filter(condition, child) =>
+      val whereOrHaving = child match {
         case _: Aggregate => "HAVING"
         case _ => "WHERE"
       }
-      build(toSQL(p.child), whereOrHaving, p.condition.sql)
+      build(toSQL(child), whereOrHaving, condition.sql)

     case p @ Distinct(u: Union) if u.children.length > 1 =>
       val childrenSql = u.children.map(c => s"(${toSQL(c)})")
@@ -179,7 +182,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
       build(
         toSQL(p.child),
         if (p.global) "ORDER BY" else "SORT BY",
-        p.order.map { case SortOrder(e, dir) => s"${e.sql} ${dir.sql}" }.mkString(", ")
+        p.order.map(_.sql).mkString(", ")
       )

     case p: RepartitionByExpression =>
@@ -268,9 +271,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
         case attr: Attribute if groupByAttrMap.contains(attr) => groupByAttrMap(attr)
       }
     }
-    val groupingSetSQL =
-      "GROUPING SETS(" +
-        groupingSet.map(e => s"(${e.map(_.sql).mkString(", ")})").mkString(", ") + ")"
+    val groupingSetSQL = "GROUPING SETS(" +
+      groupingSet.map(e => s"(${e.map(_.sql).mkString(", ")})").mkString(", ") + ")"

     val aggExprs = agg.aggregateExpressions.map { case expr =>
       expr.transformDown {
@@ -297,22 +299,50 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
     )
   }

+  private def windowToSQL(w: Window): String = {
+    build(
+      "SELECT",
+      (w.child.output ++ w.windowExpressions).map(_.sql).mkString(", "),
+      if (w.child == OneRowRelation) "" else "FROM",
+      toSQL(w.child)
+    )
+  }
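For intuition, here is roughly how the fragments assemble. The `build` helper shown below is an assumption inferred from its call sites (trim the segments, drop the empty ones, join with spaces), and every name in the example is made up:

```scala
object WindowToSqlSketch {
  // Assumption: SQLBuilder's build(...) trims its segments, drops the empty
  // ones, and joins the rest with single spaces, as its call sites suggest.
  def build(segments: String*): String =
    segments.map(_.trim).filter(_.nonEmpty).mkString(" ")

  def main(args: Array[String]): Unit = {
    // Roughly the shape windowToSQL produces for a Window over a table scan.
    println(build(
      "SELECT",
      "key, value, ROW_NUMBER() OVER (PARTITION BY key ORDER BY value ASC " +
        "ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rn",
      "FROM",
      "(SELECT key, value FROM src) gen_subquery_0"))
  }
}
```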

   object Canonicalizer extends RuleExecutor[LogicalPlan] {
     override protected def batches: Seq[Batch] = Seq(
-      Batch("Canonicalizer", FixedPoint(100),
+      Batch("Collapse Project", FixedPoint(100),
         // The `WidenSetOperationTypes` analysis rule may introduce extra `Project`s over
         // `Aggregate`s to perform type casting. This rule merges these `Project`s into
         // `Aggregate`s.
-        CollapseProject,
-
+        CollapseProject),
+      Batch("Recover Scoping Info", Once,
         // Used to handle other auxiliary `Project`s added by analyzer (e.g.
         // `ResolveAggregateFunctions` rule)
-        RecoverScopingInfo
+        AddSubquery,
+        // The previous rule adds extra subqueries; this rule re-propagates and updates the
+        // qualifiers bottom up, e.g.:
+        //
+        //   Sort
+        //     ordering = t1.a
+        //   Project
+        //     projectList = [t1.a, t1.b]
+        //   Subquery gen_subquery
+        //     child ...
+        //
+        // will be transformed to:
+        //
+        //   Sort
+        //     ordering = gen_subquery.a
+        //   Project
+        //     projectList = [gen_subquery.a, gen_subquery.b]
+        //   Subquery gen_subquery
+        //     child ...
+        UpdateQualifiers
Comment (Contributor): Which test is for this new rule?

Comment (Member): https://github.com/cloud-fan/spark/blob/window/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala#L454-L458 verifies this rule. The JIRA ticket SPARK-13720 describes why we need to add it.

Comment (Contributor): Thanks! @cloud-fan Maybe it would be good to add an example here?
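Following up on the reviewer's suggestion, a toy sketch of the qualifier re-propagation that `UpdateQualifiers` performs. The attribute model below is deliberately simplified, and the real rule matches attributes with `semanticEquals` rather than by name:

```scala
object UpdateQualifiersSketch {
  // Simplified attribute: a name plus an optional qualifier such as Some("t1").
  case class Attr(name: String, qualifier: Option[String]) {
    def sql: String = qualifier.fold(name)(q => s"$q.$name")
  }

  // Re-qualify each reference with the qualifier of the matching attribute in
  // the child's output, mirroring the bottom-up re-propagation described above.
  def updateQualifiers(exprs: Seq[Attr], childOutput: Seq[Attr]): Seq[Attr] =
    exprs.map { a =>
      val q = childOutput.find(_.name == a.name).flatMap(_.qualifier)
      a.copy(qualifier = q)
    }

  def main(args: Array[String]): Unit = {
    // After AddSubquery, the child Project sits under SubqueryAlias("gen_subquery_0", ...).
    val childOutput = Seq(Attr("a", Some("gen_subquery_0")), Attr("b", Some("gen_subquery_0")))
    val sortOrdering = Seq(Attr("a", Some("t1")))
    // t1.a in the Sort ordering becomes gen_subquery_0.a, as in the rule's comment.
    assert(updateQualifiers(sortOrdering, childOutput).head.sql == "gen_subquery_0.a")
    println(updateQualifiers(sortOrdering, childOutput).head.sql)
  }
}
```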

       )
     )

-    object RecoverScopingInfo extends Rule[LogicalPlan] {
-      override def apply(tree: LogicalPlan): LogicalPlan = tree transform {
+    object AddSubquery extends Rule[LogicalPlan] {
+      override def apply(tree: LogicalPlan): LogicalPlan = tree transformUp {
         // This branch handles aggregate functions within HAVING clauses. For example:
         //
         //   SELECT key FROM src GROUP BY key HAVING max(value) > "val_255"
@@ -324,8 +354,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
         //    +- Filter ...
         //        +- Aggregate ...
         //            +- MetastoreRelation default, src, None
-      case plan @ Project(_, Filter(_, _: Aggregate)) =>
-        wrapChildWithSubquery(plan)
+      case plan @ Project(_, Filter(_, _: Aggregate)) => wrapChildWithSubquery(plan)
+
+      case w @ Window(_, _, _, _, Filter(_, _: Aggregate)) => wrapChildWithSubquery(w)

       case plan @ Project(_,
         _: SubqueryAlias
@@ -338,20 +369,39 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
         | _: Sample
       ) => plan

-      case plan: Project =>
-        wrapChildWithSubquery(plan)
+      case plan: Project => wrapChildWithSubquery(plan)

+      // We will generate "SELECT ... FROM ..." for a Window operator, so its child operator
+      // should be able to appear in the FROM clause, or we wrap it in a subquery.
+      case w @ Window(_, _, _, _,
+        _: SubqueryAlias
+        | _: Filter
+        | _: Join
+        | _: MetastoreRelation
+        | OneRowRelation
+        | _: LocalLimit
+        | _: GlobalLimit
+        | _: Sample
+      ) => w
Comment (Contributor): Add a comment to explain why we need this rule?
+      case w: Window => wrapChildWithSubquery(w)
     }

-    def wrapChildWithSubquery(project: Project): Project = project match {
-      case Project(projectList, child) =>
-        val alias = SQLBuilder.newSubqueryName
-        val childAttributes = child.outputSet
-        val aliasedProjectList = projectList.map(_.transform {
-          case a: Attribute if childAttributes.contains(a) =>
-            a.withQualifiers(alias :: Nil)
-        }.asInstanceOf[NamedExpression])
-
-        Project(aliasedProjectList, SubqueryAlias(alias, child))
-    }
+    private def wrapChildWithSubquery(plan: UnaryNode): LogicalPlan = {
+      val newChild = SubqueryAlias(SQLBuilder.newSubqueryName, plan.child)
+      plan.withNewChildren(Seq(newChild))
+    }
   }

+  object UpdateQualifiers extends Rule[LogicalPlan] {
+    override def apply(tree: LogicalPlan): LogicalPlan = tree transformUp {
+      case plan =>
+        val inputAttributes = plan.children.flatMap(_.output)
+        plan transformExpressions {
+          case a: AttributeReference if !plan.producedAttributes.contains(a) =>
Comment (Contributor): @gatorsmile Not related to this PR, but what is the difference between producedAttributes and outputSet?

Comment (Member): producedAttributes is the list of attributes that are added by this operator. For example, Generate will produce some attributes that do not exist in the child node.

Comment (Contributor): Sounds like outputSet should also have that kind of attributes?

Comment (Member): Yeah, but we do not need to add qualifiers for the attributes in producedAttributes, so we keep them untouched.
+            val qualifier = inputAttributes.find(_ semanticEquals a).map(_.qualifiers)
+            a.withQualifiers(qualifier.getOrElse(Nil))
+        }
+    }
+  }
 }
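Finally, a toy model of the subquery-wrapping fallback used by `AddSubquery`; the plan nodes and names below are hypothetical simplifications of the real operators:

```scala
object AddSubquerySketch {
  sealed trait Plan
  case class Table(name: String) extends Plan
  case class SubqueryAlias(alias: String, child: Plan) extends Plan
  case class Window(windowExprSql: Seq[String], child: Plan) extends Plan

  private var counter = 0
  private def newSubqueryName(): String = { counter += 1; s"gen_subquery_$counter" }

  // Mirrors the new wrapChildWithSubquery: keep the operator itself and
  // re-parent it onto an aliased copy of its child, so the child can be
  // rendered inside a FROM clause.
  def wrapChildWithSubquery(w: Window): Window =
    w.copy(child = SubqueryAlias(newSubqueryName(), w.child))

  def main(args: Array[String]): Unit = {
    val wrapped = wrapChildWithSubquery(Window(Seq("RANK() OVER (...)"), Table("src")))
    // Prints: Window(List(RANK() OVER (...)),SubqueryAlias(gen_subquery_1,Table(src)))
    println(wrapped)
  }
}
```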