Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-10735] [SQL] Generate aggregation w/o grouping keys [WIP] #10786

Closed
wants to merge 31 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ object FunctionRegistry {
expression[Average]("mean"),
expression[Min]("min"),
expression[StddevSamp]("stddev"),
expression[StddevSamp1]("stddev1"),
expression[StddevPop]("stddev_pop"),
expression[StddevSamp]("stddev_samp"),
expression[Sum]("sum"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w

override def dataType: DataType = DoubleType

override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType)

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, s"function $prettyName")
Expand Down Expand Up @@ -109,7 +109,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w
* Update the central moments buffer.
*/
override def update(buffer: MutableRow, input: InternalRow): Unit = {
val v = Cast(child, DoubleType).eval(input)
val v = child.eval(input)
if (v != null) {
val updateValue = v match {
case d: Double => d
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

case class StddevSamp(child: Expression,
mutableAggBufferOffset: Int = 0,
Expand Down Expand Up @@ -79,3 +81,116 @@ case class StddevPop(
}
}
}

// Compute standard deviation based on online algorithm specified here:
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
abstract class StddevAgg(child: Expression) extends DeclarativeAggregate {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is included for benchmark, could be used to replace the imperative one (should be done in follow up PR or this one?)


def isSample: Boolean

override def children: Seq[Expression] = child :: Nil

override def nullable: Boolean = true

override def dataType: DataType = resultType

// Expected input data type.
// TODO: Right now, we replace old aggregate functions (based on AggregateExpression) to the
// new version at planning time (after analysis phase). For now, NullType is added at here
// to make it resolved when we have cases like `select stddev(null)`.
// We can use our analyzer to cast NullType to the default data type of the NumericType once
// we remove the old aggregate functions. Then, we will not need NullType at here.
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(DoubleType, NullType))

private val resultType = DoubleType

private val count = AttributeReference("count", resultType, false)()
private val avg = AttributeReference("avg", resultType, false)()
private val mk = AttributeReference("mk", resultType, false)()

override val aggBufferAttributes = count :: avg :: mk :: Nil

override val initialValues: Seq[Expression] = Seq(
/* count = */ Literal(0.0),
/* avg = */ Literal(0.0),
/* mk = */ Literal(0.0)
)

override val updateExpressions: Seq[Expression] = {
val newCount = count + Literal(1.0)

// update average
// avg = avg + (value - avg)/count
val newAvg = avg + (child - avg) / newCount

// update sum ofference from mean
// Mk = Mk + (value - preAvg) * (value - updatedAvg)
val newMk = mk + (child - avg) * (child - newAvg)

if (child.nullable) {
Seq(
/* count = */ If(IsNull(child), count, newCount),
/* avg = */ If(IsNull(child), avg, newAvg),
/* mk = */ If(IsNull(child), mk, newMk)
)
} else {
Seq(
/* count = */ newCount,
/* avg = */ newAvg,
/* mk = */ newMk
)
}
}

override val mergeExpressions: Seq[Expression] = {

// count merge
val newCount = count.left + count.right

// average merge
val newAvg = ((avg.left * count.left) + (avg.right * count.right)) / newCount

// update sum of square differences
val newMk = {
val avgDelta = avg.right - avg.left
val mkDelta = (avgDelta * avgDelta) * (count.left * count.right) / newCount
mk.left + mk.right + mkDelta
}

Seq(
/* count = */ newCount,
/* avg = */ newAvg,
/* mk = */ newMk
)
}

override val evaluateExpression: Expression = {
// when count == 0, return null
// when count == 1, return 0
// when count >1
// stddev_samp = sqrt (mk/(count -1))
// stddev_pop = sqrt (mk/count)
val varCol =
if (isSample) {
mk / (count - Literal(1.0))
} else {
mk / count
}

If(EqualTo(count, Literal(0.0)), Literal.create(null, resultType),
If(EqualTo(count, Literal(1.0)), Literal(0.0),
Sqrt(varCol)))
}
}

// Compute the population standard deviation of a column
case class StddevPop1(child: Expression) extends StddevAgg(child) {
override def isSample: Boolean = false
override def prettyName: String = "stddev_pop"
}

// Compute the sample standard deviation of a column
case class StddevSamp1(child: Expression) extends StddevAgg(child) {
override def isSample: Boolean = true
override def prettyName: String = "stddev_samp"
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,33 +40,40 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu

protected def create(expressions: Seq[Expression]): (() => MutableProjection) = {
val ctx = newCodeGenContext()
val projectionCodes = expressions.zipWithIndex.map {
case (NoOp, _) => ""
case (e, i) =>
val evaluationCode = e.gen(ctx)
val (validExpr, index) = expressions.zipWithIndex.filter {
case (NoOp, _) => false
case _ => true
}.unzip
val exprVals = ctx.generateExpressions(validExpr, true)
val projectionCodes = exprVals.zip(index).map {
case (ev, i) =>
val e = expressions(i)
if (e.nullable) {
val isNull = s"isNull_$i"
val value = s"value_$i"
ctx.addMutableState("boolean", isNull, s"this.$isNull = true;")
ctx.addMutableState(ctx.javaType(e.dataType), value,
s"this.$value = ${ctx.defaultValue(e.dataType)};")
s"""
${evaluationCode.code}
this.$isNull = ${evaluationCode.isNull};
this.$value = ${evaluationCode.value};
${ev.code}
this.$isNull = ${ev.isNull};
this.$value = ${ev.value};
"""
} else {
val value = s"value_$i"
ctx.addMutableState(ctx.javaType(e.dataType), value,
s"this.$value = ${ctx.defaultValue(e.dataType)};")
s"""
${evaluationCode.code}
this.$value = ${evaluationCode.value};
${ev.code}
this.$value = ${ev.value};
"""
}
}
val updates = expressions.zipWithIndex.map {
case (NoOp, _) => ""

// Reset the subexpression values for each row.
val subexprReset = ctx.subExprResetVariables.mkString("\n")

val updates = validExpr.zip(index).map {
case (e, i) =>
if (e.nullable) {
if (e.dataType.isInstanceOf[DecimalType]) {
Expand Down Expand Up @@ -128,6 +135,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu

public java.lang.Object apply(java.lang.Object _i) {
InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i;
$subexprReset
$allProjections
// copy all the results into MutableRow
$allUpdates
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,9 +300,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val code =
s"""
$bufferHolder.reset();
$subexprReset
${subexprReset.trim}
${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)}

$result.pointTo($bufferHolder.buffer, $bufferHolder.totalSize());
"""
ExprCode(code, "false", result)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ class GroupedData protected[sql](
UnresolvedFunction("avg", inputExpr :: Nil, isDistinct = false)
case "stddev" | "std" =>
UnresolvedFunction("stddev", inputExpr :: Nil, isDistinct = false)
case "stddev1" | "std" =>
UnresolvedFunction("stddev1", inputExpr :: Nil, isDistinct = false)
// Also special handle count because we need to take care count(*).
case "count" | "size" =>
// Turn count(*) into count(1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, LeafExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.rules.Rule

Expand Down Expand Up @@ -190,7 +191,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
"""
// try to compile, helpful for debug
// println(s"${CodeFormatter.format(source)}")
CodeGenerator.compile(source)
// CodeGenerator.compile(source)

rdd.mapPartitions { iter =>
val clazz = CodeGenerator.compile(source)
Expand Down Expand Up @@ -264,12 +265,17 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
*/
private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Rule[SparkPlan] {

private def supportCodegen(e: Expression): Boolean = e match {
case e: LeafExpression => true
case e: ImperativeAggregate => true
case e: CodegenFallback => false
case e => true
}


private def supportCodegen(plan: SparkPlan): Boolean = plan match {
case plan: CodegenSupport if plan.supportCodegen =>
// Non-leaf with CodegenFallback does not work with whole stage codegen
val willFallback = plan.expressions.exists(
_.find(e => e.isInstanceOf[CodegenFallback] && !e.isInstanceOf[LeafExpression]).isDefined
)
val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined)
// the generated code will be huge if there are too many columns
val haveManyColumns = plan.output.length > 200
!willFallback && !haveManyColumns
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,53 +66,9 @@ abstract class AggregationIterator(
s"$aggregateExpressions can't have Partial/PartialMerge and Final/Complete in the same time.")
}

// Initialize all AggregateFunctions by binding references if necessary,
// and set inputBufferOffset and mutableBufferOffset.
protected def initializeAggregateFunctions(
expressions: Seq[AggregateExpression],
startingInputBufferOffset: Int): Array[AggregateFunction] = {
var mutableBufferOffset = 0
var inputBufferOffset: Int = startingInputBufferOffset
val functions = new Array[AggregateFunction](expressions.length)
var i = 0
while (i < expressions.length) {
val func = expressions(i).aggregateFunction
val funcWithBoundReferences: AggregateFunction = expressions(i).mode match {
case Partial | Complete if func.isInstanceOf[ImperativeAggregate] =>
// We need to create BoundReferences if the function is not an
// expression-based aggregate function (it does not support code-gen) and the mode of
// this function is Partial or Complete because we will call eval of this
// function's children in the update method of this aggregate function.
// Those eval calls require BoundReferences to work.
BindReferences.bindReference(func, inputAttributes)
case _ =>
// We only need to set inputBufferOffset for aggregate functions with mode
// PartialMerge and Final.
val updatedFunc = func match {
case function: ImperativeAggregate =>
function.withNewInputAggBufferOffset(inputBufferOffset)
case function => function
}
inputBufferOffset += func.aggBufferSchema.length
updatedFunc
}
val funcWithUpdatedAggBufferOffset = funcWithBoundReferences match {
case function: ImperativeAggregate =>
// Set mutableBufferOffset for this function. It is important that setting
// mutableBufferOffset happens after all potential bindReference operations
// because bindReference will create a new instance of the function.
function.withNewMutableAggBufferOffset(mutableBufferOffset)
case function => function
}
mutableBufferOffset += funcWithUpdatedAggBufferOffset.aggBufferSchema.length
functions(i) = funcWithUpdatedAggBufferOffset
i += 1
}
functions
}

protected val aggregateFunctions: Array[AggregateFunction] =
initializeAggregateFunctions(aggregateExpressions, initialInputBufferOffset)
AggregationIterator.initializeAggregateFunctions(
aggregateExpressions, inputAttributes, initialInputBufferOffset)

// Positions of those imperative aggregate functions in allAggregateFunctions.
// For example, we have func1, func2, func3, func4 in aggregateFunctions, and
Expand Down Expand Up @@ -259,3 +215,51 @@ abstract class AggregationIterator(
}
}
}

object AggregationIterator {
// Initialize all AggregateFunctions by binding references if necessary,
// and set inputBufferOffset and mutableBufferOffset.
def initializeAggregateFunctions(
expressions: Seq[AggregateExpression],
inputAttributes: Seq[Attribute],
startingInputBufferOffset: Int): Array[AggregateFunction] = {
var mutableBufferOffset = 0
var inputBufferOffset: Int = startingInputBufferOffset
val functions = new Array[AggregateFunction](expressions.length)
var i = 0
while (i < expressions.length) {
val func = expressions(i).aggregateFunction
val funcWithBoundReferences: AggregateFunction = expressions(i).mode match {
case Partial | Complete if func.isInstanceOf[ImperativeAggregate] =>
// We need to create BoundReferences if the function is not an
// expression-based aggregate function (it does not support code-gen) and the mode of
// this function is Partial or Complete because we will call eval of this
// function's children in the update method of this aggregate function.
// Those eval calls require BoundReferences to work.
BindReferences.bindReference(func, inputAttributes)
case _ =>
// We only need to set inputBufferOffset for aggregate functions with mode
// PartialMerge and Final.
val updatedFunc = func match {
case function: ImperativeAggregate =>
function.withNewInputAggBufferOffset(inputBufferOffset)
case function => function
}
inputBufferOffset += func.aggBufferSchema.length
updatedFunc
}
val funcWithUpdatedAggBufferOffset = funcWithBoundReferences match {
case function: ImperativeAggregate =>
// Set mutableBufferOffset for this function. It is important that setting
// mutableBufferOffset happens after all potential bindReference operations
// because bindReference will create a new instance of the function.
function.withNewMutableAggBufferOffset(mutableBufferOffset)
case function => function
}
mutableBufferOffset += funcWithUpdatedAggBufferOffset.aggBufferSchema.length
functions(i) = funcWithUpdatedAggBufferOffset
i += 1
}
functions
}
}
Loading