[SPARK-3681] [SQL] [PySpark] fix serialization of List and Map in SchemaRDD

Currently, the schema of objects in ArrayType or MapType is attached lazily. This gives better performance, but it introduces issues during serialization and when accessing nested objects.

This patch applies the schema to objects in ArrayType or MapType eagerly, at the moment they are accessed. This is slightly slower but much more robust.
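For context, here is a minimal self-contained sketch of the failure mode, independent of Spark (all names are illustrative stand-ins, not the real pyspark.sql helpers): a container class created at runtime, as the old code did, cannot be located by pickle, while an eagerly converted plain list serializes without trouble.

    import pickle

    def make_lazy_cls(convert):
        # Like the old code path: build a list subclass at runtime that
        # converts elements only when they are accessed.
        class LazyList(list):
            def __getitem__(self, i):
                return convert(list.__getitem__(self, i))
        return LazyList

    def make_eager_converter(convert):
        # Like the patched code path: convert every element up front and
        # return a plain list, so no runtime-created class is involved.
        def to_list(raw):
            return [convert(v) for v in raw]
        return to_list

    convert = lambda v: ("row", v)  # stand-in for applying a nested schema

    lazy = make_lazy_cls(convert)([1, 2])
    print(lazy[0])  # ('row', 1): conversion happens on access
    try:
        pickle.dumps(lazy)
    except pickle.PicklingError as e:
        # pickle cannot find the runtime-created LazyList by qualified name
        print("lazy container fails to pickle:", e)

    eager = make_eager_converter(convert)([1, 2])
    print(pickle.loads(pickle.dumps(eager)))  # [('row', 1), ('row', 2)]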

Author: Davies Liu <davies.liu@gmail.com>

Closes #2526 from davies/nested and squashes the following commits:

2399ae5 [Davies Liu] fix serialization of List and Map in SchemaRDD
davies authored and marmbrus committed Sep 27, 2014
1 parent f0c7e19 commit 0d8cdf0
Showing 2 changed files with 34 additions and 27 deletions.
python/pyspark/sql.py: 13 additions, 27 deletions
@@ -838,43 +838,29 @@ def _create_cls(dataType):
     >>> obj = _create_cls(schema)(row)
     >>> pickle.loads(pickle.dumps(obj))
     Row(a=[1], b={'key': Row(c=1, d=2.0)})
+    >>> pickle.loads(pickle.dumps(obj.a))
+    [1]
+    >>> pickle.loads(pickle.dumps(obj.b))
+    {'key': Row(c=1, d=2.0)}
     """

     if isinstance(dataType, ArrayType):
         cls = _create_cls(dataType.elementType)

-        class List(list):
-
-            def __getitem__(self, i):
-                # create object with datatype
-                return _create_object(cls, list.__getitem__(self, i))
-
-            def __repr__(self):
-                # call collect __repr__ for nested objects
-                return "[%s]" % (", ".join(repr(self[i])
-                                           for i in range(len(self))))
-
-            def __reduce__(self):
-                return list.__reduce__(self)
+        def List(l):
+            if l is None:
+                return
+            return [_create_object(cls, v) for v in l]

         return List

     elif isinstance(dataType, MapType):
-        vcls = _create_cls(dataType.valueType)
-
-        class Dict(dict):
-
-            def __getitem__(self, k):
-                # create object with datatype
-                return _create_object(vcls, dict.__getitem__(self, k))
-
-            def __repr__(self):
-                # call collect __repr__ for nested objects
-                return "{%s}" % (", ".join("%r: %r" % (k, self[k])
-                                           for k in self))
+        cls = _create_cls(dataType.valueType)

-            def __reduce__(self):
-                return dict.__reduce__(self)
+        def Dict(d):
+            if d is None:
+                return
+            return dict((k, _create_object(cls, v)) for k, v in d.items())

         return Dict
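Restated as a runnable stand-alone sketch (create_object is a simplified stand-in for pyspark.sql's _create_object, and the factory names are hypothetical), the new converters look like this:

    # A simplified, stand-alone rendering of the new eager converters.

    def create_object(cls, v):
        # stand-in for _create_object: apply the element schema, pass None through
        return cls(v) if v is not None else None

    def make_list_converter(element_cls):
        # ArrayType: convert every element immediately; a null column stays
        # None, matching the `if l is None: return` guard in the patch.
        def List(l):
            if l is None:
                return None
            return [create_object(element_cls, v) for v in l]
        return List

    def make_dict_converter(value_cls):
        # MapType: convert every value immediately; keys are left as-is.
        def Dict(d):
            if d is None:
                return None
            return dict((k, create_object(value_cls, v)) for k, v in d.items())
        return Dict

    to_floats = make_list_converter(float)
    print(to_floats([1, 2, 3]))   # [1.0, 2.0, 3.0]
    print(to_floats(None))        # None (null column)
    to_map = make_dict_converter(float)
    print(to_map({"a": 1}))       # {'a': 1.0}

Because the converters return a plain list or dict, pickle only ever sees built-in types, which is what makes the round-trips in the new doctests work.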
python/pyspark/tests.py: 21 additions, 0 deletions
@@ -698,6 +698,27 @@ def test_apply_schema_to_row(self):
         srdd3 = self.sqlCtx.applySchema(rdd, srdd.schema())
         self.assertEqual(10, srdd3.count())

+    def test_serialize_nested_array_and_map(self):
+        d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
+        rdd = self.sc.parallelize(d)
+        srdd = self.sqlCtx.inferSchema(rdd)
+        row = srdd.first()
+        self.assertEqual(1, len(row.l))
+        self.assertEqual(1, row.l[0].a)
+        self.assertEqual("2", row.d["key"].d)
+
+        l = srdd.map(lambda x: x.l).first()
+        self.assertEqual(1, len(l))
+        self.assertEqual('s', l[0].b)
+
+        d = srdd.map(lambda x: x.d).first()
+        self.assertEqual(1, len(d))
+        self.assertEqual(1.0, d["key"].c)
+
+        row = srdd.map(lambda x: x.d["key"]).first()
+        self.assertEqual(1.0, row.c)
+        self.assertEqual("2", row.d)
+

 class TestIO(PySparkTestCase):

