Skip to content

Commit

Permalink
implemented Python-friendly class
Browse files Browse the repository at this point in the history
  • Loading branch information
prabeesh committed Jul 7, 2015
1 parent a11968b commit 9767d82
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.streaming.mqtt

import scala.reflect.ClassTag

import org.apache.spark.api.java.function.Function
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext, JavaDStream}
Expand Down Expand Up @@ -74,3 +75,24 @@ object MQTTUtils {
createStream(jssc.ssc, brokerUrl, topic, storageLevel)
}
}

/**
* This is a helper class that wraps the methods in MQTTUtils into more Python-friendly class and
* function so that it can be easily instantiated and called from Python's MQTTUtils.
*/
private class MQTTUtilsPythonHelper {

def createStream(
jssc: JavaStreamingContext,
brokerUrl: String,
topic: String,
storageLevel: StorageLevel
): JavaDStream[Array[Byte]] = {
val dstream = MQTTUtils.createStream(jssc, brokerUrl, topic, storageLevel)
dstream.map(new Function[String, Array[Byte]] {
override def call(data: String): Array[Byte] = {
data.getBytes("UTF-8")
}
})
}
}
10 changes: 5 additions & 5 deletions python/pyspark/streaming/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
# limitations under the License.
#

from py4j.java_collections import MapConverter
from py4j.java_gateway import java_import, Py4JJavaError
from py4j.java_gateway import Py4JJavaError

from pyspark.storagelevel import StorageLevel
from pyspark.serializers import UTF8Deserializer
Expand All @@ -38,12 +37,13 @@ def createStream(ssc, brokerUrl, topic,
:param storageLevel: RDD storage level.
:return: A DStream object
"""
java_import(ssc._jvm, "org.apache.spark.streaming.mqtt.MQTTUtils")

jlevel = ssc._sc._getJavaStorageLevel(storageLevel)

try:
jstream = ssc._jvm.MQTTUtils.createStream(ssc._jssc, brokerUrl, topic, jlevel)
helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
.loadClass("org.apache.spark.streaming.mqtt.MQTTUtilsPythonHelper")
helper = helperClass.newInstance()
jstream = helper.createStream(ssc._jssc, brokerUrl, topic, jlevel)
except Py4JJavaError as e:
if 'ClassNotFoundException' in str(e.java_exception):
MQTTUtils._printErrorMsg(ssc.sparkContext)
Expand Down

0 comments on commit 9767d82

Please sign in to comment.