diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala
index 8df542b367d27..f663def4c0511 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala
@@ -34,12 +34,32 @@ import org.apache.spark.{Logging, SparkConf}
  */
 private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging {
 
-  private val desiredRate = conf.getInt("spark.streaming.receiver.maxRate", 0)
-  private lazy val rateLimiter = GuavaRateLimiter.create(desiredRate)
+  // treated as an upper limit
+  private val maxRateLimit = conf.getLong("spark.streaming.receiver.maxRate", Long.MaxValue)
+  private lazy val rateLimiter = GuavaRateLimiter.create(maxRateLimit.toDouble)
 
   def waitToPush() {
-    if (desiredRate > 0) {
-      rateLimiter.acquire()
-    }
+    rateLimiter.acquire()
   }
+
+  /**
+   * Return the current rate limit. If no limit has been set so far, it returns {{{Long.MaxValue}}}.
+   */
+  def getCurrentLimit: Long =
+    rateLimiter.getRate.toLong
+
+  /**
+   * Set the rate limit to `newRate`. The new rate will not exceed the maximum rate configured by
+   * {{{spark.streaming.receiver.maxRate}}}, even if `newRate` is higher than that.
+   *
+   * @param newRate A new rate in events per second. It has no effect if it's 0 or negative.
+   */
+  private[receiver] def updateRate(newRate: Long): Unit =
+    if (newRate > 0) {
+      if (maxRateLimit > 0) {
+        rateLimiter.setRate(newRate.min(maxRateLimit))
+      } else {
+        rateLimiter.setRate(newRate)
+      }
+    }
 }
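The behavioral change above: `waitToPush()` now always goes through Guava's `RateLimiter`, and `updateRate` clamps any requested rate to the configured `spark.streaming.receiver.maxRate`. A minimal stand-alone sketch of that clamping, assuming only Guava on the classpath; `RateClampSketch` and its values are illustrative, not part of the patch:

```scala
// Stand-alone sketch of the clamping that updateRate implements above; only
// Guava is assumed on the classpath, and all names/values are illustrative.
import com.google.common.util.concurrent.{RateLimiter => GuavaRateLimiter}

object RateClampSketch extends App {
  val maxRateLimit = 100L // stand-in for spark.streaming.receiver.maxRate
  val limiter = GuavaRateLimiter.create(maxRateLimit.toDouble)

  def updateRate(newRate: Long): Unit =
    if (newRate > 0) limiter.setRate(newRate.min(maxRateLimit))

  updateRate(105)                 // requested rate exceeds the ceiling
  println(limiter.getRate.toLong) // prints 100: the update was clamped
}
```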
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala
index 5b5a3fe648602..7504fa44d9fae 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala
@@ -271,7 +271,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
   }
 
   /** Get the attached executor. */
-  private def executor = {
+  private def executor: ReceiverSupervisor = {
     assert(executor_ != null, "Executor has not been attached to this receiver")
     executor_
   }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala
index 7bf3c33319491..1eb55affaa9d0 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala
@@ -23,4 +23,5 @@ import org.apache.spark.streaming.Time
 private[streaming] sealed trait ReceiverMessage extends Serializable
 private[streaming] object StopReceiver extends ReceiverMessage
 private[streaming] case class CleanupOldBlocks(threshTime: Time) extends ReceiverMessage
-
+private[streaming] case class UpdateRateLimit(elementsPerSecond: Long)
+  extends ReceiverMessage
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala
index 6467029a277b2..a7c220f426ecf 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala
@@ -59,6 +59,9 @@ private[streaming] abstract class ReceiverSupervisor(
   /** Time between a receiver is stopped and started again */
   private val defaultRestartDelay = conf.getInt("spark.streaming.receiverRestartDelay", 2000)
 
+  /** The current maximum rate limit for this receiver. */
+  private[streaming] def getCurrentRateLimit: Option[Long] = None
+
   /** Exception associated with the stopping of the receiver */
   @volatile protected var stoppingError: Throwable = null
 
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
index f6ba66b3ae036..2f6841ee8879c 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
@@ -77,6 +77,9 @@ private[streaming] class ReceiverSupervisorImpl(
         case CleanupOldBlocks(threshTime) =>
           logDebug("Received delete old batch signal")
           cleanupOldBlocks(threshTime)
+        case UpdateRateLimit(eps) =>
+          logInfo(s"Received a new rate limit: $eps.")
+          blockGenerator.updateRate(eps)
       }
     })
 
@@ -98,6 +101,9 @@ private[streaming] class ReceiverSupervisorImpl(
     }
   }, streamId, env.conf)
 
+  override private[streaming] def getCurrentRateLimit: Option[Long] =
+    Some(blockGenerator.getCurrentLimit)
+
   /** Push a single record of received data into block generator. */
   def pushSingle(data: Any) {
     blockGenerator.addData(data)
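Together these hunks wire a new control message through the receiver side: `UpdateRateLimit` is a `ReceiverMessage`, and the supervisor's RPC endpoint pattern-matches on it and forwards the rate to the block generator. A self-contained sketch of that dispatch shape; every type here is a stand-in, not Spark's API:

```scala
// Stand-ins mirroring the receiver-side dispatch above; none are Spark types.
object DispatchSketch {
  sealed trait Msg extends Serializable
  case object Stop extends Msg
  final case class UpdateRate(elementsPerSecond: Long) extends Msg

  trait RateLimited { def updateRate(newRate: Long): Unit }

  // The endpoint's receive block reduces to a match over the sealed trait:
  def handle(msg: Msg, blockGenerator: RateLimited): Unit = msg match {
    case UpdateRate(eps) => blockGenerator.updateRate(eps) // forward to the limiter
    case Stop            => ()                             // stop handling lives elsewhere
  }

  def main(args: Array[String]): Unit = {
    val gen = new RateLimited {
      def updateRate(newRate: Long): Unit = println(s"rate set to $newRate")
    }
    handle(UpdateRate(200L), gen) // prints: rate set to 200
  }
}
```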
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
index 6910d81d9866e..9cc6ffcd12f61 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
@@ -26,7 +26,7 @@ import org.apache.spark.{Logging, SparkEnv, SparkException}
 import org.apache.spark.rpc._
 import org.apache.spark.streaming.{StreamingContext, Time}
 import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl,
-  StopReceiver}
+  StopReceiver, UpdateRateLimit}
 import org.apache.spark.util.SerializableConfiguration
 
 /**
@@ -226,6 +226,13 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
     logError(s"Deregistered receiver for stream $streamId: $messageWithError")
   }
 
+  /** Update a receiver's maximum ingestion rate */
+  def sendRateUpdate(streamUID: Int, newRate: Long): Unit = {
+    for (info <- receiverInfo.get(streamUID); eP <- Option(info.endpoint)) {
+      eP.send(UpdateRateLimit(newRate))
+    }
+  }
+
   /** Add new blocks for the given stream */
   private def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = {
     receivedBlockTracker.addBlock(receivedBlockInfo)
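Note that `sendRateUpdate` is deliberately fire-and-forget: the for-comprehension over `receiverInfo.get(streamUID)` and `Option(info.endpoint)` means an unknown stream id or a null endpoint silently drops the update instead of throwing. A toy reduction of that behavior; the registry and endpoints here are stand-ins:

```scala
// Toy reduction of sendRateUpdate's Option-based dispatch: an unknown stream
// id or a null endpoint drops the update silently instead of throwing.
object SendRateSketch extends App {
  // stand-in for ReceiverTracker.receiverInfo; endpoints here are just strings
  val receiverInfo: Map[Int, String] = Map(0 -> "endpoint-0", 1 -> null)

  def sendRateUpdate(streamUID: Int, newRate: Long): Unit =
    for (info <- receiverInfo.get(streamUID); ep <- Option(info)) {
      println(s"$ep ! UpdateRateLimit($newRate)")
    }

  sendRateUpdate(0, 500L) // delivered
  sendRateUpdate(1, 500L) // endpoint is null: dropped
  sendRateUpdate(9, 500L) // unknown stream id: dropped
}
```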
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/RateLimiterSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/RateLimiterSuite.scala
new file mode 100644
index 0000000000000..c6330eb3673fb
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/RateLimiterSuite.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.receiver
+
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkFunSuite
+
+/** Test suite for the RateLimiter used by network receivers. */
+class RateLimiterSuite extends SparkFunSuite {
+
+  test("rate limiter initializes even without a maxRate set") {
+    val conf = new SparkConf()
+    val rateLimiter = new RateLimiter(conf){}
+    rateLimiter.updateRate(105)
+    assert(rateLimiter.getCurrentLimit === 105)
+  }
+
+  test("rate limiter updates when below maxRate") {
+    val conf = new SparkConf().set("spark.streaming.receiver.maxRate", "110")
+    val rateLimiter = new RateLimiter(conf){}
+    rateLimiter.updateRate(105)
+    assert(rateLimiter.getCurrentLimit === 105)
+  }
+
+  test("rate limiter stays below maxRate despite large updates") {
+    val conf = new SparkConf().set("spark.streaming.receiver.maxRate", "100")
+    val rateLimiter = new RateLimiter(conf){}
+    rateLimiter.updateRate(105)
+    assert(rateLimiter.getCurrentLimit === 100)
+  }
+}
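A note on the `new RateLimiter(conf){}` idiom in these tests: `RateLimiter` is abstract (its production subclass is the block generator), but it declares no abstract members, so an empty anonymous subclass is the cheapest way to instantiate it. The same idiom in isolation, with illustrative names:

```scala
// RateLimiter above is abstract but has no abstract members, so the tests can
// instantiate it with an empty anonymous subclass. The idiom in isolation:
object AnonSubclassSketch extends App {
  abstract class Throttle(val maxRate: Long) { // stand-in for RateLimiter
    def currentLimit: Long = maxRate
  }

  val t = new Throttle(100L) {} // `{}` creates an anonymous, empty subclass
  assert(t.currentLimit == 100L)
}
```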
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
index a6e783861dbe6..aadb7231757b8 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
@@ -17,11 +17,17 @@
 
 package org.apache.spark.streaming.scheduler
 
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.concurrent.Timeouts
+import org.scalatest.time.SpanSugar._
 import org.apache.spark.streaming._
 import org.apache.spark.SparkConf
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming.receiver._
 import org.apache.spark.util.Utils
+import org.apache.spark.streaming.dstream.InputDStream
+import scala.reflect.ClassTag
+import org.apache.spark.streaming.dstream.ReceiverInputDStream
 
 /** Testsuite for receiver scheduling */
 class ReceiverTrackerSuite extends TestSuiteBase {
@@ -72,8 +78,64 @@ class ReceiverTrackerSuite extends TestSuiteBase {
     assert(locations(0).length === 1)
     assert(locations(3).length === 1)
   }
+
+  test("Receiver tracker - propagates rate limit") {
+    object ReceiverStartedWaiter extends StreamingListener {
+      @volatile
+      var started = false
+
+      override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = {
+        started = true
+      }
+    }
+
+    ssc.addStreamingListener(ReceiverStartedWaiter)
+    ssc.scheduler.listenerBus.start(ssc.sc)
+
+    val newRateLimit = 100L
+    val inputDStream = new RateLimitInputDStream(ssc)
+    val tracker = new ReceiverTracker(ssc)
+    tracker.start()
+
+    // Wait until the receiver has registered with the tracker,
+    // otherwise the rate update is lost.
+    eventually(timeout(5 seconds)) {
+      assert(ReceiverStartedWaiter.started)
+    }
+    tracker.sendRateUpdate(inputDStream.id, newRateLimit)
+    // UpdateRateLimit is an async message, so wait until it has been processed.
+    eventually(timeout(3 seconds)) {
+      assert(inputDStream.getCurrentRateLimit.get === newRateLimit)
+    }
+  }
 }
 
+/** An input DStream with a hard-coded receiver that exposes internals for testing. */
+private class RateLimitInputDStream(@transient ssc_ : StreamingContext)
+  extends ReceiverInputDStream[Int](ssc_) {
+
+  override def getReceiver(): DummyReceiver = SingletonDummyReceiver
+
+  def getCurrentRateLimit: Option[Long] = {
+    invokeExecutorMethod.getCurrentRateLimit
+  }
+
+  private def invokeExecutorMethod: ReceiverSupervisor = {
+    val c = classOf[Receiver[_]]
+    val ex = c.getDeclaredMethod("executor")
+    ex.setAccessible(true)
+    ex.invoke(SingletonDummyReceiver).asInstanceOf[ReceiverSupervisor]
+  }
+}
+
+/**
+ * A receiver declared as a top-level object so the test can read its rate limit.
+ *
+ * @note It must be a top-level object: otherwise serialization would create another
+ * instance on the executor side and we would not be able to read its rate limit.
+ */
+private object SingletonDummyReceiver extends DummyReceiver
+
 /**
  * Dummy receiver implementation
  */
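`RateLimitInputDStream` reaches the private `executor` accessor through Java reflection, which is presumably also why the patch gives `executor` an explicit `ReceiverSupervisor` return type: it documents the type the test casts the reflective result to. The trick in isolation, with illustrative names:

```scala
// The reflection trick used by RateLimitInputDStream, in isolation:
// read a private method via getDeclaredMethod + setAccessible.
object ReflectionSketch extends App {
  class Secretive {              // stand-in for Receiver
    private def secret: Int = 42 // stand-in for the private `executor` accessor
  }

  val m = classOf[Secretive].getDeclaredMethod("secret")
  m.setAccessible(true)          // bypass the access modifier
  println(m.invoke(new Secretive)) // prints 42
}
```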