diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 48792a958130c..2a8220ff40090 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -20,13 +20,15 @@ package org.apache.spark import java.util.Collections import java.util.concurrent.TimeUnit +import scala.concurrent._ +import scala.concurrent.duration.Duration +import scala.util.Try + +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.JavaFutureAction import org.apache.spark.rdd.RDD -import org.apache.spark.scheduler.{JobFailed, JobSucceeded, JobWaiter} +import org.apache.spark.scheduler.JobWaiter -import scala.concurrent._ -import scala.concurrent.duration.Duration -import scala.util.{Failure, Try} /** * A future for the result of an action to support cancellation. This is an extension of the @@ -105,6 +107,7 @@ trait FutureAction[T] extends Future[T] { * A [[FutureAction]] holding the result of an action that triggers a single job. Examples include * count, collect, reduce. */ +@DeveloperApi class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T) extends FutureAction[T] { @@ -116,142 +119,96 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: } override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = { - if (!atMost.isFinite()) { - awaitResult() - } else jobWaiter.synchronized { - val finishTime = System.currentTimeMillis() + atMost.toMillis - while (!isCompleted) { - val time = System.currentTimeMillis() - if (time >= finishTime) { - throw new TimeoutException - } else { - jobWaiter.wait(finishTime - time) - } - } - } + jobWaiter.completionFuture.ready(atMost) this } @throws(classOf[Exception]) override def result(atMost: Duration)(implicit permit: CanAwait): T = { - ready(atMost)(permit) - awaitResult() match { - case scala.util.Success(res) => res - case scala.util.Failure(e) => throw e - } + jobWaiter.completionFuture.ready(atMost) + assert(value.isDefined, "Future has not completed properly") + value.get.get } override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) { - executor.execute(new Runnable { - override def run() { - func(awaitResult()) - } - }) + jobWaiter.completionFuture onComplete {_ => func(value.get)} } override def isCompleted: Boolean = jobWaiter.jobFinished override def isCancelled: Boolean = _cancelled - override def value: Option[Try[T]] = { - if (jobWaiter.jobFinished) { - Some(awaitResult()) - } else { - None - } - } - - private def awaitResult(): Try[T] = { - jobWaiter.awaitResult() match { - case JobSucceeded => scala.util.Success(resultFunc) - case JobFailed(e: Exception) => scala.util.Failure(e) - } - } + override def value: Option[Try[T]] = + jobWaiter.completionFuture.value.map {res => res.map(_ => resultFunc)} def jobIds: Seq[Int] = Seq(jobWaiter.jobId) } +/** + * Handle via which a "run" function passed to a [[ComplexFutureAction]] + * can submit jobs for execution. + */ +@DeveloperApi +trait JobSubmitter { + /** + * Submit a job for execution and return a FutureAction holding the result. + * This is a wrapper around the same functionality provided by SparkContext + * to enable cancellation. + */ + def submitJob[T, U, R]( + rdd: RDD[T], + processPartition: Iterator[T] => U, + partitions: Seq[Int], + resultHandler: (Int, U) => Unit, + resultFunc: => R): FutureAction[R] +} + + /** * A [[FutureAction]] for actions that could trigger multiple Spark jobs. Examples include take, - * takeSample. Cancellation works by setting the cancelled flag to true and interrupting the - * action thread if it is being blocked by a job. + * takeSample. Cancellation works by setting the cancelled flag to true and cancelling any pending + * jobs. */ -class ComplexFutureAction[T] extends FutureAction[T] { +@DeveloperApi +class ComplexFutureAction[T](run : JobSubmitter => Future[T]) + extends FutureAction[T] { self => - // Pointer to the thread that is executing the action. It is set when the action is run. - @volatile private var thread: Thread = _ + @volatile private var _cancelled = false - // A flag indicating whether the future has been cancelled. This is used in case the future - // is cancelled before the action was even run (and thus we have no thread to interrupt). - @volatile private var _cancelled: Boolean = false - - @volatile private var jobs: Seq[Int] = Nil + @volatile private var subActions: List[FutureAction[_]] = Nil // A promise used to signal the future. - private val p = promise[T]() + private val p = Promise[T]().tryCompleteWith(run(jobSubmitter)) - override def cancel(): Unit = this.synchronized { + override def cancel(): Unit = synchronized { _cancelled = true - if (thread != null) { - thread.interrupt() - } - } - - /** - * Executes some action enclosed in the closure. To properly enable cancellation, the closure - * should use runJob implementation in this promise. See takeAsync for example. - */ - def run(func: => T)(implicit executor: ExecutionContext): this.type = { - scala.concurrent.future { - thread = Thread.currentThread - try { - p.success(func) - } catch { - case e: Exception => p.failure(e) - } finally { - // This lock guarantees when calling `thread.interrupt()` in `cancel`, - // thread won't be set to null. - ComplexFutureAction.this.synchronized { - thread = null - } - } - } - this + p.tryFailure(new SparkException("Action has been cancelled")) + subActions.foreach(_.cancel()) } - /** - * Runs a Spark job. This is a wrapper around the same functionality provided by SparkContext - * to enable cancellation. - */ - def runJob[T, U, R]( + private def jobSubmitter = new JobSubmitter { + def submitJob[T, U, R]( rdd: RDD[T], processPartition: Iterator[T] => U, partitions: Seq[Int], resultHandler: (Int, U) => Unit, - resultFunc: => R) { - // If the action hasn't been cancelled yet, submit the job. The check and the submitJob - // command need to be in an atomic block. - val job = this.synchronized { + resultFunc: => R): FutureAction[R] = self.synchronized { + // If the action hasn't been cancelled yet, submit the job. The check and the submitJob + // command need to be in an atomic block. if (!isCancelled) { - rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc) + val job = rdd.context.submitJob( + rdd, + processPartition, + partitions, + resultHandler, + resultFunc) + subActions = job :: subActions + job } else { throw new SparkException("Action has been cancelled") } } - - this.jobs = jobs ++ job.jobIds - - // Wait for the job to complete. If the action is cancelled (with an interrupt), - // cancel the job and stop the execution. This is not in a synchronized block because - // Await.ready eventually waits on the monitor in FutureJob.jobWaiter. - try { - Await.ready(job, Duration.Inf) - } catch { - case e: InterruptedException => - job.cancel() - throw new SparkException("Action has been cancelled") - } } override def isCancelled: Boolean = _cancelled @@ -276,10 +233,11 @@ class ComplexFutureAction[T] extends FutureAction[T] { override def value: Option[Try[T]] = p.future.value - def jobIds: Seq[Int] = jobs + def jobIds: Seq[Int] = subActions.flatMap(_.jobIds) } + private[spark] class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S => T) extends JavaFutureAction[T] { @@ -303,7 +261,7 @@ class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S Await.ready(futureAction, timeout) futureAction.value.get match { case scala.util.Success(value) => converter(value) - case Failure(exception) => + case scala.util.Failure(exception) => if (isCancelled) { throw new CancellationException("Job cancelled").initCause(exception) } else { diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index d5e853613b05b..14f541f937b4c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -19,13 +19,12 @@ package org.apache.spark.rdd import java.util.concurrent.atomic.AtomicLong -import org.apache.spark.util.ThreadUtils - import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext +import scala.concurrent.{Future, ExecutionContext} import scala.reflect.ClassTag -import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} +import org.apache.spark.{JobSubmitter, ComplexFutureAction, FutureAction, Logging} +import org.apache.spark.util.ThreadUtils /** * A set of asynchronous RDD actions available through an implicit conversion. @@ -65,17 +64,23 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi * Returns a future for retrieving the first num elements of the RDD. */ def takeAsync(num: Int): FutureAction[Seq[T]] = self.withScope { - val f = new ComplexFutureAction[Seq[T]] val callSite = self.context.getCallSite - - f.run { - // This is a blocking action so we should use "AsyncRDDActions.futureExecutionContext" which - // is a cached thread pool. - val results = new ArrayBuffer[T](num) - val totalParts = self.partitions.length - var partsScanned = 0 - self.context.setCallSite(callSite) - while (results.size < num && partsScanned < totalParts) { + val localProperties = self.context.getLocalProperties + // Cached thread pool to handle aggregation of subtasks. + implicit val executionContext = AsyncRDDActions.futureExecutionContext + val results = new ArrayBuffer[T](num) + val totalParts = self.partitions.length + + /* + Recursively triggers jobs to scan partitions until either the requested + number of elements are retrieved, or the partitions to scan are exhausted. + This implementation is non-blocking, asynchronously handling the + results of each job and triggering the next job using callbacks on futures. + */ + def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] = + if (results.size >= num || partsScanned >= totalParts) { + Future.successful(results.toSeq) + } else { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. var numPartsToTry = 1 @@ -97,19 +102,20 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) val buf = new Array[Array[T]](p.size) - f.runJob(self, + self.context.setCallSite(callSite) + self.context.setLocalProperties(localProperties) + val job = jobSubmitter.submitJob(self, (it: Iterator[T]) => it.take(left).toArray, p, (index: Int, data: Array[T]) => buf(index) = data, Unit) - - buf.foreach(results ++= _.take(num - results.size)) - partsScanned += numPartsToTry + job.flatMap {_ => + buf.foreach(results ++= _.take(num - results.size)) + continue(partsScanned + numPartsToTry) + } } - results.toSeq - }(AsyncRDDActions.futureExecutionContext) - f + new ComplexFutureAction[Seq[T]](continue(0)(_)) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 5582720bbcff2..8d0e0c8624a55 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -24,6 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.Map import scala.collection.mutable.{HashMap, HashSet, Stack} +import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.existentials import scala.language.postfixOps @@ -610,11 +611,12 @@ class DAGScheduler( properties: Properties): Unit = { val start = System.nanoTime val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties) - waiter.awaitResult() match { - case JobSucceeded => + Await.ready(waiter.completionFuture, atMost = Duration.Inf) + waiter.completionFuture.value.get match { + case scala.util.Success(_) => logInfo("Job %d finished: %s, took %f s".format (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) - case JobFailed(exception: Exception) => + case scala.util.Failure(exception) => logInfo("Job %d failed: %s, took %f s".format (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler. diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index 382b09422a4a0..4326135186a73 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -17,6 +17,10 @@ package org.apache.spark.scheduler +import java.util.concurrent.atomic.AtomicInteger + +import scala.concurrent.{Future, Promise} + /** * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their * results to the given handler function. @@ -28,17 +32,15 @@ private[spark] class JobWaiter[T]( resultHandler: (Int, T) => Unit) extends JobListener { - private var finishedTasks = 0 - - // Is the job as a whole finished (succeeded or failed)? - @volatile - private var _jobFinished = totalTasks == 0 - - def jobFinished: Boolean = _jobFinished - + private val finishedTasks = new AtomicInteger(0) // If the job is finished, this will be its result. In the case of 0 task jobs (e.g. zero // partition RDDs), we set the jobResult directly to JobSucceeded. - private var jobResult: JobResult = if (jobFinished) JobSucceeded else null + private val jobPromise: Promise[Unit] = + if (totalTasks == 0) Promise.successful(()) else Promise() + + def jobFinished: Boolean = jobPromise.isCompleted + + def completionFuture: Future[Unit] = jobPromise.future /** * Sends a signal to the DAGScheduler to cancel the job. The cancellation itself is handled @@ -49,29 +51,17 @@ private[spark] class JobWaiter[T]( dagScheduler.cancelJob(jobId) } - override def taskSucceeded(index: Int, result: Any): Unit = synchronized { - if (_jobFinished) { - throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter") + override def taskSucceeded(index: Int, result: Any): Unit = { + // resultHandler call must be synchronized in case resultHandler itself is not thread safe. + synchronized { + resultHandler(index, result.asInstanceOf[T]) } - resultHandler(index, result.asInstanceOf[T]) - finishedTasks += 1 - if (finishedTasks == totalTasks) { - _jobFinished = true - jobResult = JobSucceeded - this.notifyAll() + if (finishedTasks.incrementAndGet() == totalTasks) { + jobPromise.success(()) } } - override def jobFailed(exception: Exception): Unit = synchronized { - _jobFinished = true - jobResult = JobFailed(exception) - this.notifyAll() - } + override def jobFailed(exception: Exception): Unit = + jobPromise.failure(exception) - def awaitResult(): JobResult = synchronized { - while (!_jobFinished) { - this.wait() - } - return jobResult - } } diff --git a/core/src/test/scala/org/apache/spark/Smuggle.scala b/core/src/test/scala/org/apache/spark/Smuggle.scala new file mode 100644 index 0000000000000..01694a6e6f741 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/Smuggle.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.util.UUID +import java.util.concurrent.locks.ReentrantReadWriteLock + +import scala.collection.mutable + +/** + * Utility wrapper to "smuggle" objects into tasks while bypassing serialization. + * This is intended for testing purposes, primarily to make locks, semaphores, and + * other constructs that would not survive serialization available from within tasks. + * A Smuggle reference is itself serializable, but after being serialized and + * deserialized, it still refers to the same underlying "smuggled" object, as long + * as it was deserialized within the same JVM. This can be useful for tests that + * depend on the timing of task completion to be deterministic, since one can "smuggle" + * a lock or semaphore into the task, and then the task can block until the test gives + * the go-ahead to proceed via the lock. + */ +class Smuggle[T] private(val key: Symbol) extends Serializable { + def smuggledObject: T = Smuggle.get(key) +} + + +object Smuggle { + /** + * Wraps the specified object to be smuggled into a serialized task without + * being serialized itself. + * + * @param smuggledObject + * @tparam T + * @return Smuggle wrapper around smuggledObject. + */ + def apply[T](smuggledObject: T): Smuggle[T] = { + val key = Symbol(UUID.randomUUID().toString) + lock.writeLock().lock() + try { + smuggledObjects += key -> smuggledObject + } finally { + lock.writeLock().unlock() + } + new Smuggle(key) + } + + private val lock = new ReentrantReadWriteLock + private val smuggledObjects = mutable.WeakHashMap.empty[Symbol, Any] + + private def get[T](key: Symbol) : T = { + lock.readLock().lock() + try { + smuggledObjects(key).asInstanceOf[T] + } finally { + lock.readLock().unlock() + } + } + + /** + * Implicit conversion of a Smuggle wrapper to the object being smuggled. + * + * @param smuggle the wrapper to unpack. + * @tparam T + * @return the smuggled object represented by the wrapper. + */ + implicit def unpackSmuggledObject[T](smuggle : Smuggle[T]): T = smuggle.smuggledObject + +} diff --git a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala index 46516e8d25298..5483f2b8434aa 100644 --- a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala @@ -86,4 +86,30 @@ class StatusTrackerSuite extends SparkFunSuite with Matchers with LocalSparkCont Set(firstJobId, secondJobId)) } } + + test("getJobIdsForGroup() with takeAsync()") { + sc = new SparkContext("local", "test", new SparkConf(false)) + sc.setJobGroup("my-job-group2", "description") + sc.statusTracker.getJobIdsForGroup("my-job-group2") shouldBe empty + val firstJobFuture = sc.parallelize(1 to 1000, 1).takeAsync(1) + val firstJobId = eventually(timeout(10 seconds)) { + firstJobFuture.jobIds.head + } + eventually(timeout(10 seconds)) { + sc.statusTracker.getJobIdsForGroup("my-job-group2") should be (Seq(firstJobId)) + } + } + + test("getJobIdsForGroup() with takeAsync() across multiple partitions") { + sc = new SparkContext("local", "test", new SparkConf(false)) + sc.setJobGroup("my-job-group2", "description") + sc.statusTracker.getJobIdsForGroup("my-job-group2") shouldBe empty + val firstJobFuture = sc.parallelize(1 to 1000, 2).takeAsync(999) + val firstJobId = eventually(timeout(10 seconds)) { + firstJobFuture.jobIds.head + } + eventually(timeout(10 seconds)) { + sc.statusTracker.getJobIdsForGroup("my-job-group2") should have size 2 + } + } } diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index ec99f2a1bad66..de015ebd5d237 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.util.concurrent.Semaphore -import scala.concurrent.{Await, TimeoutException} +import scala.concurrent._ import scala.concurrent.duration.Duration import scala.concurrent.ExecutionContext.Implicits.global @@ -27,7 +27,7 @@ import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ -import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark._ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Timeouts { @@ -197,4 +197,33 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim Await.result(f, Duration(20, "milliseconds")) } } + + private def testAsyncAction[R](action: RDD[Int] => FutureAction[R]): Unit = { + val executionContextInvoked = Promise[Unit] + val fakeExecutionContext = new ExecutionContext { + override def execute(runnable: Runnable): Unit = { + executionContextInvoked.success(()) + } + override def reportFailure(t: Throwable): Unit = () + } + val starter = Smuggle(new Semaphore(0)) + starter.drainPermits() + val rdd = sc.parallelize(1 to 100, 4).mapPartitions {itr => starter.acquire(1); itr} + val f = action(rdd) + f.onComplete(_ => ())(fakeExecutionContext) + // Here we verify that registering the callback didn't cause a thread to be consumed. + assert(!executionContextInvoked.isCompleted) + // Now allow the executors to proceed with task processing. + starter.release(rdd.partitions.length) + // Waiting for the result verifies that the tasks were successfully processed. + Await.result(executionContextInvoked.future, atMost = 15.seconds) + } + + test("SimpleFutureAction callback must not consume a thread while waiting") { + testAsyncAction(_.countAsync()) + } + + test("ComplexFutureAction callback must not consume a thread while waiting") { + testAsyncAction((_.takeAsync(100))) + } }