diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 462e09466bfa6..4273fdc5b3fb3 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -588,19 +588,30 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536) + /** + * We try to reuse a single Socket to transfer accumulator updates, as they are all added + * by the DAGScheduler's single-threaded actor anyway. + */ + @transient var socket: Socket = _ + + def openSocket(): Socket = synchronized { + if (socket == null || socket.isClosed) { + socket = new Socket(serverHost, serverPort) + } + socket + } + override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]]) - : JList[Array[Byte]] = { + : JList[Array[Byte]] = synchronized { if (serverHost == null) { // This happens on the worker node, where we just want to remember all the updates val1.addAll(val2) val1 } else { // This happens on the master, where we pass the updates to Python through a socket - val socket = new Socket(serverHost, serverPort) - // SPARK-2282: Immediately reuse closed sockets because we create one per task. - socket.setReuseAddress(true) + val socket = openSocket() val in = socket.getInputStream val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize)) out.writeInt(val2.size) @@ -614,7 +625,6 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: if (byteRead == -1) { throw new SparkException("EOF reached before Python server acknowledged") } - socket.close() null } } diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 2204e9c9ca701..45d36e5d0e764 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -86,6 +86,7 @@ Exception:... """ +import select import struct import SocketServer import threading @@ -209,19 +210,38 @@ def addInPlace(self, value1, value2): class _UpdateRequestHandler(SocketServer.StreamRequestHandler): + """ + This handler will keep polling updates from the same socket until the + server is shutdown. + """ + def handle(self): from pyspark.accumulators import _accumulatorRegistry - num_updates = read_int(self.rfile) - for _ in range(num_updates): - (aid, update) = pickleSer._read_with_length(self.rfile) - _accumulatorRegistry[aid] += update - # Write a byte in acknowledgement - self.wfile.write(struct.pack("!b", 1)) + while not self.server.server_shutdown: + # Poll every 1 second for new data -- don't block in case of shutdown. + r, _, _ = select.select([self.rfile], [], [], 1) + if self.rfile in r: + num_updates = read_int(self.rfile) + for _ in range(num_updates): + (aid, update) = pickleSer._read_with_length(self.rfile) + _accumulatorRegistry[aid] += update + # Write a byte in acknowledgement + self.wfile.write(struct.pack("!b", 1)) + +class AccumulatorServer(SocketServer.TCPServer): + """ + A simple TCP server that intercepts shutdown() in order to interrupt + our continuous polling on the handler. + """ + server_shutdown = False + def shutdown(self): + self.server_shutdown = True + SocketServer.TCPServer.shutdown(self) def _start_update_server(): """Start a TCP server to receive accumulator updates in a daemon thread, and returns it""" - server = SocketServer.TCPServer(("localhost", 0), _UpdateRequestHandler) + server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler) thread = threading.Thread(target=server.serve_forever) thread.daemon = True thread.start()