From bf62aed080c9a2b6b46e8ee656c70b5ae76c0d45 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Sat, 28 Apr 2018 15:29:17 +0800 Subject: [PATCH] Fix rate source rowsPerSecond <= rampUpTime corner case --- .../sources/RateStreamProvider.scala | 19 +++++++++++++++---- .../sources/RateStreamProviderSuite.scala | 18 ++++++++++++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala index 6bdd492f0cb35..916209e832668 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -107,14 +107,25 @@ object RateStreamProvider { // seconds = 0 1 2 3 4 5 6 // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 - val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) + val speedDeltaPerSecond = math.max(1, rowsPerSecond / (rampUpTimeSeconds + 1)) + + // If rowsPerSecond is smaller than rampUpTimeSeconds, speed will exceed the rowsPerSecond in + // the rampUpTimes, so we should delay the ramp up. + // E.g., rampUpTimeSeconds = 10, rowsPerSecond = 6 + // Then speedDeltaPerSecond = 1 + // + // seconds = 0 1 2 3 4 5 6 7 8 9 10 11 + // speed = 0 0 0 0 0 1 2 3 4 5 6 6 + // end value = 0 0 0 0 0 1 3 6 10 15 21 27 + val validSeconds = math.max(0, math.min(seconds, seconds - (rampUpTimeSeconds - rowsPerSecond))) + if (seconds <= rampUpTimeSeconds) { // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to // avoid overflow - if (seconds % 2 == 1) { - (seconds + 1) / 2 * speedDeltaPerSecond * seconds + if (validSeconds % 2 == 1) { + (validSeconds + 1) / 2 * speedDeltaPerSecond * validSeconds } else { - seconds / 2 * speedDeltaPerSecond * (seconds + 1) + validSeconds / 2 * speedDeltaPerSecond * (validSeconds + 1) } } else { // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index ff14ec38e66a8..60aad0c662b7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -190,6 +190,24 @@ class RateSourceSuite extends StreamTest { assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12) assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20) assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 5) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 5) === 1) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 5) === 3) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 5) === 6) + assert(valueAtSecond(seconds = 4, rowsPerSecond = 5, rampUpTimeSeconds = 5) === 10) + assert(valueAtSecond(seconds = 5, rowsPerSecond = 5, rampUpTimeSeconds = 5) === 15) + assert(valueAtSecond(seconds = 6, rowsPerSecond = 5, rampUpTimeSeconds = 5) === 20) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 7) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 7) === 0) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 7) === 0) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 7) === 1) + assert(valueAtSecond(seconds = 4, rowsPerSecond = 5, rampUpTimeSeconds = 7) === 3) + assert(valueAtSecond(seconds = 5, rowsPerSecond = 5, rampUpTimeSeconds = 7) === 6) + assert(valueAtSecond(seconds = 6, rowsPerSecond = 5, rampUpTimeSeconds = 7) === 10) + assert(valueAtSecond(seconds = 7, rowsPerSecond = 5, rampUpTimeSeconds = 7) === 15) + assert(valueAtSecond(seconds = 8, rowsPerSecond = 5, rampUpTimeSeconds = 7) === 20) } test("rampUpTime") {