Skip to content

Commit

Permalink
[SPARK-22957] ApproxQuantile breaks if the number of rows exceeds MaxInt
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

A 32-bit Int was used for the row rank.
This overflowed in a DataFrame with more than 2 billion rows.

## How was this patch tested?

Added a test, but marked it as ignored, since it takes 4 minutes to run.

Author: Juliusz Sompolski <julek@databricks.com>

Closes #20152 from juliuszsompolski/SPARK-22957.
  • Loading branch information
juliuszsompolski authored and cloud-fan committed Jan 5, 2018
1 parent 0428368 commit df7fc3e
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,8 @@ object ApproximatePercentile {
Ints.BYTES + Doubles.BYTES + Longs.BYTES +
// length of summary.sampled
Ints.BYTES +
// summary.sampled, Array[Stat(value: Double, g: Int, delta: Int)]
summaries.sampled.length * (Doubles.BYTES + Ints.BYTES + Ints.BYTES)
// summary.sampled, Array[Stat(value: Double, g: Long, delta: Long)]
summaries.sampled.length * (Doubles.BYTES + Longs.BYTES + Longs.BYTES)
}

final def serialize(obj: PercentileDigest): Array[Byte] = {
Expand All @@ -312,8 +312,8 @@ object ApproximatePercentile {
while (i < summary.sampled.length) {
val stat = summary.sampled(i)
buffer.putDouble(stat.value)
buffer.putInt(stat.g)
buffer.putInt(stat.delta)
buffer.putLong(stat.g)
buffer.putLong(stat.delta)
i += 1
}
buffer.array()
Expand All @@ -330,8 +330,8 @@ object ApproximatePercentile {
var i = 0
while (i < sampledLength) {
val value = buffer.getDouble()
val g = buffer.getInt()
val delta = buffer.getInt()
val g = buffer.getLong()
val delta = buffer.getLong()
sampled(i) = Stats(value, g, delta)
i += 1
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class QuantileSummaries(
if (newSamples.isEmpty || (sampleIdx == sampled.length && opsIdx == sorted.length - 1)) {
0
} else {
math.floor(2 * relativeError * currentCount).toInt
math.floor(2 * relativeError * currentCount).toLong
}

val tuple = Stats(currentSample, 1, delta)
Expand Down Expand Up @@ -192,10 +192,10 @@ class QuantileSummaries(
}

// Target rank
val rank = math.ceil(quantile * count).toInt
val rank = math.ceil(quantile * count).toLong
val targetError = relativeError * count
// Minimum rank at current sample
var minRank = 0
var minRank = 0L
var i = 0
while (i < sampled.length - 1) {
val curSample = sampled(i)
Expand Down Expand Up @@ -235,7 +235,7 @@ object QuantileSummaries {
* @param g the minimum rank jump from the previous value's minimum rank
* @param delta the maximum span of the rank.
*/
case class Stats(value: Double, g: Int, delta: Int)
case class Stats(value: Double, g: Long, delta: Long)

private def compressImmut(
currentSamples: IndexedSeq[Stats],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,14 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
assert(res2(1).isEmpty)
}

// SPARK-22957: check for 32bit overflow when computing rank.
// ignored - takes 4 minutes to run.
ignore("approx quantile 4: test for Int overflow") {
val res = spark.range(3000000000L).stat.approxQuantile("id", Array(0.8, 0.9), 0.05)
assert(res(0) > 2200000000.0)
assert(res(1) > 2200000000.0)
}

test("crosstab") {
withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") {
val rng = new Random()
Expand Down

0 comments on commit df7fc3e

Please sign in to comment.