From f11288d5272bc18585b8cad4ee3bd59eade7c296 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 15 Apr 2015 12:58:02 -0700 Subject: [PATCH] [SPARK-6886] [PySpark] fix big closure with shuffle Currently, the created broadcast object will have same life cycle as RDD in Python. For multistage jobs, an PythonRDD will be created in JVM and the RDD in Python may be GCed, then the broadcast will be destroyed in JVM before the PythonRDD. This PR change to use PythonRDD to track the lifecycle of the broadcast object. It also have a refactor about getNumPartitions() to avoid unnecessary creation of PythonRDD, which could be heavy. cc JoshRosen Author: Davies Liu Closes #5496 from davies/big_closure and squashes the following commits: 9a0ea4c [Davies Liu] fix big closure with shuffle --- python/pyspark/rdd.py | 15 +++++---------- python/pyspark/tests.py | 6 ++---- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index c9ac95d117574..93e658eded9e2 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1197,7 +1197,7 @@ def take(self, num): [91, 92, 93] """ items = [] - totalParts = self._jrdd.partitions().size() + totalParts = self.getNumPartitions() partsScanned = 0 while len(items) < num and partsScanned < totalParts: @@ -1260,7 +1260,7 @@ def isEmpty(self): >>> sc.parallelize([1]).isEmpty() False """ - return self._jrdd.partitions().size() == 0 or len(self.take(1)) == 0 + return self.getNumPartitions() == 0 or len(self.take(1)) == 0 def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None): """ @@ -2235,11 +2235,9 @@ def _prepare_for_python_RDD(sc, command, obj=None): ser = CloudPickleSerializer() pickled_command = ser.dumps((command, sys.version_info[:2])) if len(pickled_command) > (1 << 20): # 1M + # The broadcast will have same life cycle as created PythonRDD broadcast = sc.broadcast(pickled_command) pickled_command = ser.dumps(broadcast) - # tracking the life cycle by obj - if obj is not None: - obj._broadcast = broadcast broadcast_vars = ListConverter().convert( [x._jbroadcast for x in sc._pickled_broadcast_vars], sc._gateway._gateway_client) @@ -2294,12 +2292,9 @@ def pipeline_func(split, iterator): self._jrdd_deserializer = self.ctx.serializer self._bypass_serializer = False self.partitioner = prev.partitioner if self.preservesPartitioning else None - self._broadcast = None - def __del__(self): - if self._broadcast: - self._broadcast.unpersist() - self._broadcast = None + def getNumPartitions(self): + return self._prev_jrdd.partitions().size() @property def _jrdd(self): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index b938b9ce12395..ee67e80d539f8 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -550,10 +550,8 @@ def test_large_closure(self): data = [float(i) for i in xrange(N)] rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data)) self.assertEquals(N, rdd.first()) - self.assertTrue(rdd._broadcast is not None) - rdd = self.sc.parallelize(range(1), 1).map(lambda x: 1) - self.assertEqual(1, rdd.first()) - self.assertTrue(rdd._broadcast is None) + # regression test for SPARK-6886 + self.assertEqual(1, rdd.map(lambda x: (x, 1)).groupByKey().count()) def test_zip_with_different_serializers(self): a = self.sc.parallelize(range(5))