Skip to content

Commit

Permalink
additional fixes for zip along with unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
aray committed Dec 2, 2016
1 parent ad43e31 commit 6e3d9d0
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 11 deletions.
24 changes: 13 additions & 11 deletions python/pyspark/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,19 +316,21 @@ def __init__(self, key_ser, val_ser):
self.key_ser = key_ser
self.val_ser = val_ser

def prepare_keys_values(self, stream):
    """Yield (key_batch, value_batch) tuples read in lockstep.

    Both serializers pull batches from the same underlying stream, so
    the key batch and its matching value batch arrive alternately.
    """
    keys_iter = self.key_ser._load_stream_without_unbatching(stream)
    vals_iter = self.val_ser._load_stream_without_unbatching(stream)
    for batch_pair in zip(keys_iter, vals_iter):
        yield batch_pair
def _load_stream_without_unbatching(self, stream):
    """Yield one iterator of (key, value) pairs per deserialized batch.

    Raises ValueError when a key batch and its matching value batch do
    not contain the same number of items.
    """
    keys_batches = self.key_ser._load_stream_without_unbatching(stream)
    vals_batches = self.val_ser._load_stream_without_unbatching(stream)
    for keys, vals in zip(keys_batches, vals_batches):
        # Materialize each batch exactly once and in order (keys first):
        # both serializers read from the same underlying stream.
        keys = list(keys)
        vals = list(vals)
        if len(keys) != len(vals):
            raise ValueError("Can not deserialize PairRDD with different number of items"
                             " in batches: (%d, %d)" % (len(keys), len(vals)))
        yield zip(keys, vals)

def load_stream(self, stream):
    """Deserialize a stream of (key, value) pairs.

    Delegates batch handling (including the per-batch length check) to
    _load_stream_without_unbatching and flattens the resulting
    per-batch pair iterators into a single stream of pairs.
    """
    # The flattened diff left the old body (calling the removed
    # prepare_keys_values) in front of the new delegation; only the
    # delegation belongs here.
    return chain.from_iterable(self._load_stream_without_unbatching(stream))

def __repr__(self):
    """Debug representation showing the wrapped key and value serializers."""
    key_repr = str(self.key_ser)
    val_repr = str(self.val_ser)
    return "PairDeserializer(%s, %s)" % (key_repr, val_repr)
Expand Down
5 changes: 5 additions & 0 deletions python/pyspark/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,11 @@ def test_cartesian_chaining(self):
set([(x, (y, z)) for x in range(10) for y in range(10) for z in range(10)])
)

# Regression check: zipping an RDD with itself pairs each element with
# itself, and a cartesian product against the original RDD must still
# deserialize correctly (exercises PairDeserializer batching).
self.assertSetEqual(
    set(rdd.cartesian(rdd.zip(rdd)).collect()),
    set([(x, (y, y)) for x in range(10) for y in range(10)])
)

def test_deleting_input_files(self):
# Regression test for SPARK-1025
tempFile = tempfile.NamedTemporaryFile(delete=False)
Expand Down

0 comments on commit 6e3d9d0

Please sign in to comment.