diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index 48792a958130c..4799aadb49d9b 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -22,7 +22,7 @@ import java.util.concurrent.TimeUnit
 
 import org.apache.spark.api.java.JavaFutureAction
 import org.apache.spark.rdd.RDD
-import org.apache.spark.scheduler.{JobFailed, JobSucceeded, JobWaiter}
+import org.apache.spark.scheduler.JobWaiter
 
 import scala.concurrent._
 import scala.concurrent.duration.Duration
@@ -108,7 +108,14 @@ trait FutureAction[T] extends Future[T] {
 class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T)
   extends FutureAction[T] {
 
+  // Note: `resultFunc` is a closure which may contain references to state that's updated by the
+  // JobWaiter's result handler function. It should only be evaluated once the job has succeeded.
+
   @volatile private var _cancelled: Boolean = false
+  private[this] lazy val resultFuncOutput: T = {
+    assert(isCompleted, "resultFunc should only be evaluated after the job has completed")
+    resultFunc
+  }
 
   override def cancel() {
     _cancelled = true
@@ -116,55 +123,27 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
   }
 
   override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = {
-    if (!atMost.isFinite()) {
-      awaitResult()
-    } else jobWaiter.synchronized {
-      val finishTime = System.currentTimeMillis() + atMost.toMillis
-      while (!isCompleted) {
-        val time = System.currentTimeMillis()
-        if (time >= finishTime) {
-          throw new TimeoutException
-        } else {
-          jobWaiter.wait(finishTime - time)
-        }
-      }
-    }
+    jobWaiter.ready(atMost)(permit)
     this
   }
 
   @throws(classOf[Exception])
   override def result(atMost: Duration)(implicit permit: CanAwait): T = {
-    ready(atMost)(permit)
-    awaitResult() match {
-      case scala.util.Success(res) => res
-      case scala.util.Failure(e) => throw e
-    }
+    jobWaiter.result(atMost)(permit) // Throws exception if the job failed.
+    resultFuncOutput // This function is safe to evaluate because the job must have succeeded.
   }
 
-  override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) {
-    executor.execute(new Runnable {
-      override def run() {
-        func(awaitResult())
-      }
-    })
+  override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext): Unit = {
+    jobWaiter.map { _ => resultFuncOutput }.onComplete(func)
   }
 
-  override def isCompleted: Boolean = jobWaiter.jobFinished
+  override def isCompleted: Boolean = jobWaiter.isCompleted
 
   override def isCancelled: Boolean = _cancelled
 
   override def value: Option[Try[T]] = {
-    if (jobWaiter.jobFinished) {
-      Some(awaitResult())
-    } else {
-      None
-    }
-  }
-
-  private def awaitResult(): Try[T] = {
-    jobWaiter.awaitResult() match {
-      case JobSucceeded => scala.util.Success(resultFunc)
-      case JobFailed(e: Exception) => scala.util.Failure(e)
+    jobWaiter.value.map { valueTry =>
+      valueTry.map(_ => resultFuncOutput)
     }
   }
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 552dabcfa5139..d5d21f03a77a9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -24,6 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.Map
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Stack}
+import scala.concurrent.Await
 import scala.concurrent.duration._
 import scala.language.existentials
 import scala.language.postfixOps
@@ -543,11 +544,11 @@ class DAGScheduler(
       properties: Properties): Unit = {
     val start = System.nanoTime
     val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)
-    waiter.awaitResult() match {
-      case JobSucceeded =>
+    Await.ready(waiter, Duration.Inf).value.get match {
+      case scala.util.Success(_) =>
         logInfo("Job %d finished: %s, took %f s".format
           (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
-      case JobFailed(exception: Exception) =>
+      case scala.util.Failure(exception: Throwable) =>
         logInfo("Job %d failed: %s, took %f s".format
           (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
         // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
index 382b09422a4a0..586f9eff5a8ed 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
@@ -17,6 +17,10 @@
 
 package org.apache.spark.scheduler
 
+import scala.concurrent.duration.Duration
+import scala.concurrent._
+import scala.util.Try
+
 /**
  * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their
  * results to the given handler function.
@@ -26,19 +30,37 @@ private[spark] class JobWaiter[T](
     val jobId: Int,
     totalTasks: Int,
     resultHandler: (Int, T) => Unit)
-  extends JobListener {
+  extends JobListener with Future[Unit] {
+
+  private[this] val promise: Promise[Unit] = {
+    if (totalTasks == 0) {
+      Promise.successful[Unit]()
+    } else {
+      Promise[Unit]()
+    }
+  }
+  private[this] val promiseFuture: Future[Unit] = promise.future
+  private[this] var finishedTasks = 0
+
+  override def onComplete[U](func: (Try[Unit]) => U)(implicit executor: ExecutionContext): Unit = {
+    promiseFuture.onComplete(func)
+  }
 
-  private var finishedTasks = 0
+  override def isCompleted: Boolean = promiseFuture.isCompleted
 
-  // Is the job as a whole finished (succeeded or failed)?
-  @volatile
-  private var _jobFinished = totalTasks == 0
+  override def value: Option[Try[Unit]] = promiseFuture.value
 
-  def jobFinished: Boolean = _jobFinished
+  @throws(classOf[Exception])
+  override def result(atMost: Duration)(implicit permit: CanAwait): Unit = {
+    promiseFuture.result(atMost)(permit)
+  }
 
-  // If the job is finished, this will be its result. In the case of 0 task jobs (e.g. zero
-  // partition RDDs), we set the jobResult directly to JobSucceeded.
-  private var jobResult: JobResult = if (jobFinished) JobSucceeded else null
+  @throws(classOf[InterruptedException])
+  @throws(classOf[TimeoutException])
+  override def ready(atMost: Duration)(implicit permit: CanAwait): this.type = {
+    promiseFuture.ready(atMost)(permit)
+    this
+  }
 
   /**
    * Sends a signal to the DAGScheduler to cancel the job. The cancellation itself is handled
@@ -50,28 +72,23 @@ private[spark] class JobWaiter[T](
   }
 
   override def taskSucceeded(index: Int, result: Any): Unit = synchronized {
-    if (_jobFinished) {
+    if (isCompleted) {
       throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
     }
     resultHandler(index, result.asInstanceOf[T])
     finishedTasks += 1
     if (finishedTasks == totalTasks) {
-      _jobFinished = true
-      jobResult = JobSucceeded
-      this.notifyAll()
+      promise.success()
     }
   }
 
   override def jobFailed(exception: Exception): Unit = synchronized {
-    _jobFinished = true
-    jobResult = JobFailed(exception)
-    this.notifyAll()
-  }
-
-  def awaitResult(): JobResult = synchronized {
-    while (!_jobFinished) {
-      this.wait()
+    // There are certain situations where jobFailed can be called multiple times for the same
+    // job. We guard against this by making this method idempotent.
+    if (!isCompleted) {
+      promise.failure(exception)
+    } else {
+      assert(promiseFuture.value.get.isFailure)
     }
-    return jobResult
   }
 }
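
Editor's note (not part of the patch): the core of this change is that JobWaiter now signals completion through a scala.concurrent.Promise[Unit] that is fulfilled exactly once, replacing the hand-rolled wait()/notifyAll() synchronization, and SimpleFutureAction simply delegates ready/result/onComplete/value to the waiter, evaluating resultFunc at most once (via a lazy val) and only after the job has succeeded. A minimal self-contained sketch of the same Promise-backed waiter pattern follows; SketchWaiter, completionFuture, and the demo object are hypothetical names used for illustration, not Spark APIs.

import scala.concurrent.{Await, Future, Promise}
import scala.concurrent.duration.Duration
import scala.util.{Failure, Success}

// Sketch of the Promise-backed waiter pattern used in the patch (hypothetical
// class, not a Spark API). The Promise is fulfilled exactly once, so callers
// can block with Await or attach callbacks instead of using wait()/notifyAll().
class SketchWaiter(totalTasks: Int) {
  // Zero-task jobs complete immediately, mirroring the `totalTasks == 0`
  // branch in the patch.
  private val promise: Promise[Unit] =
    if (totalTasks == 0) Promise.successful(()) else Promise[Unit]()

  private var finishedTasks = 0

  def completionFuture: Future[Unit] = promise.future

  def taskSucceeded(): Unit = synchronized {
    finishedTasks += 1
    if (finishedTasks == totalTasks) promise.success(())
  }

  def jobFailed(e: Exception): Unit = synchronized {
    // tryFailure is a no-op once the promise is completed, giving the
    // idempotence the patch implements with an explicit isCompleted check.
    promise.tryFailure(e)
  }
}

object SketchWaiterDemo extends App {
  val waiter = new SketchWaiter(totalTasks = 2)
  waiter.taskSucceeded()
  waiter.taskSucceeded()
  // Blocking use from a driver thread, analogous to DAGScheduler.runJob
  // after this patch:
  Await.ready(waiter.completionFuture, Duration.Inf).value.get match {
    case Success(_) => println("job finished")
    case Failure(e) => println(s"job failed: $e")
  }
}

The same once-only principle motivates memoizing resultFunc in SimpleFutureAction: the lazy val guarantees the closure is evaluated at most once, and the assert documents that this may only happen after the job has completed.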