From 8466916cec3ce6ebba8c3c2c35f7ad4c74f90e66 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Tue, 30 Sep 2014 11:51:54 -0700
Subject: [PATCH] support checkpoint

---
 python/pyspark/streaming/context.py          |  7 +-
 python/pyspark/streaming/util.py             | 28 +++++-
 .../spark/streaming/StreamingContext.scala   |  2 +-
 .../streaming/api/python/PythonDStream.scala | 87 +++++++++++++++----
 4 files changed, 101 insertions(+), 23 deletions(-)

diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index ae4a1d5b6b069..da645a6201503 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -19,11 +19,11 @@
 from py4j.java_gateway import java_import
 
 from pyspark import RDD
-from pyspark.serializers import UTF8Deserializer
+from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer
 from pyspark.context import SparkContext
 from pyspark.storagelevel import StorageLevel
 from pyspark.streaming.dstream import DStream
-from pyspark.streaming.util import RDDFunction
+from pyspark.streaming.util import RDDFunction, RDDFunctionSerializer
 
 __all__ = ["StreamingContext"]
 
@@ -100,6 +100,9 @@ def _initialize_context(self, sc, duration):
         java_import(self._jvm, "org.apache.spark.streaming.*")
         java_import(self._jvm, "org.apache.spark.streaming.api.java.*")
         java_import(self._jvm, "org.apache.spark.streaming.api.python.*")
+        # register serializer for RDDFunction
+        ser = RDDFunctionSerializer(self._sc, CloudPickleSerializer())
+        self._jvm.PythonDStream.registerSerializer(ser)
         return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration))
 
     def _jduration(self, seconds):
diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
index 4838ec6c8c6e9..c15f9d98c1866 100644
--- a/python/pyspark/streaming/util.py
+++ b/python/pyspark/streaming/util.py
@@ -16,6 +16,7 @@
 #
 
 from datetime import datetime
+import traceback
 
 from pyspark.rdd import RDD
 
@@ -47,7 +48,6 @@ def call(self, milliseconds, jrdds):
             if r:
                 return r._jrdd
         except Exception:
-            import traceback
             traceback.print_exc()
 
     def __repr__(self):
@@ -57,6 +57,32 @@ class Java:
         implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction']
 
 
+class RDDFunctionSerializer(object):
+    def __init__(self, ctx, serializer):
+        self.ctx = ctx
+        self.serializer = serializer
+
+    def dumps(self, id):
+        try:
+            func = self.ctx._gateway.gateway_property.pool[id]
+            return bytearray(self.serializer.dumps((func.func, func.deserializers)))
+        except Exception:
+            traceback.print_exc()
+
+    def loads(self, bytes):
+        try:
+            f, deserializers = self.serializer.loads(str(bytes))
+            return RDDFunction(self.ctx, f, *deserializers)
+        except Exception:
+            traceback.print_exc()
+
+    def __repr__(self):
+        return "RDDFunctionSerializer(%s)" % self.serializer
+
+    class Java:
+        implements = ['org.apache.spark.streaming.api.python.PythonRDDFunctionSerializer']
+
+
 def rddToFileName(prefix, suffix, time):
     """
     Return string prefix-time(.suffix)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index ef7631788f26d..5a8eef1372e23 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -413,7 +413,7 @@ class StreamingContext private[streaming] (
       dstreams: Seq[DStream[_]],
       transformFunc: (Seq[RDD[_]], Time) => RDD[T]
     ): DStream[T] = {
-    new TransformedDStream[T](dstreams, transformFunc)
+    new TransformedDStream[T](dstreams, sparkContext.clean(transformFunc))
   }
 
   /** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
index 4a52ce1c4f43a..ddbbf107abb3e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -17,10 +17,11 @@
 
 package org.apache.spark.streaming.api.python
 
+import java.io.{ObjectInputStream, ObjectOutputStream}
+import java.lang.reflect.Proxy
 import java.util.{ArrayList => JArrayList, List => JList}
 import scala.collection.JavaConversions._
 import scala.collection.JavaConverters._
-import scala.collection.mutable
 
 import org.apache.spark.api.java._
 import org.apache.spark.api.python._
@@ -35,14 +36,14 @@ import org.apache.spark.streaming.api.java._
  * Interface for Python callback function with three arguments
  */
 private[python] trait PythonRDDFunction {
-  // callback in Python
   def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]]
 }
 
 /**
  * Wrapper for PythonRDDFunction
+ * TODO: support checkpoint
  */
-private[python] class RDDFunction(pfunc: PythonRDDFunction)
+private[python] class RDDFunction(@transient var pfunc: PythonRDDFunction)
   extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] with Serializable {
 
   def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
@@ -58,23 +59,47 @@ private[python] class RDDFunction(pfunc: PythonRDDFunction)
   def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = {
     pfunc.call(time.milliseconds, rdds)
   }
-}
 
+  private def writeObject(out: ObjectOutputStream): Unit = {
+    assert(PythonDStream.serializer != null, "Serializer has not been registered!")
+    val bytes = PythonDStream.serializer.serialize(pfunc)
+    out.writeInt(bytes.length)
+    out.write(bytes)
+  }
+
+  private def readObject(in: ObjectInputStream): Unit = {
+    assert(PythonDStream.serializer != null, "Serializer has not been registered!")
+    val length = in.readInt()
+    val bytes = new Array[Byte](length)
+    in.readFully(bytes)
+    pfunc = PythonDStream.serializer.deserialize(bytes)
+  }
+}
 
 /**
- * Base class for PythonDStream with some common methods
+ * Interface for a Python serializer to serialize PythonRDDFunction
  */
-private[python]
-abstract class PythonDStream(parent: DStream[_], pfunc: PythonRDDFunction)
-  extends DStream[Array[Byte]] (parent.ssc) {
-
-  val func = new RDDFunction(pfunc)
-
-  override def dependencies = List(parent)
+private[python] trait PythonRDDFunctionSerializer {
+  def dumps(id: String): Array[Byte]
+  def loads(bytes: Array[Byte]): PythonRDDFunction
+}
 
-  override def slideDuration: Duration = parent.slideDuration
+/**
+ * Wrapper for PythonRDDFunctionSerializer
+ */
+private[python] class RDDFunctionSerializer(pser: PythonRDDFunctionSerializer) {
+  def serialize(func: PythonRDDFunction): Array[Byte] = {
+    // get the id of the PythonRDDFunction proxy in py4j
+    val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy])
+    val f = h.getClass().getDeclaredField("id")
+    f.setAccessible(true)
+    val id = f.get(h).asInstanceOf[String]
+    pser.dumps(id)
+  }
 
-  val asJavaDStream = JavaDStream.fromDStream(this)
+  def deserialize(bytes: Array[Byte]): PythonRDDFunction = {
+    pser.loads(bytes)
+  }
 }
 
 /**
@@ -82,6 +107,14 @@ abstract class PythonDStream(parent: DStream[_], pfunc: PythonRDDFunction)
  */
 private[python] object PythonDStream {
+  // A serializer in Python, used to serialize PythonRDDFunction
+  var serializer: RDDFunctionSerializer = _
+
+  // Register a serializer from Python; should be called during initialization
+  def registerSerializer(ser: PythonRDDFunctionSerializer): Unit = {
+    serializer = new RDDFunctionSerializer(ser)
+  }
+
   // convert Option[RDD[_]] to JavaRDD, handle null gracefully
   def wrapRDD(rdd: Option[RDD[_]]): JavaRDD[_] = {
     if (rdd.isDefined) {
       JavaRDD.fromRDD(rdd.get)
     } else {
       null
     }
   }
@@ -123,6 +156,22 @@ private[python] object PythonDStream {
   }
 }
 
+/**
+ * Base class for PythonDStream with some common methods
+ */
+private[python]
+abstract class PythonDStream(parent: DStream[_], @transient pfunc: PythonRDDFunction)
+  extends DStream[Array[Byte]] (parent.ssc) {
+
+  val func = new RDDFunction(pfunc)
+
+  override def dependencies = List(parent)
+
+  override def slideDuration: Duration = parent.slideDuration
+
+  val asJavaDStream = JavaDStream.fromDStream(this)
+}
+
 /**
  * Transformed DStream in Python.
 *
@@ -130,7 +179,7 @@ private[python] object PythonDStream {
 * as a template for future use, this can reduce the Python callbacks.
 */
 private[python]
-class PythonTransformedDStream (parent: DStream[_], pfunc: PythonRDDFunction,
+class PythonTransformedDStream (parent: DStream[_], @transient pfunc: PythonRDDFunction,
   var reuse: Boolean = false)
   extends PythonDStream(parent, pfunc) {
 
@@ -170,7 +219,7 @@ class PythonTransformedDStream (parent: DStream[_], pfunc: PythonRDDFunction,
  */
 private[python]
 class PythonTransformed2DStream(parent: DStream[_], parent2: DStream[_],
-                                pfunc: PythonRDDFunction)
+                                @transient pfunc: PythonRDDFunction)
   extends DStream[Array[Byte]] (parent.ssc) {
 
   val func = new RDDFunction(pfunc)
@@ -190,7 +239,7 @@ class PythonTransformed2DStream(parent: DStream[_], parent2: DStream[_],
  * similar to StateDStream
  */
 private[python]
-class PythonStateDStream(parent: DStream[Array[Byte]], reduceFunc: PythonRDDFunction)
+class PythonStateDStream(parent: DStream[Array[Byte]], @transient reduceFunc: PythonRDDFunction)
   extends PythonDStream(parent, reduceFunc) {
 
   super.persist(StorageLevel.MEMORY_ONLY)
@@ -212,8 +261,8 @@ class PythonStateDStream(parent: DStream[Array[Byte]], reduceFunc: PythonRDDFunc
  */
 private[python]
 class PythonReducedWindowedDStream(parent: DStream[Array[Byte]],
-                                   preduceFunc: PythonRDDFunction,
-                                   pinvReduceFunc: PythonRDDFunction,
+                                   @transient preduceFunc: PythonRDDFunction,
+                                   @transient pinvReduceFunc: PythonRDDFunction,
                                    _windowDuration: Duration,
                                    _slideDuration: Duration
   ) extends PythonStateDStream(parent, preduceFunc) {
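
Note on the serialization path: on checkpoint, RDDFunction.writeObject asks the registered Python serializer to look up the callback behind its py4j proxy id and CloudPickle the (func, deserializers) pair; on recovery, readObject hands the bytes back so RDDFunctionSerializer.loads can rebuild an RDDFunction. The round trip can be exercised on its own. A minimal sketch, assuming only that pyspark's CloudPickleSerializer is importable (Python 2 era; the transform function here is a made-up stand-in, not part of this patch):

    from pyspark.serializers import CloudPickleSerializer

    ser = CloudPickleSerializer()

    # stand-in for a user transform function captured by RDDFunction
    def transform(time, rdd):
        return rdd

    # roughly what RDDFunctionSerializer.dumps() hands to the JVM
    payload = bytearray(ser.dumps((transform, [])))

    # roughly what RDDFunctionSerializer.loads() gets back on recovery
    func, deserializers = ser.loads(str(payload))
    assert func(0, None) is None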
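For context, this is the kind of driver script the change is meant to enable: a stateful DStream forces its Python callbacks to be written into the checkpoint, which previously failed because RDDFunction was not serializable. A hypothetical sketch, assuming the rest of the PySpark streaming API on this branch (socketTextStream, updateStateByKey, checkpoint) matches the eventual public API; the host, port, and checkpoint directory are example values:

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext

    sc = SparkContext("local[2]", "checkpoint-demo")
    ssc = StreamingContext(sc, 1)
    ssc.checkpoint("/tmp/checkpoint")  # example directory

    lines = ssc.socketTextStream("localhost", 9999)
    counts = lines.flatMap(lambda line: line.split()) \
                  .map(lambda word: (word, 1)) \
                  .updateStateByKey(lambda values, state: sum(values) + (state or 0))
    counts.pprint()

    ssc.start()
    ssc.awaitTermination()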