[SPARK-27929][SQL] Make percentile function receive frq of double
- Make sql.catalyst.expressions.aggregate.Percentile receive the frequency column as a double.
- Reduce the number of binary searches per interpolation from two to one (sketched below).
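
As a rough illustration (not part of this patch's changes), the widened signature can be exercised the way PercentileSuite wires things up; the bound references below are placeholders, and the new fourth argument flags whether the frequency column is integral:

```
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Cast, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.Percentile
import org.apache.spark.sql.types.{DoubleType, IntegerType}

// Sketch only: column 0 holds the value, column 1 a fractional (double) frequency.
val child = Cast(BoundReference(0, IntegerType, nullable = false), DoubleType)
val frequency = BoundReference(1, DoubleType, nullable = false)

// Four-argument constructor added by this patch; Literal(false) marks the
// frequency column as non-integral (is_int_frequency = false).
val median = new Percentile(child, Literal(0.5), frequency, Literal(false))
```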

Existing Maven suites have been run:
```
./build/mvn -Phadoop-2.7 -Phive -Phive-thriftserver -Pmesos -Pyarn -Pkubernetes -Dtest=none -DwildcardSuites=org.apache.spark.sql.hive.execution.HiveWindowFunctionQuerySuite test

./build/mvn -Dtest=none -DwildcardSuites=org.apache.spark.sql.catalyst.expressions.aggregate.PercentileSuite test
```
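
For reference, here is a minimal standalone sketch of the new lookup logic (hypothetical names, not the committed code). The integral path derives the lower index from the result of a single binary search instead of issuing a second one; the fractional path (is_int_frequency = false) searches percentile * total_weight directly:

```
import java.util

// Standalone sketch, assuming `acc` holds (key, cumulative frequency) pairs
// sorted by key, built from strictly positive frequencies.
object PercentileSketch {

  // Insertion point of `value` in the cumulative counts (cf. binarySearchCount).
  private def search(counts: Array[Double], value: Double): Int =
    util.Arrays.binarySearch(counts, 0, counts.length, value) match {
      case ix if ix < 0 => -(ix + 1)
      case ix => ix
    }

  // Integral-frequency path: one binary search locates the higher position;
  // the lower position can only sit at the same index or the one before it.
  def percentileIntegral(acc: Seq[(Double, Double)], p: Double): Double = {
    val counts = acc.map(_._2).toArray
    val position = p * (counts.last - 1)
    val (lower, higher) = (position.floor, position.ceil)
    val higherIndex = search(counts, higher + 1d)
    val higherKey = acc(higherIndex)._1
    val lowerIndex = higherIndex - 1
    if (higher == lower || lowerIndex < 0 || counts(lowerIndex) < lower + 1d) {
      higherKey // both positions fall on the same key: no interpolation
    } else {
      // Linear interpolation between the two neighbouring keys.
      (higher - position) * acc(lowerIndex)._1 + (position - lower) * higherKey
    }
  }

  // Fractional-frequency path: search the weighted position once and average
  // the neighbouring keys only on an exact cumulative-weight boundary.
  def percentileFractional(acc: Seq[(Double, Double)], p: Double): Double = {
    val counts = acc.map(_._2).toArray
    val position = p * counts.last
    val lowerIndex = math.min(search(counts, position), counts.length - 1)
    val lowerKey = acc(lowerIndex)._1
    if (lowerIndex + 1 >= counts.length || counts(lowerIndex) != position) lowerKey
    else 0.5 * lowerKey + 0.5 * acc(lowerIndex + 1)._1
  }

  def main(args: Array[String]): Unit = {
    // Keys 1..4 with frequencies 1, 2, 3, 4 => cumulative counts 1, 3, 6, 10.
    val acc = Seq(1d -> 1d, 2d -> 3d, 3d -> 6d, 4d -> 10d)
    println(percentileIntegral(acc, 0.25))   // 2.25
    println(percentileFractional(acc, 0.25)) // 2.0
  }
}
```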

Signed-off-by: Taeksang Kim <voidbag@gmail.com>
voidbag committed Jun 12, 2019
1 parent 8486680 commit c36194e
Showing 2 changed files with 89 additions and 68 deletions.
@@ -45,14 +45,14 @@ import org.apache.spark.util.collection.OpenHashMap
@ExpressionDescription(
usage =
"""
_FUNC_(col, percentage [, frequency]) - Returns the exact percentile value of numeric column
`col` at the given percentage. The value of percentage must be between 0.0 and 1.0. The
value of frequency should be positive integral
_FUNC_(col, percentage [, frequency [, is_int_frequency]]) - Returns the exact
percentile value of numeric column `col` at the given percentage. The value of
percentage must be between 0.0 and 1.0. The value of frequency should be positive numeric
_FUNC_(col, array(percentage1 [, percentage2]...) [, frequency]) - Returns the exact
percentile value array of numeric column `col` at the given percentage(s). Each value
of the percentage array must be between 0.0 and 1.0. The value of frequency should be
positive integral
_FUNC_(col, array(percentage1 [, percentage2]...) [, frequency [, is_int_frequency]]) -
Returns the exact percentile value array of numeric column `col` at the given percentage(s).
Each value of the percentage array must be between 0.0 and 1.0. The value of frequency should
be positive numeric. The value of is_int_frequency should be boolean
""",
examples = """
@@ -67,16 +67,23 @@ case class Percentile(
child: Expression,
percentageExpression: Expression,
frequencyExpression : Expression,
isIntFreqExpression: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends TypedImperativeAggregate[OpenHashMap[AnyRef, Long]] with ImplicitCastInputTypes {
extends TypedImperativeAggregate[OpenHashMap[AnyRef, Double]] with ImplicitCastInputTypes {

def this(child: Expression, percentageExpression: Expression) = {
this(child, percentageExpression, Literal(1L), 0, 0)
this(child, percentageExpression, Literal(1L), Literal(true), 0, 0)
}

def this(child: Expression, percentageExpression: Expression, frequency: Expression) = {
this(child, percentageExpression, frequency, 0, 0)
this(child, percentageExpression, frequency, Literal(true), 0, 0)
}

def this(
child: Expression, percentageExpression: Expression,
frequency: Expression, isInt: Expression) = {
this(child, percentageExpression, frequency, isInt, 0, 0)
}

override def prettyName: String = "percentile"
@@ -98,7 +105,7 @@ case class Percentile(
}

override def children: Seq[Expression] = {
child :: percentageExpression ::frequencyExpression :: Nil
    child :: percentageExpression :: frequencyExpression :: isIntFreqExpression :: Nil
}

// Returns null for empty inputs
@@ -114,7 +121,7 @@ case class Percentile(
case _: ArrayType => ArrayType(DoubleType)
case _ => DoubleType
}
Seq(NumericType, percentageExpType, IntegralType)
Seq(NumericType, percentageExpType, NumericType, BooleanType)
}

// Check the inputTypes are valid, and the percentageExpression satisfies:
@@ -143,57 +150,61 @@ case class Percentile(
case n: Number => n.doubleValue
}

override def createAggregationBuffer(): OpenHashMap[AnyRef, Long] = {
override def createAggregationBuffer(): OpenHashMap[AnyRef, Double] = {
// Initialize new counts map instance here.
new OpenHashMap[AnyRef, Long]()
new OpenHashMap[AnyRef, Double]()
}

override def update(
buffer: OpenHashMap[AnyRef, Long],
input: InternalRow): OpenHashMap[AnyRef, Long] = {
buffer: OpenHashMap[AnyRef, Double],
input: InternalRow): OpenHashMap[AnyRef, Double] = {
val key = child.eval(input).asInstanceOf[AnyRef]
val frqValue = frequencyExpression.eval(input)

// Null values are ignored in counts map.
if (key != null && frqValue != null) {
val frqLong = frqValue.asInstanceOf[Number].longValue()
val frqDouble = frqValue.asInstanceOf[Number].doubleValue()
// add only when frequency is positive
if (frqLong > 0) {
buffer.changeValue(key, frqLong, _ + frqLong)
} else if (frqLong < 0) {
if (frqDouble > 0) {
buffer.changeValue(key, frqDouble, _ + frqDouble)
} else if (frqDouble < 0) {
throw new SparkException(s"Negative values found in ${frequencyExpression.sql}")
}
}
buffer
}

override def merge(
buffer: OpenHashMap[AnyRef, Long],
other: OpenHashMap[AnyRef, Long]): OpenHashMap[AnyRef, Long] = {
buffer: OpenHashMap[AnyRef, Double],
other: OpenHashMap[AnyRef, Double]): OpenHashMap[AnyRef, Double] = {
other.foreach { case (key, count) =>
buffer.changeValue(key, count, _ + count)
}
buffer
}

override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = {
override def eval(buffer: OpenHashMap[AnyRef, Double]): Any = {
generateOutput(getPercentiles(buffer))
}

private def getPercentiles(buffer: OpenHashMap[AnyRef, Long]): Seq[Double] = {
private def getPercentiles(buffer: OpenHashMap[AnyRef, Double]): Seq[Double] = {
if (buffer.isEmpty) {
return Seq.empty
}

val sortedCounts = buffer.toSeq.sortBy(_._1)(
child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[AnyRef]])
val accumlatedCounts = sortedCounts.scanLeft((sortedCounts.head._1, 0L)) {
val accumlatedCounts = sortedCounts.scanLeft((sortedCounts.head._1, 0D)) {
case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
}.tail
val maxPosition = accumlatedCounts.last._2 - 1

val maxPosition = accumlatedCounts.last._2
if (maxPosition == 0D) {
return Seq.empty
}

percentages.map { percentile =>
getPercentile(accumlatedCounts, maxPosition * percentile)
getPercentile(accumlatedCounts, percentile)
}
}

@@ -213,27 +224,36 @@ case class Percentile(
* This function has been based upon similar function from HIVE
* `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`.
*/
private def getPercentile(aggreCounts: Seq[(AnyRef, Long)], position: Double): Double = {
private def getPercentile(aggreCounts: Seq[(AnyRef, Double)], percentile: Double): Double = {
val countsArray = aggreCounts.map(_._2).toArray[Double]
if (this.isIntFreqExpression.eval() == false) {
val position = percentile * aggreCounts.last._2
val tmpLowerIndex = binarySearchCount(countsArray, 0, aggreCounts.size, position)
val lowerIndex = math.min(tmpLowerIndex, aggreCounts.size - 1)
val lowerKey = aggreCounts(lowerIndex)._1
val higherIndex = lowerIndex + 1
if (higherIndex >= aggreCounts.size || aggreCounts(lowerIndex)._2 != position) {
// no interpolation needed
return toDoubleValue(lowerKey)
}
val higherKey = aggreCounts(higherIndex)._1
// Linear interpolation to calculate the point on boundary
return 0.5D * toDoubleValue(lowerKey) + 0.5D * toDoubleValue(higherKey)
}
// We may need to do linear interpolation to get the exact percentile
val position = percentile * (aggreCounts.last._2 - 1)
val lower = position.floor.toLong
val higher = position.ceil.toLong

// Use binary search to find the lower and the higher position.
val countsArray = aggreCounts.map(_._2).toArray[Long]
val lowerIndex = binarySearchCount(countsArray, 0, aggreCounts.size, lower + 1)
val higherIndex = binarySearchCount(countsArray, 0, aggreCounts.size, higher + 1)

val lowerKey = aggreCounts(lowerIndex)._1
if (higher == lower) {
// no interpolation needed because position does not have a fraction
return toDoubleValue(lowerKey)
}

// Use binary search to find the higher position
val higherIndex = binarySearchCount(countsArray, 0, aggreCounts.size, higher + 1D)
val lowerIndex = higherIndex - 1
val higherKey = aggreCounts(higherIndex)._1
if (higherKey == lowerKey) {
// no interpolation needed because lower position and higher position has the same key
return toDoubleValue(lowerKey)

if (lowerIndex < 0 || aggreCounts(lowerIndex)._2 < lower + 1D) {
// no interpolation needed
return toDoubleValue(higherKey)
}
val lowerKey = aggreCounts(lowerIndex)._1

// Linear interpolation to get the exact percentile
(higher - position) * toDoubleValue(lowerKey) + (position - lower) * toDoubleValue(higherKey)
@@ -243,19 +263,19 @@ case class Percentile(
* use a binary search to find the index of the position closest to the current value.
*/
private def binarySearchCount(
countsArray: Array[Long], start: Int, end: Int, value: Long): Int = {
countsArray: Array[Double], start: Int, end: Int, value: Double): Int = {
util.Arrays.binarySearch(countsArray, 0, end, value) match {
case ix if ix < 0 => -(ix + 1)
case ix => ix
}
}

override def serialize(obj: OpenHashMap[AnyRef, Long]): Array[Byte] = {
override def serialize(obj: OpenHashMap[AnyRef, Double]): Array[Byte] = {
val buffer = new Array[Byte](4 << 10) // 4K
val bos = new ByteArrayOutputStream()
val out = new DataOutputStream(bos)
try {
val projection = UnsafeProjection.create(Array[DataType](child.dataType, LongType))
val projection = UnsafeProjection.create(Array[DataType](child.dataType, DoubleType))
// Write pairs in counts map to byte buffer.
obj.foreach { case (key, count) =>
val row = InternalRow.apply(key, count)
@@ -273,11 +293,11 @@ case class Percentile(
}
}

override def deserialize(bytes: Array[Byte]): OpenHashMap[AnyRef, Long] = {
override def deserialize(bytes: Array[Byte]): OpenHashMap[AnyRef, Double] = {
val bis = new ByteArrayInputStream(bytes)
val ins = new DataInputStream(bis)
try {
val counts = new OpenHashMap[AnyRef, Long]
val counts = new OpenHashMap[AnyRef, Double]
// Read unsafeRow size and content in bytes.
var sizeOfNextRow = ins.readInt()
while (sizeOfNextRow >= 0) {
@@ -287,11 +307,12 @@ case class Percentile(
row.pointTo(bs, sizeOfNextRow)
// Insert the pairs into counts map.
val key = row.get(0, child.dataType)
val count = row.get(1, LongType).asInstanceOf[Long]
val count = row.get(1, DoubleType).asInstanceOf[Double]
counts.update(key, count)
sizeOfNextRow = ins.readInt()
}


counts
} finally {
ins.close()
@@ -38,12 +38,12 @@ class PercentileSuite extends SparkFunSuite {
val agg = new Percentile(BoundReference(0, IntegerType, true), Literal(0.5))

// Check empty serialize and deserialize
val buffer = new OpenHashMap[AnyRef, Long]()
val buffer = new OpenHashMap[AnyRef, Double]()
assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))

// Check non-empty buffer serialize and deserialize.
data.foreach { key =>
buffer.changeValue(Integer.valueOf(key), 1L, _ + 1L)
buffer.changeValue(Integer.valueOf(key), 1D, _ + 1D)
}
assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
}
@@ -52,26 +52,27 @@ class PercentileSuite extends SparkFunSuite {
val count = 10000
val percentages = Seq(0, 0.25, 0.5, 0.75, 1)
val expectedPercentiles = Seq(1, 2500.75, 5000.5, 7500.25, 10000)
val childExpression = Cast(BoundReference(0, IntegerType, nullable = false), DoubleType)
val childExpression = Cast(BoundReference(0, IntegerType, nullable = false), FloatType)
val percentageExpression = CreateArray(percentages.toSeq.map(Literal(_)))
val agg = new Percentile(childExpression, percentageExpression)

// Test with rows without frequency
val rows = (1 to count).map(x => Seq(x))
runTest(agg, rows, expectedPercentiles)

// Test with row with frequency. Second and third columns are frequency in Int and Long
// Test with row with frequency. Second and third columns are frequency in Float and Double
val countForFrequencyTest = 1000
val rowsWithFrequency = (1 to countForFrequencyTest).map(x => Seq(x, x):+ x.toLong)
val rowsWithFrequency = (1 to countForFrequencyTest).map(x =>
(Seq(x) :+ x.toFloat):+ x.toDouble)
val expectedPercentilesWithFrquency = Seq(1.0, 500.0, 707.0, 866.0, 1000.0)

val frequencyExpressionInt = BoundReference(1, IntegerType, nullable = false)
val aggInt = new Percentile(childExpression, percentageExpression, frequencyExpressionInt)
runTest(aggInt, rowsWithFrequency, expectedPercentilesWithFrquency)
val frequencyExpressionFloat = BoundReference(1, FloatType, nullable = false)
val aggFloat = new Percentile(childExpression, percentageExpression, frequencyExpressionFloat)
runTest(aggFloat, rowsWithFrequency, expectedPercentilesWithFrquency)

val frequencyExpressionLong = BoundReference(2, LongType, nullable = false)
val aggLong = new Percentile(childExpression, percentageExpression, frequencyExpressionLong)
runTest(aggLong, rowsWithFrequency, expectedPercentilesWithFrquency)
val frequencyExpressionDouble = BoundReference(2, DoubleType, nullable = false)
val aggDouble = new Percentile(childExpression, percentageExpression, frequencyExpressionDouble)
runTest(aggDouble, rowsWithFrequency, expectedPercentilesWithFrquency)

// Run test with Flatten data
val flattenRows = (1 to countForFrequencyTest).flatMap(current =>
@@ -151,7 +152,7 @@ class PercentileSuite extends SparkFunSuite {
assertEqual(percentile.checkInputDataTypes(), TypeCheckSuccess)
}

val validFrequencyTypes = Seq(ByteType, ShortType, IntegerType, LongType)
val validFrequencyTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)
for (dataType <- validDataTypes;
frequencyType <- validFrequencyTypes) {
val child = AttributeReference("a", dataType)()
@@ -171,9 +172,8 @@
s"'`a`' is of ${dataType.simpleString} type."))
}

val invalidFrequencyDataTypes = Seq(FloatType, DoubleType, BooleanType,
StringType, DateType, TimestampType,
CalendarIntervalType, NullType)
val invalidFrequencyDataTypes = Seq(BooleanType, StringType, DateType,
TimestampType, CalendarIntervalType, NullType)

for(dataType <- invalidDataTypes;
frequencyType <- validFrequencyTypes) {
@@ -191,7 +191,7 @@
val frq = AttributeReference("frq", frequencyType)()
val percentile = new Percentile(child, percentage, frq)
assertEqual(percentile.checkInputDataTypes(),
TypeCheckFailure(s"argument 3 requires integral type, however, " +
TypeCheckFailure(s"argument 3 requires numeric type, however, " +
s"'`frq`' is of ${frequencyType.simpleString} type."))
}
}
@@ -259,7 +259,7 @@ class PercentileSuite extends SparkFunSuite {
assert(agg.eval(buffer) == null)

// Percentile with Frequency column
val frequencyExpression = Cast(BoundReference(1, IntegerType, nullable = true), IntegerType)
val frequencyExpression = Cast(BoundReference(1, IntegerType, nullable = true), DoubleType)
val aggWithFrequency = new Percentile(childExpression, Literal(0.5), frequencyExpression)
val bufferWithFrequency = new GenericInternalRow(new Array[Any](2))
aggWithFrequency.initialize(bufferWithFrequency)
@@ -285,7 +285,7 @@

test("negatives frequency column handling") {
val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType)
val freqExpression = Cast(BoundReference(1, IntegerType, nullable = true), IntegerType)
val freqExpression = Cast(BoundReference(1, IntegerType, nullable = true), FloatType)
val agg = new Percentile(childExpression, Literal(0.5), freqExpression)
val buffer = new GenericInternalRow(new Array[Any](2))
agg.initialize(buffer)
@@ -300,7 +300,7 @@
}

private def compareEquals(
left: OpenHashMap[AnyRef, Long], right: OpenHashMap[AnyRef, Long]): Boolean = {
left: OpenHashMap[AnyRef, Double], right: OpenHashMap[AnyRef, Double]): Boolean = {
left.size == right.size && left.forall { case (key, count) =>
right.apply(key) == count
}
