diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala index 219892dd02648..0cf28c8262523 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.receiver import org.apache.spark.{Logging, SparkConf} -import java.util.concurrent.TimeUnit._ +import com.google.common.util.concurrent.{RateLimiter=>GuavaRateLimiter} /** Provides waitToPush() method to limit the rate at which receivers consume data. * @@ -33,37 +33,13 @@ import java.util.concurrent.TimeUnit._ */ private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging { - private var lastSyncTime = System.nanoTime - private var messagesWrittenSinceSync = 0L private val desiredRate = conf.getInt("spark.streaming.receiver.maxRate", 0) - private val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS) + private lazy val rateLimiter = GuavaRateLimiter.create(desiredRate) def waitToPush() { if( desiredRate <= 0 ) { return } - val now = System.nanoTime - val elapsedNanosecs = math.max(now - lastSyncTime, 1) - val rate = messagesWrittenSinceSync.toDouble * 1000000000 / elapsedNanosecs - if (rate < desiredRate) { - // It's okay to write; just update some variables and return - messagesWrittenSinceSync += 1 - if (now > lastSyncTime + SYNC_INTERVAL) { - // Sync interval has passed; let's resync - lastSyncTime = now - messagesWrittenSinceSync = 1 - } - } else { - // Calculate how much time we should sleep to bring ourselves to the desired rate. - val targetTimeInMillis = messagesWrittenSinceSync.toDouble * 1000 / desiredRate - val elapsedTimeInMillis = elapsedNanosecs / 1000000 - val sleepTimeInMillis = targetTimeInMillis - elapsedTimeInMillis - if (sleepTimeInMillis > 0) { - logTrace("Natural rate is " + rate + " per second but desired rate is " + - desiredRate + ", sleeping for " + sleepTimeInMillis + " ms to compensate.") - Thread.sleep(sleepTimeInMillis.toInt) - } - waitToPush() - } + rateLimiter.acquire() } }