From 62261891e7cfd89935676f77eddf2e2a66f7f9d2 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Tue, 12 Aug 2014 17:25:55 -0700
Subject: [PATCH 1/6] improve large broadcast

Passing a large object through Py4J is very slow (and costs a lot of
memory), so pass broadcast objects via files instead (similar to
parallelize()). Add an option to keep the object in the driver (False by
default) to save memory in the driver.
---
 .../apache/spark/api/python/PythonRDD.scala |  9 +++++++++
 python/pyspark/context.py                   | 19 +++++++++++++------
 2 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 0b5322c6fb965..fd67523ecbbea 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -315,6 +315,15 @@ private[spark] object PythonRDD extends Logging {
     JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
   }
 
+  def readBroadcastFromFile(sc: JavaSparkContext, filename: String):
+      Broadcast[Array[Byte]] = {
+    val file = new DataInputStream(new FileInputStream(filename))
+    val length = file.readInt()
+    val obj = new Array[Byte](length)
+    file.readFully(obj)
+    sc.broadcast(obj)
+  }
+
   def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
     // The right way to implement this would be to use TypeTags to get the full
     // type of T. Since I don't want to introduce breaking changes throughout the
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 4001ecab5ea00..ce5ad856575f3 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -562,17 +562,24 @@ def union(self, rdds):
         rest = ListConverter().convert(rest, self._gateway._gateway_client)
         return RDD(self._jsc.union(first, rest), self, rdds[0]._jrdd_deserializer)
 
-    def broadcast(self, value):
+    def broadcast(self, value, keep=False):
         """
         Broadcast a read-only variable to the cluster, returning a
         L{Broadcast}
-        object for reading it in distributed functions. The variable will be
-        sent to each cluster only once.
+        object for reading it in distributed functions. The variable will
+        be sent to each cluster only once.
+
+        :keep: Keep the `value` in driver or not.
""" pickleSer = PickleSerializer() - pickled = pickleSer.dumps(value) - jbroadcast = self._jsc.broadcast(bytearray(pickled)) - return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars) + # pass large object by py4j is very slow and need much memory + tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir) + pickleSer.dump_stream([value], tempFile) + tempFile.close() + jbroadcast = self._jvm.PythonRDD.readBroadcastFromFile(self._jsc, tempFile.name) + os.unlink(tempFile.name) + return Broadcast(jbroadcast.id(), value if keep else None, + jbroadcast, self._pickled_broadcast_vars) def accumulator(self, value, accum_param=None): """ From e93cf4b660d8612a4c34b05b630b507d2d89325d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 12 Aug 2014 23:57:56 -0700 Subject: [PATCH 2/6] address comments: add test --- .../main/scala/org/apache/spark/api/python/PythonRDD.scala | 3 +-- python/pyspark/tests.py | 7 +++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index fd67523ecbbea..1e27912dab638 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -315,8 +315,7 @@ private[spark] object PythonRDD extends Logging { JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) } - def readBroadcastFromFile(sc: JavaSparkContext, filename: String): - Broadcast[Array[Byte]] = { + def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = { val file = new DataInputStream(new FileInputStream(filename)) val length = file.readInt() val obj = new Array[Byte](length) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 22b51110ed671..f1fece998cd54 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -323,6 +323,13 @@ def test_namedtuple_in_rdd(self): theDoes = self.sc.parallelize([jon, jane]) self.assertEquals([jon, jane], theDoes.collect()) + def test_large_broadcast(self): + N = 100000 + data = [[float(i) for i in range(300)] for i in range(N)] + bdata = self.sc.broadcast(data) # 270MB + m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + self.assertEquals(N, m) + class TestIO(PySparkTestCase): From 9a7161f146a2eee5f3d10eb2a2084f1cd0c60dec Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 13 Aug 2014 09:05:47 -0700 Subject: [PATCH 3/6] fix doc tests --- python/pyspark/broadcast.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index f3e64989ed564..20dbb4994f156 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -19,18 +19,13 @@ >>> from pyspark.context import SparkContext >>> sc = SparkContext('local', 'test') >>> b = sc.broadcast([1, 2, 3, 4, 5]) ->>> b.value -[1, 2, 3, 4, 5] - ->>> from pyspark.broadcast import _broadcastRegistry ->>> _broadcastRegistry[b.bid] = b ->>> from cPickle import dumps, loads ->>> loads(dumps(b)).value -[1, 2, 3, 4, 5] - >>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect() [1, 2, 3, 4, 5, 1, 2, 3, 4, 5] +>>> b = sc.broadcast([1, 2, 3, 4, 5], keep=True) +>>> b.value +[1, 2, 3, 4, 5] + >>> large_broadcast = sc.broadcast(list(range(10000))) """ # Holds broadcasted data received from Java, keyed by its id. 
@@ -66,3 +61,8 @@ def __init__(self, bid, value, java_broadcast=None, pickle_registry=None):
     def __reduce__(self):
         self._pickle_registry.add(self)
         return (_from_id, (self.bid, ))
+
+
+if __name__ == "__main__":
+    import doctest
+    doctest.testmod()

From c7baa8c33b21d12f45afc83389ab99c3f96b76bb Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 13 Aug 2014 09:48:02 -0700
Subject: [PATCH 4/6] compress serialized broadcast and command

---
 python/pyspark/context.py     |  6 +++---
 python/pyspark/rdd.py         |  5 +++--
 python/pyspark/serializers.py | 17 +++++++++++++++++
 python/pyspark/worker.py      |  5 +++--
 4 files changed, 26 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index ce5ad856575f3..a143e249b3309 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -29,7 +29,7 @@
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
 from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
-    PairDeserializer
+    PairDeserializer, CompressedSerializer
 from pyspark.storagelevel import StorageLevel
 from pyspark import rdd
 from pyspark.rdd import RDD
@@ -571,10 +571,10 @@ def broadcast(self, value, keep=False):
 
         :keep: Keep the `value` in driver or not.
         """
-        pickleSer = PickleSerializer()
+        ser = CompressedSerializer(PickleSerializer())
         # pass large object by py4j is very slow and need much memory
         tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
-        pickleSer.dump_stream([value], tempFile)
+        ser.dump_stream([value], tempFile)
         tempFile.close()
         jbroadcast = self._jvm.PythonRDD.readBroadcastFromFile(self._jsc, tempFile.name)
         os.unlink(tempFile.name)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 756e8f35fb03d..9541f051f6167 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -35,7 +35,7 @@
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
-    PickleSerializer, pack_long
+    PickleSerializer, pack_long, CompressedSerializer
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup
 from pyspark.statcounter import StatCounter
@@ -1809,7 +1809,8 @@ def _jrdd(self):
             self._jrdd_deserializer = NoOpSerializer()
         command = (self.func, self._prev_jrdd_deserializer,
                    self._jrdd_deserializer)
-        pickled_command = CloudPickleSerializer().dumps(command)
+        ser = CompressedSerializer(CloudPickleSerializer())
+        pickled_command = ser.dumps(command)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
             self.ctx._gateway._gateway_client)
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index df90cafb245bf..74870c0edcf99 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -67,6 +67,7 @@
 import sys
 import types
 import collections
+import zlib
 
 from pyspark import cloudpickle
 
@@ -403,6 +404,22 @@ def loads(self, obj):
         raise ValueError("invalid sevialization type: %s" % _type)
 
 
+class CompressedSerializer(FramedSerializer):
+    """
+    compress the serialized data
+    """
+
+    def __init__(self, serializer):
+        FramedSerializer.__init__(self)
+        self.serializer = serializer
+
+    def dumps(self, obj):
+        return zlib.compress(self.serializer.dumps(obj), 1)
+
+    def loads(self, obj):
+        return self.serializer.loads(zlib.decompress(obj))
+
+
 class UTF8Deserializer(Serializer):
 
     """
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 2770f63059853..62659cd30b146 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -30,10 +30,11 @@
 from pyspark.cloudpickle import CloudPickler
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
-    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
+    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
+    CompressedSerializer
 
 
-pickleSer = PickleSerializer()
+pickleSer = CompressedSerializer(PickleSerializer())
 utf8_deserializer = UTF8Deserializer()
 
 

From db3f232ed28a5795a6c81c3f9acdb1349474b34d Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 13 Aug 2014 13:03:57 -0700
Subject: [PATCH 5/6] fix serialization of accumulator

Add a better message when trying to access Broadcast.value in the driver.
---
 python/pyspark/broadcast.py | 13 +++++++++++--
 python/pyspark/context.py   |  4 ++--
 python/pyspark/worker.py    |  7 ++++---
 3 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index 20dbb4994f156..82fda45bf472f 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -47,21 +47,30 @@ class Broadcast(object):
     Access its value through C{.value}.
     """
 
-    def __init__(self, bid, value, java_broadcast=None, pickle_registry=None):
+    def __init__(self, bid, value, java_broadcast=None, pickle_registry=None, keep=True):
         """
         Should not be called directly by users -- use
         L{SparkContext.broadcast()}
         instead.
         """
-        self.value = value
         self.bid = bid
+        if keep:
+            self.value = value
         self._jbroadcast = java_broadcast
         self._pickle_registry = pickle_registry
+        self.keep = keep
 
     def __reduce__(self):
         self._pickle_registry.add(self)
         return (_from_id, (self.bid, ))
 
+    def __getattr__(self, item):
+        if item == 'value' and not self.keep:
+            raise Exception("please create broadcast with keep=True to make"
+                            " it accessable in driver")
+
+        raise AttributeError(item)
+
 
 if __name__ == "__main__":
     import doctest
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index a143e249b3309..91dc5e3e5872b 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -578,8 +578,8 @@ def broadcast(self, value, keep=False):
         tempFile.close()
         jbroadcast = self._jvm.PythonRDD.readBroadcastFromFile(self._jsc, tempFile.name)
         os.unlink(tempFile.name)
-        return Broadcast(jbroadcast.id(), value if keep else None,
-                         jbroadcast, self._pickled_broadcast_vars)
+        return Broadcast(jbroadcast.id(), value, jbroadcast,
+                         self._pickled_broadcast_vars, keep)
 
     def accumulator(self, value, accum_param=None):
         """
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 62659cd30b146..77a9c4a0e0677 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -34,7 +34,7 @@
     CompressedSerializer
 
 
-pickleSer = CompressedSerializer(PickleSerializer())
+pickleSer = PickleSerializer()
 utf8_deserializer = UTF8Deserializer()
 
 
@@ -66,12 +66,13 @@ def main(infile, outfile):
 
         # fetch names and values of broadcast variables
        num_broadcast_variables = read_int(infile)
+        ser = CompressedSerializer(pickleSer)
         for _ in range(num_broadcast_variables):
             bid = read_long(infile)
-            value = pickleSer._read_with_length(infile)
+            value = ser._read_with_length(infile)
             _broadcastRegistry[bid] = Broadcast(bid, value)
 
-        command = pickleSer._read_with_length(infile)
+        command = ser._read_with_length(infile)
         (func, deserializer, serializer) = command
         init_time = time.time()
         iterator = deserializer.load_stream(infile)
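
The CompressedSerializer introduced in PATCH 4/6 is a thin wrapper that runs
zlib over whatever the inner serializer produces, which is what shrinks the
pickled broadcast and command payloads handled above. A minimal, self-contained
sketch of the same round trip (plain Python; CompressedPickle is an
illustrative stand-in, not the actual pyspark class):

    import pickle
    import zlib

    class CompressedPickle(object):
        """Illustrative stand-in for CompressedSerializer(PickleSerializer())."""

        def dumps(self, obj):
            # level 1: cheap compression, good enough for highly redundant pickles
            return zlib.compress(pickle.dumps(obj), 1)

        def loads(self, data):
            return pickle.loads(zlib.decompress(data))

    ser = CompressedPickle()
    data = [[float(i) for i in range(300)] for _ in range(1000)]
    blob = ser.dumps(data)
    assert ser.loads(blob) == data
    print(len(pickle.dumps(data)), len(blob))  # the compressed blob is much smaller
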
From e06df4a8c211f53a5b7d176c6ec655033f1419ee Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Thu, 14 Aug 2014 22:47:58 -0700
Subject: [PATCH 6/6] load broadcast from disk in driver automatically

add Broadcast.unpersist()
---
 python/pyspark/broadcast.py | 30 ++++++++++++++++++++----------
 python/pyspark/context.py   |  7 +++----
 2 files changed, 23 insertions(+), 14 deletions(-)

diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index 82fda45bf472f..675a2fcd2ff4e 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -19,15 +19,18 @@
 >>> from pyspark.context import SparkContext
 >>> sc = SparkContext('local', 'test')
 >>> b = sc.broadcast([1, 2, 3, 4, 5])
->>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
-[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
-
->>> b = sc.broadcast([1, 2, 3, 4, 5], keep=True)
 >>> b.value
 [1, 2, 3, 4, 5]
+>>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
+[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
+>>> b.unpersist()
 
 >>> large_broadcast = sc.broadcast(list(range(10000)))
 """
+import os
+
+from pyspark.serializers import CompressedSerializer, PickleSerializer
+
 # Holds broadcasted data received from Java, keyed by its id.
 _broadcastRegistry = {}
 
@@ -47,27 +50,34 @@ class Broadcast(object):
     Access its value through C{.value}.
     """
 
-    def __init__(self, bid, value, java_broadcast=None, pickle_registry=None, keep=True):
+    def __init__(self, bid, value, java_broadcast=None,
+                 pickle_registry=None, path=None):
         """
         Should not be called directly by users -- use
         L{SparkContext.broadcast()}
         instead.
         """
         self.bid = bid
-        if keep:
+        if path is None:
             self.value = value
         self._jbroadcast = java_broadcast
         self._pickle_registry = pickle_registry
-        self.keep = keep
+        self.path = path
+
+    def unpersist(self, blocking=False):
+        self._jbroadcast.unpersist(blocking)
+        os.unlink(self.path)
 
     def __reduce__(self):
         self._pickle_registry.add(self)
         return (_from_id, (self.bid, ))
 
     def __getattr__(self, item):
-        if item == 'value' and not self.keep:
-            raise Exception("please create broadcast with keep=True to make"
-                            " it accessable in driver")
+        if item == 'value' and self.path is not None:
+            ser = CompressedSerializer(PickleSerializer())
+            value = ser.load_stream(open(self.path)).next()
+            self.value = value
+            return value
 
         raise AttributeError(item)
 
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 91dc5e3e5872b..182b0249fa900 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -562,7 +562,7 @@ def union(self, rdds):
         rest = ListConverter().convert(rest, self._gateway._gateway_client)
         return RDD(self._jsc.union(first, rest), self, rdds[0]._jrdd_deserializer)
 
-    def broadcast(self, value, keep=False):
+    def broadcast(self, value):
         """
         Broadcast a read-only variable to the cluster, returning a
         L{Broadcast}
@@ -577,9 +577,8 @@ def broadcast(self, value, keep=False):
         ser.dump_stream([value], tempFile)
         tempFile.close()
         jbroadcast = self._jvm.PythonRDD.readBroadcastFromFile(self._jsc, tempFile.name)
-        os.unlink(tempFile.name)
-        return Broadcast(jbroadcast.id(), value, jbroadcast,
-                         self._pickled_broadcast_vars, keep)
+        return Broadcast(jbroadcast.id(), None, jbroadcast,
+                         self._pickled_broadcast_vars, tempFile.name)
 
     def accumulator(self, value, accum_param=None):
         """
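
Taken together, the series replaces the Py4J transfer of broadcast values with
a temp file that the JVM reads back via PythonRDD.readBroadcastFromFile,
compresses the serialized payload, and, after PATCH 6/6, reloads the value in
the driver from that file only on first access. A minimal usage sketch under
those assumptions (local mode; the sizes are illustrative, not benchmarks):

    from pyspark import SparkContext

    sc = SparkContext("local", "broadcast-demo")

    # The large value reaches the JVM through a temp file, not through Py4J.
    data = [list(range(300)) for _ in range(1000)]
    b = sc.broadcast(data)

    # Workers read the broadcast value as usual.
    counts = sc.parallelize(range(2), 2).map(lambda _: len(b.value)).collect()
    print(counts)  # [1000, 1000]

    # In the driver, b.value is loaded lazily from the temp file on first
    # access (PATCH 6/6), so no second in-memory copy is kept until needed.
    print(len(b.value))  # 1000

    # Release the JVM-side blocks and delete the driver's temp file.
    b.unpersist()

    sc.stop()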