[PYSPARK] Updates to Accumulators
(cherry picked from commit 15fc237)
LucaCanali authored and squito committed Aug 3, 2018
1 parent a3eb07d · commit b2e0f68
Showing 3 changed files with 53 additions and 17 deletions.
12 changes: 9 additions & 3 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -886,8 +886,9 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By
  */
 private[spark] class PythonAccumulatorV2(
     @transient private val serverHost: String,
-    private val serverPort: Int)
-  extends CollectionAccumulator[Array[Byte]] {
+    private val serverPort: Int,
+    private val secretToken: String)
+  extends CollectionAccumulator[Array[Byte]] with Logging {

   Utils.checkHost(serverHost, "Expected hostname")

@@ -902,12 +903,17 @@ private[spark] class PythonAccumulatorV2(
   private def openSocket(): Socket = synchronized {
     if (socket == null || socket.isClosed) {
       socket = new Socket(serverHost, serverPort)
+      logInfo(s"Connected to AccumulatorServer at host: $serverHost port: $serverPort")
+      // send the secret just for the initial authentication when opening a new connection
+      socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8))
     }
     socket
   }

   // Need to override so the types match with PythonFunction
-  override def copyAndReset(): PythonAccumulatorV2 = new PythonAccumulatorV2(serverHost, serverPort)
+  override def copyAndReset(): PythonAccumulatorV2 = {
+    new PythonAccumulatorV2(serverHost, serverPort, secretToken)
+  }

   override def merge(other: AccumulatorV2[Array[Byte], JList[Array[Byte]]]): Unit = synchronized {
     val otherPythonAccumulator = other.asInstanceOf[PythonAccumulatorV2]
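What changed on the JVM side: a newly opened socket now writes the shared secret before any accumulator data, so the Python server can authenticate the connection up front, and copyAndReset() threads secretToken through so every copy of the accumulator authenticates the same way. A minimal Python sketch of that client-side handshake (open_accumulator_socket is an illustrative name, not part of the patch):

    import socket

    def open_accumulator_socket(host, port, secret_token):
        # Mirrors the new openSocket(): connect, then send the raw UTF-8
        # token exactly once, before any updates go over the wire.
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.connect((host, port))
        sock.sendall(secret_token.encode("utf-8"))
        return sock
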
53 changes: 41 additions & 12 deletions python/pyspark/accumulators.py
@@ -228,20 +228,49 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler):

     def handle(self):
         from pyspark.accumulators import _accumulatorRegistry
-        while not self.server.server_shutdown:
-            # Poll every 1 second for new data -- don't block in case of shutdown.
-            r, _, _ = select.select([self.rfile], [], [], 1)
-            if self.rfile in r:
-                num_updates = read_int(self.rfile)
-                for _ in range(num_updates):
-                    (aid, update) = pickleSer._read_with_length(self.rfile)
-                    _accumulatorRegistry[aid] += update
-                # Write a byte in acknowledgement
-                self.wfile.write(struct.pack("!b", 1))
+        auth_token = self.server.auth_token
+
+        def poll(func):
+            while not self.server.server_shutdown:
+                # Poll every 1 second for new data -- don't block in case of shutdown.
+                r, _, _ = select.select([self.rfile], [], [], 1)
+                if self.rfile in r:
+                    if func():
+                        break
+
+        def accum_updates():
+            num_updates = read_int(self.rfile)
+            for _ in range(num_updates):
+                (aid, update) = pickleSer._read_with_length(self.rfile)
+                _accumulatorRegistry[aid] += update
+            # Write a byte in acknowledgement
+            self.wfile.write(struct.pack("!b", 1))
+            return False
+
+        def authenticate_and_accum_updates():
+            received_token = self.rfile.read(len(auth_token))
+            if isinstance(received_token, bytes):
+                received_token = received_token.decode("utf-8")
+            if received_token == auth_token:
+                accum_updates()
+                # we've authenticated, we can break out of the first loop now
+                return True
+            else:
+                raise Exception(
+                    "The value of the provided token to the AccumulatorServer is not correct.")
+
+        # first we keep polling till we've received the authentication token
+        poll(authenticate_and_accum_updates)
+        # now we've authenticated, don't need to check for the token anymore
+        poll(accum_updates)


 class AccumulatorServer(SocketServer.TCPServer):
+
+    def __init__(self, server_address, RequestHandlerClass, auth_token):
+        SocketServer.TCPServer.__init__(self, server_address, RequestHandlerClass)
+        self.auth_token = auth_token
+
     """
     A simple TCP server that intercepts shutdown() in order to interrupt
     our continuous polling on the handler.
@@ -254,9 +283,9 @@ def shutdown(self):
         self.server_close()


-def _start_update_server():
+def _start_update_server(auth_token):
     """Start a TCP server to receive accumulator updates in a daemon thread, and returns it"""
-    server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler)
+    server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler, auth_token)
     thread = threading.Thread(target=server.serve_forever)
     thread.daemon = True
     thread.start()
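With these changes the handler speaks a two-phase protocol: the first len(auth_token) bytes on a new connection must be the token, and everything after that is the original update stream. A hedged client-side sketch of one authenticated batch, assuming the wire format matches read_int and _read_with_length (a big-endian count, then length-prefixed pickled (aid, update) pairs, each batch acknowledged with one byte); send_update_batch is a made-up helper:

    import pickle
    import socket
    import struct

    def send_update_batch(sock, updates, auth_token=None):
        # On a brand-new connection, authenticate first: the server reads
        # exactly len(auth_token) bytes before it starts polling for updates.
        if auth_token is not None:
            sock.sendall(auth_token.encode("utf-8"))
        # One batch: a big-endian int count, then each (aid, value) pair
        # pickled and length-prefixed, as _read_with_length expects.
        sock.sendall(struct.pack("!i", len(updates)))
        for aid, value in updates:
            payload = pickle.dumps((aid, value))
            sock.sendall(struct.pack("!i", len(payload)))
            sock.sendall(payload)
        # The handler answers each batch with a single acknowledgement byte.
        (ack,) = struct.unpack("!b", sock.recv(1))
        return ack == 1

A wrong token makes authenticate_and_accum_updates raise, so the connection dies before any update is applied.
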
5 changes: 3 additions & 2 deletions python/pyspark/context.py
@@ -185,9 +185,10 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,

         # Create a single Accumulator in Java that we'll send all our updates through;
         # they will be passed back to us through a TCP server
-        self._accumulatorServer = accumulators._start_update_server()
+        auth_token = self._gateway.gateway_parameters.auth_token
+        self._accumulatorServer = accumulators._start_update_server(auth_token)
         (host, port) = self._accumulatorServer.server_address
-        self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port)
+        self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token)
         self._jsc.sc().register(self._javaAccumulator)

         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
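The driver wiring now reuses the secret Py4J already negotiated for the gateway: start the accumulator server with it, then hand the same value to the JVM accumulator so its first write authenticates the socket. A hedged usage sketch of the new surface (the literal token is a stand-in for the gateway's real secret):

    from pyspark import accumulators

    # In _do_init the token comes from self._gateway.gateway_parameters.auth_token.
    auth_token = "not-a-real-secret"
    server = accumulators._start_update_server(auth_token)
    host, port = server.server_address
    # The JVM side is then built as jvm.PythonAccumulatorV2(host, port, auth_token),
    # whose first bytes over a fresh socket are exactly this token.
    server.shutdown()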
