From e9892b446ac99c6743bdd6cb9710f9a5f8bb86b0 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 30 Jul 2014 20:50:32 -0700 Subject: [PATCH 1/7] [WIP] [SPARK-2764] Simplify daemon.py process structure. Currently, daemon.py forks a pool of numProcessors subprocesses, and those processes fork themselves again to create the actual Python worker processes that handle data. I think that this extra layer of indirection is unnecessary and adds a lot of complexity. This commit attempts to remove this middle layer of subprocesses by launching the workers directly from daemon.py. See https://github.com/mesos/spark/pull/563 for the original PR that added daemon.py, where I raise some issues with the current design. --- python/pyspark/daemon.py | 131 ++++++++++++++------------------------- 1 file changed, 46 insertions(+), 85 deletions(-) diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 8a5873ded2b8b..ffbb537f99c86 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -17,6 +17,7 @@ import os import signal +import select import socket import sys import traceback @@ -28,11 +29,6 @@ from pyspark.worker import main as worker_main from pyspark.serializers import write_int -try: - POOLSIZE = multiprocessing.cpu_count() -except NotImplementedError: - POOLSIZE = 4 - exit_flag = multiprocessing.Value(c_bool, False) @@ -50,29 +46,16 @@ def compute_real_exit_code(exit_code): return 1 -def worker(listen_sock): +def worker(sock): + """ + Called by a worker process after the fork(). + """ # Redirect stdout to stderr os.dup2(2, 1) sys.stdout = sys.stderr # The sys.stdout object is different from file descriptor 1 - # Manager sends SIGHUP to request termination of workers in the pool - def handle_sighup(*args): - assert should_exit() - signal.signal(SIGHUP, handle_sighup) - - # Cleanup zombie children - def handle_sigchld(*args): - pid = status = None - try: - while (pid, status) != (0, 0): - pid, status = os.waitpid(0, os.WNOHANG) - except EnvironmentError as err: - if err.errno == EINTR: - # retry - handle_sigchld() - elif err.errno != ECHILD: - raise - signal.signal(SIGCHLD, handle_sigchld) + signal.signal(SIGHUP, SIG_DFL) + signal.signal(SIGCHLD, SIG_DFL) # Blocks until the socket is closed by draining the input stream # until it raises an exception or returns EOF. @@ -85,55 +68,21 @@ def waitSocketClose(sock): except: pass - # Handle clients - while not should_exit(): - # Wait until a client arrives or we have to exit - sock = None - while not should_exit() and sock is None: - try: - sock, addr = listen_sock.accept() - except EnvironmentError as err: - if err.errno != EINTR: - raise - - if sock is not None: - # Fork a child to handle the client. - # The client is handled in the child so that the manager - # never receives SIGCHLD unless a worker crashes. - if os.fork() == 0: - # Leave the worker pool - signal.signal(SIGHUP, SIG_DFL) - signal.signal(SIGCHLD, SIG_DFL) - listen_sock.close() - # Read the socket using fdopen instead of socket.makefile() because the latter - # seems to be very slow; note that we need to dup() the file descriptor because - # otherwise writes also cause a seek that makes us miss data on the read side. - infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) - outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) - exit_code = 0 - try: - worker_main(infile, outfile) - except SystemExit as exc: - exit_code = exc.code - finally: - outfile.flush() - # The Scala side will close the socket upon task completion.
- waitSocketClose(sock) - os._exit(compute_real_exit_code(exit_code)) - else: - sock.close() - - -def launch_worker(listen_sock): - if os.fork() == 0: - try: - worker(listen_sock) - except Exception as err: - traceback.print_exc() - os._exit(1) - else: - assert should_exit() - os._exit(0) + # Read the socket using fdopen instead of socket.makefile() because the latter + # seems to be very slow; note that we need to dup() the file descriptor because + # otherwise writes also cause a seek that makes us miss data on the read side. + infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) + outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) + exit_code = 0 + try: + worker_main(infile, outfile) + except SystemExit as exc: + exit_code = exc.code + finally: + outfile.flush() + # The Scala side will close the socket upon task completion. + waitSocketClose(sock) + os._exit(compute_real_exit_code(exit_code)) def manager(): @@ -143,15 +92,10 @@ def manager(): # Create a listening socket on the AF_INET loopback interface listen_sock = socket.socket(AF_INET, SOCK_STREAM) listen_sock.bind(('127.0.0.1', 0)) - listen_sock.listen(max(1024, 2 * POOLSIZE, SOMAXCONN)) + listen_sock.listen(max(1024, SOMAXCONN)) listen_host, listen_port = listen_sock.getsockname() write_int(listen_port, sys.stdout) - # Launch initial worker pool - for idx in range(POOLSIZE): - launch_worker(listen_sock) - listen_sock.close() - def shutdown(): global exit_flag exit_flag.value = True @@ -176,13 +120,30 @@ def handle_sigchld(*args): try: while not should_exit(): try: - # Spark tells us to exit by closing stdin - if os.read(0, 512) == '': - shutdown() - except EnvironmentError as err: - if err.errno != EINTR: - shutdown() + ready_fds = select.select([0, listen_sock], [], [])[0] + except select.error as ex: + if ex[0] == 4: + continue + else: raise + if 0 in ready_fds: + # Spark told us to exit by closing stdin + shutdown() + if listen_sock in ready_fds: + sock, addr = listen_sock.accept() + # Launch a worker process + if os.fork() == 0: + listen_sock.close() + try: + worker(sock) + except: + traceback.print_exc() + os._exit(1) + else: + assert should_exit() + os._exit(0) + else: + sock.close() finally: signal.signal(SIGTERM, SIG_DFL) exit_flag.value = True From 4e0fab85b1583fd98f793e5ada4fb77c7ba650c3 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 31 Jul 2014 12:10:11 -0700 Subject: [PATCH 2/7] Remove shared-memory exit_flag; don't die on worker death. --- python/pyspark/daemon.py | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index ffbb537f99c86..89d4280a3f8ce 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -15,31 +15,22 @@ # limitations under the License. 
# +import numbers import os import signal import select import socket import sys import traceback -import multiprocessing -from ctypes import c_bool from errno import EINTR, ECHILD from socket import AF_INET, SOCK_STREAM, SOMAXCONN from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN from pyspark.worker import main as worker_main from pyspark.serializers import write_int -exit_flag = multiprocessing.Value(c_bool, False) - - -def should_exit(): - global exit_flag - return exit_flag.value - def compute_real_exit_code(exit_code): # SystemExit's code can be integer or string, but os._exit only accepts integers - import numbers if isinstance(exit_code, numbers.Integral): return exit_code else: @@ -108,8 +99,9 @@ def shutdown(): def handle_sigchld(*args): try: pid, status = os.waitpid(0, os.WNOHANG) - if status != 0 and not should_exit(): - raise RuntimeError("worker crashed: %s, %s" % (pid, status)) + if status != 0: + msg = "worker %s crashed abruptly with exit status %s" % (pid, status) + print >> sys.stderr, msg except EnvironmentError as err: if err.errno not in (ECHILD, EINTR): raise @@ -118,7 +110,7 @@ def handle_sigchld(*args): # Initialization complete sys.stdout.close() try: - while not should_exit(): + while True: try: ready_fds = select.select([0, listen_sock], [], [])[0] except select.error as ex: @@ -140,13 +132,11 @@ def handle_sigchld(*args): traceback.print_exc() os._exit(1) else: - assert should_exit() os._exit(0) else: sock.close() finally: signal.signal(SIGTERM, SIG_DFL) - exit_flag.value = True # Send SIGHUP to notify workers of shutdown os.kill(0, SIGHUP) From 855453627f8aa2fb9c9cc4cec6e2f11795cd48bd Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 31 Jul 2014 16:52:49 -0700 Subject: [PATCH 3/7] =?UTF-8?q?Fix=20daemon=E2=80=99s=20shutdown();=20log?= =?UTF-8?q?=20shutdown=20reason.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/pyspark/daemon.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 89d4280a3f8ce..680811b176354 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -87,13 +87,17 @@ def manager(): listen_host, listen_port = listen_sock.getsockname() write_int(listen_port, sys.stdout) - def shutdown(): - global exit_flag - exit_flag.value = True + def shutdown(code): + signal.signal(SIGTERM, SIG_DFL) + # Send SIGHUP to notify workers of shutdown + os.kill(0, SIGHUP) + exit(code) - # Gracefully exit on SIGTERM, don't die on SIGHUP - signal.signal(SIGTERM, lambda signum, frame: shutdown()) - signal.signal(SIGHUP, SIG_IGN) + def sig_term(signum, frame): + print >> sys.stderr, "daemon.py shutting down due to SIGTERM" + shutdown(1) + signal.signal(SIGTERM, sig_term) # Gracefully exit on SIGTERM + signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP # Cleanup zombie children def handle_sigchld(*args): @@ -120,7 +124,8 @@ def handle_sigchld(*args): raise if 0 in ready_fds: # Spark told us to exit by closing stdin - shutdown() + print >> sys.stderr, "daemon.py shutting down because Java closed stdin" + shutdown(0) if listen_sock in ready_fds: sock, addr = listen_sock.accept() # Launch a worker process @@ -136,9 +141,8 @@ def handle_sigchld(*args): else: sock.close() finally: - signal.signal(SIGTERM, SIG_DFL) - # Send SIGHUP to notify workers of shutdown - os.kill(0, SIGHUP) + print >> sys.stderr, "daemon.py shutting down due to uncaught exception" + shutdown(1) if __name__ == '__main__': 
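Taken together, the three patches above leave daemon.py with a single level of processes: one manager that select()s on stdin and its listening socket, forks one worker per accepted connection, and reaps finished children from a SIGCHLD handler, exiting when the JVM closes stdin. The following sketch illustrates that shape only; it is not the actual pyspark/daemon.py, and handle_client(), reap_children(), and manager() are illustrative names (handle_client() stands in for pyspark.worker.main, and the error handling and shutdown signalling are simplified).

import os
import select
import signal
import socket
from errno import ECHILD, EINTR


def handle_client(conn):
    # Stand-in for pyspark.worker.main: serve one task stream, then return.
    conn.sendall(("handled by worker %d\n" % os.getpid()).encode())


def reap_children(*_):
    # Collect exit statuses so finished workers do not linger as zombies.
    try:
        while True:
            pid, _ = os.waitpid(0, os.WNOHANG)
            if pid == 0:
                break
    except OSError as err:
        if err.errno not in (ECHILD, EINTR):
            raise


def manager():
    listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    listen_sock.bind(("127.0.0.1", 0))
    listen_sock.listen(socket.SOMAXCONN)
    signal.signal(signal.SIGCHLD, reap_children)

    while True:
        # Block until stdin is readable (EOF means the parent JVM is gone)
        # or a client connection is waiting to be accepted.
        ready, _, _ = select.select([0, listen_sock], [], [])
        if 0 in ready and os.read(0, 512) == b"":
            break  # stdin closed: exit (the real daemon also signals its workers)
        if listen_sock in ready:
            conn, _ = listen_sock.accept()
            if os.fork() == 0:
                listen_sock.close()  # child serves exactly one client
                try:
                    handle_client(conn)
                finally:
                    conn.close()
                os._exit(0)
            conn.close()  # parent keeps only the listening socket


if __name__ == "__main__":
    manager()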
From 282c2c48d5ca4072fab3c5fb04068efbb440e950 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 31 Jul 2014 18:37:00 -0700 Subject: [PATCH 4/7] Remove daemon.py exit logging, since it caused problems: After including these logging statements, orphaned workers might stay alive after the driver died. My current theory is that the print to sys.stderr failed due to Java closing the file, and an exception was thrown that managed to propagate to the uncaught exception handler, causing daemon.py to exit before it could send SIGHUP to its children. --- python/pyspark/daemon.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 680811b176354..962e4775c339e 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -47,6 +47,7 @@ def worker(sock): signal.signal(SIGHUP, SIG_DFL) signal.signal(SIGCHLD, SIG_DFL) + signal.signal(SIGTERM, SIG_DFL) # Blocks until the socket is closed by draining the input stream # until it raises an exception or returns EOF. @@ -93,10 +94,9 @@ def shutdown(code): os.kill(0, SIGHUP) exit(code) - def sig_term(signum, frame): - print >> sys.stderr, "daemon.py shutting down due to SIGTERM" + def handle_sigterm(*args): shutdown(1) - signal.signal(SIGTERM, sig_term) # Gracefully exit on SIGTERM + signal.signal(SIGTERM, handle_sigterm) # Gracefully exit on SIGTERM signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP # Cleanup zombie children @@ -124,7 +124,6 @@ def handle_sigchld(*args): raise if 0 in ready_fds: # Spark told us to exit by closing stdin - print >> sys.stderr, "daemon.py shutting down because Java closed stdin" shutdown(0) if listen_sock in ready_fds: sock, addr = listen_sock.accept() # Launch a worker process @@ -141,7 +140,6 @@ def handle_sigchld(*args): else: sock.close() finally: - print >> sys.stderr, "daemon.py shutting down due to uncaught exception" shutdown(1) From b79254d838469ddcb13db6eefa938701052d9773 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 1 Aug 2014 12:54:06 -0700 Subject: [PATCH 5/7] Detect failed fork() calls; improve error logging.
--- .../api/python/PythonWorkerFactory.scala | 12 ++++++-- python/pyspark/daemon.py | 28 +++++++++++++------ 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 759cbe2c46c52..8357b41a0e16c 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -18,7 +18,7 @@ package org.apache.spark.api.python import java.io.{DataInputStream, InputStream, OutputStreamWriter} -import java.net.{InetAddress, ServerSocket, Socket, SocketException} +import java.net._ import scala.collection.JavaConversions._ @@ -64,10 +64,16 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String // Attempt to connect, restart and retry once if it fails try { - new Socket(daemonHost, daemonPort) + val socket = new Socket(daemonHost, daemonPort) + val launchStatus = new DataInputStream(socket.getInputStream).readInt() + if (launchStatus != 0) { + logWarning("Python daemon failed to launch worker") + } + socket } catch { case exc: SocketException => - logWarning("Python daemon unexpectedly quit, attempting to restart") + logWarning("Failed to open socket to Python daemon:", exc) + logWarning("Assuming that daemon unexpectedly quit, attempting to restart") stopDaemon() startDaemon() new Socket(daemonHost, daemonPort) diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 962e4775c339e..4c7aefd7f9922 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -67,6 +67,8 @@ def waitSocketClose(sock): outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) exit_code = 0 try: + write_int(0, outfile) # Acknowledge that the fork was successful + outfile.flush() worker_main(infile, outfile) except SystemExit as exc: exit_code = exc.code @@ -128,16 +130,24 @@ def handle_sigchld(*args): if listen_sock in ready_fds: sock, addr = listen_sock.accept() # Launch a worker process - if os.fork() == 0: - listen_sock.close() - try: - worker(sock) - except: - traceback.print_exc() - os._exit(1) + try: + fork_return_code = os.fork() + if fork_return_code == 0: + listen_sock.close() + try: + worker(sock) + except: + traceback.print_exc() + os._exit(1) + else: + os._exit(0) else: - os._exit(0) - else: + sock.close() + except OSError as e: + print >> sys.stderr, "Daemon failed to fork PySpark worker: %s" % e + outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) + write_int(-1, outfile) # Signal that the fork failed + outfile.flush() sock.close() finally: shutdown(1) From 5495dffe5b11bb90602757f71d1a43aa69e0f258 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 1 Aug 2014 15:22:53 -0700 Subject: [PATCH 6/7] Throw IllegalStateException if worker launch fails. 
--- .../org/apache/spark/api/python/PythonWorkerFactory.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 8357b41a0e16c..15fe8a9be6bfe 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -18,7 +18,7 @@ package org.apache.spark.api.python import java.io.{DataInputStream, InputStream, OutputStreamWriter} -import java.net._ +import java.net.{InetAddress, ServerSocket, Socket, SocketException} import scala.collection.JavaConversions._ @@ -67,7 +67,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String val socket = new Socket(daemonHost, daemonPort) val launchStatus = new DataInputStream(socket.getInputStream).readInt() if (launchStatus != 0) { - logWarning("Python daemon failed to launch worker") + throw new IllegalStateException("Python daemon failed to launch worker") } socket } catch { From 5abbcb9ab219796910f86bfb26343e4cd683a52a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 1 Aug 2014 18:12:06 -0700 Subject: [PATCH 7/7] Replace magic number: 4 -> EINTR --- python/pyspark/daemon.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 4c7aefd7f9922..9fde0dde0f4b4 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -120,7 +120,7 @@ def handle_sigchld(*args): try: ready_fds = select.select([0, listen_sock], [], [])[0] except select.error as ex: - if ex[0] == 4: + if ex[0] == EINTR: continue else: raise
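The one-line change above addresses the interrupted-system-call case in the manager's select() loop: under the Python 2 interpreter this daemon targets (note the print >> sys.stderr statements), a signal such as SIGCHLD interrupts select() and raises select.error with errno EINTR, which simply means the call should be retried, so patch 7 replaces the magic number 4 with the EINTR constant from errno. Below is a minimal sketch of that retry idiom; wait_for_readable() is an illustrative helper, not anything in daemon.py, and Python 3.5+ retries interrupted select() calls internally per PEP 475.

import select
from errno import EINTR


def wait_for_readable(fds):
    # Retry select() when it is interrupted by a signal (errno EINTR),
    # e.g. SIGCHLD arriving as a finished worker is reaped.
    while True:
        try:
            return select.select(fds, [], [])[0]
        except select.error as ex:
            if ex.args[0] == EINTR:
                continue
            raise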
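Patches 5 and 6 also introduce a small launch handshake between daemon.py and PythonWorkerFactory: before a worker touches any task data, the daemon writes a 4-byte big-endian status int on the accepted socket, 0 after a successful fork() and -1 if fork() itself failed, and the Scala side reads it with DataInputStream.readInt(), throwing IllegalStateException when it is non-zero. The sketch below shows the Python half of that idea under simplified assumptions; LAUNCH_OK, LAUNCH_FORK_FAILED, write_status(), and run_worker are illustrative names, and struct stands in for pyspark.serializers.write_int, which uses the same 4-byte big-endian framing.

import os
import struct
import sys

LAUNCH_OK = 0            # illustrative constant: fork succeeded
LAUNCH_FORK_FAILED = -1  # illustrative constant: fork() raised OSError


def write_status(conn, status):
    # 4-byte big-endian int, the framing DataInputStream.readInt() expects.
    conn.sendall(struct.pack("!i", status))


def launch_worker(listen_sock, conn, run_worker):
    try:
        pid = os.fork()
    except OSError as exc:
        # fork() failed (e.g. EAGAIN under process/memory pressure): tell the
        # JVM instead of leaving it waiting on a socket that never answers.
        print("Daemon failed to fork PySpark worker: %s" % exc, file=sys.stderr)
        write_status(conn, LAUNCH_FORK_FAILED)
        conn.close()
        return
    if pid == 0:
        listen_sock.close()
        write_status(conn, LAUNCH_OK)  # acknowledge the successful fork
        run_worker(conn)               # hand the socket to the worker loop
        os._exit(0)
    conn.close()  # parent: the child owns this connection now

Turning the non-zero status from a logged warning (patch 5) into an IllegalStateException (patch 6) means a failed fork surfaces as an error on the JVM side rather than only as a warning in the logs.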