refactor, combine TransformedRDD, fix reuse PythonRDD, fix union
davies committed Sep 26, 2014
1 parent 9a57685 commit eec401e
Showing 5 changed files with 178 additions and 76 deletions.
7 changes: 1 addition & 6 deletions python/pyspark/streaming/context.py
@@ -15,12 +15,7 @@
# limitations under the License.
#

import sys
from signal import signal, SIGTERM, SIGINT
import atexit
import time

from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer
from pyspark.serializers import UTF8Deserializer
from pyspark.context import SparkContext
from pyspark.streaming.dstream import DStream
from pyspark.streaming.duration import Duration, Seconds
112 changes: 76 additions & 36 deletions python/pyspark/streaming/dstream.py
@@ -15,21 +15,15 @@
# limitations under the License.
#

from collections import defaultdict
from itertools import chain, ifilter, imap
import operator

from pyspark import RDD
from pyspark.serializers import NoOpSerializer,\
BatchedSerializer, CloudPickleSerializer, pack_long,\
CompressedSerializer
from pyspark.storagelevel import StorageLevel
from pyspark.resultiterable import ResultIterable
from pyspark.streaming.util import rddToFileName, RDDFunction
from pyspark.rdd import portable_hash, _parse_memory
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.streaming.util import rddToFileName, RDDFunction, RDDFunction2
from pyspark.rdd import portable_hash
from pyspark.streaming.duration import Seconds

from py4j.java_collections import ListConverter, MapConverter

__all__ = ["DStream"]

@@ -42,7 +36,6 @@ def __init__(self, jdstream, ssc, jrdd_deserializer):
self._jrdd_deserializer = jrdd_deserializer
self.is_cached = False
self.is_checkpointed = False
self._partitionFunc = None

def context(self):
"""
@@ -159,7 +152,7 @@ def foreachRDD(self, func):
This is an output operator, so this DStream will be registered as an output
stream and therefore materialized.
"""
jfunc = RDDFunction(self.ctx, lambda a, b, t: func(a, t), self._jrdd_deserializer)
jfunc = RDDFunction(self.ctx, func, self._jrdd_deserializer)
self.ctx._jvm.PythonForeachDStream(self._jdstream.dstream(), jfunc)

def pyprint(self):
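A hedged usage sketch of the new callback contract: foreachRDD now hands the user function the batch RDD and the batch time in milliseconds, matching RDDFunction's reworked call(jrdd, milliseconds). The StreamingContext and the DStream `lines` are assumptions, not part of this diff.

    def print_batch(rdd, time_ms):
        # rdd is the batch's RDD; time_ms is the batch time in milliseconds
        print(time_ms, rdd.take(10))

    lines.foreachRDD(print_batch)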
@@ -306,19 +299,19 @@ def get_output(rdd, time):
return result

def transform(self, func):
return TransformedRDD(self, lambda a, b, t: func(a), cache=True)

def transformWith(self, func, other):
return TransformedRDD(self, lambda a, b, t: func(a, b), other)
return TransformedRDD(self, lambda a, t: func(a), True)

def transformWithTime(self, func):
return TransformedRDD(self, lambda a, b, t: func(a, t))
return TransformedRDD(self, func, False)

def transformWith(self, func, other, keepSerializer=False):
return Transformed2RDD(self, lambda a, b, t: func(a, b), other, keepSerializer)

def repartitions(self, numPartitions):
return self.transform(lambda rdd: rdd.repartition(numPartitions))

def union(self, other):
return self.transformWith(lambda a, b: a.union(b), other)
return self.transformWith(lambda a, b: a.union(b), other, True)

def cogroup(self, other):
return self.transformWith(lambda a, b: a.cogroup(b), other)
@@ -329,32 +322,79 @@ def leftOuterJoin(self, other):
def rightOuterJoin(self, other):
return self.transformWith(lambda a, b: a.rightOuterJoin(b), other)
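Hedged examples of the refactored transform family (`d1` and `d2` are hypothetical DStreams): transform() routes through TransformedRDD with reuse enabled, transformWith() through the new Transformed2RDD, and union() asks to keep the parents' serializer.

    mapped = d1.transform(lambda rdd: rdd.map(lambda x: x * 2))
    joined = d1.transformWith(lambda a, b: a.join(b), d2)
    merged = d1.union(d2)  # keepSerializer=True: output records are unchanged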

def slice(self, fromTime, toTime):
jrdds = self._jdstream.slice(fromTime._jtime, toTime._jtime)
# FIXME: serializer
return [RDD(jrdd, self.ctx, self.ctx.serializer) for jrdd in jrdds]
def _jtime(self, milliseconds):
return self.ctx._jvm.Time(milliseconds)

def slice(self, begin, end):
jrdds = self._jdstream.slice(self._jtime(begin), self._jtime(end))
return [RDD(jrdd, self.ctx, self._jrdd_deserializer) for jrdd in jrdds]
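Illustrative only: slice() now accepts plain millisecond timestamps and wraps them into JVM Time objects via _jtime(). `stream` and the timestamps below are assumptions.

    batch_rdds = stream.slice(start_ms, start_ms + 2000)  # RDDs covering a 2-second span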

def window(self, windowDuration, slideDuration=None):
d = Seconds(windowDuration)
if slideDuration is None:
return DStream(self._jdstream.window(d), self._ssc, self._jrdd_deserializer)
s = Seconds(slideDuration)
return DStream(self._jdstream.window(d, s), self._ssc, self._jrdd_deserializer)
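Illustrative usage, assuming a DStream `stream`: durations are given in seconds and wrapped into Duration objects by Seconds().

    w1 = stream.window(20)      # 20-second window, slide defaults to the batch interval
    w2 = stream.window(20, 10)  # 20-second window, sliding every 10 seconds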

def reduceByWindow(self, reduceFunc, inReduceFunc, windowDuration, slideDuration):
pass

def countByWindow(self, window, slide):
pass

def countByValueAndWindow(self, window, slide, numPartitions=None):
pass

def groupByKeyAndWindow(self, window, slide, numPartitions=None):
pass

def reduceByKeyAndWindow(self, reduceFunc, inReduceFunc, window, slide, numPartitions=None):
pass

def updateStateByKey(self, updateFunc):
# FIXME: convert updateFunc to java JFunction2
jFunc = updateFunc
return self._jdstream.updateStateByKey(jFunc)


# Window Operations
# TODO: implement window
# TODO: implement groupByKeyAndWindow
# TODO: implement reduceByKeyAndWindow
# TODO: implement countByValueAndWindow
# TODO: implement countByWindow
# TODO: implement reduceByWindow
class TransformedRDD(DStream):
def __init__(self, prev, func, reuse=False):
ssc = prev._ssc
self._ssc = ssc
self.ctx = ssc._sc
self._jrdd_deserializer = self.ctx.serializer
self.is_cached = False
self.is_checkpointed = False

if isinstance(prev, TransformedRDD) and not prev.is_cached and not prev.is_checkpointed:
prev_func = prev.func
old_func = func
func = lambda rdd, t: old_func(prev_func(rdd, t), t)
reuse = reuse and prev.reuse
prev = prev.prev

self.prev = prev
self.func = func
self.reuse = reuse
self._jdstream_val = None

class TransformedRDD(DStream):
# TODO: better name for cache
def __init__(self, prev, func, other=None, cache=False):
# TODO: combine transformed RDD
@property
def _jdstream(self):
if self._jdstream_val is not None:
return self._jdstream_val

jfunc = RDDFunction(self.ctx, self.func, self.prev._jrdd_deserializer)
jdstream = self.ctx._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(),
jfunc, self.reuse).asJavaDStream()
self._jdstream_val = jdstream
return jdstream
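Two ideas interact here: __init__ folds a chain of TransformedRDDs into a single closure, and the lazy _jdstream property above defers creating the JVM DStream until the chain is final, so only one Python callback crosses the gateway per batch. A standalone model of the folding, using lists as stand-ins for RDDs (nothing here is Spark API):

    prev_func = lambda rdd, t: [x * 2 for x in rdd]       # stand-in for an earlier map()
    next_func = lambda rdd, t: [x for x in rdd if x > 2]  # stand-in for a later filter()
    combined = lambda rdd, t: next_func(prev_func(rdd, t), t)

    assert combined([1, 2, 3], 0) == [4, 6]  # one call performs both steps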


class Transformed2RDD(DStream):
def __init__(self, prev, func, other, keepSerializer=False):
ssc = prev._ssc
t = RDDFunction(ssc._sc, func, prev._jrdd_deserializer)
jdstream = ssc._jvm.PythonTransformedDStream(prev._jdstream.dstream(),
other and other._jdstream, t, cache)
DStream.__init__(self, jdstream.asJavaDStream(), ssc, ssc._sc.serializer)
jfunc = RDDFunction2(ssc._sc, func, prev._jrdd_deserializer)
jdstream = ssc._jvm.PythonTransformed2DStream(prev._jdstream.dstream(),
other._jdstream.dstream(), jfunc)
jrdd_serializer = prev._jrdd_deserializer if keepSerializer else ssc._sc.serializer
DStream.__init__(self, jdstream.asJavaDStream(), ssc, jrdd_serializer)
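This is where the union fix lands: with keepSerializer=True the result DStream is decoded with the same deserializer as its parents, instead of always falling back to the default serializer. A hedged sketch of the resulting invariant (`d1`, `d2` hypothetical DStreams):

    u = d1.union(d2)
    assert u._jrdd_deserializer == d1._jrdd_deserializer  # parents' serializer kept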
26 changes: 26 additions & 0 deletions python/pyspark/streaming/tests.py
@@ -213,6 +213,32 @@ def add(a, b):
[("a", "11"), ("b", "1"), ("", "111")]]
self._test_func(input, func, expected, sort=True)

def test_union(self):
input1 = [range(3), range(5), range(1)]
input2 = [range(3, 6), range(5, 6), range(1, 6)]

d1 = self.ssc._makeStream(input1)
d2 = self.ssc._makeStream(input2)
d = d1.union(d2)
result = d.collect()
expected = [range(6), range(6), range(6)]

self.ssc.start()
start_time = time.time()
# Loop until we get the expected number of results from the stream.
while True:
current_time = time.time()
# Check time out.
if (current_time - start_time) > self.timeout * 2:
break
# StreamingContext.awaitTermination is not used to wait here because
# calling into the py4j server every 50 milliseconds raises an error.
time.sleep(0.05)
# Check if the output is the same length as the expected output.
if len(expected) == len(result):
break
self.assertEqual(expected, result)

def _sort_result_based_on_key(self, outputs):
"""Sort the list base onf first value."""
for output in outputs:
31 changes: 29 additions & 2 deletions python/pyspark/streaming/util.py
@@ -28,9 +28,36 @@ def __init__(self, ctx, func, jrdd_deserializer):
self.func = func
self.deserializer = jrdd_deserializer

def call(self, jrdd, jrdd2, milliseconds):
def call(self, jrdd, milliseconds):
try:
rdd = RDD(jrdd, self.ctx, self.deserializer)
r = self.func(rdd, milliseconds)
if r:
return r._jrdd
except:
import traceback
traceback.print_exc()

def __repr__(self):
return "RDDFunction(%s, %s)" % (str(self.deserializer), str(self.func))

class Java:
implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction']


class RDDFunction2(object):
"""
This class is a py4j callback; it implements the Java interface
org.apache.spark.streaming.api.python.PythonRDDFunction2.
"""
def __init__(self, ctx, func, jrdd_deserializer):
self.ctx = ctx
self.func = func
self.deserializer = jrdd_deserializer

def call(self, jrdd, jrdd2, milliseconds):
try:
rdd = RDD(jrdd, self.ctx, self.deserializer) if jrdd else None
other = RDD(jrdd2, self.ctx, self.deserializer) if jrdd2 else None
r = self.func(rdd, other, milliseconds)
if r:
@@ -43,7 +70,7 @@ def __repr__(self):
return "RDDFunction(%s, %s)" % (str(self.deserializer), str(self.func))

class Java:
implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction']
implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction2']
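For reference, a standalone model of the py4j callback pattern shared by RDDFunction and RDDFunction2: the nested Java class declares which JVM interface the Python object implements, so Java code can invoke call() through the gateway. The interface name below is illustrative, not a real Spark class.

    class Adder(object):
        def call(self, a, b):
            return a + b

        class Java:
            implements = ['com.example.BinaryFunction']  # hypothetical interface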


def rddToFileName(prefix, suffix, time):
78 changes: 46 additions & 32 deletions streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -28,77 +28,91 @@ import org.apache.spark.streaming.api.java._


/**
* Interface for Python callback function
* Interface for Python callback function with two arguments
*/
trait PythonRDDFunction {
def call(rdd: JavaRDD[_], rdd2: JavaRDD[_], time: Long): JavaRDD[Array[Byte]]
def call(rdd: JavaRDD[_], time: Long): JavaRDD[Array[Byte]]
}

/**
* Interface for Python callback function with three arguments
*/
trait PythonRDDFunction2 {
def call(rdd: JavaRDD[_], rdd2: JavaRDD[_], time: Long): JavaRDD[Array[Byte]]
}

/**
* Transformed DStream in Python.
*
* If the result RDD is a PythonRDD, it will be cached as a template for future
* batches; this reduces the number of Python callbacks.
*
* @param parent
* @param func
* @param reuse
*/
class PythonTransformedDStream (parent: DStream[_], parent2: DStream[_], func: PythonRDDFunction,
cache: Boolean = false)
class PythonTransformedDStream (parent: DStream[_], func: PythonRDDFunction,
var reuse: Boolean = false)
extends DStream[Array[Byte]] (parent.ssc) {

var lastResult: PythonRDD = _

override def dependencies = {
if (parent2 == null) {
List(parent)
} else {
List(parent, parent2)
}
}
override def dependencies = List(parent)

override def slideDuration: Duration = parent.slideDuration

override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
val rdd1 = parent.getOrCompute(validTime).getOrElse(null)
val rdd2 = if (parent2 != null) parent2.getOrCompute(validTime).getOrElse(null) else null

val r = if (rdd2 != null) {
func.call(JavaRDD.fromRDD(rdd1), JavaRDD.fromRDD(rdd2), validTime.milliseconds)
} else if (cache && lastResult != null) {
lastResult.copyTo(rdd1).asJavaRDD
if (reuse && lastResult != null) {
Some(lastResult.copyTo(rdd1))
} else {
func.call(JavaRDD.fromRDD(rdd1), null, validTime.milliseconds)
}
if (r != null) {
if (lastResult == null && r.isInstanceOf[PythonRDD]) {
lastResult = r.asInstanceOf[PythonRDD]
val r = func.call(JavaRDD.fromRDD(rdd1), validTime.milliseconds).rdd
if (reuse && lastResult == null) {
r match {
case rdd: PythonRDD =>
if (rdd.parent(0) == rdd1) {
// only one PythonRDD
lastResult = rdd
} else {
// may have multiple stages
reuse = false
}
}
}
Some(r)
} else {
None
}
}

val asJavaDStream = JavaDStream.fromDStream(this)
}
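A minimal, self-contained Python model of the reuse decision in compute() above; the names are illustrative, not Spark APIs. The template is kept only when the first result is a single-stage PythonRDD reading directly from the parent RDD, since only then can copyTo() re-bind it to each new batch.

    class FakePythonRDD(object):
        """Stand-in for PythonRDD: remembers its parent and can be re-bound."""
        def __init__(self, parent):
            self.parents = [parent]

        def copy_to(self, new_parent):
            return FakePythonRDD(new_parent)

    def compute(parent_rdd, func, state):
        # state mirrors the mutable fields of the Scala class: reuse, last_result
        if state["reuse"] and state["last_result"] is not None:
            # Re-bind the cached template to this batch's input; no Python callback.
            return state["last_result"].copy_to(parent_rdd)
        result = func(parent_rdd)
        if state["reuse"] and state["last_result"] is None:
            if isinstance(result, FakePythonRDD) and result.parents[0] is parent_rdd:
                state["last_result"] = result  # single PythonRDD: safe to reuse
            else:
                state["reuse"] = False  # multiple stages: give up on reuse
        return result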

/**
* Transformed from two DStreams in Python.
*/
class PythonTransformed2DStream (parent: DStream[_], parent2: DStream[_], func: PythonRDDFunction2)
extends DStream[Array[Byte]] (parent.ssc) {

override def dependencies = List(parent, parent2)

override def slideDuration: Duration = parent.slideDuration

override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
def resultRdd(stream: DStream[_]): JavaRDD[_] = stream.getOrCompute(validTime) match {
case Some(rdd) => JavaRDD.fromRDD(rdd)
case None => null
}
Some(func.call(resultRdd(parent), resultRdd(parent2), validTime.milliseconds))
}

val asJavaDStream = JavaDStream.fromDStream(this)
}

/**
* This is used for foreachRDD() in Python
* @param prev
* @param foreachFunction
*/
class PythonForeachDStream(
prev: DStream[Array[Byte]],
foreachFunction: PythonRDDFunction
) extends ForEachDStream[Array[Byte]](
prev,
(rdd: RDD[Array[Byte]], time: Time) => {
foreachFunction.call(rdd.toJavaRDD(), null, time.milliseconds)
foreachFunction.call(rdd.toJavaRDD(), time.milliseconds)
}
) {

