[SPARK-22100][SQL] Make percentile_approx support date/timestamp type and change the output type to be the same as input type

## What changes were proposed in this pull request?

The `percentile_approx` function previously accepted numeric type input and returned double type results.

But since date and timestamp types, like all numeric types, are represented internally as numeric values, `percentile_approx` can support them easily.

After this PR, it supports date type, timestamp type and numeric types as input types. The result type is also changed to be the same as the input type, which is more reasonable for percentiles.

This change is also required when we generate equi-height histograms for these types.
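
As a rough illustration of the new behavior (a minimal sketch, not part of this patch, using a made-up view name and a local session), a `percentile_approx` query over a timestamp column now returns a timestamp instead of a double:

```scala
import java.sql.Timestamp

import org.apache.spark.sql.SparkSession

object PercentileApproxTimestampDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("percentile_approx on timestamps")
      .getOrCreate()
    import spark.implicits._

    // 1000 timestamps, one per second starting at the epoch.
    (1 to 1000).map(i => Tuple1(new Timestamp(i * 1000L)))
      .toDF("ts")
      .createOrReplaceTempView("events")

    val result = spark.sql("SELECT percentile_approx(ts, 0.5) AS median_ts FROM events")
    // median_ts is TimestampType; before this patch percentile_approx accepted only
    // numeric input and returned doubles.
    result.printSchema()
    result.show(truncate = false)

    spark.stop()
  }
}
```

The same round trip applies to date and decimal columns, which the new query suite test below exercises.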

## How was this patch tested?

Added a new test and modified some existing tests.

Author: Zhenhua Wang <wangzhenhua@huawei.com>

Closes #19321 from wzhfy/approx_percentile_support_types.
wzhfy authored and gatorsmile committed Sep 25, 2017
1 parent 20adf9a commit 365a29b
Showing 7 changed files with 70 additions and 19 deletions.
4 changes: 2 additions & 2 deletions R/pkg/tests/fulltests/test_sparkSQL.R
@@ -2538,14 +2538,14 @@ test_that("describe() and summary() on a DataFrame", {

stats2 <- summary(df)
expect_equal(collect(stats2)[5, "summary"], "25%")
-expect_equal(collect(stats2)[5, "age"], "30.0")
+expect_equal(collect(stats2)[5, "age"], "30")

stats3 <- summary(df, "min", "max", "55.1%")

expect_equal(collect(stats3)[1, "summary"], "min")
expect_equal(collect(stats3)[2, "summary"], "max")
expect_equal(collect(stats3)[3, "summary"], "55.1%")
-expect_equal(collect(stats3)[3, "age"], "30.0")
+expect_equal(collect(stats3)[3, "age"], "30")

# SPARK-16425: SparkR summary() fails on column of type logical
df <- withColumn(df, "boolean", df$age == 30)
1 change: 1 addition & 0 deletions docs/sql-programming-guide.md
@@ -1553,6 +1553,7 @@ options.
## Upgrading From Spark SQL 2.2 to 2.3

- Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`.
+- The `percentile_approx` function previously accepted numeric type input and returned double type results. Now it supports date type, timestamp type and numeric types as input types. The result type is also changed to be the same as the input type, which is more reasonable for percentiles.

## Upgrading From Spark SQL 2.1 to 2.2

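The migration note added above amounts to a result-type change that is easy to observe from a shell. A minimal sketch (not part of the patch; it assumes a `spark-shell` session where `spark` is predefined) for the numeric case: the median of an integer column used to come back as a double and now comes back as an integer.

```scala
// spark-shell (Spark 2.3+)
import spark.implicits._

Seq(1, 2, 3, 4, 5).toDF("col").createOrReplaceTempView("t")

val df = spark.sql("SELECT percentile_approx(col, 0.5) AS p50 FROM t")
df.printSchema() // |-- p50: integer (nullable = true) -- was double before Spark 2.3
df.show()        // +---+
                 // |p50|
                 // +---+
                 // |  3|
                 // +---+
```
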
10 changes: 5 additions & 5 deletions python/pyspark/sql/dataframe.py
@@ -1038,9 +1038,9 @@ def summary(self, *statistics):
| mean| 3.5| null|
| stddev|2.1213203435596424| null|
| min| 2|Alice|
-| 25%| 5.0| null|
-| 50%| 5.0| null|
-| 75%| 5.0| null|
+| 25%| 5| null|
+| 50%| 5| null|
+| 75%| 5| null|
| max| 5| Bob|
+-------+------------------+-----+
@@ -1050,8 +1050,8 @@ def summary(self, *statistics):
+-------+---+-----+
| count| 2| 2|
| min| 2|Alice|
-| 25%|5.0| null|
-| 75%|5.0| null|
+| 25%| 5| null|
+| 75%| 5| null|
| max| 5| Bob|
+-------+---+-----+
@@ -85,7 +85,10 @@ case class ApproximatePercentile(
private lazy val accuracy: Int = accuracyExpression.eval().asInstanceOf[Int]

override def inputTypes: Seq[AbstractDataType] = {
-Seq(DoubleType, TypeCollection(DoubleType, ArrayType(DoubleType)), IntegerType)
+// Support NumericType, DateType and TimestampType since their internal types are all numeric,
+// and can be easily cast to double for processing.
+Seq(TypeCollection(NumericType, DateType, TimestampType),
+  TypeCollection(DoubleType, ArrayType(DoubleType)), IntegerType)
}

// Mark as lazy so that percentageExpression is not evaluated during tree transformation.
@@ -123,7 +126,15 @@ case class ApproximatePercentile(
val value = child.eval(inputRow)
// Ignore empty rows, for example: percentile_approx(null)
if (value != null) {
-buffer.add(value.asInstanceOf[Double])
+// Convert the value to a double value
+val doubleValue = child.dataType match {
+  case DateType => value.asInstanceOf[Int].toDouble
+  case TimestampType => value.asInstanceOf[Long].toDouble
+  case n: NumericType => n.numeric.toDouble(value.asInstanceOf[n.InternalType])
+  case other: DataType =>
+    throw new UnsupportedOperationException(s"Unexpected data type $other")
+}
+buffer.add(doubleValue)
}
buffer
}
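
The conversion above relies on Spark's internal encodings: a `DateType` value is stored as an `Int` counting days since 1970-01-01 and a `TimestampType` value as a `Long` counting microseconds since the epoch, so both cast to `Double` for the digest without trouble. The standalone sketch below (not part of the patch; the object name is invented) walks through that round trip with the same `DateTimeUtils` helpers the new query suite test uses.

```scala
import java.sql.{Date, Timestamp}

import org.apache.spark.sql.catalyst.util.DateTimeUtils

object InternalEncodingDemo {
  def main(args: Array[String]): Unit = {
    // DateType is stored as days since 1970-01-01 (Int),
    // TimestampType as microseconds since the epoch (Long).
    val days: Int = DateTimeUtils.fromJavaDate(Date.valueOf("2017-09-25"))
    val micros: Long = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2017-09-25 00:00:00"))

    // update() feeds doubles into the digest ...
    val dateAsDouble = days.toDouble
    val tsAsDouble = micros.toDouble

    // ... and eval() converts the resulting doubles back to the input type.
    println(DateTimeUtils.toJavaDate(dateAsDouble.toInt))       // 2017-09-25
    println(DateTimeUtils.toJavaTimestamp(tsAsDouble.toLong))   // 2017-09-25 00:00:00.0
  }
}
```
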
@@ -134,7 +145,20 @@
}

override def eval(buffer: PercentileDigest): Any = {
-val result = buffer.getPercentiles(percentages)
+val doubleResult = buffer.getPercentiles(percentages)
+val result = child.dataType match {
+  case DateType => doubleResult.map(_.toInt)
+  case TimestampType => doubleResult.map(_.toLong)
+  case ByteType => doubleResult.map(_.toByte)
+  case ShortType => doubleResult.map(_.toShort)
+  case IntegerType => doubleResult.map(_.toInt)
+  case LongType => doubleResult.map(_.toLong)
+  case FloatType => doubleResult.map(_.toFloat)
+  case DoubleType => doubleResult
+  case _: DecimalType => doubleResult.map(Decimal(_))
+  case other: DataType =>
+    throw new UnsupportedOperationException(s"Unexpected data type $other")
+}
if (result.length == 0) {
null
} else if (returnPercentileArray) {
@@ -155,8 +179,9 @@
// Returns null for empty inputs
override def nullable: Boolean = true

+// The result type is the same as the input type.
override def dataType: DataType = {
-  if (returnPercentileArray) ArrayType(DoubleType, false) else DoubleType
+  if (returnPercentileArray) ArrayType(child.dataType, false) else child.dataType
}

override def prettyName: String = "percentile_approx"
@@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, BoundReference, Cast, CreateArray, DecimalLiteral, GenericInternalRow, Literal}
@@ -270,7 +270,6 @@ class ApproximatePercentileSuite extends SparkFunSuite {
percentageExpression = percentageExpression,
accuracyExpression = Literal(100))

-val result = wrongPercentage.checkInputDataTypes()
assert(
wrongPercentage.checkInputDataTypes() match {
case TypeCheckFailure(msg) if msg.contains("must be between 0.0 and 1.0") => true
@@ -281,7 +280,6 @@

test("class ApproximatePercentile, automatically add type casting for parameters") {
val testRelation = LocalRelation('a.int)
-val analyzer = SimpleAnalyzer

// Compatible accuracy types: Long type and decimal type
val accuracyExpressions = Seq(Literal(1000L), DecimalLiteral(10000), Literal(123.0D))
@@ -299,7 +297,7 @@
analyzed match {
case Alias(agg: ApproximatePercentile, _) =>
assert(agg.resolved)
-assert(agg.child.dataType == DoubleType)
+assert(agg.child.dataType == IntegerType)
assert(agg.percentageExpression.dataType == DoubleType ||
agg.percentageExpression.dataType == ArrayType(DoubleType, containsNull = false))
assert(agg.accuracyExpression.dataType == IntegerType)
@@ -17,8 +17,11 @@

package org.apache.spark.sql

+import java.sql.{Date, Timestamp}

import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.test.SharedSQLContext

/**
@@ -67,6 +70,30 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext {
}
}

test("percentile_approx, different column types") {
withTempView(table) {
val intSeq = 1 to 1000
val data: Seq[(java.math.BigDecimal, Date, Timestamp)] = intSeq.map { i =>
(new java.math.BigDecimal(i), DateTimeUtils.toJavaDate(i), DateTimeUtils.toJavaTimestamp(i))
}
data.toDF("cdecimal", "cdate", "ctimestamp").createOrReplaceTempView(table)
checkAnswer(
spark.sql(
s"""SELECT
| percentile_approx(cdecimal, array(0.25, 0.5, 0.75D)),
| percentile_approx(cdate, array(0.25, 0.5, 0.75D)),
| percentile_approx(ctimestamp, array(0.25, 0.5, 0.75D))
|FROM $table
""".stripMargin),
Row(
Seq("250.000000000000000000", "500.000000000000000000", "750.000000000000000000")
.map(i => new java.math.BigDecimal(i)),
Seq(250, 500, 750).map(DateTimeUtils.toJavaDate),
Seq(250, 500, 750).map(i => DateTimeUtils.toJavaTimestamp(i.toLong)))
)
}
}

test("percentile_approx, multiple records with the minimum value in a partition") {
withTempView(table) {
spark.sparkContext.makeRDD(Seq(1, 1, 2, 1, 1, 3, 1, 1, 4, 1, 1, 5), 4).toDF("col")
@@ -88,7 +115,7 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext {
val accuracies = Array(1, 10, 100, 1000, 10000)
val errors = accuracies.map { accuracy =>
val df = spark.sql(s"SELECT percentile_approx(col, 0.25, $accuracy) FROM $table")
-val approximatePercentile = df.collect().head.getDouble(0)
+val approximatePercentile = df.collect().head.getInt(0)
val error = Math.abs(approximatePercentile - expectedPercentile)
error
}
@@ -803,9 +803,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
Row("mean", null, "33.0", "178.0"),
Row("stddev", null, "19.148542155126762", "11.547005383792516"),
Row("min", "Alice", "16", "164"),
Row("25%", null, "24.0", "176.0"),
Row("50%", null, "24.0", "176.0"),
Row("75%", null, "32.0", "180.0"),
Row("25%", null, "24", "176"),
Row("50%", null, "24", "176"),
Row("75%", null, "32", "180"),
Row("max", "David", "60", "192"))

val emptySummaryResult = Seq(
