Added thread-safe broadcast pickle registry, using thread local storage
added regression test for multithreaded broadcast pickle
BryanCutler committed Aug 2, 2017
1 parent b31b302 commit cc239a4
Showing 3 changed files with 65 additions and 2 deletions.
python/pyspark/broadcast.py: 19 additions & 0 deletions
@@ -19,6 +19,7 @@
 import sys
 import gc
 from tempfile import NamedTemporaryFile
+import threading

 from pyspark.cloudpickle import print_exec

@@ -137,6 +138,24 @@ def __reduce__(self):
         return _from_id, (self._jbroadcast.id(),)


+class BroadcastPickleRegistry(threading.local):
+    """ Thread-local registry for broadcast variables that have been pickled
+    """
+
+    def __init__(self):
+        self.__dict__.setdefault("_registry", set())
+
+    def __iter__(self):
+        for bcast in self._registry:
+            yield bcast
+
+    def add(self, bcast):
+        self._registry.add(bcast)
+
+    def clear(self):
+        self._registry.clear()
+
+
 if __name__ == "__main__":
     import doctest
     (failure_count, test_count) = doctest.testmod()
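
The thread safety above comes entirely from subclassing threading.local: each thread that touches the registry gets its own __dict__, and __init__ runs again on first access from a new thread, so setdefault seeds a fresh, empty _registry set per thread. Below is a minimal sketch of that per-thread behavior, mirroring the class in the diff; the demo scaffolding around it is illustrative only, not part of the commit.

import threading

class BroadcastPickleRegistry(threading.local):
    """Each thread sees its own _registry set."""

    def __init__(self):
        # threading.local re-runs __init__ on first access from each new
        # thread, so every thread starts with a fresh, empty set.
        self.__dict__.setdefault("_registry", set())

    def __iter__(self):
        return iter(self._registry)

    def add(self, bcast):
        self._registry.add(bcast)

registry = BroadcastPickleRegistry()
registry.add("main-thread-item")

def worker():
    # A new thread gets an empty registry; the main thread's entry
    # is invisible here.
    assert list(registry) == []
    registry.add("worker-thread-item")

t = threading.Thread(target=worker)
t.start()
t.join()

# The worker's entry did not leak back into the main thread's set.
assert list(registry) == ["main-thread-item"]

Using setdefault rather than plain assignment is defensive: it leaves a thread's existing set intact if __init__ were ever re-run for that thread.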
python/pyspark/context.py: 2 additions & 2 deletions
@@ -30,7 +30,7 @@

 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
-from pyspark.broadcast import Broadcast
+from pyspark.broadcast import Broadcast, BroadcastPickleRegistry
 from pyspark.conf import SparkConf
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
@@ -200,7 +200,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
         # This allows other code to determine which Broadcast instances have
         # been pickled, so it can determine which Java broadcast objects to
         # send.
-        self._pickled_broadcast_vars = set()
+        self._pickled_broadcast_vars = BroadcastPickleRegistry()

         SparkFiles._sc = self
         root_dir = SparkFiles.getRootDirectory()
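
The comment in this hunk describes a register-then-drain pattern: pickling a Broadcast records it in _pickled_broadcast_vars as a side effect, and the context later reads and clears the registry to learn which Java broadcast objects to ship with a command. A rough sketch of that pattern follows, using a hypothetical FakeBroadcast and a module-level _pickle_registry in place of pyspark's Broadcast and its thread-local registry (the diff shows only part of the real __reduce__).

import pickle

_pickle_registry = set()  # stand-in for the thread-local BroadcastPickleRegistry

class FakeBroadcast(object):
    def __init__(self, bid):
        self.bid = bid

    def __reduce__(self):
        # Side effect of pickling: record which broadcast was serialized.
        _pickle_registry.add(self.bid)
        return (FakeBroadcast, (self.bid,))

b = FakeBroadcast(7)
pickle.dumps(b)                   # serialization triggers __reduce__
assert _pickle_registry == {7}    # the context now knows b was pickled
_pickle_registry.clear()          # drained, ready for the next command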
python/pyspark/tests.py: 44 additions & 0 deletions
@@ -793,6 +793,50 @@ def test_multiple_broadcasts(self):
         self.assertEqual(N, size)
         self.assertEqual(checksum, csum)

+    def test_multithread_broadcast_pickle(self):
+        import threading
+
+        b1 = self.sc.broadcast(list(range(3)))
+        b2 = self.sc.broadcast(list(range(3)))
+
+        def f1():
+            return b1.value
+
+        def f2():
+            return b2.value
+
+        funcs_num_pickled = {f1: None, f2: None}
+
+        def do_pickle(f, sc):
+            command = (f, None, sc.serializer, sc.serializer)
+            ser = CloudPickleSerializer()
+            ser.dumps(command)
+
+        def process_vars(sc):
+            broadcast_vars = list(sc._pickled_broadcast_vars)
+            num_pickled = len(broadcast_vars)
+            sc._pickled_broadcast_vars.clear()
+            return num_pickled
+
+        def run(f, sc):
+            do_pickle(f, sc)
+            funcs_num_pickled[f] = process_vars(sc)
+
+        # pickle f1, adds b1 to sc._pickled_broadcast_vars in main thread local storage
+        do_pickle(f1, self.sc)
+
+        # run all for f2, should only add/count/clear b2 from worker thread local storage
+        t = threading.Thread(target=run, args=(f2, self.sc))
+        t.start()
+        t.join()
+
+        # count number of vars pickled in main thread, only b1 should be counted and cleared
+        funcs_num_pickled[f1] = process_vars(self.sc)
+
+        self.assertEqual(funcs_num_pickled[f1], 1)
+        self.assertEqual(funcs_num_pickled[f2], 1)
+        self.assertEqual(len(list(self.sc._pickled_broadcast_vars)), 0)
+
     def test_large_closure(self):
         N = 200000
         data = [float(i) for i in xrange(N)]
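
For contrast, here is a rough sketch of the failure mode this test guards against: with the old plain shared set (self._pickled_broadcast_vars = set()), a worker thread draining the registry could take entries registered by the main thread along with its own. The names shared and drained_by_worker are illustrative, standard-library-only scaffolding.

import threading

shared = set()          # old approach: one set shared by all threads
drained_by_worker = []

shared.add("b1")        # main thread pickles f1, registering b1

def worker():
    shared.add("b2")                  # worker pickles f2, registering b2
    drained_by_worker.extend(shared)  # worker drains the registry...
    shared.clear()                    # ...taking b1 along with b2

t = threading.Thread(target=worker)
t.start()
t.join()

assert sorted(drained_by_worker) == ["b1", "b2"]  # b1 leaked into f2's command
assert shared == set()  # main thread now finds nothing to ship for f1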
