[SPARK-36062][PYTHON] Try to capture faulthandler when a Python worker crashes

### What changes were proposed in this pull request?

Try to capture the error message from the `faulthandler` when the Python worker crashes.

### Why are the changes needed?

Currently, we only see an error message saying `"exited unexpectedly (crashed)"` when a UDF causes the Python worker to crash, for example with a segmentation fault.
We should take advantage of [`faulthandler`](https://docs.python.org/3/library/faulthandler.html) and try to capture its error message.

### Does this PR introduce _any_ user-facing change?

Yes. When the Spark config `spark.python.worker.faulthandler.enabled` is set to `true`, the stack trace is included in the error message when the Python worker crashes.

```py
>>> def f():
...   import ctypes
...   ctypes.string_at(0)
...
>>> sc.parallelize([1]).map(lambda x: f()).count()
```

```
org.apache.spark.SparkException: Python worker exited unexpectedly (crashed): Fatal Python error: Segmentation fault

Current thread 0x000000010965b5c0 (most recent call first):
  File "/.../ctypes/__init__.py", line 525 in string_at
  File "<stdin>", line 3 in f
  File "<stdin>", line 1 in <lambda>
...
```
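
For reference, here is a minimal sketch (not part of the PR) of enabling the new flag before running the crash-triggering job above; the config key comes from this PR, while the session setup is illustrative:

```py
from pyspark import SparkConf, SparkContext

# Illustrative setup: enable the new flag on a fresh SparkContext.
conf = SparkConf().set("spark.python.worker.faulthandler.enabled", "true")
sc = SparkContext(conf=conf)

def f():
    import ctypes
    ctypes.string_at(0)  # dereferences address 0 and crashes the worker

# With the flag enabled, the resulting SparkException should carry the
# faulthandler stack trace instead of only "exited unexpectedly (crashed)".
sc.parallelize([1]).map(lambda x: f()).count()
```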

### How was this patch tested?

Added some tests, and tested manually.

Closes #33273 from ueshin/issues/SPARK-36062/faulthandler.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
ueshin authored and HyukjinKwon committed Jul 9, 2021
1 parent a1ce649 commit 115b8a1
Showing 8 changed files with 104 additions and 13 deletions.
4 changes: 3 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -113,7 +113,9 @@ class SparkEnv (
}

private[spark]
def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
def createPythonWorker(
pythonExec: String,
envVars: Map[String, String]): (java.net.Socket, Option[Int]) = {
synchronized {
val key = (pythonExec, envVars)
pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create()
33 changes: 30 additions & 3 deletions core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -21,6 +21,7 @@ import java.io._
import java.net._
import java.nio.charset.StandardCharsets
import java.nio.charset.StandardCharsets.UTF_8
import java.nio.file.{Files => JavaFiles, Path}
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicBoolean

@@ -65,6 +66,15 @@ private[spark] object PythonEvalType {
}
}

private object BasePythonRunner {

private lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler")

private def faultHandlerLogPath(pid: Int): Path = {
new File(faultHandlerLogDir, pid.toString).toPath
}
}

/**
* A helper class to run Python mapPartition/UDFs in Spark.
*
@@ -83,6 +93,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
protected val bufferSize: Int = conf.get(BUFFER_SIZE)
protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
private val reuseWorker = conf.get(PYTHON_WORKER_REUSE)
private val faultHandlerEnabled = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED)
protected val simplifiedTraceback: Boolean = false

// All the Python functions should have the same exec, version and envvars.
@@ -143,7 +154,12 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
}
envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
if (faultHandlerEnabled) {
envVars.put("PYTHON_FAULTHANDLER_DIR", BasePythonRunner.faultHandlerLogDir.toString)
}

val (worker: Socket, pid: Option[Int]) = env.createPythonWorker(
pythonExec, envVars.asScala.toMap)
// Whether is the worker released into idle pool or closed. When any codes try to release or
// close a worker, they should use `releasedOrClosed.compareAndSet` to flip the state to make
// sure there is only one winner that is going to release or close the worker.
@@ -180,7 +196,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))

val stdoutIterator = newReaderIterator(
stream, writerThread, startTime, env, worker, releasedOrClosed, context)
stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context)
new InterruptibleIterator(context, stdoutIterator)
}

@@ -197,6 +213,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
startTime: Long,
env: SparkEnv,
worker: Socket,
pid: Option[Int],
releasedOrClosed: AtomicBoolean,
context: TaskContext): Iterator[OUT]

@@ -468,6 +485,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
startTime: Long,
env: SparkEnv,
worker: Socket,
pid: Option[Int],
releasedOrClosed: AtomicBoolean,
context: TaskContext)
extends Iterator[OUT] {
@@ -556,6 +574,13 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
logError("This may have been caused by a prior exception:", writerThread.exception.get)
throw writerThread.exception.get

case eof: EOFException if faultHandlerEnabled && pid.isDefined &&
JavaFiles.exists(BasePythonRunner.faultHandlerLogPath(pid.get)) =>
val path = BasePythonRunner.faultHandlerLogPath(pid.get)
val error = String.join("\n", JavaFiles.readAllLines(path)) + "\n"
JavaFiles.deleteIfExists(path)
throw new SparkException(s"Python worker exited unexpectedly (crashed): $error", eof)

case eof: EOFException =>
throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
}
@@ -654,9 +679,11 @@ private[spark] class PythonRunner(funcs: Seq[ChainedPythonFunctions])
startTime: Long,
env: SparkEnv,
worker: Socket,
pid: Option[Int],
releasedOrClosed: AtomicBoolean,
context: TaskContext): Iterator[Array[Byte]] = {
new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
new ReaderIterator(
stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) {

protected override def read(): Array[Byte] = {
if (writerThread.exception.isDefined) {
core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -95,11 +95,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
envVars.getOrElse("PYTHONPATH", ""),
sys.env.getOrElse("PYTHONPATH", ""))

def create(): Socket = {
def create(): (Socket, Option[Int]) = {
if (useDaemon) {
self.synchronized {
if (idleWorkers.nonEmpty) {
return idleWorkers.dequeue()
val worker = idleWorkers.dequeue()
return (worker, daemonWorkers.get(worker))
}
}
createThroughDaemon()
@@ -113,9 +114,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
* processes itself to avoid the high cost of forking from Java. This currently only works
* on UNIX-based systems.
*/
private def createThroughDaemon(): Socket = {
private def createThroughDaemon(): (Socket, Option[Int]) = {

def createSocket(): Socket = {
def createSocket(): (Socket, Option[Int]) = {
val socket = new Socket(daemonHost, daemonPort)
val pid = new DataInputStream(socket.getInputStream).readInt()
if (pid < 0) {
@@ -124,7 +125,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String

authHelper.authToServer(socket)
daemonWorkers.put(socket, pid)
socket
(socket, Some(pid))
}

self.synchronized {
Expand All @@ -148,7 +149,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
/**
* Launch a worker by executing worker.py (by default) directly and telling it to connect to us.
*/
private def createSimpleWorker(): Socket = {
private def createSimpleWorker(): (Socket, Option[Int]) = {
var serverSocket: ServerSocket = null
try {
serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
@@ -173,10 +174,15 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
try {
val socket = serverSocket.accept()
authHelper.authClient(socket)
// TODO: When we drop JDK 8, we can just use worker.pid()
val pid = new DataInputStream(socket.getInputStream).readInt()
if (pid < 0) {
throw new IllegalStateException("Python failed to launch worker with code " + pid)
}
self.synchronized {
simpleWorkers.put(socket, worker)
}
return socket
return (socket, Some(pid))
} catch {
case e: Exception =>
throw new SparkException("Python worker failed to connect back.", e)
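
For context, `DataInputStream.readInt()` above expects a 4-byte big-endian integer, so the worker sends its pid in that form right after authentication (the matching `write_int(os.getpid(), sock_file)` call appears in `worker.py` below). A minimal Python-side sketch of that handshake, using `struct` directly instead of PySpark's `write_int` helper:

```py
import os
import socket
import struct

def send_pid(sock: socket.socket) -> None:
    # Roughly what write_int(os.getpid(), sock_file) does: pack the pid as a
    # 4-byte big-endian int so DataInputStream.readInt() on the JVM side can read it.
    sock.sendall(struct.pack(">i", os.getpid()))
```
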
core/src/main/scala/org/apache/spark/internal/config/Python.scala
@@ -56,4 +56,12 @@ private[spark] object Python {
.version("3.1.0")
.timeConf(TimeUnit.SECONDS)
.createWithDefaultString("15s")

val PYTHON_WORKER_FAULTHANLDER_ENABLED = ConfigBuilder("spark.python.worker.faulthandler.enabled")
.doc("When true, Python workers set up the faulthandler for the case when the Python worker " +
"exits unexpectedly (crashes), and shows the stack trace of the moment the Python worker " +
"crashes in the error message if captured successfully.")
.version("3.2.0")
.booleanConf
.createWithDefault(false)
}
29 changes: 29 additions & 0 deletions python/pyspark/tests/test_worker.py
@@ -206,6 +206,35 @@ def getrlimit():
def tearDown(self):
self.sc.stop()


class WorkerSegfaultTest(ReusedPySparkTestCase):

@classmethod
def conf(cls):
_conf = super(WorkerSegfaultTest, cls).conf()
_conf.set("spark.python.worker.faulthandler.enabled", "true")
return _conf

def test_python_segfault(self):
try:
def f():
import ctypes
ctypes.string_at(0)

self.sc.parallelize([1]).map(lambda x: f()).count()
except Py4JJavaError as e:
self.assertRegex(str(e), "Segmentation fault")


class WorkerSegfaultNonDaemonTest(WorkerSegfaultTest):

@classmethod
def conf(cls):
_conf = super(WorkerSegfaultNonDaemonTest, cls).conf()
_conf.set("spark.python.use.daemon", "false")
return _conf


if __name__ == "__main__":
import unittest
from pyspark.tests.test_worker import * # noqa: F401
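
For comparison, a slightly stricter variant of the test (illustrative only, not part of this PR) that also fails when no exception is raised at all, written with `assertRaisesRegex`; the class name is hypothetical and the imports mirror the ones assumed by the test module above:

```py
from py4j.protocol import Py4JJavaError
from pyspark.testing.utils import ReusedPySparkTestCase


class WorkerSegfaultRaisesTest(ReusedPySparkTestCase):  # hypothetical variant

    @classmethod
    def conf(cls):
        _conf = super(WorkerSegfaultRaisesTest, cls).conf()
        _conf.set("spark.python.worker.faulthandler.enabled", "true")
        return _conf

    def test_python_segfault(self):
        def f():
            import ctypes
            ctypes.string_at(0)

        # assertRaisesRegex matches the regex against str(exception) and fails
        # the test if nothing is raised.
        with self.assertRaisesRegex(Py4JJavaError, "Segmentation fault"):
            self.sc.parallelize([1]).map(lambda x: f()).count()
```
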
15 changes: 15 additions & 0 deletions python/pyspark/worker.py
@@ -31,6 +31,7 @@
has_resource_module = False
import traceback
import warnings
import faulthandler

from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
@@ -463,7 +464,13 @@ def mapper(a):


def main(infile, outfile):
faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
try:
if faulthandler_log_path:
faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
faulthandler_log_file = open(faulthandler_log_path, "w")
faulthandler.enable(file=faulthandler_log_file)

boot_time = time.time()
split_index = read_int(infile)
if split_index == -1: # for unit tests
@@ -636,6 +643,11 @@ def process():
print("PySpark worker failed with exception:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
sys.exit(-1)
finally:
if faulthandler_log_path:
faulthandler.disable()
faulthandler_log_file.close()
os.remove(faulthandler_log_path)
finish_time = time.time()
report_times(outfile, boot_time, init_time, finish_time)
write_long(shuffle.MemoryBytesSpilled, outfile)
@@ -661,4 +673,7 @@ def process():
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
(sock_file, _) = local_connect_and_auth(java_port, auth_secret)
# TODO: Remove the following two lines and use `Process.pid()` when we drop JDK 8.
write_int(os.getpid(), sock_file)
sock_file.flush()
main(sock_file, sock_file)
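
To illustrate the mechanism `worker.py` now relies on, here is a standalone sketch (not from the PR): a child process enables `faulthandler` on a per-pid log file and then segfaults, and the parent reads the captured trace afterwards, roughly mirroring what `BasePythonRunner` does on `EOFException`. The paths and crash trigger are illustrative:

```py
import faulthandler
import os
import subprocess
import sys
import tempfile

if __name__ == "__main__":
    if len(sys.argv) > 2 and sys.argv[1] == "child":
        # Child: enable faulthandler on a per-pid log file, then crash.
        log_path = os.path.join(sys.argv[2], str(os.getpid()))
        log_file = open(log_path, "w")
        faulthandler.enable(file=log_file)
        import ctypes
        ctypes.string_at(0)  # segmentation fault
    else:
        # Parent: run the child, wait for the crash, then read and delete the log.
        log_dir = tempfile.mkdtemp(prefix="faulthandler")
        proc = subprocess.Popen([sys.executable, __file__, "child", log_dir])
        proc.wait()
        log_path = os.path.join(log_dir, str(proc.pid))
        if os.path.exists(log_path):
            with open(log_path) as f:
                print("Python worker exited unexpectedly (crashed):\n" + f.read())
            os.remove(log_path)
```
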
sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
@@ -43,10 +43,12 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc
startTime: Long,
env: SparkEnv,
worker: Socket,
pid: Option[Int],
releasedOrClosed: AtomicBoolean,
context: TaskContext): Iterator[ColumnarBatch] = {

new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
new ReaderIterator(
stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) {

private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
s"stdin reader for $pythonExec", 0, Long.MaxValue)
sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -62,9 +62,11 @@
startTime: Long,
env: SparkEnv,
worker: Socket,
pid: Option[Int],
releasedOrClosed: AtomicBoolean,
context: TaskContext): Iterator[Array[Byte]] = {
new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
new ReaderIterator(
stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) {

protected override def read(): Array[Byte] = {
if (writerThread.exception.isDefined) {
