[SPARK-44705][PYTHON] Make PythonRunner single-threaded
### What changes were proposed in this pull request?
PythonRunner, the utility that executes Python UDFs in Spark, currently uses two threads in a producer-consumer model. This multi-threading model is problematic and confusing, as Spark's execution model within a task is commonly understood to be single-threaded.
More importantly, this departure into a two-threaded execution model has resulted in a series of customer issues involving [race conditions](https://issues.apache.org/jira/browse/SPARK-33277) and [deadlocks](https://issues.apache.org/jira/browse/SPARK-38677) between threads, as the code was hard to reason about. There have been multiple attempts to rein in these issues, viz., [fix 1](https://issues.apache.org/jira/browse/SPARK-22535), [fix 2](#30177), [fix 3](243c321). Moreover, the fixes have made the code base somewhat abstruse by introducing multiple daemon [monitor threads](https://github.com/apache/spark/blob/a3a32912be04d3760cb34eb4b79d6d481bbec502/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala#L579) to detect deadlocks. This PR makes PythonRunner single-threaded, making it easier to reason about and improving code health.

#### Current Execution Model in Spark for Python UDFs
For queries containing Python UDFs, the main Java task thread spins up a new writer thread to pipe data from the child Spark plan into the Python worker evaluating the UDF. The writer thread runs in a tight loop: it evaluates the child Spark plan and feeds the resulting output to the Python worker. The main task thread simultaneously consumes the Python UDF’s output and evaluates the parent Spark plan to produce the final result.
The I/O to and from the Python worker uses blocking Java sockets, necessitating two threads: one responsible for input to the Python worker and the other for its output. Without two threads, it is easy to run into a deadlock. For example, the task thread can block forever waiting for output from the Python worker, but that output will never arrive until input is supplied to the worker, which cannot happen while the task thread is blocked waiting on output.
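
To make the deadlock risk concrete, below is a minimal sketch of the two-threaded pattern described above. It is illustrative only, not Spark's actual `PythonRunner` code; the names (`childPlan`, `toWorker`, `fromWorker`) and the framing protocol are placeholders.

```scala
import java.io.{DataInputStream, DataOutputStream}

// Illustrative two-threaded model: a dedicated writer thread feeds the Python worker
// while the main task thread blocks on the worker's output.
class TwoThreadedRunnerSketch(
    childPlan: Iterator[Array[Byte]],   // rows produced by the child Spark plan
    toWorker: DataOutputStream,         // blocking stream into the Python worker
    fromWorker: DataInputStream) {      // blocking stream out of the Python worker

  def run(): Iterator[Array[Byte]] = {
    // Producer: evaluates the child plan and writes its rows to the worker.
    val writerThread = new Thread("python-udf-writer") {
      override def run(): Unit = {
        childPlan.foreach { row =>
          toWorker.writeInt(row.length)
          toWorker.write(row)
        }
        toWorker.writeInt(-1) // end-of-data marker
        toWorker.flush()
      }
    }
    writerThread.setDaemon(true)
    writerThread.start()

    // Consumer: the main task thread blocks here until the worker produces output.
    new Iterator[Array[Byte]] {
      private var nextLen = fromWorker.readInt()
      override def hasNext: Boolean = nextLen >= 0
      override def next(): Array[Byte] = {
        val bytes = new Array[Byte](nextLen)
        fromWorker.readFully(bytes)
        nextLen = fromWorker.readInt()
        bytes
      }
    }
  }
}
```

If the writer thread were simply removed from this model, the single task thread would block in `fromWorker.readInt()` while the Python worker itself waits for input that nobody is writing, which is exactly the deadlock described above.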

#### Proposed Fix

The proposed fix is to move to the standard single-threaded execution model within a task, i.e., to do away with the writer thread. In addition to mitigating the crashes, the fix reduces the complexity of the existing code by removing the many safety checks that were put in place to detect deadlocks in the two-threaded execution model.

In the new model, the main task thread alternates between feeding data to and consuming data from the Python worker, using non-blocking I/O through Java’s [SocketChannel](https://docs.oracle.com/javase/7/docs/api/java/nio/channels/SocketChannel.html). See the `read()` method in the sketch below for an approximation of how this is achieved.

```scala
class PythonUDFRunner {

  private var nextRow: Row = _
  private var endOfStream = false
  private var childHasNext = true
  private var buffer: ByteBuffer = _

  def hasNext(): Boolean = nextRow != null || {
    if (!endOfStream) {
      read(buffer)
      // `deserialize` is assumed to set `endOfStream` once the end-of-stream marker is read.
      nextRow = deserialize(buffer)
      hasNext()
    } else {
      false
    }
  }

  def next(): Row = {
    if (hasNext()) {
      val outputRow = nextRow
      nextRow = null
      outputRow
    } else {
      null
    }
  }

  def read(buf: ByteBuffer): Unit = {
    var n = 0
    while (n == 0) {
      // Alternate between reading from and writing to the Python worker using non-blocking I/O.
      if (pythonWorker.isReadable) {
        n = pythonWorker.read(buf)
      }
      if (pythonWorker.isWritable) {
        consumeChildPlanAndWriteDataToPythonWorker()
      }
    }
  }

  def consumeChildPlanAndWriteDataToPythonWorker(): Unit = {
    // Tracks whether the connection to the Python worker can currently accept more data.
    var socketAcceptsInput = true
    while (socketAcceptsInput && (childHasNext || buffer.hasRemaining)) {
      if (!buffer.hasRemaining && childHasNext) {
        // Consume data from the child plan and buffer it.
        writeToBuffer(childPlan.next(), buffer)
        childHasNext = childPlan.hasNext()
        if (!childHasNext) {
          // The child plan's output is exhausted. Write a marker to the Python worker
          // signaling the end of data input.
          writeToBuffer(endOfStream)
        }
      }
      // Try to write as much buffered data as possible to the Python worker.
      while (buffer.hasRemaining && socketAcceptsInput) {
        val n = writeToPythonWorker(buffer)
        // `writeToPythonWorker()` returns 0 when the socket cannot accept more data right now.
        socketAcceptsInput = n > 0
      }
    }
  }
}
```
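
The `isReadable`/`isWritable` checks above correspond to the readiness API of Java NIO. For reference, here is a minimal, self-contained sketch of how a single thread can poll a non-blocking `SocketChannel` for readability and writability; it assumes a selector-based loop with a placeholder address and payload, and it is not the code added by this PR.

```scala
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.{SelectionKey, Selector, SocketChannel}
import java.nio.charset.StandardCharsets

object NonBlockingChannelSketch {
  def main(args: Array[String]): Unit = {
    // Placeholder endpoint; in Spark the worker's port comes from the worker factory.
    val channel = SocketChannel.open(new InetSocketAddress("localhost", 9999)) // blocking connect
    channel.configureBlocking(false) // subsequent reads and writes never block

    val selector = Selector.open()
    val key = channel.register(selector, SelectionKey.OP_READ | SelectionKey.OP_WRITE)

    val outBuf = ByteBuffer.wrap("hello".getBytes(StandardCharsets.UTF_8))
    val inBuf = ByteBuffer.allocate(64 * 1024)

    var done = false
    while (!done) {
      selector.select()               // wait until the channel is readable and/or writable
      selector.selectedKeys().clear() // readiness is recomputed on the next select()

      if (key.isWritable && outBuf.hasRemaining) {
        // write() returns immediately and may write fewer bytes than requested (even 0).
        channel.write(outBuf)
        if (!outBuf.hasRemaining) {
          // Nothing left to send: stop watching for writability to avoid spinning.
          key.interestOps(SelectionKey.OP_READ)
        }
      }
      if (key.isReadable) {
        // read() also returns immediately; -1 means the peer closed the connection.
        done = channel.read(inBuf) == -1
      }
    }
    channel.close()
    selector.close()
  }
}
```

Because both `read()` and `write()` on a non-blocking channel return immediately, one thread can interleave the two directions, which is what removes the need for a separate writer thread.
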
### Why are the changes needed?
This PR makes PythonRunner single-threaded, making it easier to reason about and improving code health.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Existing tests.

Closes #42385 from utkarsh39/SPARK-44705.

Authored-by: Utkarsh <utkarsh.agarwal@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
utkarsh39 authored and HyukjinKwon committed Aug 11, 2023
1 parent 9bde882 commit 8aaff55
Showing 21 changed files with 666 additions and 306 deletions.
@@ -30,8 +30,10 @@ import org.apache.spark.annotation.DeveloperApi
* Thus, we should use [[ContextAwareIterator]] to stop consuming after the task ends.
*
* @since 3.1.0
* @deprecated since 4.0.0 as its only usage for Python evaluation is now extinct
*/
@DeveloperApi
@deprecated("Only usage for Python evaluation is now extinct", "3.5.0")
class ContextAwareIterator[+T](val context: TaskContext, val delegate: Iterator[T])
extends Iterator[T] {

15 changes: 7 additions & 8 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -18,7 +18,6 @@
package org.apache.spark

import java.io.File
import java.net.Socket
import java.util.Locale

import scala.collection.JavaConverters._
@@ -30,7 +29,7 @@ import com.google.common.cache.CacheBuilder
import org.apache.hadoop.conf.Configuration

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.python.PythonWorkerFactory
import org.apache.spark.api.python.{PythonWorker, PythonWorkerFactory}
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.executor.ExecutorBackend
import org.apache.spark.internal.{config, Logging}
@@ -129,7 +128,7 @@ class SparkEnv (
pythonExec: String,
workerModule: String,
daemonModule: String,
envVars: Map[String, String]): (java.net.Socket, Option[Int]) = {
envVars: Map[String, String]): (PythonWorker, Option[Int]) = {
synchronized {
val key = PythonWorkersKey(pythonExec, workerModule, daemonModule, envVars)
pythonWorkers.getOrElseUpdate(key,
@@ -140,7 +139,7 @@
private[spark] def createPythonWorker(
pythonExec: String,
workerModule: String,
envVars: Map[String, String]): (java.net.Socket, Option[Int]) = {
envVars: Map[String, String]): (PythonWorker, Option[Int]) = {
createPythonWorker(
pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule, envVars)
}
@@ -150,7 +149,7 @@
workerModule: String,
daemonModule: String,
envVars: Map[String, String],
worker: Socket): Unit = {
worker: PythonWorker): Unit = {
synchronized {
val key = PythonWorkersKey(pythonExec, workerModule, daemonModule, envVars)
pythonWorkers.get(key).foreach(_.stopWorker(worker))
@@ -161,7 +160,7 @@
pythonExec: String,
workerModule: String,
envVars: Map[String, String],
worker: Socket): Unit = {
worker: PythonWorker): Unit = {
destroyPythonWorker(
pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule, envVars, worker)
}
@@ -171,7 +170,7 @@
workerModule: String,
daemonModule: String,
envVars: Map[String, String],
worker: Socket): Unit = {
worker: PythonWorker): Unit = {
synchronized {
val key = PythonWorkersKey(pythonExec, workerModule, daemonModule, envVars)
pythonWorkers.get(key).foreach(_.releaseWorker(worker))
@@ -182,7 +181,7 @@
pythonExec: String,
workerModule: String,
envVars: Map[String, String],
worker: Socket): Unit = {
worker: PythonWorker): Unit = {
releasePythonWorker(
pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule, envVars, worker)
}
22 changes: 18 additions & 4 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -137,15 +137,15 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte]
private[spark] object PythonRDD extends Logging {

// remember the broadcasts sent to each worker
private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()
private val workerBroadcasts = new mutable.WeakHashMap[PythonWorker, mutable.Set[Long]]()

// Authentication helper used when serving iterator data.
private lazy val authHelper = {
val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
new SocketAuthHelper(conf)
}

def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = {
def getWorkerBroadcasts(worker: PythonWorker): mutable.Set[Long] = {
synchronized {
workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]())
}
@@ -300,7 +300,11 @@ private[spark] object PythonRDD extends Logging {
new PythonBroadcast(path)
}

def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream): Unit = {
/**
* Writes the next element of the iterator `iter` to `dataOut`. Returns true if any data was
* written to the stream. Returns false if no data was written as the iterator has been exhausted.
*/
def writeNextElementToStream[T](iter: Iterator[T], dataOut: DataOutputStream): Boolean = {

def write(obj: Any): Unit = obj match {
case null =>
Expand All @@ -318,8 +322,18 @@ private[spark] object PythonRDD extends Logging {
case other =>
throw new SparkException("Unexpected element type " + other.getClass)
}
if (iter.hasNext) {
write(iter.next())
true
} else {
false
}
}

iter.foreach(write)
def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream): Unit = {
while (writeNextElementToStream(iter, dataOut)) {
// Nothing.
}
}

/**