[SPARK-6886] [PySpark] fix big closure with shuffle
Currently, the created broadcast object has the same life cycle as the RDD object in Python. For multistage jobs, a PythonRDD is created in the JVM while the RDD in Python may already have been GCed, so the broadcast gets destroyed in the JVM before the PythonRDD is done with it.

This PR changes the code to use the PythonRDD to track the life cycle of the broadcast object. It also refactors getNumPartitions() to avoid unnecessary creation of a PythonRDD, which can be heavy.
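For illustration, a minimal driver-side sketch of the failing scenario, modeled on the updated regression test below; the SparkContext setup and the value of N are assumptions, not part of this commit:

from pyspark import SparkContext

sc = SparkContext("local[2]", "big-closure-repro")  # assumed local setup
N = 200000  # assumed size, large enough that the pickled closure exceeds the 1MB threshold
data = [float(i) for i in range(N)]

# The lambda captures `data`, so the serialized command is larger than 1MB and
# is shipped to the executors through a broadcast variable.
rdd = sc.parallelize(range(1), 1).map(lambda x: len(data))

# The shuffle makes this a multistage job. Before this fix, the intermediate
# Python RDD objects could be GCed between stages, which unpersisted the
# broadcast in the JVM while the PythonRDD still needed it.
print(rdd.map(lambda x: (x, 1)).groupByKey().count())  # expected output: 1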

cc JoshRosen

Author: Davies Liu <davies@databricks.com>

Closes #5496 from davies/big_closure and squashes the following commits:

9a0ea4c [Davies Liu] fix big closure with shuffle
Davies Liu authored and JoshRosen committed Apr 15, 2015
1 parent 6c5ed8a commit f11288d
Showing 2 changed files with 7 additions and 14 deletions.
15 changes: 5 additions & 10 deletions python/pyspark/rdd.py
@@ -1197,7 +1197,7 @@ def take(self, num):
         [91, 92, 93]
         """
         items = []
-        totalParts = self._jrdd.partitions().size()
+        totalParts = self.getNumPartitions()
         partsScanned = 0
 
         while len(items) < num and partsScanned < totalParts:
@@ -1260,7 +1260,7 @@ def isEmpty(self):
         >>> sc.parallelize([1]).isEmpty()
         False
         """
-        return self._jrdd.partitions().size() == 0 or len(self.take(1)) == 0
+        return self.getNumPartitions() == 0 or len(self.take(1)) == 0
 
     def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None):
         """
@@ -2235,11 +2235,9 @@ def _prepare_for_python_RDD(sc, command, obj=None):
     ser = CloudPickleSerializer()
     pickled_command = ser.dumps((command, sys.version_info[:2]))
     if len(pickled_command) > (1 << 20):  # 1M
+        # The broadcast will have same life cycle as created PythonRDD
         broadcast = sc.broadcast(pickled_command)
         pickled_command = ser.dumps(broadcast)
-        # tracking the life cycle by obj
-        if obj is not None:
-            obj._broadcast = broadcast
     broadcast_vars = ListConverter().convert(
         [x._jbroadcast for x in sc._pickled_broadcast_vars],
         sc._gateway._gateway_client)
@@ -2294,12 +2292,9 @@ def pipeline_func(split, iterator):
         self._jrdd_deserializer = self.ctx.serializer
         self._bypass_serializer = False
         self.partitioner = prev.partitioner if self.preservesPartitioning else None
-        self._broadcast = None
-
-    def __del__(self):
-        if self._broadcast:
-            self._broadcast.unpersist()
-            self._broadcast = None
+
+    def getNumPartitions(self):
+        return self._prev_jrdd.partitions().size()
 
     @property
     def _jrdd(self):
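A side note on the getNumPartitions() refactor: PipelinedRDD now answers the partition count from its parent JVM RDD (_prev_jrdd), so calls such as take() and isEmpty() no longer force creation of the wrapped PythonRDD just to count partitions. A small illustrative sketch, assuming an existing SparkContext named sc (for example, the PySpark shell):

# Sketch only: the partition count of a mapped RDD is read from the parent
# JVM RDD, without materializing the (potentially heavy) PythonRDD first.
rdd = sc.parallelize(range(100), 4).map(lambda x: x * 2)
assert rdd.getNumPartitions() == 4  # map() preserves the number of partitions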
6 changes: 2 additions & 4 deletions python/pyspark/tests.py
@@ -550,10 +550,8 @@ def test_large_closure(self):
         data = [float(i) for i in xrange(N)]
         rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data))
         self.assertEquals(N, rdd.first())
-        self.assertTrue(rdd._broadcast is not None)
-        rdd = self.sc.parallelize(range(1), 1).map(lambda x: 1)
-        self.assertEqual(1, rdd.first())
-        self.assertTrue(rdd._broadcast is None)
+        # regression test for SPARK-6886
+        self.assertEqual(1, rdd.map(lambda x: (x, 1)).groupByKey().count())
 
     def test_zip_with_different_serializers(self):
         a = self.sc.parallelize(range(5))
