Skip to content

Commit

Permalink
Register StreamingListener before starting StreamingContext; Revert …
Browse files Browse the repository at this point in the history
…unnecessary changes; fix the Python unit test
  • Loading branch information
zsxwing committed Jul 26, 2015
1 parent a6747cb commit d07f454
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,41 +20,42 @@ package org.apache.spark.streaming.mqtt
import scala.concurrent.duration._
import scala.language.postfixOps

import org.scalatest.BeforeAndAfterAll
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext}

class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterAll {
class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter {

private val batchDuration = Milliseconds(500)
private val master = "local[2]"
private val framework = this.getClass.getSimpleName
private val topic = "def"

private val topic = "topic"
private var ssc: StreamingContext = _
private var MQTTTestUtils: MQTTTestUtils = _

override def beforeAll(): Unit = {
before {
ssc = new StreamingContext(master, framework, batchDuration)
MQTTTestUtils = new MQTTTestUtils
MQTTTestUtils.setup()
}

override def afterAll(): Unit = {
after {
if (ssc != null) {
ssc.stop()
ssc = null
}

if (MQTTTestUtils != null) {
MQTTTestUtils.teardown()
MQTTTestUtils = null
}
}

test("mqtt input stream") {
val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName)
ssc = new StreamingContext(sparkConf, Milliseconds(500))
val sendMessage = "MQTT demo for spark streaming"

val receiveStream = MQTTUtils.createStream(ssc, "tcp://" + MQTTTestUtils.brokerUri, topic,
StorageLevel.MEMORY_ONLY)

Expand All @@ -65,6 +66,9 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterA
receiveMessage
}
}

MQTTTestUtils.registerStreamingListener(ssc)

ssc.start()

// wait for the receiver to start before publishing data, or we risk failing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.util.concurrent.{CountDownLatch, TimeUnit}

import scala.language.postfixOps

import com.google.common.base.Charsets.UTF_8
import org.apache.activemq.broker.{BrokerService, TransportConnector}
import org.apache.commons.lang3.RandomUtils
import org.eclipse.paho.client.mqttv3._
Expand All @@ -46,6 +47,8 @@ private class MQTTTestUtils extends Logging {
private var broker: BrokerService = _
private var connector: TransportConnector = _

private var receiverStartedLatch = new CountDownLatch(1)

/** Broker address rendered as `brokerHost:brokerPort`. */
def brokerUri: String = s"${brokerHost}:${brokerPort}"
Expand All @@ -69,6 +72,8 @@ private class MQTTTestUtils extends Logging {
connector.stop()
connector = null
}
Utils.deleteRecursively(persistenceDir)
receiverStartedLatch = null
}

private def findFreePort(): Int = {
Expand All @@ -88,7 +93,7 @@ private class MQTTTestUtils extends Logging {
client.connect()
if (client.isConnected) {
val msgTopic = client.getTopic(topic)
val message = new MqttMessage(data.getBytes("utf-8"))
val message = new MqttMessage(data.getBytes(UTF_8))
message.setQos(1)
message.setRetained(true)

Expand All @@ -110,27 +115,37 @@ private class MQTTTestUtils extends Logging {
}

/**
* Block until at least one receiver has started or timeout occurs.
* Call this one before starting StreamingContext so that we won't miss the
* StreamingListenerReceiverStarted event.
*/
def waitForReceiverToStart(ssc: StreamingContext) : Unit = {
val latch = new CountDownLatch(1)
// Java-friendly overload: unwraps the underlying StreamingContext and delegates to the
// Scala overload of registerStreamingListener.
def registerStreamingListener(jssc: JavaStreamingContext): Unit =
  registerStreamingListener(jssc.ssc)

/**
 * Registers a StreamingListener that counts down `receiverStartedLatch` as soon as a
 * receiver reports that it has started.
 *
 * Call this one before starting StreamingContext so that we won't miss the
 * StreamingListenerReceiverStarted event. This method only registers the listener and
 * never blocks; use `waitForReceiverToStart` after the context has been started to wait
 * on the latch.
 *
 * @param ssc the streaming context the listener is attached to
 */
def registerStreamingListener(ssc: StreamingContext): Unit = {
  ssc.addStreamingListener(new StreamingListener {
    override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = {
      // Signal the shared, instance-level latch; there is no method-local latch here.
      receiverStartedLatch.countDown()
    }
  })
}

def waitForReceiverToStart(jssc: JavaStreamingContext) : Unit = {
val latch = new CountDownLatch(1)
jssc.addStreamingListener(new StreamingListener {
override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) {
latch.countDown()
}
})
/**
 * Java-friendly overload: block until at least one receiver has started or timeout occurs.
 * Delegates to the overload taking the underlying StreamingContext.
 */
def waitForReceiverToStart(jssc: JavaStreamingContext): Unit =
  waitForReceiverToStart(jssc.ssc)

assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.")
/**
 * Block until at least one receiver has started or timeout occurs.
 *
 * Waits up to 10 seconds on `receiverStartedLatch` and fails the assertion if the latch
 * has not been counted down by then.
 */
def waitForReceiverToStart(ssc: StreamingContext): Unit = {
  val receiverStarted = receiverStartedLatch.await(10, TimeUnit.SECONDS)
  assert(receiverStarted, "Timeout waiting for receiver to start.")
}
}
25 changes: 6 additions & 19 deletions python/pyspark/streaming/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,31 +863,18 @@ def getOutput(_, rdd):
self.ssc.start()
return result

def _publishData(self, topic, data):
start_time = time.time()
while True:
try:
self._MQTTTestUtils.publishData(topic, data)
break
except:
if time.time() - start_time < self.timeout:
time.sleep(0.01)
else:
raise

def _validateStreamResult(self, sendData, result):
receiveData = ''.join(result[0])
self.assertEqual(sendData, receiveData)

def test_mqtt_stream(self):
"""Test the Python MQTT stream API."""
sendData = "MQTT demo for spark streaming"
topic = self._randomTopic()
self._MQTTTestUtils.registerStreamingListener(self.ssc._jssc)
result = self._startContext(topic)
self._MQTTTestUtils.waitForReceiverToStart(self.ssc._jssc)
self._publishData(topic, sendData)
self.wait_for(result, len(sendData))
self._validateStreamResult(sendData, result)
self._MQTTTestUtils.publishData(topic, sendData)
self.wait_for(result, 1)
# Because "publishData" sends duplicate messages, here we should use > 0
self.assertTrue(len(result) > 0)
self.assertEqual(sendData, result[0])


def search_kafka_assembly_jar():
Expand Down

0 comments on commit d07f454

Please sign in to comment.