Skip to content

Commit

Permalink
Add a test for the propagation of a new rate limit from driver to rec…
Browse files Browse the repository at this point in the history
…eivers.
  • Loading branch information
dragos committed Jul 17, 2015
1 parent 6369b30 commit cd1397d
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
}

/** Get the attached executor. */
private def executor = {
private[streaming] def executor = {
assert(executor_ != null, "Executor has not been attached to this receiver")
executor_
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ private[streaming] abstract class ReceiverSupervisor(
/** Time between a receiver is stopped and started again */
private val defaultRestartDelay = conf.getInt("spark.streaming.receiverRestartDelay", 2000)

/** The current maximum rate limit for this receiver. */
private[streaming] def getCurrentRateLimit: Option[Int] = None

/** Exception associated with the stopping of the receiver */
@volatile protected var stoppingError: Throwable = null

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ private[streaming] class ReceiverSupervisorImpl(
}
}, streamId, env.conf)

override private[streaming] def getCurrentRateLimit: Option[Int] =
Some(blockGenerator.currentRateLimit.get)

/** Push a single record of received data into block generator. */
def pushSingle(data: Any) {
blockGenerator.addData(data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -537,4 +537,19 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging {
verifyOutput[W](output, expectedOutput, useSet)
}
}

/**
* Wait until `cond` becomes true, or timeout ms have passed. This method checks the condition
* every 100ms, so it won't wait more than 100ms more than necessary.
*
* @param cond A boolean that should become `true`
* @param timemout How many millis to wait before giving up
*/
def waitUntil(cond: => Boolean, timeout: Int): Unit = {
val start = System.currentTimeMillis()
val end = start + timeout
while ((System.currentTimeMillis() < end) && !cond) {
Thread.sleep(100)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.receiver._
import org.apache.spark.util.Utils
import org.apache.spark.streaming.dstream.InputDStream
import scala.reflect.ClassTag
import org.apache.spark.streaming.dstream.ReceiverInputDStream

/** Testsuite for receiver scheduling */
class ReceiverTrackerSuite extends TestSuiteBase {
Expand Down Expand Up @@ -72,15 +75,46 @@ class ReceiverTrackerSuite extends TestSuiteBase {
assert(locations(0).length === 1)
assert(locations(3).length === 1)
}

test("Receiver tracker - propagates rate limit") {
val newRateLimit = 100
val ids = new TestReceiverInputDStream(ssc)
val tracker = new ReceiverTracker(ssc)
tracker.start()
waitUntil(TestDummyReceiver.started, 5000)
tracker.sendRateUpdate(ids.id, newRateLimit)
// this is an async message, we need to wait a bit for it to be processed
waitUntil(ids.getRateLimit.get == newRateLimit, 1000)
assert(ids.getRateLimit.get === newRateLimit)
}
}

/** An input DStream with a hard-coded receiver that gives access to internals for testing. */
private class TestReceiverInputDStream(@transient ssc_ : StreamingContext)
extends ReceiverInputDStream[Int](ssc_) {

override def getReceiver(): DummyReceiver = TestDummyReceiver

def getRateLimit: Option[Int] =
TestDummyReceiver.executor.getCurrentRateLimit
}

/**
* We need the receiver to be an object, otherwise serialization will create another one
* and we won't be able to read its rate limit.
*/
private object TestDummyReceiver extends DummyReceiver

/**
* Dummy receiver implementation
*/
private class DummyReceiver(host: Option[String] = None)
extends Receiver[Int](StorageLevel.MEMORY_ONLY) {

var started = false

def onStart() {
started = true
}

def onStop() {
Expand Down

0 comments on commit cd1397d

Please sign in to comment.