
Commit

support ssc.transform()
davies committed Sep 29, 2014
1 parent b983f0f commit 98ac6c2
Showing 6 changed files with 96 additions and 48 deletions.
18 changes: 14 additions & 4 deletions python/pyspark/streaming/context.py
@@ -20,6 +20,7 @@
from pyspark.context import SparkContext
from pyspark.storagelevel import StorageLevel
from pyspark.streaming.dstream import DStream
from pyspark.streaming.util import RDDFunction

from py4j.java_collections import ListConverter
from py4j.java_gateway import java_import
@@ -212,11 +213,20 @@ def queueStream(self, queue, oneAtATime=True, default=None):

def transform(self, dstreams, transformFunc):
"""
Create a new DStream in which each RDD is generated by applying a function on RDDs of
the DStreams. The order of the JavaRDDs in the transform function parameter will be the
same as the order of corresponding DStreams in the list.
Create a new DStream in which each RDD is generated by applying
a function on RDDs of the DStreams. The order of the JavaRDDs in
the transform function parameter will be the same as the order
of corresponding DStreams in the list.
"""
# TODO
jdstreams = ListConverter().convert([d._jdstream for d in dstreams],
SparkContext._gateway._gateway_client)
# change the final serializer to sc.serializer
jfunc = RDDFunction(self._sc,
lambda t, *rdds: transformFunc(rdds).map(lambda x: x),
*[d._jrdd_deserializer for d in dstreams])

jdstream = self._jvm.PythonDStream.callTransform(self._jssc, jdstreams, jfunc)
return DStream(jdstream, self, self._sc.serializer)
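
For illustration, a minimal sketch of the new ssc.transform() API shown above; the context setup, batch interval, and input data are assumptions made for this example, not part of the diff:

from pyspark.context import SparkContext
from pyspark.streaming.context import StreamingContext

sc = SparkContext("local[2]", "transform-example")   # illustrative master/app name
ssc = StreamingContext(sc, 1)                        # assumed 1-second batch interval

s1 = ssc.queueStream([[1, 2]])
s2 = ssc.queueStream([[3, 4]])

# The RDDs are handed to the function in the same order as the DStreams in the list.
merged = ssc.transform([s1, s2], lambda rdds: rdds[0].union(rdds[1]))
merged.pprint()
# ssc.start() would then run the pipeline.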

def union(self, *dstreams):
"""
36 changes: 18 additions & 18 deletions python/pyspark/streaming/dstream.py
@@ -132,7 +132,7 @@ def partitionBy(self, numPartitions, partitionFunc=portable_hash):
return self.transform(lambda rdd: rdd.partitionBy(numPartitions, partitionFunc))

def foreach(self, func):
return self.foreachRDD(lambda rdd, _: rdd.foreach(func))
return self.foreachRDD(lambda _, rdd: rdd.foreach(func))

def foreachRDD(self, func):
"""
@@ -142,7 +142,7 @@ def foreachRDD(self, func):
This is an output operator, so this DStream will be registered as an output
stream and there materialized.
"""
jfunc = RDDFunction(self.ctx, lambda a, _, t: func(a, t), self._jrdd_deserializer)
jfunc = RDDFunction(self.ctx, func, self._jrdd_deserializer)
api = self._ssc._jvm.PythonDStream
api.callForeachRDD(self._jdstream, jfunc)
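
With this change the callback passed to foreachRDD receives the batch time first and the RDD second; a hedged usage sketch, assuming an existing DStream named stream:

def dump_batch(time, rdd):
    # time is the batch time in milliseconds; rdd holds that batch's data
    print time, rdd.collect()

stream.foreachRDD(dump_batch)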

@@ -151,10 +151,10 @@ def pprint(self):
Print the first ten elements of each RDD generated in this DStream. This is an output
operator, so this DStream will be registered as an output stream and there materialized.
"""
def takeAndPrint(rdd, time):
def takeAndPrint(timestamp, rdd):
taken = rdd.take(11)
print "-------------------------------------------"
print "Time: %s" % datetime.fromtimestamp(time / 1000.0)
print "Time: %s" % datetime.fromtimestamp(timestamp / 1000.0)
print "-------------------------------------------"
for record in taken[:10]:
print record
@@ -176,15 +176,15 @@ def take(self, n):
"""
rdds = []

def take(rdd, _):
if rdd:
def take(_, rdd):
if rdd and len(rdds) < n:
rdds.append(rdd)
if len(rdds) == n:
# FIXME: NPE in JVM
self._ssc.stop(False)
self.foreachRDD(take)

self._ssc.start()
self._ssc.awaitTermination()
while len(rdds) < n:
time.sleep(0.01)
self._ssc.stop(False, True)
return rdds

def collect(self):
@@ -195,7 +195,7 @@ def collect(self):
"""
result = []

def get_output(rdd, time):
def get_output(_, rdd):
r = rdd.collect()
result.append(r)
self.foreachRDD(get_output)
@@ -317,7 +317,7 @@ def transform(self, func):
Return a new DStream in which each RDD is generated by applying a function
on each RDD of 'this' DStream.
"""
return TransformedDStream(self, lambda a, t: func(a), True)
return TransformedDStream(self, lambda t, a: func(a), True)
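
A minimal usage sketch (stream is an assumed DStream); the lambda above discards the batch time so the user function only sees the RDD:

evens = stream.transform(lambda rdd: rdd.filter(lambda x: x % 2 == 0))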

def transformWithTime(self, func):
"""
@@ -331,7 +331,7 @@ def transformWith(self, func, other, keepSerializer=False):
Return a new DStream in which each RDD is generated by applying a function
on each RDD of 'this' DStream and 'other' DStream.
"""
jfunc = RDDFunction(self.ctx, lambda a, b, t: func(a, b), self._jrdd_deserializer)
jfunc = RDDFunction(self.ctx, lambda t, a, b: func(a, b), self._jrdd_deserializer)
dstream = self.ctx._jvm.PythonTransformed2DStream(self._jdstream.dstream(),
other._jdstream.dstream(), jfunc)
jrdd_serializer = self._jrdd_deserializer if keepSerializer else self.ctx.serializer
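
A hedged usage sketch of transformWith, assuming two existing pair DStreams stream1 and stream2:

joined = stream1.transformWith(lambda a, b: a.join(b), stream2)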
@@ -549,14 +549,14 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None
self._check_window(windowDuration, slideDuration)
reduced = self.reduceByKey(func)

def reduceFunc(a, b, t):
def reduceFunc(t, a, b):
b = b.reduceByKey(func, numPartitions)
r = a.union(b).reduceByKey(func, numPartitions) if a else b
if filterFunc:
r = r.filter(filterFunc)
return r

def invReduceFunc(a, b, t):
def invReduceFunc(t, a, b):
b = b.reduceByKey(func, numPartitions)
joined = a.leftOuterJoin(b, numPartitions)
return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1)
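
For context, a hedged usage sketch of reduceByKeyAndWindow with an inverse function, assuming a (key, value) DStream named pairs and window/slide lengths of 30 and 10 (seconds assumed for the example):

# Add values entering the window, subtract values leaving it.
windowed = pairs.reduceByKeyAndWindow(lambda a, b: a + b, lambda a, b: a - b, 30, 10)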
@@ -582,7 +582,7 @@ def updateStateByKey(self, updateFunc, numPartitions=None):
@param updateFunc State update function ([(k, vs, s)] -> [(k, s)]).
If `s` is None, then `k` will be eliminated.
"""
def reduceFunc(a, b, t):
def reduceFunc(t, a, b):
if a is None:
g = b.groupByKey(numPartitions).map(lambda (k, vs): (k, list(vs), None))
else:
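
The hunk above rewrites updateStateByKey's internal reduce function to the new (time, a, b) ordering. For context, a hedged usage sketch of the public API following the documented [(k, vs, s)] -> [(k, s)] contract; pairs is an assumed (key, value) DStream and the counting logic is only an example:

def update(records):
    # records: iterable of (key, new_values, previous_state) triples
    return [(k, (s or 0) + sum(vs)) for k, vs, s in records]

counts = pairs.updateStateByKey(update)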
@@ -610,7 +610,7 @@ def __init__(self, prev, func, reuse=False):
not prev.is_cached and not prev.is_checkpointed):
prev_func = prev.func
old_func = func
func = lambda rdd, t: old_func(prev_func(rdd, t), t)
func = lambda t, rdd: old_func(t, prev_func(t, rdd))
reuse = reuse and prev.reuse
prev = prev.prev

@@ -625,7 +625,7 @@ def _jdstream(self):
return self._jdstream_val

func = self.func
jfunc = RDDFunction(self.ctx, lambda a, _, t: func(a, t), self.prev._jrdd_deserializer)
jfunc = RDDFunction(self.ctx, func, self.prev._jrdd_deserializer)
jdstream = self.ctx._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(),
jfunc, self.reuse).asJavaDStream()
self._jdstream_val = jdstream
13 changes: 13 additions & 0 deletions python/pyspark/streaming/tests.py
@@ -374,6 +374,19 @@ def test_union(self):
expected = [i * 2 for i in input]
self.assertEqual(expected, result[:3])

def test_transform(self):
dstream1 = self.ssc.queueStream([[1]])
dstream2 = self.ssc.queueStream([[2]])
dstream3 = self.ssc.queueStream([[3]])

def func(rdds):
rdd1, rdd2, rdd3 = rdds
return rdd2.union(rdd3).union(rdd1)

dstream = self.ssc.transform([dstream1, dstream2, dstream3], func)

self.assertEqual([2, 3, 1], dstream.first().collect())


if __name__ == "__main__":
unittest.main()
26 changes: 15 additions & 11 deletions python/pyspark/streaming/util.py
@@ -22,29 +22,33 @@ class RDDFunction(object):
"""
This class is for py4j callback.
"""
def __init__(self, ctx, func, deserializer, deserializer2=None):
def __init__(self, ctx, func, *deserializers):
self.ctx = ctx
self.func = func
self.deserializer = deserializer
self.deserializer2 = deserializer2 or deserializer
self.deserializers = deserializers
emptyRDD = getattr(self.ctx, "_emptyRDD", None)
if emptyRDD is None:
self.ctx._emptyRDD = emptyRDD = self.ctx.parallelize([]).cache()
self.emptyRDD = emptyRDD

def call(self, jrdd, jrdd2, milliseconds):
def call(self, milliseconds, jrdds):
try:
emptyRDD = getattr(self.ctx, "_emptyRDD", None)
if emptyRDD is None:
self.ctx._emptyRDD = emptyRDD = self.ctx.parallelize([]).cache()
# extend deserializers with the first one
sers = self.deserializers
if len(sers) < len(jrdds):
sers += (sers[0],) * (len(jrdds) - len(sers))

rdd = RDD(jrdd, self.ctx, self.deserializer) if jrdd else emptyRDD
other = RDD(jrdd2, self.ctx, self.deserializer2) if jrdd2 else emptyRDD
r = self.func(rdd, other, milliseconds)
rdds = [RDD(jrdd, self.ctx, ser) if jrdd else self.emptyRDD
for jrdd, ser in zip(jrdds, sers)]
r = self.func(milliseconds, *rdds)
if r:
return r._jrdd
except Exception:
import traceback
traceback.print_exc()

def __repr__(self):
return "RDDFunction2(%s)" % (str(self.func))
return "RDDFunction(%s)" % (str(self.func))

class Java:
implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction']
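
To illustrate the new variadic wrapper, a hedged sketch of how it might be constructed on the driver side; sc, ser, and join_func are assumptions for the example. The JVM side later invokes call(milliseconds, jrdds), and when fewer deserializers than RDDs were supplied the first one is reused for the rest:

def join_func(t, a, b):
    # t: batch time in milliseconds; a, b: already-deserialized RDDs
    return a.union(b)

wrapped = RDDFunction(sc, join_func, ser)   # one deserializer, reused for the second RDD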
streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -413,7 +413,7 @@ class StreamingContext private[streaming] (
dstreams: Seq[DStream[_]],
transformFunc: (Seq[RDD[_]], Time) => RDD[T]
): DStream[T] = {
new TransformedDStream[T](dstreams, sparkContext.clean(transformFunc))
new TransformedDStream[T](dstreams, (transformFunc))
}

/** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for
streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -17,30 +17,32 @@

package org.apache.spark.streaming.api.python

import java.util.{ArrayList => JArrayList}
import java.util.{ArrayList => JArrayList, List => JList}
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.spark.api.java._
import org.apache.spark.api.java.function.{Function2 => JFunction2}
import org.apache.spark.api.python._
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Interval, Duration, Time}
import org.apache.spark.streaming.dstream._
import org.apache.spark.streaming.api.java._


/**
* Interface for the Python callback function, invoked with the batch time and a list of JavaRDDs
*/
trait PythonRDDFunction {
def call(rdd: JavaRDD[_], rdd2: JavaRDD[_], time: Long): JavaRDD[Array[Byte]]
def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]]
}

class RDDFunction(pfunc: PythonRDDFunction) extends Serializable {

def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
apply(rdd, None, time)
}
/**
* Wrapper for PythonRDDFunction
*/
class RDDFunction(pfunc: PythonRDDFunction)
extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] with Serializable {

def wrapRDD(rdd: Option[RDD[_]]): JavaRDD[_] = {
if (rdd.isDefined) {
@@ -50,14 +52,25 @@ class RDDFunction(pfunc: PythonRDDFunction) extends Serializable {
}
}

def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
val r = pfunc.call(wrapRDD(rdd), wrapRDD(rdd2), time.milliseconds)
if (r != null) {
Some(r.rdd)
def some(jrdd: JavaRDD[Array[Byte]]): Option[RDD[Array[Byte]]] = {
if (jrdd != null) {
Some(jrdd.rdd)
} else {
None
}
}

def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
some(pfunc.call(time.milliseconds, List(wrapRDD(rdd)).asJava))
}

def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
some(pfunc.call(time.milliseconds, List(wrapRDD(rdd), wrapRDD(rdd2)).asJava))
}

def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = {
pfunc.call(time.milliseconds, rdds)
}
}

private[python]
@@ -74,8 +87,16 @@ private[spark] object PythonDStream {

// helper function for DStream.foreachRDD();
// it cannot be named `foreachRDD`, as that would confuse py4j
def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pyfunc: PythonRDDFunction): Unit = {
jdstream.dstream.foreachRDD((rdd, time) => pyfunc.call(rdd, null, time.milliseconds))
def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pyfunc: PythonRDDFunction) {
val func = new RDDFunction(pyfunc)
jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time))
}

// helper function for ssc.transform()
def callTransform(ssc: JavaStreamingContext, jdstreams: JList[JavaDStream[_]], pyfunc: PythonRDDFunction)
: JavaDStream[Array[Byte]] = {
val func = new RDDFunction(pyfunc)
ssc.transform(jdstreams, func)
}

// convert list of RDD into queue of RDDs, for ssc.queueStream()

