[SPARK-37138][SQL] Support ANSI Interval types in ApproxCountDistinctForIntervals/ApproximatePercentile/Percentile

### What changes were proposed in this pull request?

Support ANSI interval types in the following aggregate expressions (a brief usage sketch follows the list):
- ApproxCountDistinctForIntervals
- ApproximatePercentile
- Percentile
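
A minimal usage sketch (not part of the patch) of calling these aggregates on ANSI interval columns; the inputs and expected results mirror the SQL examples added by this change, and the local `SparkSession` setup is an assumption:

```scala
import org.apache.spark.sql.SparkSession

// Assumed local session for the sketch only.
val spark = SparkSession.builder().master("local[*]").appName("ansi-interval-aggs").getOrCreate()

// percentile_approx now accepts interval columns and returns a value of the same interval type.
spark.sql(
  """SELECT percentile_approx(col, 0.5, 100)
    |FROM VALUES (INTERVAL '0' MONTH), (INTERVAL '1' MONTH),
    |            (INTERVAL '2' MONTH), (INTERVAL '10' MONTH) AS tab(col)
    |""".stripMargin).show()
// expected: 0-1 (a year-month interval of one month)

// percentile (exact) accepts interval columns and returns a double in the
// interval's internal unit: months for year-month, microseconds for day-time.
spark.sql(
  """SELECT percentile(col, 0.5)
    |FROM VALUES (INTERVAL '0' SECOND), (INTERVAL '10' SECOND) AS tab(col)
    |""".stripMargin).show()
// expected: 5000000.0 (microseconds)
```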

### Why are the changes needed?
To improve user experience with Spark SQL.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Added new unit tests.

Closes #34412 from AngersZhuuuu/SPARK-37138.

Authored-by: Angerszhuuuu <angers.zhu@gmail.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
AngersZhuuuu authored and MaxGekk committed Oct 30, 2021
1 parent b0548c6 commit 08123a3
Showing 8 changed files with 153 additions and 31 deletions.
@@ -61,7 +61,8 @@ case class ApproxCountDistinctForIntervals(
}

override def inputTypes: Seq[AbstractDataType] = {
Seq(TypeCollection(NumericType, TimestampType, DateType, TimestampNTZType), ArrayType)
Seq(TypeCollection(NumericType, TimestampType, DateType, TimestampNTZType,
YearMonthIntervalType, DayTimeIntervalType), ArrayType)
}

// Mark as lazy so that endpointsExpression is not evaluated during tree transformation.
@@ -79,14 +80,16 @@ case class ApproxCountDistinctForIntervals(
TypeCheckFailure("The endpoints provided must be constant literals")
} else {
endpointsExpression.dataType match {
case ArrayType(_: NumericType | DateType | TimestampType | TimestampNTZType, _) =>
case ArrayType(_: NumericType | DateType | TimestampType | TimestampNTZType |
_: AnsiIntervalType, _) =>
if (endpoints.length < 2) {
TypeCheckFailure("The number of endpoints must be >= 2 to construct intervals")
} else {
TypeCheckSuccess
}
case _ =>
TypeCheckFailure("Endpoints require (numeric or timestamp or date) type")
TypeCheckFailure("Endpoints require (numeric or timestamp or date or timestamp_ntz or " +
"interval year to month or interval day to second) type")
}
}
}
@@ -120,9 +123,9 @@ case class ApproxCountDistinctForIntervals(
val doubleValue = child.dataType match {
case n: NumericType =>
n.numeric.toDouble(value.asInstanceOf[n.InternalType])
case _: DateType =>
case _: DateType | _: YearMonthIntervalType =>
value.asInstanceOf[Int].toDouble
case TimestampType | TimestampNTZType =>
case TimestampType | TimestampNTZType | _: DayTimeIntervalType =>
value.asInstanceOf[Long].toDouble
}
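
For context, an illustrative sketch (outside Spark, not part of the diff) of the internal encodings the cases above rely on: a year-month interval is carried as an Int count of months and a day-time interval as a Long count of microseconds, so both convert directly to Double:

```scala
import java.time.{Duration, Period}

// Year-month intervals: Int number of months internally.
def yearMonthToDouble(p: Period): Double = p.toTotalMonths.toDouble

// Day-time intervals: Long number of microseconds internally.
def dayTimeToDouble(d: Duration): Double =
  (d.getSeconds * 1000000L + d.getNano / 1000).toDouble

assert(yearMonthToDouble(Period.ofYears(1).plusMonths(2)) == 14.0)
assert(dayTimeToDouble(Duration.ofSeconds(10)) == 1.0e7)
```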

@@ -49,15 +49,16 @@ import org.apache.spark.sql.types._
* yields better accuracy, the default value is
* DEFAULT_PERCENTILE_ACCURACY.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
_FUNC_(col, percentage [, accuracy]) - Returns the approximate `percentile` of the numeric
column `col` which is the smallest value in the ordered `col` values (sorted from least to
greatest) such that no more than `percentage` of `col` values is less than the value
or equal to that value. The value of percentage must be between 0.0 and 1.0. The `accuracy`
parameter (default: 10000) is a positive numeric literal which controls approximation accuracy
at the cost of memory. Higher value of `accuracy` yields better accuracy, `1.0/accuracy` is
the relative error of the approximation.
_FUNC_(col, percentage [, accuracy]) - Returns the approximate `percentile` of the numeric or
ansi interval column `col` which is the smallest value in the ordered `col` values (sorted
from least to greatest) such that no more than `percentage` of `col` values is less than
the value or equal to that value. The value of percentage must be between 0.0 and 1.0.
The `accuracy` parameter (default: 10000) is a positive numeric literal which controls
approximation accuracy at the cost of memory. Higher value of `accuracy` yields better
accuracy, `1.0/accuracy` is the relative error of the approximation.
When `percentage` is an array, each value of the percentage array must be between 0.0 and 1.0.
In this case, returns the approximate percentile array of column `col` at the given
percentage array.
@@ -68,9 +69,14 @@ import org.apache.spark.sql.types._
[1,1,0]
> SELECT _FUNC_(col, 0.5, 100) FROM VALUES (0), (6), (7), (9), (10) AS tab(col);
7
> SELECT _FUNC_(col, 0.5, 100) FROM VALUES (INTERVAL '0' MONTH), (INTERVAL '1' MONTH), (INTERVAL '2' MONTH), (INTERVAL '10' MONTH) AS tab(col);
0-1
> SELECT _FUNC_(col, array(0.5, 0.7), 100) FROM VALUES (INTERVAL '0' SECOND), (INTERVAL '1' SECOND), (INTERVAL '2' SECOND), (INTERVAL '10' SECOND) AS tab(col);
[0 00:00:01.000000000,0 00:00:02.000000000]
""",
group = "agg_funcs",
since = "2.1.0")
// scalastyle:on line.size.limit
case class ApproximatePercentile(
child: Expression,
percentageExpression: Expression,
@@ -94,7 +100,8 @@ case class ApproximatePercentile(
override def inputTypes: Seq[AbstractDataType] = {
// Support NumericType, DateType, TimestampType and TimestampNTZType since their internal types
// are all numeric, and can be easily cast to double for processing.
Seq(TypeCollection(NumericType, DateType, TimestampType, TimestampNTZType),
Seq(TypeCollection(NumericType, DateType, TimestampType, TimestampNTZType,
YearMonthIntervalType, DayTimeIntervalType),
TypeCollection(DoubleType, ArrayType(DoubleType, containsNull = false)), IntegralType)
}

@@ -138,8 +145,9 @@ case class ApproximatePercentile(
if (value != null) {
// Convert the value to a double value
val doubleValue = child.dataType match {
case DateType => value.asInstanceOf[Int].toDouble
case TimestampType | TimestampNTZType => value.asInstanceOf[Long].toDouble
case DateType | _: YearMonthIntervalType => value.asInstanceOf[Int].toDouble
case TimestampType | TimestampNTZType | _: DayTimeIntervalType =>
value.asInstanceOf[Long].toDouble
case n: NumericType => n.numeric.toDouble(value.asInstanceOf[n.InternalType])
case other: DataType =>
throw QueryExecutionErrors.dataTypeUnexpectedError(other)
@@ -157,8 +165,8 @@ case class ApproximatePercentile(
override def eval(buffer: PercentileDigest): Any = {
val doubleResult = buffer.getPercentiles(percentages)
val result = child.dataType match {
case DateType => doubleResult.map(_.toInt)
case TimestampType | TimestampNTZType => doubleResult.map(_.toLong)
case DateType | _: YearMonthIntervalType => doubleResult.map(_.toInt)
case TimestampType | TimestampNTZType | _: DayTimeIntervalType => doubleResult.map(_.toLong)
case ByteType => doubleResult.map(_.toByte)
case ShortType => doubleResult.map(_.toShort)
case IntegerType => doubleResult.map(_.toInt)
@@ -43,12 +43,13 @@ import org.apache.spark.util.collection.OpenHashMap
* percentage values. Each percentage value must be in the range
* [0.0, 1.0].
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage =
"""
_FUNC_(col, percentage [, frequency]) - Returns the exact percentile value of numeric column
`col` at the given percentage. The value of percentage must be between 0.0 and 1.0. The
value of frequency should be positive integral
_FUNC_(col, percentage [, frequency]) - Returns the exact percentile value of numeric
or ansi interval column `col` at the given percentage. The value of percentage must be
between 0.0 and 1.0. The value of frequency should be positive integral
_FUNC_(col, array(percentage1 [, percentage2]...) [, frequency]) - Returns the exact
percentile value array of numeric column `col` at the given percentage(s). Each value
@@ -62,9 +63,14 @@ import org.apache.spark.util.collection.OpenHashMap
3.0
> SELECT _FUNC_(col, array(0.25, 0.75)) FROM VALUES (0), (10) AS tab(col);
[2.5,7.5]
> SELECT _FUNC_(col, 0.5) FROM VALUES (INTERVAL '0' MONTH), (INTERVAL '10' MONTH) AS tab(col);
5.0
> SELECT _FUNC_(col, array(0.2, 0.5)) FROM VALUES (INTERVAL '0' SECOND), (INTERVAL '10' SECOND) AS tab(col);
[2000000.0,5000000.0]
""",
group = "agg_funcs",
since = "2.1.0")
// scalastyle:on line.size.limit
case class Percentile(
child: Expression,
percentageExpression: Expression,
@@ -118,7 +124,8 @@ case class Percentile(
case _: ArrayType => ArrayType(DoubleType, false)
case _ => DoubleType
}
Seq(NumericType, percentageExpType, IntegralType)
Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType),
percentageExpType, IntegralType)
}

// Check the inputTypes are valid, and the percentageExpression satisfies:
@@ -191,8 +198,15 @@ case class Percentile(
return Seq.empty
}

val sortedCounts = buffer.toSeq.sortBy(_._1)(
child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[AnyRef]])
val ordering =
if (child.dataType.isInstanceOf[NumericType]) {
child.dataType.asInstanceOf[NumericType].ordering
} else if (child.dataType.isInstanceOf[YearMonthIntervalType]) {
child.dataType.asInstanceOf[YearMonthIntervalType].ordering
} else if (child.dataType.isInstanceOf[DayTimeIntervalType]) {
child.dataType.asInstanceOf[DayTimeIntervalType].ordering
}
val sortedCounts = buffer.toSeq.sortBy(_._1)(ordering.asInstanceOf[Ordering[AnyRef]])
val accumulatedCounts = sortedCounts.scanLeft((sortedCounts.head._1, 0L)) {
case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
}.tail
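
An illustrative sketch of the idea behind the ordering selection above (not the committed code): the buffered keys are sorted with the ordering of the child's data type, and interval keys order by their internal Int/Long encodings:

```scala
// Hypothetical frequencies keyed by a year-month interval's internal month count.
val counts = Map(300 -> 1L, 100 -> 2L, 200 -> 1L)

// Year-month intervals compare as Int months; day-time intervals as Long microseconds.
val yearMonthOrdering: Ordering[Int] = Ordering.Int

val sortedCounts = counts.toSeq.sortBy(_._1)(yearMonthOrdering)
assert(sortedCounts.map(_._1) == Seq(100, 200, 300))
```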
@@ -39,7 +39,8 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
assert(
wrongColumn.checkInputDataTypes() match {
case TypeCheckFailure(msg)
if msg.contains("requires (numeric or timestamp or date or timestamp_ntz) type") => true
if msg.contains("requires (numeric or timestamp or date or timestamp_ntz or " +
"interval year to month or interval day to second) type") => true
case _ => false
})
}
@@ -69,7 +70,8 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
AttributeReference("a", DoubleType)(),
endpointsExpression = CreateArray(Array("foobar").map(Literal(_))))
assert(wrongEndpoints.checkInputDataTypes() ==
TypeCheckFailure("Endpoints require (numeric or timestamp or date) type"))
TypeCheckFailure("Endpoints require (numeric or timestamp or date or timestamp_ntz or " +
"interval year to month or interval day to second) type"))
}

/** Create an ApproxCountDistinctForIntervals instance and an input and output buffer. */
@@ -170,8 +170,8 @@ class PercentileSuite extends SparkFunSuite {
val child = AttributeReference("a", dataType)()
val percentile = new Percentile(child, percentage)
assertEqual(percentile.checkInputDataTypes(),
TypeCheckFailure(s"argument 1 requires numeric type, however, " +
s"'a' is of ${dataType.simpleString} type."))
TypeCheckFailure(s"argument 1 requires (numeric or interval year to month or " +
s"interval day to second) type, however, 'a' is of ${dataType.simpleString} type."))
}

val invalidFrequencyDataTypes = Seq(FloatType, DoubleType, BooleanType,
@@ -184,8 +184,8 @@ class PercentileSuite extends SparkFunSuite {
val frq = AttributeReference("frq", frequencyType)()
val percentile = new Percentile(child, percentage, frq)
assertEqual(percentile.checkInputDataTypes(),
TypeCheckFailure(s"argument 1 requires numeric type, however, " +
s"'a' is of ${dataType.simpleString} type."))
TypeCheckFailure(s"argument 1 requires (numeric or interval year to month or " +
s"interval day to second) type, however, 'a' is of ${dataType.simpleString} type."))
}

for(dataType <- validDataTypes;
@@ -17,6 +17,8 @@

package org.apache.spark.sql

import java.time.{Duration, Period}

import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxCountDistinctForIntervals
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
@@ -58,4 +60,30 @@ class ApproxCountDistinctForIntervalsQuerySuite extends QueryTest with SharedSpa
}
}
}

test("SPARK-37138: Support Ansi Interval type in ApproxCountDistinctForIntervals") {
val table = "approx_count_distinct_for_ansi_intervals_tbl"
withTable(table) {
Seq((Period.ofMonths(100), Duration.ofSeconds(100L)),
(Period.ofMonths(200), Duration.ofSeconds(200L)),
(Period.ofMonths(300), Duration.ofSeconds(300L)))
.toDF("col1", "col2").createOrReplaceTempView(table)
val endpoints = (0 to 5).map(_ / 10)

val relation = spark.table(table).logicalPlan
val ymAttr = relation.output.find(_.name == "col1").get
val ymAggFunc =
ApproxCountDistinctForIntervals(ymAttr, CreateArray(endpoints.map(Literal(_))))
val ymAggExpr = ymAggFunc.toAggregateExpression()
val ymNamedExpr = Alias(ymAggExpr, ymAggExpr.toString)()

val dtAttr = relation.output.find(_.name == "col2").get
val dtAggFunc =
ApproxCountDistinctForIntervals(dtAttr, CreateArray(endpoints.map(Literal(_))))
val dtAggExpr = dtAggFunc.toAggregateExpression()
val dtNamedExpr = Alias(dtAggExpr, dtAggExpr.toString)()
val result = Dataset.ofRows(spark, Aggregate(Nil, Seq(ymNamedExpr, dtNamedExpr), relation))
checkAnswer(result, Row(Array(1, 1, 1, 1, 1), Array(1, 1, 1, 1, 1)))
}
}
}
@@ -18,7 +18,7 @@
package org.apache.spark.sql

import java.sql.{Date, Timestamp}
import java.time.LocalDateTime
import java.time.{Duration, LocalDateTime, Period}

import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY
Expand All @@ -32,7 +32,7 @@ import org.apache.spark.sql.test.SharedSparkSession
class ApproximatePercentileQuerySuite extends QueryTest with SharedSparkSession {
import testImplicits._

private val table = "percentile_test"
private val table = "percentile_approx"

test("percentile_approx, single percentile value") {
withTempView(table) {
@@ -319,4 +319,22 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSparkSession
Row(18, 17, 17, 17))
}
}

test("SPARK-37138: Support Ansi Interval type in ApproximatePercentile") {
withTempView(table) {
Seq((Period.ofMonths(100), Duration.ofSeconds(100L)),
(Period.ofMonths(200), Duration.ofSeconds(200L)),
(Period.ofMonths(300), Duration.ofSeconds(300L)))
.toDF("col1", "col2").createOrReplaceTempView(table)
checkAnswer(
spark.sql(
s"""SELECT
| percentile_approx(col1, 0.5),
| SUM(null),
| percentile_approx(col2, 0.5)
|FROM $table
""".stripMargin),
Row(Period.ofMonths(200).normalized(), null, Duration.ofSeconds(200L)))
}
}
}
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import java.time.{Duration, Period}

import org.apache.spark.sql.test.SharedSparkSession

/**
* End-to-end tests for percentile aggregate function.
*/
class PercentileQuerySuite extends QueryTest with SharedSparkSession {
import testImplicits._

private val table = "percentile_test"

test("SPARK-37138: Support Ansi Interval type in Percentile") {
withTempView(table) {
Seq((Period.ofMonths(100), Duration.ofSeconds(100L)),
(Period.ofMonths(200), Duration.ofSeconds(200L)),
(Period.ofMonths(300), Duration.ofSeconds(300L)))
.toDF("col1", "col2").createOrReplaceTempView(table)
checkAnswer(
spark.sql(
s"""SELECT
| CAST(percentile(col1, 0.5) AS STRING),
| SUM(null),
| CAST(percentile(col2, 0.5) AS STRING)
|FROM $table
""".stripMargin),
Row("200.0", null, "2.0E8"))
}
}
}
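
A short illustrative check (not part of the suite) of why the expected row is `("200.0", null, "2.0E8")`: the exact percentile of an interval column is computed on the internal numeric encoding, i.e. months for year-month intervals and microseconds for day-time intervals:

```scala
// 100, 200, 300 months and 100, 200, 300 seconds expressed in internal units.
val months = Seq(100.0, 200.0, 300.0)
val micros = Seq(100L, 200L, 300L).map(_ * 1000000.0)

// The 0.5 percentile of three values is the middle one.
def median(xs: Seq[Double]): Double = xs.sorted.apply(xs.size / 2)

assert(median(months) == 200.0)  // rendered as "200.0"
assert(median(micros) == 2.0e8)  // rendered as "2.0E8"
```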
