Commit

address comments
davies committed Oct 3, 2014
1 parent 37fe06f commit e108ec1
Showing 6 changed files with 80 additions and 86 deletions.
core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -25,8 +25,6 @@ import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collectio
import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.language.existentials
import scala.reflect.ClassTag
import scala.util.{Try, Success, Failure}

import net.razorvine.pickle.{Pickler, Unpickler}

@@ -52,12 +50,6 @@ private[spark] class PythonRDD(
accumulator: Accumulator[JList[Array[Byte]]])
extends RDD[Array[Byte]](parent) {

// create a new PythonRDD with same Python setting but different parent.
def copyTo(rdd: RDD[_]): PythonRDD = {
new PythonRDD(rdd, command, envVars, pythonIncludes, preservePartitoning,
pythonExec, broadcastVars, accumulator)
}

val bufferSize = conf.getInt("spark.buffer.size", 65536)
val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)

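For context, the two configuration reads kept above are plain SparkConf lookups with defaults (64 KB buffers and worker reuse enabled). A minimal sketch of overriding them from the Python side; the values are illustrative, not recommendations:

    from pyspark import SparkConf, SparkContext

    conf = (SparkConf()
            .setAppName("python-rdd-config-sketch")
            .set("spark.buffer.size", "131072")          # buffer between the JVM and Python workers
            .set("spark.python.worker.reuse", "false"))  # fork a fresh Python worker per task
    sc = SparkContext(conf=conf)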
2 changes: 2 additions & 0 deletions python/pyspark/rdd.py
@@ -787,6 +787,8 @@ def sum(self):
>>> sc.parallelize([1.0, 2.0, 3.0]).sum()
6.0
"""
if not self.getNumPartitions():
return 0  # an empty RDD cannot be reduced
return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add)

def count(self):
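The added getNumPartitions() guard matters because reduce() raises on an empty RDD, so an RDD with no partitions at all (for example, one produced by a streaming job for an empty batch) could not be summed before. A usage sketch, assuming an active SparkContext sc:

    rdd = sc.parallelize([1.0, 2.0, 3.0])
    rdd.sum()   # 6.0: per-partition sums combined with reduce(operator.add)

    # An RDD with zero partitions used to fail inside reduce(); with the
    # guard above, sum() now short-circuits and returns 0 for such RDDs.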
38 changes: 22 additions & 16 deletions python/pyspark/streaming/context.py
@@ -84,17 +84,18 @@ class StreamingContext(object):
"""
_transformerSerializer = None

def __init__(self, sparkContext, duration=None, jssc=None):
def __init__(self, sparkContext, batchDuration=None, jssc=None):
"""
Create a new StreamingContext.
@param sparkContext: L{SparkContext} object.
@param duration: number of seconds.
@param batchDuration: the time interval (in seconds) at which streaming
data will be divided into batches
"""

self._sc = sparkContext
self._jvm = self._sc._jvm
self._jssc = jssc or self._initialize_context(self._sc, duration)
self._jssc = jssc or self._initialize_context(self._sc, batchDuration)

def _initialize_context(self, sc, duration):
self._ensure_initialized()
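For reference, a minimal construction sketch using the renamed batchDuration argument; it assumes an existing SparkContext sc, and the 1-second interval is illustrative:

    from pyspark.streaming import StreamingContext

    # Incoming data will be divided into 1-second batches.
    ssc = StreamingContext(sc, batchDuration=1)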
@@ -134,26 +135,27 @@ def _ensure_initialized(cls):
SparkContext._active_spark_context, CloudPickleSerializer(), gw)

@classmethod
def getOrCreate(cls, path, setupFunc):
def getOrCreate(cls, checkpointPath, setupFunc):
"""
Get the StreamingContext from checkpoint file at `path`, or setup
it by `setupFunc`.
Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
recreated from the checkpoint data. If the data does not exist, then the provided setupFunc
will be used to create a JavaStreamingContext.
:param path: directory of checkpoint
:param setupFunc: a function used to create StreamingContext and
setup DStreams.
:return: a StreamingContext
@param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program
@param setupFunc Function to create a new JavaStreamingContext and setup DStreams
"""
if not os.path.exists(path) or not os.path.isdir(path) or not os.listdir(path):
# TODO: support checkpoint in HDFS
if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath):
ssc = setupFunc()
ssc.checkpoint(path)
ssc.checkpoint(checkpointPath)
return ssc

cls._ensure_initialized()
gw = SparkContext._gateway

try:
jssc = gw.jvm.JavaStreamingContext(path)
jssc = gw.jvm.JavaStreamingContext(checkpointPath)
except Exception:
print >>sys.stderr, "failed to load StreamingContext from checkpoint"
raise
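A hedged usage sketch of the checkpoint-or-create pattern implemented above; the paths and the DStream pipeline are illustrative, and an existing SparkContext sc is assumed:

    from pyspark.streaming import StreamingContext

    checkpoint_dir = "/tmp/streaming-checkpoint"      # illustrative local path

    def create_context():
        # Runs only when no checkpoint data exists at checkpoint_dir;
        # getOrCreate() calls ssc.checkpoint(checkpoint_dir) on the result.
        ssc = StreamingContext(sc, batchDuration=1)
        lines = ssc.textFileStream("/tmp/streaming-input")   # illustrative source
        lines.count().pprint()
        return ssc

    ssc = StreamingContext.getOrCreate(checkpoint_dir, create_context)
    ssc.start()
    ssc.awaitTermination()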
@@ -249,12 +251,12 @@ def textFileStream(self, directory):
"""
return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer())

def _check_serialzers(self, rdds):
def _check_serializers(self, rdds):
# make sure they have same serializer
if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1:
for i in range(len(rdds)):
# reset them to sc.serializer
rdds[i] = rdds[i].map(lambda x: x, preservesPartitioning=True)
rdds[i] = rdds[i]._reserialize()

def queueStream(self, rdds, oneAtATime=True, default=None):
"""
@@ -275,7 +277,7 @@ def queueStream(self, rdds, oneAtATime=True, default=None):

if rdds and not isinstance(rdds[0], RDD):
rdds = [self._sc.parallelize(input) for input in rdds]
self._check_serialzers(rdds)
self._check_serializers(rdds)

jrdds = ListConverter().convert([r._jrdd for r in rdds],
SparkContext._gateway._gateway_client)
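A small usage sketch for queueStream with plain Python lists, which the code above turns into RDDs via parallelize and passes through the renamed _check_serializers helper; it assumes an existing StreamingContext ssc, and the values are illustrative:

    # Each list becomes one RDD; with oneAtATime=True the stream emits
    # one queued RDD per batch interval.
    batches = [[1, 2, 3], [4, 5], [6]]
    stream = ssc.queueStream(batches, oneAtATime=True)
    stream.pprint()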
@@ -313,6 +315,10 @@ def union(self, *dstreams):
raise ValueError("should have at least one DStream to union")
if len(dstreams) == 1:
return dstreams[0]
if len(set(s._jrdd_deserializer for s in dstreams)) > 1:
raise ValueError("All DStreams should have same serializer")
if len(set(s._slideDuration for s in dstreams)) > 1:
raise ValueError("All DStreams should have same slide duration")
first = dstreams[0]
jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]],
SparkContext._gateway._gateway_client)
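The two new checks require every DStream passed to union to share a serializer and a slide duration. A hedged sketch, assuming an existing StreamingContext ssc:

    # Both streams come from the same context and batch interval, so they
    # share a slide duration and a serializer and can be unioned.
    s1 = ssc.queueStream([[1, 2], [3]])
    s2 = ssc.queueStream([[4], [5, 6]])
    ssc.union(s1, s2).pprint()

    # Streams with different slide durations (e.g. a windowed stream with a
    # different slide interval) or different serializers now raise ValueError
    # up front instead of failing later.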