
Commit

remove wasteful duplicated code
giwa committed Aug 15, 2014
1 parent a14c7e1 commit e3033fc
Showing 2 changed files with 56 additions and 62 deletions.
43 changes: 1 addition & 42 deletions python/pyspark/streaming/context.py
@@ -130,48 +130,7 @@ def stop(self, stopSparkContext=True, stopGraceFully=False):
# Stop Callback server
SparkContext._gateway.shutdown()

def checkpoint(self, directory):
"""
Not tested
"""
self._jssc.checkpoint(directory)

def _testInputStream(self, test_inputs, numSlices=None):
"""
Generate multiple files to make a "stream" on the Scala side for tests.
Scala chooses one of the files and generates an RDD using PythonRDD.readRDDFromFile.
QueueStream may be a good way to implement this function.
"""
numSlices = numSlices or self._sc.defaultParallelism
# Calling the Java parallelize() method with an ArrayList is too slow,
# because it sends O(n) Py4J commands. As an alternative, serialized
# objects are written to a file and loaded through textFile().

tempFiles = list()
for test_input in test_inputs:
tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)

# Make sure we distribute data evenly if it's smaller than self.batchSize
if "__len__" not in dir(test_input):
test_input = list(test_input) # Make it a list so we can compute its length
batchSize = min(len(test_input) // numSlices, self._sc._batchSize)
if batchSize > 1:
serializer = BatchedSerializer(self._sc._unbatched_serializer,
batchSize)
else:
serializer = self._sc._unbatched_serializer
serializer.dump_stream(test_input, tempFile)
tempFile.close()
tempFiles.append(tempFile.name)

jtempFiles = ListConverter().convert(tempFiles, SparkContext._gateway._gateway_client)
jinput_stream = self._jvm.PythonTestInputStream(self._jssc,
jtempFiles,
numSlices).asJavaDStream()
return DStream(jinput_stream, self, BatchedSerializer(PickleSerializer()))

def _testInputStream2(self, test_inputs, numSlices=None):
"""
This is inspired by the QueueStream implementation. Given a list of RDDs, it generates
a DStream which contains those RDDs.
@@ -184,7 +143,7 @@ def _testInputStream2(self, test_inputs, numSlices=None):
test_rdd_deserializers.append(test_rdd._jrdd_deserializer)

jtest_rdds = ListConverter().convert(test_rdds, SparkContext._gateway._gateway_client)
jinput_stream = self._jvm.PythonTestInputStream2(self._jssc, jtest_rdds).asJavaDStream()
jinput_stream = self._jvm.PythonTestInputStream(self._jssc, jtest_rdds).asJavaDStream()

dstream = DStream(jinput_stream, self, test_rdd_deserializers[0])
return dstream
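
For context, the removed _testInputStream relied on PySpark's file-based serialization path: each test batch was written to a temp file with a (possibly batched) serializer and then re-read on the JVM side. The following standalone sketch mirrors just the dump step of that deleted helper, using only pyspark.serializers and the standard library; the function name dump_test_input and the default_batch_size parameter are invented for illustration, and no SparkContext is involved.

from tempfile import NamedTemporaryFile

from pyspark.serializers import BatchedSerializer, PickleSerializer


def dump_test_input(test_input, num_slices, default_batch_size=1024):
    """Serialize one test batch to a temp file, mirroring the removed helper.

    Batch size and serializer choice follow the same rules as the deleted
    _testInputStream code above; returns the temp file's path.
    """
    if not hasattr(test_input, "__len__"):
        test_input = list(test_input)  # materialize so len() works

    batch_size = min(len(test_input) // num_slices, default_batch_size)
    if batch_size > 1:
        serializer = BatchedSerializer(PickleSerializer(), batch_size)
    else:
        serializer = PickleSerializer()

    temp_file = NamedTemporaryFile(delete=False)
    serializer.dump_stream(test_input, temp_file)
    temp_file.close()
    return temp_file.name


# Example: serialize a small batch split across 2 slices.
print(dump_test_input([1, 2, 3, 4, 5, 6], num_slices=2))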
75 changes: 55 additions & 20 deletions python/pyspark/streaming/dstream.py
@@ -17,12 +17,13 @@

from collections import defaultdict
from itertools import chain, ifilter, imap
import time
import operator

from pyspark.serializers import NoOpSerializer,\
BatchedSerializer, CloudPickleSerializer, pack_long
from pyspark.rdd import _JavaStackTrace
from pyspark.storagelevel import StorageLevel
from pyspark.resultiterable import ResultIterable

from py4j.java_collections import ListConverter, MapConverter

@@ -35,6 +36,8 @@ def __init__(self, jdstream, ssc, jrdd_deserializer):
self._ssc = ssc
self.ctx = ssc._sc
self._jrdd_deserializer = jrdd_deserializer
self.is_cached = False
self.is_checkpointed = False

def context(self):
"""
@@ -234,8 +237,6 @@ def takeAndPrint(rdd, time):
taken = rdd.take(11)
print "-------------------------------------------"
print "Time: %s" % (str(time))
print rdd.glom().collect()
print "-------------------------------------------"
print "-------------------------------------------"
for record in taken[:10]:
print record
@@ -290,32 +291,65 @@ def get_output(rdd, time):

self.foreachRDD(get_output)

def _test_switch_dserializer(self, serializer_que):
def cache(self):
"""
Persist this DStream with the default storage level (C{MEMORY_ONLY_SER}).
"""
self.is_cached = True
self.persist(StorageLevel.MEMORY_ONLY_SER)
return self

def persist(self, storageLevel):
"""
Set this DStream's storage level to persist its values across operations
after the first time it is computed. This can only be used to assign
a new storage level if the DStream does not have a storage level set yet.
"""
self.is_cached = True
javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
self._jdstream.persist(javaStorageLevel)
return self

def checkpoint(self, interval):
"""
The deserializer is changed dynamically based on numSlice and the number of
inputs. This function chooses the deserializer; currently this is just FIFO.
Mark this DStream for checkpointing. It will be saved to a file inside the
checkpoint directory set with L{SparkContext.setCheckpointDir()},
and all references to its parent RDDs will be removed (I am not sure this
part applies to DStream). This function must be called before any job has
been executed on this RDD. It is strongly recommended that this RDD is
persisted in memory, otherwise saving it to a file will require recomputation.
interval must be a pyspark.streaming.duration
"""

jrdd_deserializer = self._jrdd_deserializer
self.is_checkpointed = True
self._jdstream.checkpoint(interval)
return self

def groupByKey(self, numPartitions=None):
def createCombiner(x):
return [x]

def switch(rdd, jtime):
try:
print serializer_que
jrdd_deserializer = serializer_que.pop(0)
print jrdd_deserializer
except Exception as e:
print e
def mergeValue(xs, x):
xs.append(x)
return xs

self.foreachRDD(switch)
def mergeCombiners(a, b):
a.extend(b)
return a

return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
numPartitions).mapValues(lambda x: ResultIterable(x))
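
To make the combiner-based groupByKey above concrete, here is a small pure-Python illustration (no Spark involved) of how createCombiner, mergeValue and mergeCombiners cooperate: values are collected list-by-list within each partition, and the per-partition lists are then concatenated across partitions. The two partitions below are invented purely for the example.

def create_combiner(x):
    return [x]

def merge_value(xs, x):
    xs.append(x)
    return xs

def merge_combiners(a, b):
    a.extend(b)
    return a

# Two hypothetical partitions of (key, value) pairs.
partitions = [
    [("a", 1), ("b", 2), ("a", 3)],
    [("a", 4), ("b", 5)],
]

# Phase 1: combine values within each partition.
per_partition = []
for part in partitions:
    combined = {}
    for key, value in part:
        if key in combined:
            combined[key] = merge_value(combined[key], value)
        else:
            combined[key] = create_combiner(value)
    per_partition.append(combined)

# Phase 2: merge the per-partition combiners.
result = {}
for combined in per_partition:
    for key, values in combined.items():
        if key in result:
            result[key] = merge_combiners(result[key], values)
        else:
            result[key] = values

print(result)  # e.g. {'a': [1, 3, 4], 'b': [2, 5]}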


# TODO: implement groupByKey
# TODO: implement saveAsTextFile

# Following operations have a dependency on transform
# TODO: implement union
# TODO: implement cache
# TODO: implement persist
# TODO: implement repartition
# TODO: implement saveAsTextFile
# TODO: implement cogroup
# TODO: implement join
# TODO: implement countByValue
@@ -342,6 +376,7 @@ def pipeline_func(split, iterator):
self._prev_jdstream = prev._prev_jdstream # maintain the pipeline
self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer
self.is_cached = False
self.is_checkpointed = False
self._ssc = prev._ssc
self.ctx = prev.ctx
self.prev = prev
@@ -378,4 +413,4 @@ def _jdstream(self):
return self._jdstream_val

def _is_pipelinable(self):
return not self.is_cached
return not (self.is_cached or self.is_checkpointed)
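
Taken together, the new is_cached and is_checkpointed flags mean a DStream stops being pipelinable once it is cached or checkpointed, so later transformations start a new stage instead of being fused into the existing one. A hedged usage sketch of the persist/checkpoint API follows; it is written against the PySpark Streaming API as it was eventually released (where DStream.checkpoint takes an interval in seconds), not necessarily against this work-in-progress branch, and the host, port and checkpoint directory are arbitrary examples.

from pyspark import SparkContext
from pyspark.streaming import StreamingContext
from pyspark.storagelevel import StorageLevel

sc = SparkContext("local[2]", "DStreamPersistSketch")
ssc = StreamingContext(sc, batchDuration=1)
ssc.checkpoint("/tmp/streaming-checkpoint")      # metadata/RDD checkpoint directory

lines = ssc.socketTextStream("localhost", 9999)  # assumes a text source on port 9999
counts = (lines.flatMap(lambda line: line.split(" "))
               .map(lambda word: (word, 1))
               .reduceByKey(lambda a, b: a + b))

counts.persist(StorageLevel.MEMORY_AND_DISK)     # explicit storage level, as with persist() above
counts.checkpoint(10)                            # checkpoint the stream's RDDs every 10 seconds
counts.pprint()

ssc.start()
ssc.awaitTermination()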
