rename RDDFunction to TransformFunction
davies committed Oct 1, 2014
1 parent d328aca commit ff88bec
Showing 4 changed files with 49 additions and 49 deletions.
16 changes: 8 additions & 8 deletions python/pyspark/streaming/context.py
@@ -25,7 +25,7 @@
from pyspark.context import SparkContext
from pyspark.storagelevel import StorageLevel
from pyspark.streaming.dstream import DStream
- from pyspark.streaming.util import RDDFunction, RDDFunctionSerializer
+ from pyspark.streaming.util import TransformFunction, TransformFunctionSerializer

__all__ = ["StreamingContext"]

@@ -114,10 +114,10 @@ def _ensure_initialized(cls):
java_import(gw.jvm, "org.apache.spark.streaming.*")
java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")
- # register serializer for RDDFunction
+ # register serializer for TransformFunction
# it happens before creating SparkContext when loading from checkpointing
- cls._transformerSerializer = RDDFunctionSerializer(SparkContext._active_spark_context,
- CloudPickleSerializer(), gw)
+ cls._transformerSerializer = TransformFunctionSerializer(
+ SparkContext._active_spark_context, CloudPickleSerializer(), gw)
gw.jvm.PythonDStream.registerSerializer(cls._transformerSerializer)

@classmethod
@@ -284,10 +284,10 @@ def transform(self, dstreams, transformFunc):
jdstreams = ListConverter().convert([d._jdstream for d in dstreams],
SparkContext._gateway._gateway_client)
# change the final serializer to sc.serializer
- func = RDDFunction(self._sc,
- lambda t, *rdds: transformFunc(rdds).map(lambda x: x),
- *[d._jrdd_deserializer for d in dstreams])
- jfunc = self._jvm.RDDFunction(func)
+ func = TransformFunction(self._sc,
+ lambda t, *rdds: transformFunc(rdds).map(lambda x: x),
+ *[d._jrdd_deserializer for d in dstreams])
+ jfunc = self._jvm.TransformFunction(func)
jdstream = self._jssc.transform(jdstreams, jfunc)
return DStream(jdstream, self, self._sc.serializer)

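For orientation, a minimal usage sketch of the code path renamed above (not part of this commit; the socket sources and stream names are hypothetical). StreamingContext.transform wraps the user function in the new TransformFunction before handing it to the JVM, and passes the batch's RDDs to that function as a list:

    # Sketch: ssc.transform hands one RDD per input DStream to the user function.
    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext

    sc = SparkContext(appName="transform-sketch")
    ssc = StreamingContext(sc, 1)
    a = ssc.socketTextStream("localhost", 9999)  # assumed sources
    b = ssc.socketTextStream("localhost", 9998)

    def union_both(rdds):
        x, y = rdds          # one RDD per input DStream
        return x.union(y)

    merged = ssc.transform([a, b], union_both)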
16 changes: 8 additions & 8 deletions python/pyspark/streaming/dstream.py
@@ -22,7 +22,7 @@

from pyspark import RDD
from pyspark.storagelevel import StorageLevel
- from pyspark.streaming.util import rddToFileName, RDDFunction
+ from pyspark.streaming.util import rddToFileName, TransformFunction
from pyspark.rdd import portable_hash
from pyspark.resultiterable import ResultIterable

@@ -154,7 +154,7 @@ def foreachRDD(self, func):
"""
Apply a function to each RDD in this DStream.
"""
- jfunc = RDDFunction(self.ctx, func, self._jrdd_deserializer)
+ jfunc = TransformFunction(self.ctx, func, self._jrdd_deserializer)
api = self._ssc._jvm.PythonDStream
api.callForeachRDD(self._jdstream, jfunc)

@@ -292,7 +292,7 @@ def transformWith(self, func, other, keepSerializer=False):
oldfunc = func
func = lambda t, a, b: oldfunc(a, b)
assert func.func_code.co_argcount == 3, "func should take two or three arguments"
- jfunc = RDDFunction(self.ctx, func, self._jrdd_deserializer, other._jrdd_deserializer)
+ jfunc = TransformFunction(self.ctx, func, self._jrdd_deserializer, other._jrdd_deserializer)
dstream = self.ctx._jvm.PythonTransformed2DStream(self._jdstream.dstream(),
other._jdstream.dstream(), jfunc)
jrdd_serializer = self._jrdd_deserializer if keepSerializer else self.ctx.serializer
@@ -535,9 +535,9 @@ def invReduceFunc(t, a, b):
joined = a.leftOuterJoin(b, numPartitions)
return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1)

- jreduceFunc = RDDFunction(self.ctx, reduceFunc, reduced._jrdd_deserializer)
+ jreduceFunc = TransformFunction(self.ctx, reduceFunc, reduced._jrdd_deserializer)
if invReduceFunc:
- jinvReduceFunc = RDDFunction(self.ctx, invReduceFunc, reduced._jrdd_deserializer)
+ jinvReduceFunc = TransformFunction(self.ctx, invReduceFunc, reduced._jrdd_deserializer)
else:
jinvReduceFunc = None
if slideDuration is None:
@@ -568,8 +568,8 @@ def reduceFunc(t, a, b):
state = g.mapPartitions(lambda x: updateFunc(x))
return state.filter(lambda (k, v): v is not None)

- jreduceFunc = RDDFunction(self.ctx, reduceFunc,
- self.ctx.serializer, self._jrdd_deserializer)
+ jreduceFunc = TransformFunction(self.ctx, reduceFunc,
+ self.ctx.serializer, self._jrdd_deserializer)
dstream = self.ctx._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc)
return DStream(dstream.asJavaDStream(), self._ssc, self.ctx.serializer)

@@ -609,7 +609,7 @@ def _jdstream(self):
return self._jdstream_val

func = self.func
- jfunc = RDDFunction(self.ctx, func, self.prev._jrdd_deserializer)
+ jfunc = TransformFunction(self.ctx, func, self.prev._jrdd_deserializer)
jdstream = self.ctx._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(),
jfunc, self.reuse).asJavaDStream()
self._jdstream_val = jdstream
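The foreachRDD hunk above is the simplest consumer of TransformFunction. A hedged sketch of its use (assuming, per TransformFunction.call in util.py, that the wrapped function receives the batch time plus one RDD per stream; `lines` is a hypothetical DStream):

    # Sketch: foreachRDD wraps this callback in
    # TransformFunction(self.ctx, func, self._jrdd_deserializer).
    def dump_counts(time, rdd):
        print("batch %s: %d records" % (time, rdd.count()))

    lines.foreachRDD(dump_counts)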
14 changes: 7 additions & 7 deletions python/pyspark/streaming/util.py
@@ -21,7 +21,7 @@
from pyspark import SparkContext, RDD


- class RDDFunction(object):
+ class TransformFunction(object):
"""
This class is for the py4j callback.
"""
@@ -58,13 +58,13 @@ def call(self, milliseconds, jrdds):
traceback.print_exc()

def __repr__(self):
return "RDDFunction(%s)" % self.func
return "TransformFunction(%s)" % self.func

class Java:
- implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction']
+ implements = ['org.apache.spark.streaming.api.python.PythonTransformFunction']


- class RDDFunctionSerializer(object):
+ class TransformFunctionSerializer(object):
def __init__(self, ctx, serializer, gateway=None):
self.ctx = ctx
self.serializer = serializer
@@ -80,15 +80,15 @@ def dumps(self, id):
def loads(self, bytes):
try:
f, deserializers = self.serializer.loads(str(bytes))
- return RDDFunction(self.ctx, f, *deserializers)
+ return TransformFunction(self.ctx, f, *deserializers)
except Exception:
traceback.print_exc()

def __repr__(self):
return "RDDFunctionSerializer(%s)" % self.serializer
return "TransformFunctionSerializer(%s)" % self.serializer

class Java:
- implements = ['org.apache.spark.streaming.api.python.PythonRDDFunctionSerializer']
+ implements = ['org.apache.spark.streaming.api.python.PythonTransformFunctionSerializer']


def rddToFileName(prefix, suffix, time):
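The `class Java: implements = [...]` idiom used by TransformFunction and TransformFunctionSerializer above is py4j's callback mechanism: the JVM treats the Python object as an implementation of the named Java interface and routes calls to it over the callback gateway. A minimal sketch of the same pattern, with a hypothetical callback class:

    # Hypothetical callback in the style of TransformFunction: Scala code holding
    # a PythonTransformFunction proxy invokes call(), which executes here.
    class NoopTransformFunction(object):
        def call(self, milliseconds, jrdds):
            # return the first RDD unchanged (None for an empty batch)
            return jrdds[0] if jrdds else None

        class Java:
            implements = ['org.apache.spark.streaming.api.python.PythonTransformFunction']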
52 changes: 26 additions & 26 deletions streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -35,15 +35,15 @@ import org.apache.spark.streaming.api.java._
/**
* Interface for Python callback function with two arguments
*/
- private[python] trait PythonRDDFunction {
+ private[python] trait PythonTransformFunction {
def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]]
}

/**
- * Wrapper for PythonRDDFunction
+ * Wrapper for PythonTransformFunction
* TODO: support checkpoint
*/
- private[python] class RDDFunction(@transient var pfunc: PythonRDDFunction)
+ private[python] class TransformFunction(@transient var pfunc: PythonTransformFunction)
extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] with Serializable {

def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
@@ -77,27 +77,27 @@ private[python] class RDDFunction(@transient var pfunc: PythonRDDFunction)
}

/**
- * Interface for Python Serializer to serialize PythonRDDFunction
+ * Interface for Python Serializer to serialize PythonTransformFunction
*/
- private[python] trait PythonRDDFunctionSerializer {
+ private[python] trait PythonTransformFunctionSerializer {
def dumps(id: String): Array[Byte]
- def loads(bytes: Array[Byte]): PythonRDDFunction
+ def loads(bytes: Array[Byte]): PythonTransformFunction
}

/**
- * Wrapper for PythonRDDFunctionSerializer
+ * Wrapper for PythonTransformFunctionSerializer
*/
- private[python] class RDDFunctionSerializer(pser: PythonRDDFunctionSerializer) {
- def serialize(func: PythonRDDFunction): Array[Byte] = {
- // get the id of PythonRDDFunction in py4j
+ private[python] class TransformFunctionSerializer(pser: PythonTransformFunctionSerializer) {
+ def serialize(func: PythonTransformFunction): Array[Byte] = {
+ // get the id of PythonTransformFunction in py4j
val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy])
val f = h.getClass().getDeclaredField("id")
f.setAccessible(true)
val id = f.get(h).asInstanceOf[String]
pser.dumps(id)
}

- def deserialize(bytes: Array[Byte]): PythonRDDFunction = {
+ def deserialize(bytes: Array[Byte]): PythonTransformFunction = {
pser.loads(bytes)
}
}
@@ -107,18 +107,18 @@ private[python] class RDDFunctionSerializer(pser: PythonRDDFunctionSerializer) {
*/
private[python] object PythonDStream {

- // A serializer in Python, used to serialize PythonRDDFunction
- var serializer: RDDFunctionSerializer = _
+ // A serializer in Python, used to serialize PythonTransformFunction
+ var serializer: TransformFunctionSerializer = _

// Register a serializer from Python, should be called during initialization
- def registerSerializer(ser: PythonRDDFunctionSerializer) = {
- serializer = new RDDFunctionSerializer(ser)
+ def registerSerializer(ser: PythonTransformFunctionSerializer) = {
+ serializer = new TransformFunctionSerializer(ser)
}

// helper function for DStream.foreachRDD(),
// cannot be `foreachRDD`, it would confuse py4j
- def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonRDDFunction) {
- val func = new RDDFunction((pfunc))
+ def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonTransformFunction) {
+ val func = new TransformFunction((pfunc))
jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time))
}

@@ -134,10 +134,10 @@ private[python] object PythonDStream {
* Base class for PythonDStream with some common methods
*/
private[python]
- abstract class PythonDStream(parent: DStream[_], @transient pfunc: PythonRDDFunction)
+ abstract class PythonDStream(parent: DStream[_], @transient pfunc: PythonTransformFunction)
extends DStream[Array[Byte]] (parent.ssc) {

- val func = new RDDFunction(pfunc)
+ val func = new TransformFunction(pfunc)

override def dependencies = List(parent)

@@ -153,7 +153,7 @@ abstract class PythonDStream(parent: DStream[_], @transient pfunc: PythonRDDFunction)
* as a template for future use; this can reduce the Python callbacks.
*/
private[python]
- class PythonTransformedDStream (parent: DStream[_], @transient pfunc: PythonRDDFunction,
+ class PythonTransformedDStream (parent: DStream[_], @transient pfunc: PythonTransformFunction,
var reuse: Boolean = false)
extends PythonDStream(parent, pfunc) {

@@ -193,10 +193,10 @@ class PythonTransformedDStream (parent: DStream[_], @transient pfunc: PythonRDDFunction,
*/
private[python]
class PythonTransformed2DStream(parent: DStream[_], parent2: DStream[_],
- @transient pfunc: PythonRDDFunction)
+ @transient pfunc: PythonTransformFunction)
extends DStream[Array[Byte]] (parent.ssc) {

- val func = new RDDFunction(pfunc)
+ val func = new TransformFunction(pfunc)

override def slideDuration: Duration = parent.slideDuration

@@ -213,7 +213,7 @@ class PythonTransformed2DStream(parent: DStream[_], parent2: DStream[_],
* similar to StateDStream
*/
private[python]
- class PythonStateDStream(parent: DStream[Array[Byte]], @transient reduceFunc: PythonRDDFunction)
+ class PythonStateDStream(parent: DStream[Array[Byte]], @transient reduceFunc: PythonTransformFunction)
extends PythonDStream(parent, reduceFunc) {

super.persist(StorageLevel.MEMORY_ONLY)
@@ -235,16 +235,16 @@ class PythonStateDStream(parent: DStream[Array[Byte]], @transient reduceFunc: PythonRDDFunction)
*/
private[python]
class PythonReducedWindowedDStream(parent: DStream[Array[Byte]],
- @transient preduceFunc: PythonRDDFunction,
- @transient pinvReduceFunc: PythonRDDFunction,
+ @transient preduceFunc: PythonTransformFunction,
+ @transient pinvReduceFunc: PythonTransformFunction,
_windowDuration: Duration,
_slideDuration: Duration
) extends PythonDStream(parent, preduceFunc) {

super.persist(StorageLevel.MEMORY_ONLY)
override val mustCheckpoint = true

- val invReduceFunc = new RDDFunction(pinvReduceFunc)
+ val invReduceFunc = new TransformFunction(pinvReduceFunc)

def windowDuration: Duration = _windowDuration
override def slideDuration: Duration = _slideDuration
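On the Python side, the payload this Scala serializer shuttles through dumps/loads is just the wrapped function and its deserializers, pickled with CloudPickleSerializer so closures survive checkpointing. A rough sketch of that round trip, assuming the (func, deserializers) tuple layout visible in util.py's loads():

    # Sketch of the payload TransformFunctionSerializer moves over py4j.
    from pyspark.serializers import CloudPickleSerializer

    ser = CloudPickleSerializer()
    func = lambda t, rdd: rdd.map(lambda x: (x, 1))
    payload = ser.dumps((func, []))         # roughly what dumps(id) sends to the JVM
    f, deserializers = ser.loads(payload)   # what loads(bytes) reconstructs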
