Skip to content

Commit

Permalink
Fix RDD.zip() with serializers that have different batch sizes.
Browse files Browse the repository at this point in the history
  • Loading branch information
davies committed Aug 11, 2014
1 parent ba28a8f commit a4aafda
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 1 deletion.
25 changes: 25 additions & 0 deletions python/pyspark/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1684,6 +1684,31 @@ def zip(self, other):
>>> x.zip(y).collect()
[(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)]
"""
if self.getNumPartitions() != other.getNumPartitions():
raise ValueError("Can only zip with RDD which has the same number of partitions")

def get_batch_size(ser):
if isinstance(ser, BatchedSerializer):
return ser.batchSize
return 0

def batch_as(rdd, batchSize):
ser = rdd._jrdd_deserializer
if isinstance(ser, BatchedSerializer):
ser = ser.serializer
return rdd._reserialize(BatchedSerializer(ser, batchSize))

my_batch = get_batch_size(self._jrdd_deserializer)
other_batch = get_batch_size(other._jrdd_deserializer)
if my_batch != other_batch:
# use the greatest batchSize to batch the other one.
if my_batch > other_batch:
other = batch_as(other, my_batch)
else:
self = batch_as(self, other_batch)

# There will be an Exception in JVM if there are different number
# of items in each partitions.
pairRDD = self._jrdd.zip(other._jrdd)
deserializer = PairDeserializer(self._jrdd_deserializer,
other._jrdd_deserializer)
Expand Down
3 changes: 3 additions & 0 deletions python/pyspark/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,9 @@ def __init__(self, key_ser, val_ser):

def load_stream(self, stream):
    """Deserialize *stream* into an iterator of (key, value) pairs.

    Each batch of keys read from the stream is paired element-wise with
    the corresponding batch of values.  If a paired batch of keys and a
    batch of values disagree in length, the zipped RDDs were batched
    differently and cannot be lined up, so deserialization fails fast.
    """
    for keys, vals in self.prepare_keys_values(stream):
        n_keys, n_vals = len(keys), len(vals)
        if n_keys != n_vals:
            # Mismatched batch sizes: refuse to silently drop/misalign items.
            raise ValueError("Can not deserialize RDD with different number of items"
                             " in pair: (%d, %d)" % (n_keys, n_vals))
        for item in izip(keys, vals):
            yield item

Expand Down
10 changes: 9 additions & 1 deletion python/pyspark/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

from pyspark.context import SparkContext
from pyspark.files import SparkFiles
from pyspark.serializers import read_int
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger

_have_scipy = False
Expand Down Expand Up @@ -318,6 +318,14 @@ def test_namedtuple_in_rdd(self):
theDoes = self.sc.parallelize([jon, jane])
self.assertEquals([jon, jane], theDoes.collect())

def test_zip_with_different_serializers(self):
    """zip() should work even when the two RDDs use different
    serializers and different batch sizes."""
    expected = [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]
    left = self.sc.parallelize(range(5))
    right = self.sc.parallelize(range(100, 105))
    # Same default serializer on both sides: baseline sanity check.
    self.assertEqual(left.zip(right).collect(), expected)
    # Now force mismatched serializers AND batch sizes; zip() must
    # re-batch one side so the pairs still line up.
    left = left._reserialize(BatchedSerializer(PickleSerializer(), 2))
    right = right._reserialize(MarshalSerializer())
    self.assertEqual(left.zip(right).collect(), expected)


class TestIO(PySparkTestCase):

Expand Down

0 comments on commit a4aafda

Please sign in to comment.