diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index d406749c9b4f1..98821c71c7aee 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -20,30 +20,34 @@ package org.apache.spark.streaming.mqtt import scala.concurrent.duration._ import scala.language.postfixOps -import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} -class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterAll { +class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter { + + private val batchDuration = Milliseconds(500) + private val master = "local[2]" + private val framework = this.getClass.getSimpleName + private val topic = "def" - private val topic = "topic" private var ssc: StreamingContext = _ private var MQTTTestUtils: MQTTTestUtils = _ - override def beforeAll(): Unit = { + before { + ssc = new StreamingContext(master, framework, batchDuration) MQTTTestUtils = new MQTTTestUtils MQTTTestUtils.setup() } - override def afterAll(): Unit = { + after { if (ssc != null) { ssc.stop() ssc = null } - if (MQTTTestUtils != null) { MQTTTestUtils.teardown() MQTTTestUtils = null @@ -51,10 +55,7 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterA } test("mqtt input stream") { - val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName) - ssc = new StreamingContext(sparkConf, Milliseconds(500)) val sendMessage = "MQTT demo for spark streaming" - val receiveStream = MQTTUtils.createStream(ssc, "tcp://" + MQTTTestUtils.brokerUri, topic, StorageLevel.MEMORY_ONLY) @@ -65,6 +66,9 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterA receiveMessage } } + + MQTTTestUtils.registerStreamingListener(ssc) + ssc.start() // wait for the receiver to start before publishing data, or we risk failing diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala index e48760a0a08f4..6c85019ae0723 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala @@ -22,6 +22,7 @@ import java.util.concurrent.{CountDownLatch, TimeUnit} import scala.language.postfixOps +import com.google.common.base.Charsets.UTF_8 import org.apache.activemq.broker.{BrokerService, TransportConnector} import org.apache.commons.lang3.RandomUtils import org.eclipse.paho.client.mqttv3._ @@ -46,6 +47,8 @@ private class MQTTTestUtils extends Logging { private var broker: BrokerService = _ private var connector: TransportConnector = _ + private var receiverStartedLatch = new CountDownLatch(1) + def brokerUri: String = { s"$brokerHost:$brokerPort" } @@ -69,6 +72,8 @@ private class MQTTTestUtils extends Logging { connector.stop() connector = null } + Utils.deleteRecursively(persistenceDir) + receiverStartedLatch = null } private def findFreePort(): Int = { @@ -88,7 +93,7 @@ private class MQTTTestUtils extends Logging { client.connect() if (client.isConnected) { val msgTopic = client.getTopic(topic) - val message = new MqttMessage(data.getBytes("utf-8")) + val message = new MqttMessage(data.getBytes(UTF_8)) message.setQos(1) message.setRetained(true) @@ -110,27 +115,37 @@ private class MQTTTestUtils extends Logging { } /** - * Block until at least one receiver has started or timeout occurs. + * Call this one before starting StreamingContext so that we won't miss the + * StreamingListenerReceiverStarted event. */ - def waitForReceiverToStart(ssc: StreamingContext) : Unit = { - val latch = new CountDownLatch(1) + def registerStreamingListener(jssc: JavaStreamingContext): Unit = { + registerStreamingListener(jssc.ssc) + } + + /** + * Call this one before starting StreamingContext so that we won't miss the + * StreamingListenerReceiverStarted event. + */ + def registerStreamingListener(ssc: StreamingContext): Unit = { ssc.addStreamingListener(new StreamingListener { override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) { - latch.countDown() + receiverStartedLatch.countDown() } }) - - assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.") } - def waitForReceiverToStart(jssc: JavaStreamingContext) : Unit = { - val latch = new CountDownLatch(1) - jssc.addStreamingListener(new StreamingListener { - override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) { - latch.countDown() - } - }) + /** + * Block until at least one receiver has started or timeout occurs. + */ + def waitForReceiverToStart(jssc: JavaStreamingContext): Unit = { + waitForReceiverToStart(jssc.ssc) + } - assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.") + /** + * Block until at least one receiver has started or timeout occurs. + */ + def waitForReceiverToStart(ssc: StreamingContext): Unit = { + assert( + receiverStartedLatch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.") } } diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index b27577f878618..77f9ccf0b114a 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -863,31 +863,18 @@ def getOutput(_, rdd): self.ssc.start() return result - def _publishData(self, topic, data): - start_time = time.time() - while True: - try: - self._MQTTTestUtils.publishData(topic, data) - break - except: - if time.time() - start_time < self.timeout: - time.sleep(0.01) - else: - raise - - def _validateStreamResult(self, sendData, result): - receiveData = ''.join(result[0]) - self.assertEqual(sendData, receiveData) - def test_mqtt_stream(self): """Test the Python MQTT stream API.""" sendData = "MQTT demo for spark streaming" topic = self._randomTopic() + self._MQTTTestUtils.registerStreamingListener(self.ssc._jssc) result = self._startContext(topic) self._MQTTTestUtils.waitForReceiverToStart(self.ssc._jssc) - self._publishData(topic, sendData) - self.wait_for(result, len(sendData)) - self._validateStreamResult(sendData, result) + self._MQTTTestUtils.publishData(topic, sendData) + self.wait_for(result, 1) + # Because "publishData" sends duplicate messages, here we should use > 0 + self.assertTrue(len(result) > 0) + self.assertEqual(sendData, result[0]) def search_kafka_assembly_jar():