Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-9026] Refactor SimpleFutureAction.onComplete to not launch separate thread for every callback #7385

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 16 additions & 37 deletions core/src/main/scala/org/apache/spark/FutureAction.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.util.concurrent.TimeUnit

import org.apache.spark.api.java.JavaFutureAction
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{JobFailed, JobSucceeded, JobWaiter}
import org.apache.spark.scheduler.JobWaiter

import scala.concurrent._
import scala.concurrent.duration.Duration
Expand Down Expand Up @@ -108,63 +108,42 @@ trait FutureAction[T] extends Future[T] {
class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T)
extends FutureAction[T] {

// Note: `resultFunc` is a closure which may contain references to state that's updated by the
// JobWaiter's result handler function. It should only be evaluated once the job has succeeded.

@volatile private var _cancelled: Boolean = false
private[this] lazy val resultFuncOutput: T = {
assert(isCompleted, "resultFunc should only be evaluated after the job has completed")
resultFunc
}

override def cancel() {
_cancelled = true
jobWaiter.cancel()
}

override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = {
if (!atMost.isFinite()) {
awaitResult()
} else jobWaiter.synchronized {
val finishTime = System.currentTimeMillis() + atMost.toMillis
while (!isCompleted) {
val time = System.currentTimeMillis()
if (time >= finishTime) {
throw new TimeoutException
} else {
jobWaiter.wait(finishTime - time)
}
}
}
jobWaiter.ready(atMost)(permit)
this
}

@throws(classOf[Exception])
override def result(atMost: Duration)(implicit permit: CanAwait): T = {
ready(atMost)(permit)
awaitResult() match {
case scala.util.Success(res) => res
case scala.util.Failure(e) => throw e
}
jobWaiter.result(atMost)(permit) // Throws exception if the job failed.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what thread does this use? does it use some implicit thread pool?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's see... jobWaiter.result delegates to the job waiter's promise.future.result.

Here, I think we have an instance of DefaultPromise, whose result (https://github.com/scala/scala/blob/v2.10.4/src/library/scala/concurrent/impl/Promise.scala#L222) is implemented in terms of ready (https://github.com/scala/scala/blob/v2.10.4/src/library/scala/concurrent/impl/Promise.scala#L218), which in turn is implemented using the internal tryAwait method: https://github.com/scala/scala/blob/v2.10.4/src/library/scala/concurrent/impl/Promise.scala#L194

It looks like this is implemented by scheduling an onComplete callback which updates a latch. This callback runs on Future's InternalCallbackExecutor: https://github.com/scala/scala/blob/v2.10.4/src/library/scala/concurrent/Future.scala#L590

resultFuncOutput // This function is safe to evaluate because the job must have succeeded.
}

override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) {
executor.execute(new Runnable {
override def run() {
func(awaitResult())
}
})
override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext): Unit = {
jobWaiter.map { _ => resultFuncOutput }.onComplete(func)
}

override def isCompleted: Boolean = jobWaiter.jobFinished
override def isCompleted: Boolean = jobWaiter.isCompleted

override def isCancelled: Boolean = _cancelled

override def value: Option[Try[T]] = {
if (jobWaiter.jobFinished) {
Some(awaitResult())
} else {
None
}
}

private def awaitResult(): Try[T] = {
jobWaiter.awaitResult() match {
case JobSucceeded => scala.util.Success(resultFunc)
case JobFailed(e: Exception) => scala.util.Failure(e)
jobWaiter.value.map { valueTry =>
valueTry.map(_ => resultFuncOutput)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger

import scala.collection.Map
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Stack}
import scala.concurrent.Await
import scala.concurrent.duration._
import scala.language.existentials
import scala.language.postfixOps
Expand Down Expand Up @@ -543,11 +544,11 @@ class DAGScheduler(
properties: Properties): Unit = {
val start = System.nanoTime
val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)
waiter.awaitResult() match {
case JobSucceeded =>
Await.ready(waiter, Duration.Inf).value.get match {
case scala.util.Success(_) =>
logInfo("Job %d finished: %s, took %f s".format
(waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
case JobFailed(exception: Exception) =>
case scala.util.Failure(exception: Throwable) =>
logInfo("Job %d failed: %s, took %f s".format
(waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
// SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler.
Expand Down
61 changes: 39 additions & 22 deletions core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@

package org.apache.spark.scheduler

import scala.concurrent.duration.Duration
import scala.concurrent._
import scala.util.Try

/**
* An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their
* results to the given handler function.
Expand All @@ -26,19 +30,37 @@ private[spark] class JobWaiter[T](
val jobId: Int,
totalTasks: Int,
resultHandler: (Int, T) => Unit)
extends JobListener {
extends JobListener with Future[Unit] {

private[this] val promise: Promise[Unit] = {
if (totalTasks == 0) {
Promise.successful[Unit]()
} else {
Promise[Unit]()
}
}
private[this] val promiseFuture: Future[Unit] = promise.future
private[this] var finishedTasks = 0

override def onComplete[U](func: (Try[Unit]) => U)(implicit executor: ExecutionContext): Unit = {
promiseFuture.onComplete(func)
}

private var finishedTasks = 0
override def isCompleted: Boolean = promiseFuture.isCompleted

// Is the job as a whole finished (succeeded or failed)?
@volatile
private var _jobFinished = totalTasks == 0
override def value: Option[Try[Unit]] = promiseFuture.value

def jobFinished: Boolean = _jobFinished
@throws(classOf[Exception])
override def result(atMost: Duration)(implicit permit: CanAwait): Unit = {
promiseFuture.result(atMost)(permit)
}

// If the job is finished, this will be its result. In the case of 0 task jobs (e.g. zero
// partition RDDs), we set the jobResult directly to JobSucceeded.
private var jobResult: JobResult = if (jobFinished) JobSucceeded else null
@throws(classOf[InterruptedException])
@throws(classOf[TimeoutException])
override def ready(atMost: Duration)(implicit permit: CanAwait): this.type = {
promiseFuture.ready(atMost)(permit)
this
}

/**
* Sends a signal to the DAGScheduler to cancel the job. The cancellation itself is handled
Expand All @@ -50,28 +72,23 @@ private[spark] class JobWaiter[T](
}

override def taskSucceeded(index: Int, result: Any): Unit = synchronized {
if (_jobFinished) {
if (isCompleted) {
throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
}
resultHandler(index, result.asInstanceOf[T])
finishedTasks += 1
if (finishedTasks == totalTasks) {
_jobFinished = true
jobResult = JobSucceeded
this.notifyAll()
promise.success()
}
}

override def jobFailed(exception: Exception): Unit = synchronized {
_jobFinished = true
jobResult = JobFailed(exception)
this.notifyAll()
}

def awaitResult(): JobResult = synchronized {
while (!_jobFinished) {
this.wait()
// There are certain situations where jobFailed can be called multiple times for the same
// job. We guard against this by making this method idempotent.
if (!isCompleted) {
promise.failure(exception)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks tryFailure would be simpler.

} else {
assert(promiseFuture.value.get.isFailure)
}
return jobResult
}
}