SPARK-2282: Reuse Socket for sending accumulator updates to Pyspark
Prior to this change, every PySpark task completion opened a new
socket to the accumulator server, passed its updates through, and
then quit. I'm not entirely sure why PySpark always sends accumulator
updates, but regardless, this causes a very rapid buildup of ephemeral
TCP connections that remain in the TIME_WAIT state for around a minute
before being cleaned up.

Rather than trying to allow these sockets to be cleaned up faster, this
patch simply reuses the connection across task completions (since updates
are applied in a single-threaded manner by the DAGScheduler anyway).

The only tricky part here was making sure that the AccumulatorServer was
able to shut down in a timely manner (i.e., stop polling for new data), and
this was accomplished via minor feats of magic.
aarondav committed Jul 20, 2014
1 parent 1b10b81 commit b3e12f7
Showing 2 changed files with 42 additions and 12 deletions.
20 changes: 15 additions & 5 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -588,19 +588,30 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort:

   val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536)
 
+  /**
+   * We try to reuse a single Socket to transfer accumulator updates, as they are all added
+   * by the DAGScheduler's single-threaded actor anyway.
+   */
+  @transient var socket: Socket = _
+
+  def openSocket(): Socket = synchronized {
+    if (socket == null || socket.isClosed) {
+      socket = new Socket(serverHost, serverPort)
+    }
+    socket
+  }
+
   override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
 
   override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
-      : JList[Array[Byte]] = {
+      : JList[Array[Byte]] = synchronized {
     if (serverHost == null) {
       // This happens on the worker node, where we just want to remember all the updates
       val1.addAll(val2)
       val1
     } else {
       // This happens on the master, where we pass the updates to Python through a socket
-      val socket = new Socket(serverHost, serverPort)
-      // SPARK-2282: Immediately reuse closed sockets because we create one per task.
-      socket.setReuseAddress(true)
+      val socket = openSocket()
       val in = socket.getInputStream
       val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
       out.writeInt(val2.size)
@@ -614,7 +625,6 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort:
       if (byteRead == -1) {
         throw new SparkException("EOF reached before Python server acknowledged")
       }
-      socket.close()
       null
     }
   }
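
For reference, the wire format this Scala code speaks to the Python accumulator
server is simple: a 4-byte big-endian count of updates, then each update as a
4-byte length prefix followed by a pickled (accumulator id, value) pair, and
finally a single acknowledgement byte read back from the server. The sketch
below is illustrative only; the send_updates helper and the direct use of
cPickle are assumptions of the sketch, not APIs in this patch (PySpark itself
forwards pre-pickled byte blobs from the JVM).

    # Hypothetical client speaking the accumulator-update framing (sketch only).
    import socket
    import struct
    import cPickle

    def send_updates(host, port, updates):
        """updates: list of (accumulator_id, value) pairs to apply on the driver."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.connect((host, port))
        out = sock.makefile("wb")
        out.write(struct.pack("!i", len(updates)))      # number of updates
        for aid, value in updates:
            payload = cPickle.dumps((aid, value), 2)    # pickled (id, value) pair
            out.write(struct.pack("!i", len(payload)))  # 4-byte length prefix
            out.write(payload)
        out.flush()
        ack = sock.recv(1)                              # one-byte acknowledgement
        assert ack == struct.pack("!b", 1)
        sock.close()

Because addInPlace is now synchronized and the socket is created lazily in
openSocket(), a single connection is shared across all task completions on the
driver instead of one connection per task.
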
34 changes: 27 additions & 7 deletions python/pyspark/accumulators.py
@@ -86,6 +86,7 @@
 Exception:...
 """
 
+import select
 import struct
 import SocketServer
 import threading
@@ -209,19 +210,38 @@ def addInPlace(self, value1, value2):


 class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
+    """
+    This handler will keep polling updates from the same socket until the
+    server is shutdown.
+    """
+
     def handle(self):
         from pyspark.accumulators import _accumulatorRegistry
-        num_updates = read_int(self.rfile)
-        for _ in range(num_updates):
-            (aid, update) = pickleSer._read_with_length(self.rfile)
-            _accumulatorRegistry[aid] += update
-        # Write a byte in acknowledgement
-        self.wfile.write(struct.pack("!b", 1))
+        while not self.server.server_shutdown:
+            # Poll every 1 second for new data -- don't block in case of shutdown.
+            r, _, _ = select.select([self.rfile], [], [], 1)
+            if self.rfile in r:
+                num_updates = read_int(self.rfile)
+                for _ in range(num_updates):
+                    (aid, update) = pickleSer._read_with_length(self.rfile)
+                    _accumulatorRegistry[aid] += update
+                # Write a byte in acknowledgement
+                self.wfile.write(struct.pack("!b", 1))
 
+class AccumulatorServer(SocketServer.TCPServer):
+    """
+    A simple TCP server that intercepts shutdown() in order to interrupt
+    our continuous polling on the handler.
+    """
+    server_shutdown = False
+
+    def shutdown(self):
+        self.server_shutdown = True
+        SocketServer.TCPServer.shutdown(self)
 
 def _start_update_server():
     """Start a TCP server to receive accumulator updates in a daemon thread, and returns it"""
-    server = SocketServer.TCPServer(("localhost", 0), _UpdateRequestHandler)
+    server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler)
     thread = threading.Thread(target=server.serve_forever)
     thread.daemon = True
     thread.start()
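
The shutdown handshake is the subtle part: a plain TCPServer services each
connection on the same thread that runs serve_forever(), so shutdown() cannot
complete until handle() returns. Setting server_shutdown before delegating to
TCPServer.shutdown() lets the handler's one-second select() poll notice the
flag and exit, which in turn lets serve_forever() (and thus shutdown()) finish.
A rough usage sketch, not taken from the patch:

    # Start the accumulator server, let updates arrive, then shut it down.
    server = _start_update_server()
    host, port = server.server_address   # ephemeral port chosen by the OS
    # ... the JVM connects once and reuses that socket across task completions ...
    server.shutdown()                    # flag is observed within ~1 second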
