fix the number of partitions during window()
davies committed Sep 30, 2014
1 parent 338580a commit 069a94c
Showing 2 changed files with 16 additions and 4 deletions.
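For context on the fix: the windowed reduce combines the previous window's RDD `a` with the new batch `b` via `union()`, and `union()` yields an RDD whose partition count is the sum of its inputs' counts. When no explicit `numPartitions` is given, that count compounds every slide interval, so this commit falls back to the average of the two. A minimal sketch of the growth, assuming only a local `SparkContext` bound to `sc` (not part of the commit):

    # Illustration only: why the partition count kept increasing.
    a = sc.parallelize(range(10), 4)    # 4 partitions (previous window)
    b = sc.parallelize(range(10), 4)    # 4 partitions (new batch)
    u = a.union(b)
    print u.getNumPartitions()          # 8: union() sums the partition counts
    # reduceByKey() with no explicit partition count keeps the upstream
    # count, so the next interval would start from 8, then 12, and so on.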
12 changes: 9 additions & 3 deletions python/pyspark/streaming/dstream.py
@@ -552,14 +552,18 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None

         def reduceFunc(t, a, b):
             b = b.reduceByKey(func, numPartitions)
-            r = a.union(b).reduceByKey(func, numPartitions) if a else b
+            # use the average of number of partitions, or it will keep increasing
+            partitions = numPartitions or (a.getNumPartitions() + b.getNumPartitions())/2
+            r = a.union(b).reduceByKey(func, partitions) if a else b
             if filterFunc:
                 r = r.filter(filterFunc)
             return r

         def invReduceFunc(t, a, b):
             b = b.reduceByKey(func, numPartitions)
-            joined = a.leftOuterJoin(b, numPartitions)
+            # use the average of number of partitions, or it will keep increasing
+            partitions = numPartitions or (a.getNumPartitions() + b.getNumPartitions())/2
+            joined = a.leftOuterJoin(b, partitions)
             return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1)

         jreduceFunc = RDDFunction(self.ctx, reduceFunc, reduced._jrdd_deserializer)
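A note on the fallback idiom above (my reading, not stated in the commit): `numPartitions` appears to default to `None` here, and Python's `or` returns its right operand when the left one is falsy, so the average only kicks in when the caller has not pinned a partition count. Under Python 2, `/` on integers floors, so the result stays a whole number:

    # Sketch of the idiom with assumed partition counts of 4 and 6.
    numPartitions = None                       # caller did not pin a count
    partitions = numPartitions or (4 + 6)/2    # -> 5 (Python 2 floor division)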
@@ -587,7 +591,9 @@ def reduceFunc(t, a, b):
             if a is None:
                 g = b.groupByKey(numPartitions).map(lambda (k, vs): (k, list(vs), None))
             else:
-                g = a.cogroup(b, numPartitions)
+                # use the average of number of partitions, or it will keep increasing
+                partitions = numPartitions or (a.getNumPartitions() + b.getNumPartitions())/2
+                g = a.cogroup(b, partitions)
             g = g.map(lambda (k, (va, vb)): (k, list(vb), list(va)[0] if len(va) else None))
             state = g.mapPartitions(lambda x: updateFunc(x))
             return state.filter(lambda (k, v): v is not None)
8 changes: 7 additions & 1 deletion python/pyspark/streaming/tests.py
@@ -22,7 +22,7 @@
 import unittest
 import tempfile

-from pyspark.context import SparkContext
+from pyspark.context import SparkContext, RDD
 from pyspark.streaming.context import StreamingContext


@@ -46,8 +46,13 @@ def _test_func(self, input, func, expected, sort=False, input2=None):
         @param func: wrapped function. This function should return PythonDStream object.
         @param expected: expected output for this testcase.
         """
+        if not isinstance(input[0], RDD):
+            input = [self.sc.parallelize(d, 1) for d in input]
         input_stream = self.ssc.queueStream(input)
+        if input2 and not isinstance(input2[0], RDD):
+            input2 = [self.sc.parallelize(d, 1) for d in input2]
         input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None

         # Apply test function to stream.
         if input2:
             stream = func(input_stream, input_stream2)
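With this change, `_test_func` accepts either plain Python lists (each wrapped by `parallelize` into a single-partition RDD) or pre-built RDDs, so a test can control the input partitioning directly. A hypothetical call under that reading, with illustrative names and expected values only:

    # Hypothetical usage, not part of the commit: pre-partitioned RDD input.
    input = [self.sc.parallelize(range(10), 4) for _ in range(3)]
    func = lambda dstream: dstream.map(lambda x: (x % 2, 1)).reduceByKey(lambda x, y: x + y)
    self._test_func(input, func, expected=[[(0, 5), (1, 5)]] * 3, sort=True)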
@@ -63,6 +68,7 @@ def _test_func(self, input, func, expected, sort=False, input2=None):
             current_time = time.time()
             # Check time out.
             if (current_time - start_time) > self.timeout:
+                print "timeout after", self.timeout
                 break
             # StreamingContext.awaitTermination is not used to wait because
             # if py4j server is called every 50 milliseconds, it gets an error.
