Added a submitJob interface that returns a Future of the result.
rxin committed Sep 18, 2013
1 parent 1cb42e6 commit 37d8f37
Showing 8 changed files with 185 additions and 87 deletions.
50 changes: 50 additions & 0 deletions core/src/main/scala/org/apache/spark/FutureJob.scala
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark

import java.util.concurrent.{ExecutionException, TimeUnit, Future}

import org.apache.spark.scheduler.{JobFailed, JobSucceeded, JobWaiter}

class FutureJob[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: () => T)
extends Future[T] {

override def isDone: Boolean = jobWaiter.jobFinished

override def cancel(mayInterruptIfRunning: Boolean): Boolean = {
jobWaiter.kill()
true
}

override def isCancelled: Boolean = {
throw new UnsupportedOperationException
}

override def get(): T = {
jobWaiter.awaitResult() match {
case JobSucceeded =>
resultFunc()
case JobFailed(e: Exception, _) =>
throw new ExecutionException(e)
}
}

override def get(timeout: Long, unit: TimeUnit): T = {
throw new UnsupportedOperationException
}
}
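
FutureJob adapts a JobWaiter to the standard java.util.concurrent.Future contract: isDone reads jobWaiter.jobFinished, cancel delegates to jobWaiter.kill(), and get() blocks on jobWaiter.awaitResult(), running resultFunc() on JobSucceeded or rethrowing the failure wrapped in an ExecutionException. A minimal caller-side sketch, not part of this commit (the helper name is illustrative):

    import java.util.concurrent.{ExecutionException, Future}

    // Hypothetical helper: block on an asynchronous Spark job and turn a failure into None,
    // mirroring how FutureJob.get() wraps JobFailed in an ExecutionException.
    def awaitJob[T](future: Future[T]): Option[T] =
      try {
        Some(future.get())                   // blocks until JobWaiter.awaitResult() returns JobSucceeded
      } catch {
        case e: ExecutionException => None   // the original job failure is e.getCause
      }

    // Cancellation routes through JobWaiter.kill():
    //   future.cancel(true)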
19 changes: 19 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -20,6 +20,7 @@ package org.apache.spark
import java.io._
import java.net.URI
import java.util.Properties
import java.util.concurrent.Future
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.Map
@@ -812,6 +813,24 @@ class SparkContext(
result
}

def submitJob[T, U, R](
rdd: RDD[T],
processPartition: Iterator[T] => U,
partitionResultHandler: (Int, U) => Unit,
resultFunc: () => R): Future[R] =
{
val callSite = Utils.formatSparkCallSite
val waiter = dagScheduler.submitJob(
rdd,
(context: TaskContext, iter: Iterator[T]) => processPartition(iter),
0 until rdd.partitions.size,
callSite,
allowLocal = false,
partitionResultHandler,
null)
new FutureJob(waiter, resultFunc)
}
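
The new submitJob entry point runs processPartition on every partition of the RDD, feeds each partition's result to partitionResultHandler as it completes, and returns a FutureJob that evaluates resultFunc() once the whole job has finished. A hedged usage sketch, with the RDD and names assumed for illustration:

    import java.util.concurrent.Future

    // Count the elements of an RDD[Int] asynchronously by folding per-partition counts,
    // the same pattern the new RDD.collectAsync uses with an ArrayBuffer.
    var total = 0L
    val countFuture: Future[Long] = sc.submitJob[Int, Long, Long](
      rdd,
      iter => iter.size.toLong,        // processPartition, executed on the executors
      (index, n) => total += n,        // partitionResultHandler, invoked per finished partition
      () => total)                     // resultFunc, assembles the final result
    val count = countFuture.get()      // blocks until the job completes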

/**
* Kill a running job.
*/
10 changes: 10 additions & 0 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -18,6 +18,7 @@
package org.apache.spark.rdd

import java.util.Random
import java.util.concurrent.Future

import scala.collection.Map
import scala.collection.JavaConversions.mapAsScalaMap
@@ -561,6 +562,15 @@ abstract class RDD[T: ClassManifest](
Array.concat(results: _*)
}

/**
* Return a future for retrieving the results of a collect in an asynchronous fashion.
*/
def collectAsync(): Future[Seq[T]] = {
val results = new ArrayBuffer[T]
sc.submitJob[T, Array[T], Seq[T]](
this, _.toArray, (index, data) => results ++= data, () => results)
}
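
collectAsync is a thin layer over SparkContext.submitJob: each partition is materialized as an array on the executors, the handler appends those arrays into a shared buffer as partitions finish, and the returned Future yields the buffer once the job is done. A usage sketch, with the input RDD assumed for illustration:

    import java.util.concurrent.Future

    val data = sc.parallelize(1 to 100, 4)
    val future: Future[Seq[Int]] = data.collectAsync()
    // ... do other work while the collect runs in the background ...
    val elems: Seq[Int] = future.get()   // blocks; a job failure surfaces as ExecutionException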

/**
* Return an array that contains all of the elements in this RDD.
*/
134 changes: 68 additions & 66 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -105,13 +105,15 @@ class DAGScheduler(

private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]

val nextJobId = new AtomicInteger(0)
private[scheduler] val nextJobId = new AtomicInteger(0)

val nextStageId = new AtomicInteger(0)
def numTotalJobs: Int = nextJobId.get()

val stageIdToStage = new TimeStampedHashMap[Int, Stage]
private val nextStageId = new AtomicInteger(0)

val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
private val stageIdToStage = new TimeStampedHashMap[Int, Stage]

private val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]

private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]

@@ -263,54 +265,50 @@ class DAGScheduler(
}

/**
* Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a
* JobWaiter whose getResult() method will return the result of the job when it is complete.
*
* The job is assumed to have at least one partition; zero partition jobs should be handled
* without a JobSubmitted event.
* Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object
* can be used to block until the job finishes executing or can be used to kill the job.
* If the requested set of partitions is empty, no tasks are run and a JobWaiter for zero
* tasks is returned immediately.
*/
private[scheduler] def prepareJob[T, U: ClassManifest](
finalRdd: RDD[T],
def submitJob[T, U](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: String,
allowLocal: Boolean,
resultHandler: (Int, U) => Unit,
properties: Properties = null)
: (JobSubmitted, JobWaiter[U]) =
properties: Properties = null): JobWaiter[U] =
{
val jobId = nextJobId.getAndIncrement()
if (partitions.size == 0) {
return new JobWaiter[U](this, jobId, 0, resultHandler)
}

// Check to make sure we are not launching a task on a partition that does not exist.
val maxPartitions = rdd.partitions.length
partitions.find(p => p >= maxPartitions).foreach { p =>
throw new IllegalArgumentException(
"Attempting to access a non-existent partition: " + p + ". " +
"Total number of partitions: " + maxPartitions)
}

assert(partitions.size > 0)
val waiter = new JobWaiter(partitions.size, resultHandler)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter,
properties)
(toSubmit, waiter)
val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite,
waiter, properties))
waiter
}

def runJob[T, U: ClassManifest](
finalRdd: RDD[T],
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: String,
allowLocal: Boolean,
resultHandler: (Int, U) => Unit,
properties: Properties = null)
{
if (partitions.size == 0) {
return
}

// Check to make sure we are not launching a task on a partition that does not exist.
val maxPartitions = finalRdd.partitions.length
partitions.find(p => p >= maxPartitions).foreach { p =>
throw new IllegalArgumentException(
"Attempting to access a non-existent partition: " + p + ". " +
"Total number of partitions: " + maxPartitions)
}

val (toSubmit: JobSubmitted, waiter: JobWaiter[_]) = prepareJob(
finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties)
eventQueue.put(toSubmit)
val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties)
waiter.awaitResult() match {
case JobSucceeded => {}
case JobFailed(exception: Exception, _) =>
@@ -331,45 +329,50 @@ class DAGScheduler(
val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.partitions.size).toArray
eventQueue.put(JobSubmitted(rdd, func2, partitions, allowLocal = false, callSite, listener, properties))
val jobId = nextJobId.getAndIncrement()
eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite,
listener, properties))
listener.awaitResult() // Will throw an exception if the job fails
}

/**
* Kill a job that is running or waiting in the queue.
*/
def killJob(jobId: Int): Unit = this.synchronized {
activeJobs.find(job => job.jobId == jobId).foreach(job => killJob(job))
}
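
killJob(jobId) looks up the matching ActiveJob and hands it to killJob(job), which clears the job's bookkeeping and kills its stages recursively through killStage. A sketch of how a user-level cancellation is assumed to reach it in this commit (JobWaiter.kill() is not shown in this excerpt):

    val future = data.collectAsync()   // data is an illustrative RDD
    // FutureJob.cancel -> JobWaiter.kill -> DAGScheduler.killJob(jobId)
    future.cancel(true)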

private def killJob(job: ActiveJob): Unit = this.synchronized {
logInfo("Killing Job and cleaning up stages %d".format(job.jobId))
activeJobs.remove(job)
idToActiveJob.remove(job.jobId)
val stage = job.finalStage
resultStageToJob.remove(stage)
killStage(job, stage)
val e = new SparkException("Job killed")
job.listener.jobFailed(e)
listenerBus.post(SparkListenerJobEnd(job, JobFailed(e, None)))
}

private def killStage(job: ActiveJob, stage: Stage): Unit = this.synchronized {
// TODO: Can we reuse taskSetFailed?
logInfo("Killing Stage %s".format(stage.id))
stageIdToStage.remove(stage.id)
if (stage.isShuffleMap) {
shuffleToMapStage.remove(stage.id)
}
waiting.remove(stage)
pendingTasks.remove(stage)
taskSched.killTasks(stage.id)

if (running.contains(stage)) {
running.remove(stage)
def killJob(job: ActiveJob): Unit = this.synchronized {
logInfo("Killing Job and cleaning up stages %d".format(job.jobId))
activeJobs.remove(job)
idToActiveJob.remove(job.jobId)
val stage = job.finalStage
resultStageToJob.remove(stage)
killStage(job, stage)
val e = new SparkException("Job killed")
listenerBus.post(SparkListenerJobEnd(job, JobFailed(e, Some(stage))))
job.listener.jobFailed(e)
listenerBus.post(SparkListenerJobEnd(job, JobFailed(e, None)))
}

stage.parents.foreach(parentStage => killStage(job, parentStage))
//stageToInfos -= stage
def killStage(job: ActiveJob, stage: Stage): Unit = this.synchronized {
// TODO: Can we reuse taskSetFailed?
logInfo("Killing Stage %s".format(stage.id))
stageIdToStage.remove(stage.id)
if (stage.isShuffleMap) {
shuffleToMapStage.remove(stage.id)
}
waiting.remove(stage)
pendingTasks.remove(stage)
taskSched.killTasks(stage.id)

if (running.contains(stage)) {
running.remove(stage)
val e = new SparkException("Job killed")
listenerBus.post(SparkListenerJobEnd(job, JobFailed(e, Some(stage))))
}

stage.parents.foreach(parentStage => killStage(job, parentStage))
//stageToInfos -= stage
}
}

/**
Expand All @@ -378,9 +381,8 @@ class DAGScheduler(
*/
private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
event match {
case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener, properties) =>
val jobId = nextJobId.getAndIncrement()
val finalStage = newStage(finalRDD, None, jobId, Some(callSite))
case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
val finalStage = newStage(rdd, None, jobId, Some(callSite))
val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
clearCacheLocs()
logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length +
@@ -32,9 +32,10 @@ import org.apache.spark.executor.TaskMetrics
* submitted) but there is a single "logic" thread that reads these events and takes decisions.
* This greatly simplifies synchronization.
*/
private[spark] sealed trait DAGSchedulerEvent
private[scheduler] sealed trait DAGSchedulerEvent

private[spark] case class JobSubmitted(
private[scheduler] case class JobSubmitted(
jobId: Int,
finalRDD: RDD[_],
func: (TaskContext, Iterator[_]) => _,
partitions: Array[Int],
@@ -44,9 +45,10 @@ private[spark] case class JobSubmitted(
properties: Properties = null)
extends DAGSchedulerEvent

private[spark] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
private[scheduler]
case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent

private[spark] case class CompletionEvent(
private[scheduler] case class CompletionEvent(
task: Task[_],
reason: TaskEndReason,
result: Any,
@@ -55,10 +57,12 @@ private[spark] case class CompletionEvent(
taskMetrics: TaskMetrics)
extends DAGSchedulerEvent

private[spark] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent
private[scheduler]
case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent

private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent

private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
private[scheduler]
case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent

private[spark] case object StopDAGScheduler extends DAGSchedulerEvent
private[scheduler] case object StopDAGScheduler extends DAGSchedulerEvent
@@ -40,7 +40,7 @@ private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler, sc: Spar
})

metricRegistry.register(MetricRegistry.name("job", "allJobs", "number"), new Gauge[Int] {
override def getValue: Int = dagScheduler.nextJobId.get()
override def getValue: Int = dagScheduler.numTotalJobs
})

metricRegistry.register(MetricRegistry.name("job", "activeJobs", "number"), new Gauge[Int] {