Skip to content

Commit

Permalink
change _first(), _take(), _collect() as private API
Browse files Browse the repository at this point in the history
  • Loading branch information
davies committed Sep 30, 2014
1 parent 19797f9 commit 338580a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 13 deletions.
8 changes: 4 additions & 4 deletions python/pyspark/streaming/dstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,13 @@ def takeAndPrint(timestamp, rdd):

self.foreachRDD(takeAndPrint)

def first(self):
def _first(self):
"""
Return the first RDD in the stream.
"""
return self.take(1)[0]
return self._take(1)[0]

def take(self, n):
def _take(self, n):
"""
Return the first `n` RDDs in the stream (will start and stop).
"""
Expand All @@ -188,7 +188,7 @@ def take(_, rdd):
self._ssc.stop(False, True)
return results

def collect(self):
def _collect(self):
"""
Collect each RDDs into the returned list.
Expand Down
23 changes: 14 additions & 9 deletions python/pyspark/streaming/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _test_func(self, input, func, expected, sort=False, input2=None):
else:
stream = func(input_stream)

result = stream.collect()
result = stream._collect()
self.ssc.start()

start_time = time.time()
Expand Down Expand Up @@ -86,12 +86,12 @@ class TestBasicOperations(PySparkStreamingTestCase):
def test_take(self):
input = [range(i) for i in range(3)]
dstream = self.ssc.queueStream(input)
self.assertEqual([0, 0, 1], dstream.take(3))
self.assertEqual([0, 0, 1], dstream._take(3))

def test_first(self):
input = [range(10)]
dstream = self.ssc.queueStream(input)
self.assertEqual(0, dstream.first())
self.assertEqual(0, dstream._first())

def test_map(self):
"""Basic operation test for DStream.map."""
Expand Down Expand Up @@ -415,26 +415,31 @@ def _addInputStream(self):
# Make sure each length of input is over 3
inputs = map(lambda x: range(1, x), range(5, 101))
stream = self.ssc.queueStream(inputs)
stream.collect()
stream._collect()

def test_queueStream(self):
input = [range(i) for i in range(3)]
dstream = self.ssc.queueStream(input)
result = dstream.collect()
result = dstream._collect()
self.ssc.start()
time.sleep(1)
self.assertEqual(input, result[:3])

# TODO: test textFileStream
# TODO: fix this test
# def test_textFileStream(self):
# input = [range(i) for i in range(3)]
# dstream = self.ssc.queueStream(input)
# d = os.path.join(tempfile.gettempdir(), str(id(self)))
# if not os.path.exists(d):
# os.makedirs(d)
# dstream.saveAsTextFiles(os.path.join(d, 'test'))
# self.ssc.start()
# time.sleep(1)
# self.ssc.stop(False, True)
#
# self.ssc = StreamingContext(self.sc, self.batachDuration)
# dstream2 = self.ssc.textFileStream(d)
# result = dstream2.collect()
# result = dstream2._collect()
# self.ssc.start()
# time.sleep(2)
# self.assertEqual(input, result[:3])
Expand All @@ -444,7 +449,7 @@ def test_union(self):
dstream = self.ssc.queueStream(input)
dstream2 = self.ssc.queueStream(input)
dstream3 = self.ssc.union(dstream, dstream2)
result = dstream3.collect()
result = dstream3._collect()
self.ssc.start()
time.sleep(1)
expected = [i * 2 for i in input]
Expand All @@ -461,7 +466,7 @@ def func(rdds):

dstream = self.ssc.transform([dstream1, dstream2, dstream3], func)

self.assertEqual([2, 3, 1], dstream.take(3))
self.assertEqual([2, 3, 1], dstream._take(3))


if __name__ == "__main__":
Expand Down

0 comments on commit 338580a

Please sign in to comment.