From 9767d82b46037aeb9d8af8e825a1e6a237ca5673 Mon Sep 17 00:00:00 2001 From: Prabeesh K Date: Mon, 6 Jul 2015 19:40:03 +0400 Subject: [PATCH] implemented Python-friendly class --- .../spark/streaming/mqtt/MQTTUtils.scala | 22 +++++++++++++++++++ python/pyspark/streaming/mqtt.py | 10 ++++----- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala index 1142d0f56ba34..de8f5650fbe55 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala @@ -19,6 +19,7 @@ package org.apache.spark.streaming.mqtt import scala.reflect.ClassTag +import org.apache.spark.api.java.function.Function import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext, JavaDStream} @@ -74,3 +75,24 @@ object MQTTUtils { createStream(jssc.ssc, brokerUrl, topic, storageLevel) } } + +/** + * This is a helper class that wraps the methods in MQTTUtils into a more Python-friendly class and + * function so that it can be easily instantiated and called from Python's MQTTUtils. 
+ */ +private class MQTTUtilsPythonHelper { + + def createStream( + jssc: JavaStreamingContext, + brokerUrl: String, + topic: String, + storageLevel: StorageLevel + ): JavaDStream[Array[Byte]] = { + val dstream = MQTTUtils.createStream(jssc, brokerUrl, topic, storageLevel) + dstream.map(new Function[String, Array[Byte]] { + override def call(data: String): Array[Byte] = { + data.getBytes("UTF-8") + } + }) + } +} diff --git a/python/pyspark/streaming/mqtt.py b/python/pyspark/streaming/mqtt.py index 9dee0f74589a4..f06598971c548 100644 --- a/python/pyspark/streaming/mqtt.py +++ b/python/pyspark/streaming/mqtt.py @@ -15,8 +15,7 @@ # limitations under the License. # -from py4j.java_collections import MapConverter -from py4j.java_gateway import java_import, Py4JJavaError +from py4j.java_gateway import Py4JJavaError from pyspark.storagelevel import StorageLevel from pyspark.serializers import UTF8Deserializer @@ -38,12 +37,13 @@ def createStream(ssc, brokerUrl, topic, :param storageLevel: RDD storage level. :return: A DStream object """ - java_import(ssc._jvm, "org.apache.spark.streaming.mqtt.MQTTUtils") - jlevel = ssc._sc._getJavaStorageLevel(storageLevel) try: - jstream = ssc._jvm.MQTTUtils.createStream(ssc._jssc, brokerUrl, topic, jlevel) + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.mqtt.MQTTUtilsPythonHelper") + helper = helperClass.newInstance() + jstream = helper.createStream(ssc._jssc, brokerUrl, topic, jlevel) except Py4JJavaError as e: if 'ClassNotFoundException' in str(e.java_exception): MQTTUtils._printErrorMsg(ssc.sparkContext)