[SPARK-9026][SPARK-4514] Modifications to JobWaiter, FutureAction, and AsyncRDDActions to support non-blocking operation

These changes rework the implementations of `SimpleFutureAction`, `ComplexFutureAction`, `JobWaiter`, and `AsyncRDDActions` such that asynchronous callbacks on the generated `Futures` NEVER block waiting for a job to complete. A small amount of mutex synchronization is necessary to protect the internal fields that manage cancellation, but these locks are only held very briefly and in practice should almost never cause any blocking to occur. The existing blocking APIs of these classes are retained, but they simply delegate to the underlying non-blocking API and `Await` the results with indefinite timeouts.
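
To make that last point concrete, the retained blocking calls now reduce to roughly the following idiom (a simplified sketch with made-up names, not the actual Spark code):

    import scala.concurrent.{Await, Future}
    import scala.concurrent.duration.Duration

    object BlockingShim {
      // The retained blocking API: a thin wrapper that awaits the underlying
      // non-blocking completion future with an indefinite timeout.
      def blockingResult[T](completionFuture: Future[T]): T =
        Await.result(completionFuture, Duration.Inf)
    }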

Associated JIRA ticket: https://issues.apache.org/jira/browse/SPARK-9026
Also fixes: https://issues.apache.org/jira/browse/SPARK-4514

This pull request contains all my own original work, which I release to the Spark project under its open source license.

Author: Richard W. Eggert II <richard.eggert@gmail.com>

Closes #9264 from reggert/fix-futureaction.
reggert authored and Andrew Or committed Dec 16, 2015
1 parent a63d9ed commit 765a488
Showing 7 changed files with 251 additions and 158 deletions.
164 changes: 61 additions & 103 deletions core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -20,13 +20,15 @@ package org.apache.spark
import java.util.Collections
import java.util.concurrent.TimeUnit

+import scala.concurrent._
+import scala.concurrent.duration.Duration
+import scala.util.Try

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.JavaFutureAction
import org.apache.spark.rdd.RDD
-import org.apache.spark.scheduler.{JobFailed, JobSucceeded, JobWaiter}
+import org.apache.spark.scheduler.JobWaiter

-import scala.concurrent._
-import scala.concurrent.duration.Duration
-import scala.util.{Failure, Try}

/**
* A future for the result of an action to support cancellation. This is an extension of the
@@ -105,6 +107,7 @@ trait FutureAction[T] extends Future[T] {
* A [[FutureAction]] holding the result of an action that triggers a single job. Examples include
* count, collect, reduce.
*/
+@DeveloperApi
class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T)
extends FutureAction[T] {

@@ -116,142 +119,96 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T)
}

override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = {
-    if (!atMost.isFinite()) {
-      awaitResult()
-    } else jobWaiter.synchronized {
-      val finishTime = System.currentTimeMillis() + atMost.toMillis
-      while (!isCompleted) {
-        val time = System.currentTimeMillis()
-        if (time >= finishTime) {
-          throw new TimeoutException
-        } else {
-          jobWaiter.wait(finishTime - time)
-        }
-      }
-    }
+    jobWaiter.completionFuture.ready(atMost)
this
}

@throws(classOf[Exception])
override def result(atMost: Duration)(implicit permit: CanAwait): T = {
-    ready(atMost)(permit)
-    awaitResult() match {
-      case scala.util.Success(res) => res
-      case scala.util.Failure(e) => throw e
-    }
+    jobWaiter.completionFuture.ready(atMost)
+    assert(value.isDefined, "Future has not completed properly")
+    value.get.get
}

override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) {
-    executor.execute(new Runnable {
-      override def run() {
-        func(awaitResult())
-      }
-    })
+    jobWaiter.completionFuture onComplete {_ => func(value.get)}
}

override def isCompleted: Boolean = jobWaiter.jobFinished

override def isCancelled: Boolean = _cancelled

-  override def value: Option[Try[T]] = {
-    if (jobWaiter.jobFinished) {
-      Some(awaitResult())
-    } else {
-      None
-    }
-  }
-
-  private def awaitResult(): Try[T] = {
-    jobWaiter.awaitResult() match {
-      case JobSucceeded => scala.util.Success(resultFunc)
-      case JobFailed(e: Exception) => scala.util.Failure(e)
-    }
-  }
+  override def value: Option[Try[T]] =
+    jobWaiter.completionFuture.value.map {res => res.map(_ => resultFunc)}

def jobIds: Seq[Int] = Seq(jobWaiter.jobId)
}


+/**
+ * Handle via which a "run" function passed to a [[ComplexFutureAction]]
+ * can submit jobs for execution.
+ */
+@DeveloperApi
+trait JobSubmitter {
+  /**
+   * Submit a job for execution and return a FutureAction holding the result.
+   * This is a wrapper around the same functionality provided by SparkContext
+   * to enable cancellation.
+   */
+  def submitJob[T, U, R](
+      rdd: RDD[T],
+      processPartition: Iterator[T] => U,
+      partitions: Seq[Int],
+      resultHandler: (Int, U) => Unit,
+      resultFunc: => R): FutureAction[R]
+}


/**
* A [[FutureAction]] for actions that could trigger multiple Spark jobs. Examples include take,
- * takeSample. Cancellation works by setting the cancelled flag to true and interrupting the
- * action thread if it is being blocked by a job.
+ * takeSample. Cancellation works by setting the cancelled flag to true and cancelling any pending
+ * jobs.
*/
-class ComplexFutureAction[T] extends FutureAction[T] {
+@DeveloperApi
+class ComplexFutureAction[T](run : JobSubmitter => Future[T])
+  extends FutureAction[T] { self =>

-  // Pointer to the thread that is executing the action. It is set when the action is run.
-  @volatile private var thread: Thread = _
+  @volatile private var _cancelled = false

-  // A flag indicating whether the future has been cancelled. This is used in case the future
-  // is cancelled before the action was even run (and thus we have no thread to interrupt).
-  @volatile private var _cancelled: Boolean = false

-  @volatile private var jobs: Seq[Int] = Nil
+  @volatile private var subActions: List[FutureAction[_]] = Nil

// A promise used to signal the future.
-  private val p = promise[T]()
+  private val p = Promise[T]().tryCompleteWith(run(jobSubmitter))

-  override def cancel(): Unit = this.synchronized {
+  override def cancel(): Unit = synchronized {
_cancelled = true
-    if (thread != null) {
-      thread.interrupt()
-    }
-  }
-
-  /**
-   * Executes some action enclosed in the closure. To properly enable cancellation, the closure
-   * should use runJob implementation in this promise. See takeAsync for example.
-   */
-  def run(func: => T)(implicit executor: ExecutionContext): this.type = {
-    scala.concurrent.future {
-      thread = Thread.currentThread
-      try {
-        p.success(func)
-      } catch {
-        case e: Exception => p.failure(e)
-      } finally {
-        // This lock guarantees when calling `thread.interrupt()` in `cancel`,
-        // thread won't be set to null.
-        ComplexFutureAction.this.synchronized {
-          thread = null
-        }
-      }
-    }
-    this
+    p.tryFailure(new SparkException("Action has been cancelled"))
+    subActions.foreach(_.cancel())
}

-  /**
-   * Runs a Spark job. This is a wrapper around the same functionality provided by SparkContext
-   * to enable cancellation.
-   */
-  def runJob[T, U, R](
+  private def jobSubmitter = new JobSubmitter {
+    def submitJob[T, U, R](
      rdd: RDD[T],
      processPartition: Iterator[T] => U,
      partitions: Seq[Int],
      resultHandler: (Int, U) => Unit,
-      resultFunc: => R) {
-    // If the action hasn't been cancelled yet, submit the job. The check and the submitJob
-    // command need to be in an atomic block.
-    val job = this.synchronized {
+      resultFunc: => R): FutureAction[R] = self.synchronized {
+      // If the action hasn't been cancelled yet, submit the job. The check and the submitJob
+      // command need to be in an atomic block.
      if (!isCancelled) {
-        rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc)
+        val job = rdd.context.submitJob(
+          rdd,
+          processPartition,
+          partitions,
+          resultHandler,
+          resultFunc)
+        subActions = job :: subActions
+        job
} else {
throw new SparkException("Action has been cancelled")
}
}

-    this.jobs = jobs ++ job.jobIds
-
-    // Wait for the job to complete. If the action is cancelled (with an interrupt),
-    // cancel the job and stop the execution. This is not in a synchronized block because
-    // Await.ready eventually waits on the monitor in FutureJob.jobWaiter.
-    try {
-      Await.ready(job, Duration.Inf)
-    } catch {
-      case e: InterruptedException =>
-        job.cancel()
-        throw new SparkException("Action has been cancelled")
-    }
-  }
}

override def isCancelled: Boolean = _cancelled
@@ -276,10 +233,11 @@ class ComplexFutureAction[T] extends FutureAction[T] {

override def value: Option[Try[T]] = p.future.value

-  def jobIds: Seq[Int] = jobs
+  def jobIds: Seq[Int] = subActions.flatMap(_.jobIds)

}


private[spark]
class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S => T)
extends JavaFutureAction[T] {
@@ -303,7 +261,7 @@ class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S => T)
Await.ready(futureAction, timeout)
futureAction.value.get match {
case scala.util.Success(value) => converter(value)
-      case Failure(exception) =>
+      case scala.util.Failure(exception) =>
if (isCancelled) {
throw new CancellationException("Job cancelled").initCause(exception)
} else {
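To illustrate what the SimpleFutureAction rework above buys in practice: callbacks registered through the standard Future API are now driven by the job's completion future rather than each parking a pooled thread until the job ends. A driver-side usage sketch (the app skeleton and values are hypothetical; countAsync and onComplete are the real APIs):

    import scala.concurrent.ExecutionContext.Implicits.global
    import scala.util.{Failure, Success}

    import org.apache.spark.{SparkConf, SparkContext}

    object NonBlockingCountExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("demo").setMaster("local[2]"))
        // Registers a callback on the job's completion future; no thread is
        // tied up waiting for the job to finish.
        sc.parallelize(1 to 1000).countAsync().onComplete {
          case Success(n) => println(s"counted $n elements")
          case Failure(e) => println(s"job failed: $e")
        }
        // A real application would keep running until the callback has fired
        // before calling sc.stop().
      }
    }

The ComplexFutureAction changes follow the same principle: because the promise is completed via tryCompleteWith, a later tryFailure from cancel() simply loses the race if the action already finished, so cancellation no longer depends on interrupting a dedicated runner thread.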
48 changes: 27 additions & 21 deletions core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -19,13 +19,12 @@ package org.apache.spark.rdd

import java.util.concurrent.atomic.AtomicLong

-import org.apache.spark.util.ThreadUtils
-
import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.ExecutionContext
+import scala.concurrent.{Future, ExecutionContext}
import scala.reflect.ClassTag

-import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
+import org.apache.spark.{JobSubmitter, ComplexFutureAction, FutureAction, Logging}
+import org.apache.spark.util.ThreadUtils

/**
* A set of asynchronous RDD actions available through an implicit conversion.
@@ -65,17 +64,23 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging {
* Returns a future for retrieving the first num elements of the RDD.
*/
def takeAsync(num: Int): FutureAction[Seq[T]] = self.withScope {
-    val f = new ComplexFutureAction[Seq[T]]
    val callSite = self.context.getCallSite
-
-    f.run {
-      // This is a blocking action so we should use "AsyncRDDActions.futureExecutionContext" which
-      // is a cached thread pool.
-      val results = new ArrayBuffer[T](num)
-      val totalParts = self.partitions.length
-      var partsScanned = 0
-      self.context.setCallSite(callSite)
-      while (results.size < num && partsScanned < totalParts) {
+    val localProperties = self.context.getLocalProperties
+    // Cached thread pool to handle aggregation of subtasks.
+    implicit val executionContext = AsyncRDDActions.futureExecutionContext
+    val results = new ArrayBuffer[T](num)
+    val totalParts = self.partitions.length
+
+    /*
+      Recursively triggers jobs to scan partitions until either the requested
+      number of elements are retrieved, or the partitions to scan are exhausted.
+      This implementation is non-blocking, asynchronously handling the
+      results of each job and triggering the next job using callbacks on futures.
+    */
+    def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] =
+      if (results.size >= num || partsScanned >= totalParts) {
+        Future.successful(results.toSeq)
+      } else {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = 1
@@ -97,19 +102,20 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging {
val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)

val buf = new Array[Array[T]](p.size)
-        f.runJob(self,
+        self.context.setCallSite(callSite)
+        self.context.setLocalProperties(localProperties)
+        val job = jobSubmitter.submitJob(self,
(it: Iterator[T]) => it.take(left).toArray,
p,
(index: Int, data: Array[T]) => buf(index) = data,
Unit)

-        buf.foreach(results ++= _.take(num - results.size))
-        partsScanned += numPartsToTry
+        job.flatMap {_ =>
+          buf.foreach(results ++= _.take(num - results.size))
+          continue(partsScanned + numPartsToTry)
+        }
}
-      results.toSeq
-    }(AsyncRDDActions.futureExecutionContext)

-    f
+    new ComplexFutureAction[Seq[T]](continue(0)(_))
}

/**
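The takeAsync rewrite above replaces a while loop of blocking job waits with a recursive helper that chains each job to the next through flatMap. The same idiom in plain Scala futures, stripped of Spark specifics (scanPart is a hypothetical stand-in for a submitted job):

    import scala.collection.mutable.ArrayBuffer
    import scala.concurrent.{ExecutionContext, Future}

    object TakeLoop {
      // Non-blocking "take": recursion through flatMap replaces a blocking loop.
      def takeUpTo[T](num: Int, totalParts: Int, scanPart: Int => Future[Seq[T]])
                     (implicit ec: ExecutionContext): Future[Seq[T]] = {
        val results = new ArrayBuffer[T](num)
        def continue(part: Int): Future[Seq[T]] =
          if (results.size >= num || part >= totalParts) {
            Future.successful(results.toSeq)
          } else {
            scanPart(part).flatMap { chunk =>
              results ++= chunk.take(num - results.size)
              continue(part + 1) // the next job is triggered from the callback
            }
          }
        continue(0)
      }
    }

Because each step runs only after the previous future completes, the shared buffer is never touched by two callbacks at once, which is the same reason the Spark version can get away with a plain ArrayBuffer.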
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -24,6 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger

import scala.collection.Map
import scala.collection.mutable.{HashMap, HashSet, Stack}
+import scala.concurrent.Await
import scala.concurrent.duration._
import scala.language.existentials
import scala.language.postfixOps
@@ -610,11 +611,12 @@ class DAGScheduler(
properties: Properties): Unit = {
val start = System.nanoTime
val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)
-    waiter.awaitResult() match {
-      case JobSucceeded =>
+    Await.ready(waiter.completionFuture, atMost = Duration.Inf)
+    waiter.completionFuture.value.get match {
+      case scala.util.Success(_) =>
logInfo("Job %d finished: %s, took %f s".format
(waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
-      case JobFailed(exception: Exception) =>
+      case scala.util.Failure(exception) =>
logInfo("Job %d failed: %s, took %f s".format
(waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
// SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler.
(Diffs for the remaining four changed files are not shown on this page.)
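
JobWaiter.scala is among those unshown files, but the calls visible above (waiter.completionFuture in DAGScheduler, jobWaiter.completionFuture in SimpleFutureAction) imply a waiter that completes a promise from scheduler events rather than relying on wait/notify. A hypothetical sketch of that shape, with illustrative names rather than the actual Spark source:

    import java.util.concurrent.atomic.AtomicInteger

    import scala.concurrent.{Future, Promise}

    // Hypothetical promise-backed waiter; names and structure are illustrative.
    class PromiseBackedWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit) {
      private val finishedTasks = new AtomicInteger(0)
      private val jobPromise = Promise[Unit]()

      // A job with no tasks is complete immediately.
      if (totalTasks == 0) jobPromise.trySuccess(())

      // Non-blocking view of job completion; blocking callers Await on this.
      def completionFuture: Future[Unit] = jobPromise.future

      def taskSucceeded(index: Int, result: T): Unit = {
        resultHandler(index, result)
        if (finishedTasks.incrementAndGet() == totalTasks) {
          jobPromise.trySuccess(())
        }
      }

      def jobFailed(exception: Exception): Unit = {
        jobPromise.tryFailure(exception)
      }
    }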
