Skip to content
Permalink
Browse files

[SPARK-27653][SQL] Add max_by() and min_by() SQL aggregate functions

## What changes were proposed in this pull request?

This PR adds the `max_by()` and `min_by()` SQL aggregate functions.

Quoting from the [Presto docs](https://prestodb.github.io/docs/current/functions/aggregate.html#max_by):

> max_by(x, y) → [same as x]
> Returns the value of x associated with the maximum value of y over all input values.

`min_by()` works similarly.

## How was this patch tested?

Added tests.

Closes #24557 from viirya/SPARK-27653.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information...
viirya authored and cloud-fan committed May 13, 2019
1 parent 126310c commit d169b0aac369d373968a6c66ee43440a2ad751a5
@@ -289,8 +289,10 @@ object FunctionRegistry {
expression[Last]("last"),
expression[Last]("last_value"),
expression[Max]("max"),
expression[MaxBy]("max_by"),
expression[Average]("mean"),
expression[Min]("min"),
expression[MinBy]("min_by"),
expression[Percentile]("percentile"),
expression[Skewness]("skewness"),
expression[ApproximatePercentile]("percentile_approx"),
@@ -0,0 +1,128 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

/**
 * Common base of the `MaxBy` and `MinBy` SQL aggregate functions.
 *
 * The aggregation buffer holds two slots: the value currently selected and the extremum
 * ordering that selected it. Subclasses only decide the direction of the comparison, via
 * `predicate` and `orderingUpdater`.
 */
abstract class MaxMinBy extends DeclarativeAggregate {

  /** Expression whose value is returned by the aggregate. */
  def valueExpr: Expression

  /** Expression whose extremum decides which value is returned. */
  def orderingExpr: Expression

  /** SQL-facing function name, embedded in the type-check message. */
  protected def funcName: String

  /**
   * Comparison between the buffered ordering (`oldExpr`) and a candidate ordering (`newExpr`).
   * When it evaluates to true the value already in the buffer is kept.
   */
  protected def predicate(oldExpr: Expression, newExpr: Expression): Expression

  /**
   * Arithmetic expression yielding the greatest/least of its parameters; it produces the
   * ordering to store in the buffer after seeing `newExpr`.
   */
  protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression

  override def children: Seq[Expression] = Seq(valueExpr, orderingExpr)

  // The buffer starts out as null and stays null until a non-null ordering is seen.
  override def nullable: Boolean = true

  // The aggregate returns a value of the same type as `valueExpr`.
  override def dataType: DataType = valueExpr.dataType

  // Only the ordering expression is constrained: its type has to be orderable.
  override def checkInputDataTypes(): TypeCheckResult = {
    val caller = s"function $funcName"
    TypeUtils.checkForOrderingExpr(orderingExpr.dataType, caller)
  }

  // Buffer slot holding the value associated with the current extremum.
  private lazy val bufferedValue: AttributeReference =
    AttributeReference("valueWithExtremumOrdering", valueExpr.dataType)()

  // Buffer slot holding the current extremum (max or min) of the ordering.
  private lazy val bufferedOrdering: AttributeReference =
    AttributeReference("extremumOrdering", orderingExpr.dataType)()

  override lazy val aggBufferAttributes: Seq[AttributeReference] =
    Seq(bufferedValue, bufferedOrdering)

  private lazy val nullValue = Literal.create(null, valueExpr.dataType)
  private lazy val nullOrdering = Literal.create(null, orderingExpr.dataType)

  // Same slot order as `aggBufferAttributes`: value first, ordering second.
  override lazy val initialValues: Seq[Literal] = Seq(nullValue, nullOrdering)

  /**
   * Builds the expression that picks between the value already held (`oldValue`) and a
   * candidate (`newValue`) from their orderings:
   *  - both orderings null: null value;
   *  - only the old ordering null: the candidate wins;
   *  - only the new ordering null: the held value is kept;
   *  - otherwise `predicate` decides, keeping the held value when it is true.
   */
  private def chooseValue(
      oldOrdering: Expression,
      newOrdering: Expression,
      oldValue: Expression,
      newValue: Expression): Expression = {
    val nullCases: Seq[(Expression, Expression)] = Seq(
      (And(IsNull(oldOrdering), IsNull(newOrdering)), nullValue),
      (IsNull(oldOrdering), newValue),
      (IsNull(newOrdering), oldValue))
    val bothPresent = If(predicate(oldOrdering, newOrdering), oldValue, newValue)
    CaseWhen(nullCases, bothPresent)
  }

  // Folds one input row into the buffer.
  override lazy val updateExpressions: Seq[Expression] = {
    val nextValue = chooseValue(bufferedOrdering, orderingExpr, bufferedValue, valueExpr)
    val nextOrdering = orderingUpdater(bufferedOrdering, orderingExpr)
    Seq(nextValue, nextOrdering)
  }

  // Combines two partial buffers; `.left` / `.right` address the two sides of the merge.
  override lazy val mergeExpressions: Seq[Expression] = {
    val mergedValue = chooseValue(
      bufferedOrdering.left,
      bufferedOrdering.right,
      bufferedValue.left,
      bufferedValue.right)
    val mergedOrdering = orderingUpdater(bufferedOrdering.left, bufferedOrdering.right)
    Seq(mergedValue, mergedOrdering)
  }

  // The final result is simply the buffered value.
  override lazy val evaluateExpression: AttributeReference = bufferedValue
}

@ExpressionDescription(
  usage = "_FUNC_(x, y) - Returns the value of `x` associated with the maximum value of `y`.",
  examples = """
Examples:
> SELECT _FUNC_(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y);
b
""",
  since = "3.0")
case class MaxBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMinBy {

  // Name embedded in the type-check message built by the parent class.
  override protected def funcName: String = "max_by"

  // The buffered value is kept only while its ordering is strictly greater than the new one.
  override protected def predicate(oldExpr: Expression, newExpr: Expression): Expression =
    GreaterThan(oldExpr, newExpr)

  // The buffered ordering becomes the greatest of the old and the incoming ordering.
  override protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression =
    Greatest(Seq(oldExpr, newExpr))
}

@ExpressionDescription(
  usage = "_FUNC_(x, y) - Returns the value of `x` associated with the minimum value of `y`.",
  examples = """
Examples:
> SELECT _FUNC_(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y);
a
""",
  since = "3.0")
case class MinBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMinBy {

  // Name embedded in the type-check message built by the parent class.
  override protected def funcName: String = "min_by"

  // The buffered value is kept only while its ordering is strictly less than the new one.
  override protected def predicate(oldExpr: Expression, newExpr: Expression): Expression =
    LessThan(oldExpr, newExpr)

  // The buffered ordering becomes the least of the old and the incoming ordering.
  override protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression =
    Least(Seq(oldExpr, newExpr))
}
@@ -782,4 +782,116 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
val countAndDistinct = df.select(count("*"), countDistinct("*"))
checkAnswer(countAndDistinct, Row(100000, 100))
}

test("max_by") {
  // Grouped aggregation over the `courseSales` view.
  val yearOfMaxEarnings =
    sql("SELECT course, max_by(year, earnings) FROM courseSales GROUP BY course")
  checkAnswer(yearOfMaxEarnings, Row("dotNET", 2013) :: Row("Java", 2013) :: Nil)

  // No null ordering values: the value paired with the largest ordering is returned.
  checkAnswer(
    sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)"),
    Row("b") :: Nil
  )

  // A null ordering in the middle of the input does not win.
  checkAnswer(
    sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', null)), (('c', 20)) AS tab(x, y)"),
    Row("c") :: Nil
  )

  // Leading null orderings are superseded by the first non-null one.
  checkAnswer(
    sql("SELECT max_by(x, y) FROM VALUES (('a', null)), (('b', null)), (('c', 20)) AS tab(x, y)"),
    Row("c") :: Nil
  )

  // A trailing null ordering does not replace the extremum found so far.
  checkAnswer(
    sql("SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', null)) AS tab(x, y)"),
    Row("b") :: Nil
  )

  // Only null orderings: the result is null.
  checkAnswer(
    sql("SELECT max_by(x, y) FROM VALUES (('a', null)), (('b', null)) AS tab(x, y)"),
    Row(null) :: Nil
  )

  // structs as ordering value.
  checkAnswer(
    sql("select max_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " +
      "(('c', (10, 60))) AS tab(x, y)"),
    Row("c") :: Nil
  )

  checkAnswer(
    sql("select max_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " +
      "(('c', null)) AS tab(x, y)"),
    Row("b") :: Nil
  )

  // Map-typed ordering values are rejected at analysis time.
  withTempView("tempView") {
    // Registering the view is a pure side effect (returns Unit), so nothing is bound here.
    Seq((0, "a"), (1, "b"), (2, "c"))
      .toDF("x", "y")
      .select($"x", map($"x", $"y").as("y"))
      .createOrReplaceTempView("tempView")
    val error = intercept[AnalysisException] {
      sql("SELECT max_by(x, y) FROM tempView").show()
    }
    assert(
      error.message.contains("function max_by does not support ordering on type map<int,string>"))
  }
}

test("min_by") {
  // Grouped aggregation over the `courseSales` view.
  val yearOfMinEarnings =
    sql("SELECT course, min_by(year, earnings) FROM courseSales GROUP BY course")
  checkAnswer(yearOfMinEarnings, Row("dotNET", 2012) :: Row("Java", 2012) :: Nil)

  // No null ordering values: the value paired with the smallest ordering is returned.
  checkAnswer(
    sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y)"),
    Row("a") :: Nil
  )

  // A null ordering in the middle of the input does not win.
  checkAnswer(
    sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', null)), (('c', 20)) AS tab(x, y)"),
    Row("a") :: Nil
  )

  // Leading null orderings are superseded by the first non-null one.
  checkAnswer(
    sql("SELECT min_by(x, y) FROM VALUES (('a', null)), (('b', null)), (('c', 20)) AS tab(x, y)"),
    Row("c") :: Nil
  )

  // A trailing null ordering does not replace the extremum found so far.
  checkAnswer(
    sql("SELECT min_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', null)) AS tab(x, y)"),
    Row("a") :: Nil
  )

  // Only null orderings: the result is null.
  checkAnswer(
    sql("SELECT min_by(x, y) FROM VALUES (('a', null)), (('b', null)) AS tab(x, y)"),
    Row(null) :: Nil
  )

  // structs as ordering value.
  checkAnswer(
    sql("select min_by(x, y) FROM VALUES (('a', (10, 20))), (('b', (10, 50))), " +
      "(('c', (10, 60))) AS tab(x, y)"),
    Row("a") :: Nil
  )

  checkAnswer(
    sql("select min_by(x, y) FROM VALUES (('a', null)), (('b', (10, 50))), " +
      "(('c', (10, 60))) AS tab(x, y)"),
    Row("b") :: Nil
  )

  // Map-typed ordering values are rejected at analysis time.
  withTempView("tempView") {
    // Registering the view is a pure side effect (returns Unit), so nothing is bound here.
    Seq((0, "a"), (1, "b"), (2, "c"))
      .toDF("x", "y")
      .select($"x", map($"x", $"y").as("y"))
      .createOrReplaceTempView("tempView")
    val error = intercept[AnalysisException] {
      sql("SELECT min_by(x, y) FROM tempView").show()
    }
    assert(
      error.message.contains("function min_by does not support ordering on type map<int,string>"))
  }
}
}

0 comments on commit d169b0a

Please sign in to comment.
You can’t perform that action at this time.