[SPARK-7712] [SQL] Move Window Functions from Hive UDAFs to Spark Native backend #6278

@@ -74,6 +74,9 @@ class Analyzer(
ResolveSortReferences ::
ResolveGenerate ::
ResolveFunctions ::
ResolvePartialWindowExpressions ::
ResolveUnspecifiedFrameWindowExpressions ::
ResolveImpliedOrderWindowExpressions ::
ExtractWindowExpressions ::
GlobalAggregates ::
UnresolvedHavingClauseAttributes ::
@@ -118,14 +121,14 @@ class Analyzer(
case WithWindowDefinition(windowDefinitions, child) =>
child.transform {
case plan => plan.transformExpressions {
case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) =>
case we @ WindowExpression(_, WindowSpecReference(windowName), _, _) =>
val errorMessage =
s"Window specification $windowName is not defined in the WINDOW clause."
val windowSpecDefinition =
windowDefinitions
.get(windowName)
.getOrElse(failAnalysis(errorMessage))
WindowExpression(c, windowSpecDefinition)
we.copy(windowSpec = windowSpecDefinition)
}
}
}
@@ -502,11 +505,17 @@ class Analyzer(
}

def containsAggregates(exprs: Seq[Expression]): Boolean = {
exprs.foreach(_.foreach {
case agg: AggregateExpression => return true
case _ =>
})
false
// Collect all Windowed Aggregate Expressions.
val blacklist = exprs.flatMap { expr =>
expr.collect {
case WindowExpression(ae: AggregateExpression, _, _, _) => ae
}
}.toSet

// Find the first Aggregate Expression that is not Windowed.
exprs.exists(_.collectFirst {
case ae: AggregateExpression if (!blacklist.contains(ae)) => ae
}.isDefined)
}
}
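For illustration, the same blacklist-then-search idea on a minimal stand-in expression tree; the Expr/Agg/Windowed classes below are hypothetical helpers, not the Catalyst types:

sealed trait Expr {
  def children: Seq[Expr]
  // Pre-order collect, mimicking TreeNode.collect.
  def collect[B](pf: PartialFunction[Expr, B]): Seq[B] =
    (if (pf.isDefinedAt(this)) Seq(pf(this)) else Nil) ++ children.flatMap(_.collect(pf))
}
case class Agg(name: String) extends Expr { val children: Seq[Expr] = Nil }
case class Windowed(agg: Agg) extends Expr { val children: Seq[Expr] = Seq(agg) }

def containsAggregates(exprs: Seq[Expr]): Boolean = {
  // Aggregates that only occur inside a window expression do not count as global aggregates.
  val windowed = exprs.flatMap(_.collect { case Windowed(a) => a }).toSet
  exprs.exists(_.collect { case a: Agg if !windowed.contains(a) => a }.nonEmpty)
}

// containsAggregates(Seq(Windowed(Agg("sum(a)"))))                == false
// containsAggregates(Seq(Windowed(Agg("sum(a)")), Agg("sum(b)"))) == true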

@@ -716,28 +725,42 @@ class Analyzer(
withName.toAttribute
}

// Extract Window Specification Expressions.
def extractSpecExpressions(spec: WindowSpecDefinition) = {
val newPartitionSpec = spec.partitionSpec.map(extractExpr(_))
val newOrderSpec = spec.orderSpec.map { so =>
so.copy(child = extractExpr(so.child))
}
spec.copy(partitionSpec = newPartitionSpec,
orderSpec = newOrderSpec)
}

// Now, we extract regular expressions from expressionsWithWindowFunctions
// by using extractExpr.
val seenWindowAggregates = new ArrayBuffer[AggregateExpression]
val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map {
_.transform {
// Extracts children expressions of a WindowFunction (input parameters of
// a WindowFunction).
case wf : WindowFunction =>
val newChildren = wf.children.map(extractExpr(_))
wf.withNewChildren(newChildren)

// Extracts expressions from the partition spec and order spec.
case wsc @ WindowSpecDefinition(partitionSpec, orderSpec, _) =>
val newPartitionSpec = partitionSpec.map(extractExpr(_))
val newOrderSpec = orderSpec.map { so =>
val newChild = extractExpr(so.child)
so.copy(child = newChild)
}
wsc.copy(partitionSpec = newPartitionSpec, orderSpec = newOrderSpec)

// Add the children of an aggregate expression to the expression list.
case we @ WindowExpression(agg: AggregateExpression,
spec: WindowSpecDefinition, _, _) =>
val newAggChildren = agg.children.map(extractExpr(_))
val newAgg = agg.withNewChildren(newAggChildren)
seenWindowAggregates += newAgg
we.copy(windowFunction = newAgg,
windowSpec = extractSpecExpressions(spec))
// Lead/Lag window functions have no aggregating operator. The function itself is
// added to the expression list.
case we @ WindowExpression(e: Expression,
spec @ WindowSpecDefinition(_, _,
SpecifiedWindowFrame(RowFrame,
FrameBoundaryExtractor(l),
FrameBoundaryExtractor(h))), _, _)
if (l == h) =>
we.copy(windowFunction = extractExpr(e),
windowSpec = extractSpecExpressions(spec))
// Extracts AggregateExpression. For example, for SUM(x) - SUM(y) OVER (...),
// we need to extract SUM(x).
case agg: AggregateExpression =>
case agg: AggregateExpression if !seenWindowAggregates.contains(agg) =>
val withName = Alias(agg, s"_w${extractedExprBuffer.length}")()
extractedExprBuffer += withName
withName.toAttribute
@@ -784,7 +807,8 @@ class Analyzer(
// Second, we group extractedWindowExprBuffer based on their Window Spec.
val groupedWindowExpressions = extractedWindowExprBuffer.groupBy { expr =>
val distinctWindowSpec = expr.collect {
case window: WindowExpression => window.windowSpec
case WindowExpression(_, spec: WindowSpecDefinition, _, _) =>
spec.copy(frameSpecification = SpecifiedWindowFrame.unbounded)
}.distinct

// We do a final check and see if we only have a single Window Spec defined in an
@@ -871,6 +895,81 @@ class Analyzer(
Project(finalProjectList, withWindow)
}
}
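To see why ExtractWindowExpressions blanks out the frame before grouping (the spec.copy(frameSpecification = SpecifiedWindowFrame.unbounded) above): expressions that share a partition/order spec but differ only in their frame should still land in the same group, i.e. the same Window operator. A standalone sketch with made-up stand-in classes:

case class Frame(sql: String)
case class Spec(partitionBy: Seq[String], orderBy: Seq[String], frame: Frame)
val unbounded = Frame("ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING")

val specs = Seq(
  Spec(Seq("dept"), Seq("salary"), Frame("RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW")),
  Spec(Seq("dept"), Seq("salary"), Frame("ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING")))

// Grouping on the spec with the frame normalized away yields a single group.
specs.groupBy(_.copy(frame = unbounded)).size  // == 1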

object ResolvePartialWindowExpressions extends Rule[LogicalPlan] {
/* Add a parent's window specification to a child. */
private def mergeSpec(expr: WindowExpression, spec: WindowSpecDefinition) = {
// Do not add an empty parent spec
if (spec == WindowSpecDefinition.empty) failAnalysis("Cannot replace window expression, " +
"the used specification is empty.")
// Do not replace a non-empty child spec
else if (expr.windowSpec != WindowSpecDefinition.empty && expr.windowSpec != spec) {
failAnalysis("Cannot replace window expression, the target expression already has "
+ "a different valid window specification.")
}
// Create a copy with the updated specification.
else expr.copy(windowSpec = spec)
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan =>
q transformExpressionsDown {
// Simple case: two nested window expressions. The parent expression should contain a
// valid specification, the child expression should contain the function and processing
// instructions. The parent is replaced by the updated child.
case we @ WindowExpression(cwe: WindowExpression, spec: WindowSpecDefinition, _, _) =>
mergeSpec(cwe, spec)
// Subtree case: a parent window expression and a window-function-containing subtree. The
// specification of the parent window expression gets injected into the child window
// expressions. The parent is then replaced by the transformed subtree.
case we @ WindowExpression(ComposedWindowFunction(subtree),
spec: WindowSpecDefinition, _, _) =>
subtree.transformDown {
// TODO see if there are any use cases when we have a partially defined subtree with
// a different window spec. That will currently fail miserably.
case cwe: WindowExpression => mergeSpec(cwe, spec)
}
// A subtree whose window functions are already fully configured: eliminate the ComposedWindowFunction wrapper.
case cwf @ ComposedWindowFunction(subtree) => subtree
}
}
}
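A standalone sketch of the merge rules enforced by mergeSpec above, using made-up WinSpec/WinExpr stand-ins rather than the Catalyst types: an empty parent spec is rejected, a conflicting non-empty child spec is rejected, and otherwise the parent spec is pushed into the child.

case class WinSpec(partitionBy: Seq[String] = Nil, orderBy: Seq[String] = Nil)
val EmptySpec = WinSpec()
case class WinExpr(function: String, spec: WinSpec)

def mergeSpec(child: WinExpr, parent: WinSpec): WinExpr = {
  require(parent != EmptySpec, "the used specification is empty")
  require(child.spec == EmptySpec || child.spec == parent,
    "the target expression already has a different valid window specification")
  child.copy(spec = parent)
}

// mergeSpec(WinExpr("rank", EmptySpec), WinSpec(orderBy = Seq("b")))
//   returns WinExpr("rank", WinSpec(Nil, Seq("b")))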

object ResolveUnspecifiedFrameWindowExpressions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case q: LogicalPlan =>
q transformExpressions {
// Replace the Spec's frame with the fixed frame.
case we @ WindowExpression(_, spec @ WindowSpecDefinition(_, _, UnspecifiedFrame),
frame: SpecifiedWindowFrame, _) =>
we.copy(windowSpec = spec.copy(frameSpecification = frame))
// Replace the Spec's frame with a default frame.
case we @ WindowExpression(_, spec @ WindowSpecDefinition(_, _,
UnspecifiedFrame), _, _) =>
we.copy(windowSpec = spec.copy(frameSpecification =
SpecifiedWindowFrame.defaultWindowFrame(spec.orderSpec.nonEmpty, true)))
// Fail when two frames are defined and they DO NOT match.
case we @ WindowExpression(_, spec @ WindowSpecDefinition(_, _,
frame: SpecifiedWindowFrame), fixedFrame: SpecifiedWindowFrame, _)
if (frame != fixedFrame) =>
failAnalysis(s"The frame of the window '$frame' does not match the required " +
s"frame '$fixedFrame'")
}
}
}
}
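The default frame picked by ResolveUnspecifiedFrameWindowExpressions follows the usual SQL convention; a rough sketch of what SpecifiedWindowFrame.defaultWindowFrame chooses, expressed over plain strings instead of the Catalyst frame types:

def defaultFrame(hasOrderSpec: Boolean, acceptWindowFrame: Boolean): String =
  if (hasOrderSpec && acceptWindowFrame)
    "RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
  else
    "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING"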

/**
* Add the window ordering specification to implied ordering functions.
*/
object ResolveImpliedOrderWindowExpressions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case q: LogicalPlan =>
q transformExpressions {
case we @ WindowExpression(implied: ImpliedOrderSpec, spec: WindowSpecDefinition, _, _) =>
we.copy(windowFunction = implied.defineOrderSpec(spec.orderSpec.map(_.child)))
}
}
}
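Functions such as rank are order-dependent: their result is derived from the values of the window's ORDER BY expressions, which is why this rule injects spec.orderSpec into the function. A tiny standalone illustration, assuming the partition is already sorted by the order spec:

// The rank only advances when the ordering value changes, so the ordering
// expressions are effectively the function's inputs.
def ranks[T](sortedOrderValues: Seq[T]): Seq[Int] =
  sortedOrderValues.map(v => sortedOrderValues.indexOf(v) + 1)

// ranks(Seq(10, 10, 20, 30)) == Seq(1, 1, 3, 4)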

/**
@@ -65,12 +65,13 @@ trait CheckAnalysis {
failAnalysis(
s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}")

case WindowExpression(UnresolvedWindowFunction(name, _), _) =>
case b: BinaryExpression if !b.resolved =>
failAnalysis(
s"Could not resolve window function '$name'. " +
"Note that, using window functions currently requires a HiveContext")
s"invalid expression ${b.prettyString} " +
s"between ${b.left.dataType.simpleString} and ${b.right.dataType.simpleString}")

case w @ WindowExpression(windowFunction, windowSpec) if windowSpec.validate.nonEmpty =>
case w @ WindowExpression(windowFunction, windowSpec: WindowSpecDefinition, _, _)
if windowSpec.validate.nonEmpty =>
// The window spec is not valid.
val reason = windowSpec.validate.get
failAnalysis(s"Window specification $windowSpec is not valid because $reason")
@@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.types.IntegerType

import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}

@@ -152,7 +154,17 @@ object FunctionRegistry {
expression[Substring]("substr"),
expression[Substring]("substring"),
expression[Upper]("ucase"),
expression[Upper]("upper")
expression[Upper]("upper"),

// window functions
leaf("row_number", WindowFunction.rowNumber()),
leaf("rank", WindowFunction.rank()),
leaf("dense_rank", WindowFunction.denseRank()),
leaf("percent_rank", WindowFunction.percentRank()),
leaf("cume_dist", WindowFunction.cumeDist()),
ntile,
lead,
lag
)

val builtin: FunctionRegistry = {
@@ -185,4 +197,65 @@ object FunctionRegistry {
}
(name, builder)
}

/** Add a leaf expression. */
private def leaf(name: String, value: => Expression): (String, FunctionBuilder) = {
val f = (args: Seq[Expression]) => {
if (args.nonEmpty) {
throw new AnalysisException(s"Invalid number of arguments for function $name")
}
value
}
(name, f)
}

private def ntile: (String, FunctionBuilder) = {
val f = (args: Seq[Expression]) => {
args match {
case IntegerLiteral(buckets) :: Nil =>
WindowFunction.ntile(buckets)
case _ =>
throw new AnalysisException(s"Invalid arguments for function ntile: $args")
}
}
("ntile", f)
}

private def lead: (String, FunctionBuilder) = {
val f = (args: Seq[Expression]) => {
val (e, offset, default) = leadLagParams("lead", args)
WindowFunction.lead(e, offset, default)
}
("lead", f)
}

private def lag: (String, FunctionBuilder) = {
val f = (args: Seq[Expression]) => {
val (e, offset, default) = leadLagParams("lag", args)
WindowFunction.lag(e, offset, default)
}
("lag", f)
}

private def leadLagParams(name: String, args: Seq[Expression]): (Expression, Int, Expression) = {
args match {
case Seq(e: Expression) =>
(e, 1, null)
case Seq(e: Expression, ExtractInteger(offset)) =>
(e, offset, null)
case Seq(e: Expression, ExtractInteger(offset), d: Expression) =>
(e, offset, d)
case _ =>
throw new AnalysisException(s"Invalid arguments for function $name: $args")
}
}

object ExtractInteger {
def unapply(e: Expression): Option[Int] = {
if (e.foldable && e.dataType == IntegerType) {
Some(e.eval(EmptyRow).asInstanceOf[Int])
} else None
}
}
}
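For reference, leadLagParams accepts three call shapes: lead(x) with the default offset of 1 and no explicit default value, lead(x, 2) with an explicit offset, and lead(x, 2, 0) with an offset and a default. A standalone sketch of the same dispatch over plain values rather than Catalyst expressions (leadLagShapes is an illustrative name):

def leadLagShapes(name: String, args: Seq[Any]): (Any, Int, Any) = args match {
  case Seq(e)                 => (e, 1, null)       // lead(x)
  case Seq(e, offset: Int)    => (e, offset, null)  // lead(x, 2)
  case Seq(e, offset: Int, d) => (e, offset, d)     // lead(x, 2, 0)
  case _ =>
    throw new IllegalArgumentException(s"Invalid arguments for function $name: $args")
}

// leadLagShapes("lag", Seq("x"))       == ("x", 1, null)
// leadLagShapes("lag", Seq("x", 2, 0)) == ("x", 2, 0)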