From c4374ebde078d893adcb594237786763d9942d8c Mon Sep 17 00:00:00 2001
From: Gerard
Date: Mon, 30 Apr 2018 15:27:01 +0200
Subject: [PATCH] implement ramp-up using the acceleration function

---
 .../sources/RateStreamProvider.scala      |  26 +--
 .../sources/RateStreamProviderSuite.scala | 149 +++++++++++++++---
 2 files changed, 130 insertions(+), 45 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala
index 6bdd492f0cb35..7934f56a0547c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala
@@ -101,25 +101,11 @@ object RateStreamProvider {
 
   /** Calculate the end value we will emit at the time `seconds`. */
   def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = {
-    // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10
-    // Then speedDeltaPerSecond = 2
-    //
-    // seconds   = 0 1 2  3  4  5  6
-    // speed     = 0 2 4  6  8 10 10 (speedDeltaPerSecond * seconds)
-    // end value = 0 2 6 12 20 30 40 ((0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2)
-    val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1)
-    if (seconds <= rampUpTimeSeconds) {
-      // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to
-      // avoid overflow
-      if (seconds % 2 == 1) {
-        (seconds + 1) / 2 * speedDeltaPerSecond * seconds
-      } else {
-        seconds / 2 * speedDeltaPerSecond * (seconds + 1)
-      }
-    } else {
-      // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds
-      val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds)
-      rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond
-    }
+    // Guard rampUpTimeSeconds == 0 explicitly: otherwise delta is Infinity and
+    // 0 * Infinity = NaN, which would only coerce to the expected 0 by accident
+    val delta = if (rampUpTimeSeconds == 0) 0.0 else rowsPerSecond.toDouble / rampUpTimeSeconds
+    val rampUpSeconds = if (seconds <= rampUpTimeSeconds) seconds else rampUpTimeSeconds
+    val afterRampUpSeconds = if (seconds > rampUpTimeSeconds) seconds - rampUpTimeSeconds else 0
+    // Use the classic distance formula for constant acceleration, s = ut + ½at², with u = 0
+    val rampUpValue = Math.floor(rampUpSeconds * rampUpSeconds * delta / 2).toLong
+    rampUpValue + afterRampUpSeconds * rowsPerSecond
   }
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
index ff14ec38e66a8..3d91b8b349969 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
@@ -173,26 +173,72 @@ class RateSourceSuite extends StreamTest {
     assert(readData.map(_.getLong(1)).sorted == Range(0, 33))
   }
 
-  test("valueAtSecond") {
+  test("valueAtSecond without ramp-up") {
     import RateStreamProvider._
+    val rowsPerSec = Seq(1, 10, 50, 100, 1000, 10000)
+    val secs = Seq(1, 10, 100, 1000, 10000, 100000)
+    for {
+      sec <- secs
+      rps <- rowsPerSec
+    } yield {
+      assert(valueAtSecond(seconds = sec, rowsPerSecond = rps, rampUpTimeSeconds = 0) === sec * rps)
+    }
+  }
 
-    assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0)
-    assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5)
+  test("valueAtSecond with ramp-up") {
+    import RateStreamProvider._
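+
+    // Worked example (a sketch; values computed from valueAtSecond above):
+    // with rowsPerSecond = 10 and rampUpTimeSeconds = 4, delta = 2.5 and the end
+    // value is floor(seconds * seconds * 2.5 / 2) during ramp-up, then it grows
+    // by rowsPerSecond for every second afterwards:
+    //
+    //   seconds   = 0 1 2  3  4  5  6
+    //   end value = 0 1 5 11 20 30 40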
+    val rowsPerSec = Seq(1, 5, 10, 50, 100, 1000, 10000)
+    val rampUpSec = Seq(10, 100, 1000)
+
+    // For any combination, the value at second zero is 0
+    for {
+      rps <- rowsPerSec
+      rampUp <- rampUpSec
+    } yield {
+      assert(valueAtSecond(seconds = 0, rowsPerSecond = rps, rampUpTimeSeconds = rampUp) === 0)
+    }
 
-    assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0)
-    assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1)
-    assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3)
-    assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8)
+    // For any combination, the value halfway through the ramp-up is > 0
+    for {
+      rps <- rowsPerSec
+      rampUp <- rampUpSec
+      if rampUp / 2 > 0
+    } yield {
+      assert(
+        valueAtSecond(seconds = rampUp / 2, rowsPerSecond = rps, rampUpTimeSeconds = rampUp) > 0
+      )
+    }
 
-    assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0)
-    assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2)
-    assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6)
-    assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12)
-    assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20)
-    assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30)
+    // The rate increases during the ramp-up and then stabilizes at rowsPerSecond
+    for {
+      rps <- rowsPerSec
+      rampUp <- rampUpSec
+    } yield {
+      val valueAtSec: Int => Long = i =>
+        valueAtSecond(i, rowsPerSecond = rps, rampUpTimeSeconds = rampUp)
+
+      val valuePerSecond = (0 to rampUp).map(i => valueAtSec(i))
+      // Calculate the actual per-second rate from consecutive end values
+      val diffs = valuePerSecond.zip(valuePerSecond.tail).map { case (prev, next) => next - prev }
+      // Some values should be emitted during the ramp-up
+      assert(diffs.sum > 0)
+      // The rate should be non-decreasing
+      assert(diffs.forall(_ >= 0))
+
+      // After the ramp-up, the rate is the configured rows per second
+      assert(valueAtSec(rampUp + 1) - valueAtSec(rampUp) === rps)
+    }
   }
 
-  test("rampUpTime") {
+  // Evenly distributes numValues over one second, in millisecond intervals, starting at startValue
+  private def distributeValues(
+      startValue: Int, numValues: Int, atSecond: Long): Seq[(Long, Long)] = {
+    val offset = atSecond * 1000
+    (0 until numValues).map { v =>
+      val timeMillis = offset + Math.round(v * 1000.0 / numValues)
+      timeMillis -> (v + startValue).toLong
+    }
+  }
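+
+  // For example (an illustrative sketch of the helper above):
+  //   distributeValues(startValue = 1, numValues = 4, atSecond = 1)
+  //     == Seq(1000L -> 1L, 1250L -> 2L, 1500L -> 3L, 1750L -> 4L)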
+
+  test("rampUpTime when rowsPerSecond > rampUpTime") {
     val input = spark.readStream
       .format("rate")
       .option("rowsPerSecond", "10")
@@ -200,28 +246,81 @@ class RateSourceSuite extends StreamTest {
       .option("useManualClock", "true")
       .load()
       .as[(java.sql.Timestamp, Long)]
-      .map(v => (v._1.getTime, v._2))
+      .map { case (ts, value) => (ts.getTime, value) }
+    testStream(input)(
+      AdvanceRateManualClock(seconds = 1),
+      CheckLastBatch((0, 0)),
+      AdvanceRateManualClock(seconds = 1),
+      CheckLastBatch(distributeValues(startValue = 1, numValues = 4, atSecond = 1): _*),
+      AdvanceRateManualClock(seconds = 1),
+      CheckLastBatch(distributeValues(startValue = 5, numValues = 6, atSecond = 2): _*),
+      AdvanceRateManualClock(seconds = 1),
+      CheckLastBatch(distributeValues(startValue = 11, numValues = 9, atSecond = 3): _*),
+      AdvanceRateManualClock(seconds = 1),
+      // Now we should reach full rate
+      CheckLastBatch(distributeValues(startValue = 20, numValues = 10, atSecond = 4): _*),
+      AdvanceRateManualClock(seconds = 1),
+      CheckLastBatch(distributeValues(startValue = 30, numValues = 10, atSecond = 5): _*),
+      AdvanceRateManualClock(seconds = 1),
+      CheckLastBatch(distributeValues(startValue = 40, numValues = 10, atSecond = 6): _*)
+    )
+  }
+
+  test("rampUpTime when rowsPerSecond < rampUpTime") {
+    val input = spark.readStream
+      .format("rate")
+      .option("rowsPerSecond", "4")
+      .option("rampUpTime", "5s")
+      .option("useManualClock", "true")
+      .load()
+      .as[(java.sql.Timestamp, Long)]
+      .map { case (ts, value) => (ts.getTime, value) }
     testStream(input)(
       AdvanceRateManualClock(seconds = 1),
-      CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2
+      CheckLastBatch(),
       AdvanceRateManualClock(seconds = 1),
-      CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4
+      CheckLastBatch((1000, 0)),
       AdvanceRateManualClock(seconds = 1),
-      CheckLastBatch({
-        Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11)
-      }: _*), // speed = 6
+      CheckLastBatch(distributeValues(startValue = 1, numValues = 2, atSecond = 2): _*),
       AdvanceRateManualClock(seconds = 1),
-      CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8
+      CheckLastBatch(distributeValues(startValue = 3, numValues = 3, atSecond = 3): _*),
       AdvanceRateManualClock(seconds = 1),
-      // Now we should reach full speed
-      CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10
+      CheckLastBatch(distributeValues(startValue = 6, numValues = 4, atSecond = 4): _*),
       AdvanceRateManualClock(seconds = 1),
-      CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10
+      CheckLastBatch(distributeValues(startValue = 10, numValues = 4, atSecond = 5): _*),
       AdvanceRateManualClock(seconds = 1),
-      CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10
+      CheckLastBatch(distributeValues(startValue = 14, numValues = 4, atSecond = 6): _*)
     )
   }
 
+  test("rampUpTime when rowsPerSecond == rampUpTime") {
+    val input = spark.readStream
+      .format("rate")
+      .option("rowsPerSecond", "5")
+      .option("rampUpTime", "5s")
+      .option("useManualClock", "true")
+      .load()
+      .as[(java.sql.Timestamp, Long)]
+      .map { case (ts, value) => (ts.getTime, value) }
+    testStream(input)(
+      AdvanceRateManualClock(seconds = 1),
+      CheckLastBatch(),
+      AdvanceRateManualClock(seconds = 1),
+      CheckLastBatch(distributeValues(startValue = 0, numValues = 2, atSecond = 1): _*),
+      AdvanceRateManualClock(seconds = 1),
+      CheckLastBatch(distributeValues(startValue = 2, numValues = 2, atSecond = 2): _*),
+      AdvanceRateManualClock(seconds = 1),
+      CheckLastBatch(distributeValues(startValue = 4, numValues = 4, atSecond = 3): _*),
+      AdvanceRateManualClock(seconds = 1),
+      CheckLastBatch(distributeValues(startValue = 8, numValues = 4, atSecond = 4): _*),
+      AdvanceRateManualClock(seconds = 1),
+      CheckLastBatch(distributeValues(startValue = 12, numValues = 5, atSecond = 5): _*),
+      AdvanceRateManualClock(seconds = 1),
+      CheckLastBatch(distributeValues(startValue = 17, numValues = 5, atSecond = 6): _*)
+    )
+  }
+
   test("numPartitions") {
     val input = spark.readStream
       .format("rate")