[SPARK-12718][SPARK-13720][SQL] SQL generation support for window functions #11555

Closed. Wants to merge 9 commits. Showing changes from 4 commits.
@@ -1180,7 +1180,7 @@ class Analyzer(

// Finally, we create a Project to output child's output plus
// newExpressionsWithWindowFunctions.
-      Project(currentChild.output ++ newExpressionsWithWindowFunctions, currentChild)
+      Project(child.output ++ newExpressionsWithWindowFunctions, currentChild)
} // end of addWindow

// We have to use transformDown at here to make sure the rule of
@@ -57,7 +57,8 @@ case class SortOrder(child: Expression, direction: SortDirection)
override def dataType: DataType = child.dataType
override def nullable: Boolean = child.nullable

-  override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}"
+  override def toString: String = s"$child ${direction.sql}"
+  override def sql: String = child.sql + " " + direction.sql

def isAscending: Boolean = direction == Ascending
}
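For intuition, here is roughly what the new SortOrder.sql override produces for a simple ascending sort key. The backtick quoting comes from Catalyst's usual attribute rendering, so treat the exact string as an expected shape rather than verified output:

    import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
    import org.apache.spark.sql.types.IntegerType

    val key = AttributeReference("key", IntegerType)()  // hypothetical column
    val order = SortOrder(key, Ascending)
    order.sql  // expected shape: `key` ASC  (child.sql + " " + direction.sql)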
@@ -75,6 +75,22 @@ case class WindowSpecDefinition(
override def nullable: Boolean = true
override def foldable: Boolean = false
override def dataType: DataType = throw new UnsupportedOperationException

+  override def sql: String = {
+    val partition = if (partitionSpec.isEmpty) {
+      ""
+    } else {
+      "PARTITION BY " + partitionSpec.map(_.sql).mkString(", ")
+    }
+
+    val order = if (orderSpec.isEmpty) {
+      ""
+    } else {
+      "ORDER BY " + orderSpec.map(_.sql).mkString(", ")
+    }
+
+    s"($partition $order ${frameSpecification.toString})"
+  }
}

/**
@@ -278,6 +294,7 @@ case class WindowExpression(
override def nullable: Boolean = windowFunction.nullable

override def toString: String = s"$windowFunction $windowSpec"
+  override def sql: String = windowFunction.sql + " OVER " + windowSpec.sql
}
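Putting the two pieces together, a window expression now renders bottom-up: the window function's sql, the OVER keyword, then the spec. A minimal sketch of the expected result, assuming Catalyst's usual backtick quoting (the column names are invented):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.IntegerType

    val a = AttributeReference("a", IntegerType)()
    val b = AttributeReference("b", IntegerType)()
    val spec = WindowSpecDefinition(
      partitionSpec = Seq(a),
      orderSpec = Seq(SortOrder(b, Ascending)),
      frameSpecification = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow))
    WindowExpression(RowNumber(), spec).sql
    // expected shape:
    // ROW_NUMBER() OVER (PARTITION BY `a` ORDER BY `b` ASC
    //   RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)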

/**
@@ -451,6 +468,7 @@ object SizeBasedWindowFunction {
the window partition.""")
case class RowNumber() extends RowNumberLike {
override val evaluateExpression = rowNumber
+  override def sql: String = "ROW_NUMBER()"
}

/**
@@ -470,6 +488,7 @@ case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction {
// return the same value for equal values in the partition.
override val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow)
override val evaluateExpression = Divide(Cast(rowNumber, DoubleType), Cast(n, DoubleType))
+  override def sql: String = "CUME_DIST()"
}

/**
@@ -608,6 +627,7 @@ abstract class RankLike extends AggregateWindowFunction {
case class Rank(children: Seq[Expression]) extends RankLike {
def this() = this(Nil)
override def withOrder(order: Seq[Expression]): Rank = Rank(order)
+  override def sql: String = "RANK()"
}

/**
@@ -632,6 +652,7 @@ case class DenseRank(children: Seq[Expression]) extends RankLike {
override val updateExpressions = increaseRank +: children
override val aggBufferAttributes = rank +: orderAttrs
override val initialValues = zero +: orderInit
+  override def sql: String = "DENSE_RANK()"
}

/**
@@ -658,4 +679,5 @@ case class PercentRank(children: Seq[Expression]) extends RankLike with SizeBasedWindowFunction {
override val evaluateExpression = If(GreaterThan(n, one),
Divide(Cast(Subtract(rank, one), DoubleType), Cast(Subtract(n, one), DoubleType)),
Literal(0.0d))
+  override def sql: String = "PERCENT_RANK()"
}
sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala (134 changes: 105 additions & 29 deletions)
@@ -42,7 +42,7 @@ case class SubqueryHolder(query: String) extends LeafExpression with Unevaluable
}

/**
- * A builder class used to convert a resolved logical plan into a SQL query string. Note that this
+ * A builder class used to convert a resolved logical plan into a SQL query string. Note that not
 * all resolved logical plans are convertible. They either don't have corresponding SQL
* representations (e.g. logical plans that operate on local Scala collections), or are simply not
* supported by this builder (yet).
@@ -103,8 +103,11 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
case p: Aggregate =>
aggregateToSQL(p)

+    case w: Window =>
+      windowToSQL(w)

    case Limit(limitExpr, child) =>
-      s"${toSQL(child)} LIMIT ${limitExpr.sql}"
+      s"${toSQL(child)} LIMIT ${exprToSQL(limitExpr, child)}"

case p: Sample if p.isTableSample =>
val fraction = math.min(100, math.max(0, (p.upperBound - p.lowerBound) * 100))
@@ -123,12 +126,12 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
build(toSQL(p.child), "TABLESAMPLE(" + fraction + " PERCENT)")
}

-    case p: Filter =>
-      val whereOrHaving = p.child match {
+    case Filter(condition, child) =>
+      val whereOrHaving = child match {
case _: Aggregate => "HAVING"
case _ => "WHERE"
}
-      build(toSQL(p.child), whereOrHaving, p.condition.sql)
+      build(toSQL(child), whereOrHaving, exprToSQL(condition, child))

case p @ Distinct(u: Union) if u.children.length > 1 =>
val childrenSql = u.children.map(c => s"(${toSQL(c)})")
@@ -163,7 +166,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
p.joinType.sql,
"JOIN",
toSQL(p.right),
-        p.condition.map(" ON " + _.sql).getOrElse(""))
+        p.condition.map(" ON " + exprToSQL(_, p)).getOrElse(""))

case p: MetastoreRelation =>
build(
@@ -173,20 +176,20 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {

case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _))
if orders.map(_.child) == partitionExprs =>
-      build(toSQL(child), "CLUSTER BY", partitionExprs.map(_.sql).mkString(", "))
+      build(toSQL(child), "CLUSTER BY", exprsToSQL(partitionExprs, child))

case p: Sort =>
build(
toSQL(p.child),
if (p.global) "ORDER BY" else "SORT BY",
-        p.order.map { case SortOrder(e, dir) => s"${e.sql} ${dir.sql}" }.mkString(", ")
+        exprsToSQL(p.order, p.child)
)

case p: RepartitionByExpression =>
build(
toSQL(p.child),
"DISTRIBUTE BY",
-        p.partitionExpressions.map(_.sql).mkString(", ")
+        exprsToSQL(p.partitionExpressions, p.child)
)

case OneRowRelation =>
@@ -204,11 +207,70 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
private def build(segments: String*): String =
segments.map(_.trim).filter(_.nonEmpty).mkString(" ")

+  /**
+   * Given a seq of qualifiers (names and their corresponding [[AttributeSet]]s), transforms the
+   * given expression tree: if an [[Attribute]] belongs to one of the [[AttributeSet]]s, updates
+   * its qualifier to the corresponding name of that [[AttributeSet]].
+   */
+  private def updateQualifier(
+      expr: Expression,
+      qualifiers: Seq[(String, AttributeSet)]): Expression = {
+    if (qualifiers.isEmpty) {
+      expr
+    } else {
+      expr transform {
+        case a: Attribute =>
+          val index = qualifiers.indexWhere {
+            case (_, inputAttributes) => inputAttributes.contains(a)
+          }
+          if (index == -1) {
+            a
+          } else {
+            a.withQualifiers(qualifiers(index)._1 :: Nil)
+          }
+      }
+    }
+  }
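A minimal sketch of the rewrite this enables (updateQualifier is private, so this is conceptual; the alias name is invented):

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet}
    import org.apache.spark.sql.types.IntegerType

    val a = AttributeReference("a", IntegerType)()
    val qualifiers = Seq("gen_subquery_0" -> AttributeSet(a :: Nil))
    // updateQualifier(a, qualifiers) returns a copy of `a` carrying the
    // "gen_subquery_0" qualifier, so its sql renders as `gen_subquery_0`.`a`
    // rather than the bare `a`.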

+  /**
+   * Finds the outermost [[SubqueryAlias]] nodes in the input logical plan and returns their
+   * alias names and output sets.
+   */
+  private def findOutermostQualifiers(input: LogicalPlan): Seq[(String, AttributeSet)] = {
Member:
I have another alternative. We face the same issue everywhere we add or remove an extra Qualifier. How about adding another rule/batch below the existing Batch("Canonicalizer")? For example:

      Batch("Replace Qualifier", Once,
        ReplaceQualifier)

The rule is simple: we can always get the qualifier from the inputSet if we do a bottom-up traversal. I did not do a full test last night. Below is the code draft:

    object ReplaceQualifier extends Rule[LogicalPlan] {
      override def apply(tree: LogicalPlan): LogicalPlan = tree transformUp { case plan =>
          plan transformExpressions {
            case e: AttributeReference => e.withQualifiers(getQualifier(plan.inputSet, e))
          }
      }

      private def getQualifier(inputSet: AttributeSet, e: AttributeReference): Seq[String] = {
        inputSet.collectFirst {
          case a if a.semanticEquals(e) => a.qualifiers
        }.getOrElse(Seq.empty[String])
      }
    }

Contributor:
Thanks, I like this one :)

Contributor Author:
This is really a good idea! Thanks, updated.

+    input.collectFirst {
+      case SubqueryAlias(alias, child) => Seq(alias -> child.outputSet)
+      case plan => plan.children.flatMap(findOutermostQualifiers)
+    }.toSeq.flatten
+  }
Contributor:
So this method is basically a DFS search for all the outermost SubqueryAlias operators. Maybe the following version is clearer:

def findOutermostQualifiers(input: LogicalPlan): Seq[(String, AttributeSet)] = {
  input.collectFirst {
    case SubqueryAlias(alias, child) => Seq(alias -> child.outputSet)
    case plan => plan.children.flatMap(findOutermostQualifiers)
  }.toSeq.flatten
}


+  /**
+   * Converts an expression to a SQL string.
+   *
+   * Note that we may add extra [[SubqueryAlias]]es to the logical plan, but the qualifiers haven't
+   * been propagated yet. So here we try to find the correct qualifiers first, then update the
+   * given expression with those qualifiers, and finally convert it to a SQL string.
+   */
+  private def exprToSQL(e: Expression, input: LogicalPlan): String = {
+    updateQualifier(e, findOutermostQualifiers(input)).sql
+  }

+  /**
+   * Converts a seq of expressions to a SQL string.
+   *
+   * Note that we may add extra [[SubqueryAlias]]es to the logical plan, but the qualifiers haven't
+   * been propagated yet. So here we try to find the correct qualifiers first, then update the
+   * given expressions with those qualifiers, and finally convert them to a SQL string.
+   */
+  private def exprsToSQL(exprs: Seq[Expression], input: LogicalPlan): String = {
+    val qualifiers = findOutermostQualifiers(input)
+    exprs.map(updateQualifier(_, qualifiers).sql).mkString(", ")
+  }
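For intuition, a conceptual end-to-end use of the two helpers above (both are private; the alias name is invented):

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, SubqueryAlias}
    import org.apache.spark.sql.types.IntegerType

    val a = AttributeReference("a", IntegerType)()
    val wrapped = SubqueryAlias("gen_subquery_0", LocalRelation(Seq(a)))
    // findOutermostQualifiers(wrapped) yields Seq("gen_subquery_0" -> the child's outputSet),
    // so exprsToSQL(Seq(a), wrapped) renders the attribute with the recovered
    // alias, e.g. `gen_subquery_0`.`a`, instead of the stale unqualified `a`.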

private def projectToSQL(plan: Project, isDistinct: Boolean): String = {
build(
"SELECT",
if (isDistinct) "DISTINCT" else "",
-      plan.projectList.map(_.sql).mkString(", "),
+      exprsToSQL(plan.projectList, plan.child),
if (plan.child == OneRowRelation) "" else "FROM",
toSQL(plan.child)
)
@@ -218,7 +280,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
val groupingSQL = plan.groupingExpressions.map(_.sql).mkString(", ")
build(
"SELECT",
-      plan.aggregateExpressions.map(_.sql).mkString(", "),
+      exprsToSQL(plan.aggregateExpressions, plan.child),
if (plan.child == OneRowRelation) "" else "FROM",
toSQL(plan.child),
if (groupingSQL.isEmpty) "" else "GROUP BY",
@@ -241,11 +303,12 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
expand: Expand,
project: Project): String = {
assert(agg.groupingExpressions.length > 1)
+    val input = project.child

// The last column of Expand is always grouping ID
val gid = expand.output.last

-    val numOriginalOutput = project.child.output.length
+    val numOriginalOutput = input.output.length
// Assumption: Aggregate's groupingExpressions is composed of
// 1) the attributes of aliased group by expressions
// 2) gid, which is always the last one
@@ -254,7 +317,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
// 1) the original output (Project's child.output),
// 2) the aliased group by expressions.
val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child)
-    val groupingSQL = groupByExprs.map(_.sql).mkString(", ")
+    val groupingSQL = exprsToSQL(groupByExprs, input)

// a map from group by attributes to the original group by expressions.
val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs))
@@ -269,8 +332,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
}
}
    val groupingSetSQL =
-      "GROUPING SETS(" +
-        groupingSet.map(e => s"(${e.map(_.sql).mkString(", ")})").mkString(", ") + ")"
+      "GROUPING SETS(" + groupingSet.map(e => "(" + exprsToSQL(e, input) + ")").mkString(", ") + ")"

val aggExprs = agg.aggregateExpressions.map { case expr =>
expr.transformDown {
Expand All @@ -288,15 +350,24 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi

build(
"SELECT",
-      aggExprs.map(_.sql).mkString(", "),
+      exprsToSQL(aggExprs, input),
if (agg.child == OneRowRelation) "" else "FROM",
-      toSQL(project.child),
+      toSQL(input),
"GROUP BY",
groupingSQL,
groupingSetSQL
)
}

+  private def windowToSQL(w: Window): String = {
+    build(
+      "SELECT",
+      exprsToSQL(w.child.output ++ w.windowExpressions, w.child),
+      if (w.child == OneRowRelation) "" else "FROM",
+      toSQL(w.child)
+    )
+  }
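To see windowToSQL in context, a hypothetical round trip through the builder (table and column names invented; the builder's public entry point in this revision may return an Option or throw for unsupported plans):

    import org.apache.spark.sql.hive.SQLBuilder

    // Assumes a Hive-backed sqlContext and a table t(key INT, value INT).
    val analyzed = sqlContext.sql(
      "SELECT ROW_NUMBER() OVER (PARTITION BY key ORDER BY value) FROM t"
    ).queryExecution.analyzed

    // windowToSQL prints "SELECT <child output>, <window expressions> FROM <child SQL>",
    // so the regenerated SQL has roughly this shape:
    //   SELECT `key`, `value`, ROW_NUMBER() OVER (PARTITION BY `key` ORDER BY `value` ASC
    //     ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM `default`.`t`
    new SQLBuilder(analyzed, sqlContext).toSQL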

object Canonicalizer extends RuleExecutor[LogicalPlan] {
override protected def batches: Seq[Batch] = Seq(
Batch("Canonicalizer", FixedPoint(100),
@@ -325,7 +396,10 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
// +- Aggregate ...
// +- MetastoreRelation default, src, None
case plan @ Project(_, Filter(_, _: Aggregate)) =>
-        wrapChildWithSubquery(plan)
+        plan.copy(child = SubqueryAlias(SQLBuilder.newSubqueryName, plan.child))
+
+      case w @ Window(_, _, _, _, Filter(_, _: Aggregate)) =>
+        w.copy(child = SubqueryAlias(SQLBuilder.newSubqueryName, w.child))

case plan @ Project(_,
_: SubqueryAlias
@@ -339,19 +413,21 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
) => plan

case plan: Project =>
-        wrapChildWithSubquery(plan)
-    }
+        plan.copy(child = SubqueryAlias(SQLBuilder.newSubqueryName, plan.child))

-    def wrapChildWithSubquery(project: Project): Project = project match {
-      case Project(projectList, child) =>
-        val alias = SQLBuilder.newSubqueryName
-        val childAttributes = child.outputSet
-        val aliasedProjectList = projectList.map(_.transform {
-          case a: Attribute if childAttributes.contains(a) =>
-            a.withQualifiers(alias :: Nil)
-        }.asInstanceOf[NamedExpression])
+      case w @ Window(_, _, _, _,
+        _: SubqueryAlias
+          | _: Filter
+          | _: Join
+          | _: MetastoreRelation
+          | OneRowRelation
+          | _: LocalLimit
+          | _: GlobalLimit
+          | _: Sample
+      ) => w
Contributor:
Add a comment to explain why we need this rule?


-        Project(aliasedProjectList, SubqueryAlias(alias, child))
+      case w: Window =>
+        w.copy(child = SubqueryAlias(SQLBuilder.newSubqueryName, w.child))
}
}
}
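On the review question above: a plausible rationale (an assumption, not stated in the patch) is that toSQL renders a Filter over an Aggregate as a HAVING clause attached to the aggregate's SELECT, while Project and Window each need to own a SELECT list of their own; wrapping the child in a named subquery keeps every SELECT block well-formed. The whitelist cases skip the wrap when the child can already follow FROM directly. The plan shapes involved, sketched:

    // Before canonicalization (would flatten into one malformed SELECT):
    //   Window(windowExprs, ..., Filter(cond, Aggregate(...)))
    // After the new Window case fires:
    //   Window(windowExprs, ..., SubqueryAlias("gen_subquery_N", Filter(cond, Aggregate(...))))
    // which toSQL can print as:
    //   SELECT <child output>, <window expressions> FROM (
    //     SELECT ... GROUP BY ... HAVING ...
    //   ) AS gen_subquery_N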