Skip to content

Commit

Permalink
rafactor of foreachRDD()
Browse files Browse the repository at this point in the history
  • Loading branch information
davies committed Sep 28, 2014
1 parent 7001b51 commit fce0ef5
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 32 deletions.
3 changes: 2 additions & 1 deletion python/pyspark/streaming/dstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ def foreachRDD(self, func):
stream and there materialized.
"""
jfunc = RDDFunction(self.ctx, lambda a, _, t: func(a, t), self._jrdd_deserializer)
self.ctx._jvm.PythonForeachDStream(self._jdstream.dstream(), jfunc)
api = self._ssc._jvm.PythonDStream
api.callForeachRDD(self._jdstream, jfunc)

def pprint(self):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ package org.apache.spark.streaming.api.python
import java.util.{ArrayList => JArrayList}
import scala.collection.JavaConversions._

import org.apache.spark.rdd.RDD
import org.apache.spark.api.java._
import org.apache.spark.api.java.function.{Function2 => JFunction2}
import org.apache.spark.api.python._
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Interval, Duration, Time}
import org.apache.spark.streaming.dstream._
Expand All @@ -35,19 +36,22 @@ trait PythonRDDFunction {
def call(rdd: JavaRDD[_], rdd2: JavaRDD[_], time: Long): JavaRDD[Array[Byte]]
}

class RDDFunction(pfunc: PythonRDDFunction) {
def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
val jrdd = if (rdd.isDefined) {
class RDDFunction(pfunc: PythonRDDFunction) extends Serializable {

def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
apply(rdd, None, time)
}

def wrapRDD(rdd: Option[RDD[_]]): JavaRDD[_] = {
if (rdd.isDefined) {
JavaRDD.fromRDD(rdd.get)
} else {
null
}
val jrdd2 = if (rdd2.isDefined) {
JavaRDD.fromRDD(rdd2.get)
} else {
null
}
val r = pfunc.call(jrdd, jrdd2, time.milliseconds)
}

def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
val r = pfunc.call(wrapRDD(rdd), wrapRDD(rdd2), time.milliseconds)
if (r != null) {
Some(r.rdd)
} else {
Expand All @@ -66,7 +70,13 @@ abstract class PythonDStream(parent: DStream[_]) extends DStream[Array[Byte]] (p
val asJavaDStream = JavaDStream.fromDStream(this)
}

object PythonDStream {
private[spark] object PythonDStream {

// helper function for DStream.foreachRDD(),
// cannot be `foreachRDD`, it will confusing py4j
def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pyfunc: PythonRDDFunction): Unit = {
jdstream.dstream.foreachRDD((rdd, time) => pyfunc.call(rdd, null, time.milliseconds))
}

// convert list of RDD into queue of RDDs, for ssc.queueStream()
def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = {
Expand Down Expand Up @@ -97,7 +107,7 @@ private[spark] class PythonTransformedDStream (parent: DStream[_], pfunc: Python
if (reuse && lastResult != null) {
Some(lastResult.copyTo(rdd1.get))
} else {
val r = func(rdd1, None, validTime)
val r = func(rdd1, validTime)
if (reuse && r.isDefined && lastResult == null) {
r.get match {
case rdd: PythonRDD =>
Expand Down Expand Up @@ -206,8 +216,9 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]],
// Get the RDD of the reduced value of the previous window
val previousWindowRDD = getOrCompute(previousWindow.endTime)

// for small window, reduce once will be better than twice
if (windowDuration > slideDuration * 5 && previousWindowRDD.isDefined) {
// subtle the values from old RDDs
// subtract the values from old RDDs
val oldRDDs =
parent.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideDuration)
val subbed = if (oldRDDs.size > 0) {
Expand Down Expand Up @@ -236,22 +247,4 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]],
}
}
}
}

/**
* This is used for foreachRDD() in Python
*/
class PythonForeachDStream(
prev: DStream[Array[Byte]],
foreachFunction: PythonRDDFunction
) extends ForEachDStream[Array[Byte]](
prev,
(rdd: RDD[Array[Byte]], time: Time) => {
if (rdd != null) {
foreachFunction.call(rdd, null, time.milliseconds)
}
}
) {

this.register()
}

0 comments on commit fce0ef5

Please sign in to comment.