[SPARK-30667][CORE] Add allGather method to BarrierTaskContext

### What changes were proposed in this pull request?

The `allGather` method is added to `BarrierTaskContext`. It provides the same synchronization semantics as `BarrierTaskContext.barrier`: the call blocks the task until all tasks in the stage have made the call, at which point they may all continue execution. In addition, `allGather` takes an input message, and upon returning each task receives the list of messages passed in by all of the tasks that made the call.

### Why are the changes needed?

There are many situations where it is useful for the tasks to communicate in a synchronized way. One simple example is when each task needs to start a server to serve requests from the others: first each task must find a free port (which cannot be determined beforehand), and only then can the tasks start making requests to one another, but to do so each task must know which ports the others chose. An `allGather` call lets the tasks inform each other of the ports they will run on, as in the sketch below.
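A minimal sketch of that pattern (not part of this PR; `rdd` is a barrier-capable RDD and `startServerOnFreePort` is a hypothetical user-defined helper that binds a server to an OS-assigned free port and returns the port number):

```scala
rdd.barrier().mapPartitions { iter =>
  val context = BarrierTaskContext.get()
  // Hypothetical helper: bind a server to a free, OS-assigned port and return that port.
  val myPort = startServerOnFreePort()
  // Block until every task has contributed its port, then receive all of them.
  val allPorts = context.allGather(myPort.toString).map(_.toInt)
  // Every task now knows its peers' ports and can start making requests to them.
  iter
}
```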

### Does this PR introduce any user-facing change?

Yes, a `BarrierTaskContext.allGather` method is added and will be available through the Scala, Java, and Python APIs.

### How was this patch tested?

Most of the code path is already covered by the existing tests of the `barrier` method, since this PR refactors the code so that much of it is shared between `barrier` and `allGather`. In addition, a test is added asserting that an all-gather of each task's partition ID returns a list containing every partition ID.

An example through the Python API:
```python
>>> from pyspark import BarrierTaskContext
>>>
>>> def f(iterator):
...     context = BarrierTaskContext.get()
...     return [context.allGather('{}'.format(context.partitionId()))]
...
>>> sc.parallelize(range(4), 4).barrier().mapPartitions(f).collect()[0]
[u'3', u'1', u'0', u'2']
```
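
A roughly equivalent sketch through the Scala API (assuming a running `SparkContext` named `sc`; as with the Python output above, the order of the returned messages is not guaranteed):

```scala
import org.apache.spark.BarrierTaskContext

val gathered = sc.parallelize(1 to 4, 4)
  .barrier()
  .mapPartitions { iter =>
    val context = BarrierTaskContext.get()
    Iterator(context.allGather(context.partitionId().toString))
  }
  .collect()(0)
// e.g. ArrayBuffer("3", "1", "0", "2") -- one message per task
```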

Closes #27395 from sarthfrey/master.

Lead-authored-by: sarthfrey-db <sarth.frey@databricks.com>
Co-authored-by: sarthfrey <sarth.frey@gmail.com>
Signed-off-by: Xiangrui Meng <meng@databricks.com>
(cherry picked from commit 57254c9)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
2 people authored and mengxr committed Feb 19, 2020
1 parent c92d437 commit af63971
Showing 6 changed files with 381 additions and 79 deletions.
113 changes: 99 additions & 14 deletions core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
@@ -17,12 +17,17 @@

package org.apache.spark

import java.nio.charset.StandardCharsets.UTF_8
import java.util.{Timer, TimerTask}
import java.util.concurrent.ConcurrentHashMap
import java.util.function.Consumer

import scala.collection.mutable.ArrayBuffer

import org.json4s.JsonAST._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted}
@@ -99,10 +104,15 @@ private[spark] class BarrierCoordinator(
// reset when a barrier() call fails due to timeout.
private var barrierEpoch: Int = 0

// An array of RPCCallContexts for barrier tasks that are waiting for reply of a barrier()
// call.
// An Array of RPCCallContexts for barrier tasks that have made a blocking runBarrier() call
private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks)

// An Array of allGather messages for barrier tasks that have made a blocking runBarrier() call
private val allGatherMessages: ArrayBuffer[String] = new Array[String](numTasks).to[ArrayBuffer]

// The blocking requestMethod called by tasks to sync up for this stage attempt
private var requestMethodToSync: RequestMethod.Value = RequestMethod.BARRIER

// A timer task that ensures we may timeout for a barrier() call.
private var timerTask: TimerTask = null

@@ -130,9 +140,32 @@

// Process the global sync request. The barrier() call succeed if collected enough requests
// within a configured time, otherwise fail all the pending requests.
def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit = synchronized {
def handleRequest(
requester: RpcCallContext,
request: RequestToSync
): Unit = synchronized {
val taskId = request.taskAttemptId
val epoch = request.barrierEpoch
val requestMethod = request.requestMethod
val partitionId = request.partitionId
val allGatherMessage = request match {
case ag: AllGatherRequestToSync => ag.allGatherMessage
case _ => ""
}

if (requesters.size == 0) {
requestMethodToSync = requestMethod
}

if (requestMethodToSync != requestMethod) {
requesters.foreach(
_.sendFailure(new SparkException(s"$barrierId tried to use requestMethod " +
s"`$requestMethod` during barrier epoch $barrierEpoch, which does not match " +
s"the current synchronized requestMethod `$requestMethodToSync`"
))
)
cleanupBarrierStage(barrierId)
}

// Require the number of tasks is correctly set from the BarrierTaskContext.
require(request.numTasks == numTasks, s"Number of tasks of $barrierId is " +
@@ -153,6 +186,7 @@
}
// Add the requester to array of RPCCallContexts pending for reply.
requesters += requester
allGatherMessages(partitionId) = allGatherMessage
logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " +
s"$taskId, current progress: ${requesters.size}/$numTasks.")
if (maybeFinishAllRequesters(requesters, numTasks)) {
@@ -162,6 +196,7 @@
s"tasks, finished successfully.")
barrierEpoch += 1
requesters.clear()
allGatherMessages.clear()
cancelTimerTask()
}
}
@@ -173,7 +208,13 @@
requesters: ArrayBuffer[RpcCallContext],
numTasks: Int): Boolean = {
if (requesters.size == numTasks) {
requesters.foreach(_.reply(()))
requestMethodToSync match {
case RequestMethod.BARRIER =>
requesters.foreach(_.reply(""))
case RequestMethod.ALL_GATHER =>
val json: String = compact(render(allGatherMessages))
requesters.foreach(_.reply(json))
}
true
} else {
false
@@ -186,6 +227,7 @@
// messages come from current stage attempt shall fail.
barrierEpoch = -1
requesters.clear()
allGatherMessages.clear()
cancelTimerTask()
}
}
@@ -199,11 +241,11 @@
}

override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) =>
case request: RequestToSync =>
// Get or init the ContextBarrierState correspond to the stage attempt.
val barrierId = ContextBarrierId(stageId, stageAttemptId)
val barrierId = ContextBarrierId(request.stageId, request.stageAttemptId)
states.computeIfAbsent(barrierId,
(key: ContextBarrierId) => new ContextBarrierState(key, numTasks))
(key: ContextBarrierId) => new ContextBarrierState(key, request.numTasks))
val barrierState = states.get(barrierId)

barrierState.handleRequest(context, request)
@@ -216,6 +258,16 @@

private[spark] sealed trait BarrierCoordinatorMessage extends Serializable

private[spark] sealed trait RequestToSync extends BarrierCoordinatorMessage {
def numTasks: Int
def stageId: Int
def stageAttemptId: Int
def taskAttemptId: Long
def barrierEpoch: Int
def partitionId: Int
def requestMethod: RequestMethod.Value
}

/**
* A global sync request message from BarrierTaskContext, by `barrier()` call. Each request is
* identified by stageId + stageAttemptId + barrierEpoch.
@@ -224,11 +276,44 @@ private[spark] sealed trait BarrierCoordinatorMessage extends Serializable
* @param stageId ID of current stage
* @param stageAttemptId ID of current stage attempt
* @param taskAttemptId Unique ID of current task
* @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls.
* @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls
* @param partitionId ID of the current partition the task is assigned to
* @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator
*/
private[spark] case class RequestToSync(
numTasks: Int,
stageId: Int,
stageAttemptId: Int,
taskAttemptId: Long,
barrierEpoch: Int) extends BarrierCoordinatorMessage
private[spark] case class BarrierRequestToSync(
numTasks: Int,
stageId: Int,
stageAttemptId: Int,
taskAttemptId: Long,
barrierEpoch: Int,
partitionId: Int,
requestMethod: RequestMethod.Value
) extends RequestToSync

/**
* A global sync request message from BarrierTaskContext, by `allGather()` call. Each request is
* identified by stageId + stageAttemptId + barrierEpoch.
*
* @param numTasks The number of global sync requests the BarrierCoordinator shall receive
* @param stageId ID of current stage
* @param stageAttemptId ID of current stage attempt
* @param taskAttemptId Unique ID of current task
* @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls
* @param partitionId ID of the current partition the task is assigned to
* @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator
* @param allGatherMessage Message sent from the BarrierTaskContext if requestMethod is ALL_GATHER
*/
private[spark] case class AllGatherRequestToSync(
numTasks: Int,
stageId: Int,
stageAttemptId: Int,
taskAttemptId: Long,
barrierEpoch: Int,
partitionId: Int,
requestMethod: RequestMethod.Value,
allGatherMessage: String
) extends RequestToSync

private[spark] object RequestMethod extends Enumeration {
val BARRIER, ALL_GATHER = Value
}
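
As the hunks above show, the coordinator replies to an `ALL_GATHER` sync with a single JSON string that each task parses back into the list of messages. A standalone json4s sketch of that round trip (illustrative only, not part of the diff):

```scala
import scala.collection.mutable.ArrayBuffer

import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, parse, render}

object AllGatherJsonRoundTrip {
  def main(args: Array[String]): Unit = {
    implicit val formats: DefaultFormats.type = DefaultFormats
    // Coordinator side: render the collected per-task messages as one JSON array string.
    val allGatherMessages = ArrayBuffer("9001", "9002", "9003")
    val json = compact(render(allGatherMessages)) // ["9001","9002","9003"]
    // Task side: parse the reply back into the per-task message list.
    val messages = ArrayBuffer(parse(json).extract[Array[String]]: _*)
    println(messages) // ArrayBuffer(9001, 9002, 9003)
  }
}
```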
153 changes: 106 additions & 47 deletions core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
@@ -17,11 +17,19 @@

package org.apache.spark

import java.nio.charset.StandardCharsets.UTF_8
import java.util.{Properties, Timer, TimerTask}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.TimeoutException
import scala.concurrent.duration._
import scala.language.postfixOps

import org.json4s.DefaultFormats
import org.json4s.JsonAST._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.parse

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.executor.TaskMetrics
@@ -59,49 +67,31 @@ class BarrierTaskContext private[spark] (
// from different tasks within the same barrier stage attempt to succeed.
private lazy val numTasks = getTaskInfos().size

/**
* :: Experimental ::
* Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to
* MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same
* stage have reached this routine.
*
* CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all
* possible code branches. Otherwise, you may get the job hanging or a SparkException after
* timeout. Some examples of '''misuses''' are listed below:
* 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it
* shall lead to timeout of the function call.
* {{{
* rdd.barrier().mapPartitions { iter =>
* val context = BarrierTaskContext.get()
* if (context.partitionId() == 0) {
* // Do nothing.
* } else {
* context.barrier()
* }
* iter
* }
* }}}
*
* 2. Include barrier() function in a try-catch code block, this may lead to timeout of the
* second function call.
* {{{
* rdd.barrier().mapPartitions { iter =>
* val context = BarrierTaskContext.get()
* try {
* // Do something that might throw an Exception.
* doSomething()
* context.barrier()
* } catch {
* case e: Exception => logWarning("...", e)
* }
* context.barrier()
* iter
* }
* }}}
*/
@Experimental
@Since("2.4.0")
def barrier(): Unit = {
private def getRequestToSync(
numTasks: Int,
stageId: Int,
stageAttemptNumber: Int,
taskAttemptId: Long,
barrierEpoch: Int,
partitionId: Int,
requestMethod: RequestMethod.Value,
allGatherMessage: String
): RequestToSync = {
requestMethod match {
case RequestMethod.BARRIER =>
BarrierRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
barrierEpoch, partitionId, requestMethod)
case RequestMethod.ALL_GATHER =>
AllGatherRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
barrierEpoch, partitionId, requestMethod, allGatherMessage)
}
}

private def runBarrier(
requestMethod: RequestMethod.Value,
allGatherMessage: String = ""
): String = {

logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " +
s"the global sync, current barrier epoch is $barrierEpoch.")
logTrace("Current callSite: " + Utils.getCallSite())
@@ -118,10 +108,12 @@
// Log the update of global sync every 60 seconds.
timer.schedule(timerTask, 60000, 60000)

var json: String = ""

try {
val abortableRpcFuture = barrierCoordinator.askAbortable[Unit](
message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
barrierEpoch),
val abortableRpcFuture = barrierCoordinator.askAbortable[String](
message = getRequestToSync(numTasks, stageId, stageAttemptNumber,
taskAttemptId, barrierEpoch, partitionId, requestMethod, allGatherMessage),
// Set a fixed timeout for RPC here, so users shall get a SparkException thrown by
// BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework.
timeout = new RpcTimeout(365.days, "barrierTimeout"))
@@ -133,7 +125,7 @@
while (!abortableRpcFuture.toFuture.isCompleted) {
// wait RPC future for at most 1 second
try {
ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second)
json = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second)
} catch {
case _: TimeoutException | _: InterruptedException =>
// If `TimeoutException` thrown, waiting RPC future reach 1 second.
@@ -163,6 +155,73 @@
timerTask.cancel()
timer.purge()
}
json
}

/**
* :: Experimental ::
* Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to
* MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same
* stage have reached this routine.
*
* CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all
* possible code branches. Otherwise, you may get the job hanging or a SparkException after
* timeout. Some examples of '''misuses''' are listed below:
* 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it
* shall lead to timeout of the function call.
* {{{
* rdd.barrier().mapPartitions { iter =>
* val context = BarrierTaskContext.get()
* if (context.partitionId() == 0) {
* // Do nothing.
* } else {
* context.barrier()
* }
* iter
* }
* }}}
*
* 2. Include barrier() function in a try-catch code block, this may lead to timeout of the
* second function call.
* {{{
* rdd.barrier().mapPartitions { iter =>
* val context = BarrierTaskContext.get()
* try {
* // Do something that might throw an Exception.
* doSomething()
* context.barrier()
* } catch {
* case e: Exception => logWarning("...", e)
* }
* context.barrier()
* iter
* }
* }}}
*/
@Experimental
@Since("2.4.0")
def barrier(): Unit = {
runBarrier(RequestMethod.BARRIER)
()
}

/**
* :: Experimental ::
* Blocks until all tasks in the same stage have reached this routine. Each task passes in
* a message and returns with a list of all the messages passed in by each of those tasks.
*
* CAUTION! The allGather method requires the same precautions as the barrier method
*
* The message is type String rather than Array[Byte] because it is more convenient for
* the user at the cost of worse performance.
*/
@Experimental
@Since("3.0.0")
def allGather(message: String): ArrayBuffer[String] = {
val json = runBarrier(RequestMethod.ALL_GATHER, message)
val jsonArray = parse(json)
implicit val formats = DefaultFormats
ArrayBuffer(jsonArray.extract[Array[String]]: _*)
}

/**
