[SPARK-12717][PYTHON][BRANCH-2.1] Adding thread-safe broadcast pickle registry #18825

Closed
19 changes: 19 additions & 0 deletions python/pyspark/broadcast.py
@@ -19,6 +19,7 @@
import sys
import gc
from tempfile import NamedTemporaryFile
import threading

from pyspark.cloudpickle import print_exec

@@ -137,6 +138,24 @@ def __reduce__(self):
return _from_id, (self._jbroadcast.id(),)


class BroadcastPickleRegistry(threading.local):
""" Thread-local registry for broadcast variables that have been pickled
"""

def __init__(self):
self.__dict__.setdefault("_registry", set())

def __iter__(self):
for bcast in self._registry:
yield bcast

def add(self, bcast):
self._registry.add(bcast)

def clear(self):
self._registry.clear()


if __name__ == "__main__":
import doctest
(failure_count, test_count) = doctest.testmod()
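
The class added above subclasses threading.local, so every thread that touches the registry gets its own private _registry set; __init__ runs again the first time each new thread uses the instance, and the setdefault guard keeps a repeated call from discarding an already-populated set. Below is a minimal standalone sketch of that isolation using only the standard library -- the class body mirrors the diff above, while the driver code and the "b1"/"b2" strings are purely illustrative:

import threading

class BroadcastPickleRegistry(threading.local):
    """ Thread-local registry for broadcast variables that have been pickled
    """

    def __init__(self):
        self.__dict__.setdefault("_registry", set())

    def __iter__(self):
        for bcast in self._registry:
            yield bcast

    def add(self, bcast):
        self._registry.add(bcast)

    def clear(self):
        self._registry.clear()

registry = BroadcastPickleRegistry()
registry.add("b1")                    # registered by the main thread

def worker():
    registry.add("b2")
    print(sorted(registry))           # ['b2'] -- the main thread's "b1" is not visible here

t = threading.Thread(target=worker)
t.start()
t.join()
print(sorted(registry))               # ['b1'] -- the worker's "b2" did not leak back
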
4 changes: 2 additions & 2 deletions python/pyspark/context.py
@@ -30,7 +30,7 @@

from pyspark import accumulators
from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
from pyspark.broadcast import Broadcast, BroadcastPickleRegistry
from pyspark.conf import SparkConf
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
@@ -200,7 +200,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
# This allows other code to determine which Broadcast instances have
# been pickled, so it can determine which Java broadcast objects to
# send.
self._pickled_broadcast_vars = set()
self._pickled_broadcast_vars = BroadcastPickleRegistry()

SparkFiles._sc = self
root_dir = SparkFiles.getRootDirectory()
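
The one-line swap above is the heart of the fix: _pickled_broadcast_vars used to be a plain set shared by every thread driving jobs through the same SparkContext, so one thread could observe and clear broadcast variables that another thread had just registered while pickling its own command. A rough, Spark-independent sketch of that cross-talk with a shared set (the "b1"/"b2" names simply echo the test below; the helper is hypothetical):

import threading

shared = set()                        # old behaviour: one registry for all threads

def collect_and_clear():
    seen = sorted(shared)
    shared.clear()
    return seen

shared.add("b1")                      # main thread pickles its function, registers b1

def worker():
    shared.add("b2")                  # worker thread pickles a different function
    print(collect_and_clear())        # ['b1', 'b2'] -- the main thread's b1 leaks in

t = threading.Thread(target=worker)
t.start()
t.join()
print(collect_and_clear())            # [] -- b1 was already cleared out from under the main thread

With BroadcastPickleRegistry each thread iterates and clears only its own _registry, which is exactly what test_multithread_broadcast_pickle below asserts: one variable counted per thread, nothing left over.
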
44 changes: 44 additions & 0 deletions python/pyspark/tests.py
@@ -793,6 +793,50 @@ def test_multiple_broadcasts(self):
self.assertEqual(N, size)
self.assertEqual(checksum, csum)

def test_multithread_broadcast_pickle(self):
import threading

b1 = self.sc.broadcast(list(range(3)))
b2 = self.sc.broadcast(list(range(3)))

def f1():
return b1.value

def f2():
return b2.value

funcs_num_pickled = {f1: None, f2: None}

def do_pickle(f, sc):
command = (f, None, sc.serializer, sc.serializer)
ser = CloudPickleSerializer()
ser.dumps(command)

def process_vars(sc):
broadcast_vars = list(sc._pickled_broadcast_vars)
num_pickled = len(broadcast_vars)
sc._pickled_broadcast_vars.clear()
return num_pickled

def run(f, sc):
do_pickle(f, sc)
funcs_num_pickled[f] = process_vars(sc)

# pickle f1, adds b1 to sc._pickled_broadcast_vars in main thread local storage
do_pickle(f1, self.sc)

# run all for f2, should only add/count/clear b2 from worker thread local storage
t = threading.Thread(target=run, args=(f2, self.sc))
t.start()
t.join()

# count number of vars pickled in main thread, only b1 should be counted and cleared
funcs_num_pickled[f1] = process_vars(self.sc)

self.assertEqual(funcs_num_pickled[f1], 1)
self.assertEqual(funcs_num_pickled[f2], 1)
self.assertEqual(len(list(self.sc._pickled_broadcast_vars)), 0)

def test_large_closure(self):
N = 200000
data = [float(i) for i in xrange(N)]