Skip to content

Commit

Permalink
support checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
davies committed Sep 30, 2014
1 parent 9a16bd1 commit 8466916
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 23 deletions.
7 changes: 5 additions & 2 deletions python/pyspark/streaming/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
from py4j.java_gateway import java_import

from pyspark import RDD
from pyspark.serializers import UTF8Deserializer
from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer
from pyspark.context import SparkContext
from pyspark.storagelevel import StorageLevel
from pyspark.streaming.dstream import DStream
from pyspark.streaming.util import RDDFunction
from pyspark.streaming.util import RDDFunction, RDDFunctionSerializer

__all__ = ["StreamingContext"]

Expand Down Expand Up @@ -100,6 +100,9 @@ def _initialize_context(self, sc, duration):
java_import(self._jvm, "org.apache.spark.streaming.*")
java_import(self._jvm, "org.apache.spark.streaming.api.java.*")
java_import(self._jvm, "org.apache.spark.streaming.api.python.*")
# register serializer for RDDFunction
ser = RDDFunctionSerializer(self._sc, CloudPickleSerializer())
self._jvm.PythonDStream.registerSerializer(ser)
return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration))

def _jduration(self, seconds):
Expand Down
28 changes: 27 additions & 1 deletion python/pyspark/streaming/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#

from datetime import datetime
import traceback

from pyspark.rdd import RDD

Expand Down Expand Up @@ -47,7 +48,6 @@ def call(self, milliseconds, jrdds):
if r:
return r._jrdd
except Exception:
import traceback
traceback.print_exc()

def __repr__(self):
Expand All @@ -57,6 +57,32 @@ class Java:
implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction']


class RDDFunctionSerializer(object):
def __init__(self, ctx, serializer):
self.ctx = ctx
self.serializer = serializer

def dumps(self, id):
try:
func = self.ctx._gateway.gateway_property.pool[id]
return bytearray(self.serializer.dumps((func.func, func.deserializers)))
except Exception:
traceback.print_exc()

def loads(self, bytes):
try:
f, deserializers = self.serializer.loads(str(bytes))
return RDDFunction(self.ctx, f, *deserializers)
except Exception:
traceback.print_exc()

def __repr__(self):
return "RDDFunctionSerializer(%s)" % self.serializer

class Java:
implements = ['org.apache.spark.streaming.api.python.PythonRDDFunctionSerializer']


def rddToFileName(prefix, suffix, time):
"""
Return string prefix-time(.suffix)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ class StreamingContext private[streaming] (
dstreams: Seq[DStream[_]],
transformFunc: (Seq[RDD[_]], Time) => RDD[T]
): DStream[T] = {
new TransformedDStream[T](dstreams, transformFunc)
new TransformedDStream[T](dstreams, sparkContext.clean(transformFunc))
}

/** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@

package org.apache.spark.streaming.api.python

import java.io.{ObjectInputStream, ObjectOutputStream}
import java.lang.reflect.Proxy
import java.util.{ArrayList => JArrayList, List => JList}
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.spark.api.java._
import org.apache.spark.api.python._
Expand All @@ -35,14 +36,14 @@ import org.apache.spark.streaming.api.java._
* Interface for Python callback function with three arguments
*/
private[python] trait PythonRDDFunction {
// callback in Python
def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]]
}

/**
* Wrapper for PythonRDDFunction
* TODO: support checkpoint
*/
private[python] class RDDFunction(pfunc: PythonRDDFunction)
private[python] class RDDFunction(@transient var pfunc: PythonRDDFunction)
extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] with Serializable {

def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
Expand All @@ -58,30 +59,62 @@ private[python] class RDDFunction(pfunc: PythonRDDFunction)
def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = {
pfunc.call(time.milliseconds, rdds)
}
}

private def writeObject(out: ObjectOutputStream): Unit = {
assert(PythonDStream.serializer != null, "Serializer has not been registered!")
val bytes = PythonDStream.serializer.serialize(pfunc)
out.writeInt(bytes.length)
out.write(bytes)
}

private def readObject(in: ObjectInputStream): Unit = {
assert(PythonDStream.serializer != null, "Serializer has not been registered!")
val length = in.readInt()
val bytes = new Array[Byte](length)
in.readFully(bytes)
pfunc = PythonDStream.serializer.deserialize(bytes)
}
}

/**
* Base class for PythonDStream with some common methods
* Inferface for Python Serializer to serialize PythonRDDFunction
*/
private[python]
abstract class PythonDStream(parent: DStream[_], pfunc: PythonRDDFunction)
extends DStream[Array[Byte]] (parent.ssc) {

val func = new RDDFunction(pfunc)

override def dependencies = List(parent)
private[python] trait PythonRDDFunctionSerializer {
def dumps(id: String): Array[Byte] //
def loads(bytes: Array[Byte]): PythonRDDFunction
}

override def slideDuration: Duration = parent.slideDuration
/**
* Wrapper for PythonRDDFunctionSerializer
*/
private[python] class RDDFunctionSerializer(pser: PythonRDDFunctionSerializer) {
def serialize(func: PythonRDDFunction): Array[Byte] = {
// get the id of PythonRDDFunction in py4j
val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy])
val f = h.getClass().getDeclaredField("id");
f.setAccessible(true);
val id = f.get(h).asInstanceOf[String];
pser.dumps(id)
}

val asJavaDStream = JavaDStream.fromDStream(this)
def deserialize(bytes: Array[Byte]): PythonRDDFunction = {
pser.loads(bytes)
}
}

/**
* Helper functions
*/
private[python] object PythonDStream {

// A serializer in Python, used to serialize PythonRDDFunction
var serializer: RDDFunctionSerializer = _

// Register a serializer from Python, should be called during initialization
def registerSerializer(ser: PythonRDDFunctionSerializer) = {
serializer = new RDDFunctionSerializer(ser)
}

// convert Option[RDD[_]] to JavaRDD, handle null gracefully
def wrapRDD(rdd: Option[RDD[_]]): JavaRDD[_] = {
if (rdd.isDefined) {
Expand Down Expand Up @@ -123,14 +156,30 @@ private[python] object PythonDStream {
}
}

/**
* Base class for PythonDStream with some common methods
*/
private[python]
abstract class PythonDStream(parent: DStream[_], @transient pfunc: PythonRDDFunction)
extends DStream[Array[Byte]] (parent.ssc) {

val func = new RDDFunction(pfunc)

override def dependencies = List(parent)

override def slideDuration: Duration = parent.slideDuration

val asJavaDStream = JavaDStream.fromDStream(this)
}

/**
* Transformed DStream in Python.
*
* If `reuse` is true and the result of the `func` is an PythonRDD, then it will cache it
* as an template for future use, this can reduce the Python callbacks.
*/
private[python]
class PythonTransformedDStream (parent: DStream[_], pfunc: PythonRDDFunction,
class PythonTransformedDStream (parent: DStream[_], @transient pfunc: PythonRDDFunction,
var reuse: Boolean = false)
extends PythonDStream(parent, pfunc) {

Expand Down Expand Up @@ -170,7 +219,7 @@ class PythonTransformedDStream (parent: DStream[_], pfunc: PythonRDDFunction,
*/
private[python]
class PythonTransformed2DStream(parent: DStream[_], parent2: DStream[_],
pfunc: PythonRDDFunction)
@transient pfunc: PythonRDDFunction)
extends DStream[Array[Byte]] (parent.ssc) {

val func = new RDDFunction(pfunc)
Expand All @@ -190,7 +239,7 @@ class PythonTransformed2DStream(parent: DStream[_], parent2: DStream[_],
* similar to StateDStream
*/
private[python]
class PythonStateDStream(parent: DStream[Array[Byte]], reduceFunc: PythonRDDFunction)
class PythonStateDStream(parent: DStream[Array[Byte]], @transient reduceFunc: PythonRDDFunction)
extends PythonDStream(parent, reduceFunc) {

super.persist(StorageLevel.MEMORY_ONLY)
Expand All @@ -212,8 +261,8 @@ class PythonStateDStream(parent: DStream[Array[Byte]], reduceFunc: PythonRDDFunc
*/
private[python]
class PythonReducedWindowedDStream(parent: DStream[Array[Byte]],
preduceFunc: PythonRDDFunction,
pinvReduceFunc: PythonRDDFunction,
@transient preduceFunc: PythonRDDFunction,
@transient pinvReduceFunc: PythonRDDFunction,
_windowDuration: Duration,
_slideDuration: Duration
) extends PythonStateDStream(parent, preduceFunc) {
Expand Down

0 comments on commit 8466916

Please sign in to comment.