Add EventLoop and change DAGScheduler to an EventLoop
zsxwing committed Jan 13, 2015
1 parent 1e42e96 commit 3b2e59c
Showing 4 changed files with 254 additions and 96 deletions.
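At a high level, the commit replaces Akka actor message sends inside DAGScheduler with posts to a dedicated, single-threaded event loop. A minimal before/after sketch of the dispatch pattern, using names taken from the diff below:

// Before this commit: scheduler events were delivered as Akka messages.
eventProcessActor ! JobCancelled(jobId)

// After this commit: events are enqueued on a plain event loop backed by one daemon thread.
eventProcessLoop.post(JobCancelled(jobId))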
112 changes: 43 additions & 69 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler

import java.io.NotSerializableException
import java.util.Properties
import java.util.concurrent.{TimeUnit, Executors}
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack}
@@ -28,8 +29,6 @@ import scala.language.postfixOps
import scala.reflect.ClassTag
import scala.util.control.NonFatal

import akka.actor._
import akka.actor.SupervisorStrategy.Stop
import akka.pattern.ask
import akka.util.Timeout

@@ -39,7 +38,7 @@ import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage._
import org.apache.spark.util.{CallSite, SystemClock, Clock, Utils}
import org.apache.spark.util.{CallSite, EventLoop, SystemClock, Clock, Utils}
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat

/**
@@ -67,8 +66,6 @@ class DAGScheduler(
clock: Clock = SystemClock)
extends Logging {

import DAGScheduler._

def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
this(
sc,
@@ -112,42 +109,31 @@ class DAGScheduler(
// stray messages to detect.
private val failedEpoch = new HashMap[String, Long]

private val dagSchedulerActorSupervisor =
env.actorSystem.actorOf(Props(new DAGSchedulerActorSupervisor(this)))

// A closure serializer that we reuse.
// This is only safe because DAGScheduler runs in a single thread.
private val closureSerializer = SparkEnv.get.closureSerializer.newInstance()

private[scheduler] var eventProcessActor: ActorRef = _

/** If enabled, we may run certain actions like take() and first() locally. */
private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false)

/** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */
private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false)

private def initializeEventProcessActor() {
// blocking the thread until supervisor is started, which ensures eventProcessActor is
// not null before any job is submitted
implicit val timeout = Timeout(30 seconds)
val initEventActorReply =
dagSchedulerActorSupervisor ? Props(new DAGSchedulerEventProcessActor(this))
eventProcessActor = Await.result(initEventActorReply, timeout.duration).
asInstanceOf[ActorRef]
}
private val messageScheduler =
Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("dag-scheduler-message"))

initializeEventProcessActor()
private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
taskScheduler.setDAGScheduler(this)

// Called by TaskScheduler to report task's starting.
def taskStarted(task: Task[_], taskInfo: TaskInfo) {
eventProcessActor ! BeginEvent(task, taskInfo)
eventProcessLoop.post(BeginEvent(task, taskInfo))
}

// Called to report that a task has completed and results are being fetched remotely.
def taskGettingResult(taskInfo: TaskInfo) {
eventProcessActor ! GettingResultEvent(taskInfo)
eventProcessLoop.post(GettingResultEvent(taskInfo))
}

// Called by TaskScheduler to report task completions or failures.
@@ -158,7 +144,8 @@ class DAGScheduler(
accumUpdates: Map[Long, Any],
taskInfo: TaskInfo,
taskMetrics: TaskMetrics) {
eventProcessActor ! CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)
eventProcessLoop.post(
CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics))
}

/**
@@ -180,18 +167,18 @@ class DAGScheduler(

// Called by TaskScheduler when an executor fails.
def executorLost(execId: String) {
eventProcessActor ! ExecutorLost(execId)
eventProcessLoop.post(ExecutorLost(execId))
}

// Called by TaskScheduler when a host is added
def executorAdded(execId: String, host: String) {
eventProcessActor ! ExecutorAdded(execId, host)
eventProcessLoop.post(ExecutorAdded(execId, host))
}

// Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
// cancellation of the job itself.
def taskSetFailed(taskSet: TaskSet, reason: String) {
eventProcessActor ! TaskSetFailed(taskSet, reason)
eventProcessLoop.post(TaskSetFailed(taskSet, reason))
}

private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = {
@@ -496,8 +483,8 @@ class DAGScheduler(
assert(partitions.size > 0)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
eventProcessActor ! JobSubmitted(
jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
eventProcessLoop.post(JobSubmitted(
jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties))
waiter
}

@@ -537,8 +524,8 @@ class DAGScheduler(
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.partitions.size).toArray
val jobId = nextJobId.getAndIncrement()
eventProcessActor ! JobSubmitted(
jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties)
eventProcessLoop.post(JobSubmitted(
jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties))
listener.awaitResult() // Will throw an exception if the job fails
}

@@ -547,19 +534,19 @@
*/
def cancelJob(jobId: Int) {
logInfo("Asked to cancel job " + jobId)
eventProcessActor ! JobCancelled(jobId)
eventProcessLoop.post(JobCancelled(jobId))
}

def cancelJobGroup(groupId: String) {
logInfo("Asked to cancel job group " + groupId)
eventProcessActor ! JobGroupCancelled(groupId)
eventProcessLoop.post(JobGroupCancelled(groupId))
}

/**
* Cancel all jobs that are running or waiting in the queue.
*/
def cancelAllJobs() {
eventProcessActor ! AllJobsCancelled
eventProcessLoop.post(AllJobsCancelled)
}

private[scheduler] def doCancelAllJobs() {
@@ -575,7 +562,7 @@ class DAGScheduler(
* Cancel all jobs associated with a running or scheduled stage.
*/
def cancelStage(stageId: Int) {
eventProcessActor ! StageCancelled(stageId)
eventProcessLoop.post(StageCancelled(stageId))
}

/**
@@ -1059,16 +1046,16 @@

if (disallowStageRetryForTest) {
abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
} else if (failedStages.isEmpty && eventProcessActor != null) {
} else if (failedStages.isEmpty && eventProcessLoop != null) {
// Don't schedule an event to resubmit failed stages if failed isn't empty, because
// in that case the event will already have been scheduled. eventProcessActor may be
// in that case the event will already have been scheduled. eventProcessLoop may be
// null during unit tests.
// TODO: Cancel running tasks in the stage
import env.actorSystem.dispatcher
logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
s"$failedStage (${failedStage.name}) due to fetch failure")
env.actorSystem.scheduler.scheduleOnce(
RESUBMIT_TIMEOUT, eventProcessActor, ResubmitFailedStages)
messageScheduler.schedule(new Runnable {
override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
}, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
}
failedStages += failedStage
failedStages += mapStage
@@ -1326,40 +1313,21 @@

def stop() {
logInfo("Stopping DAGScheduler")
dagSchedulerActorSupervisor ! PoisonPill
eventProcessLoop.stop()
taskScheduler.stop()
}
}

private[scheduler] class DAGSchedulerActorSupervisor(dagScheduler: DAGScheduler)
extends Actor with Logging {

override val supervisorStrategy =
OneForOneStrategy() {
case x: Exception =>
logError("eventProcesserActor failed; shutting down SparkContext", x)
try {
dagScheduler.doCancelAllJobs()
} catch {
case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t)
}
dagScheduler.sc.stop()
Stop
}

def receive = {
case p: Props => sender ! context.actorOf(p)
case _ => logWarning("received unknown message in DAGSchedulerActorSupervisor")
}
// Start the event thread at the end of the constructor
eventProcessLoop.start()
}

private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGScheduler)
extends Actor with Logging {
private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler)
extends EventLoop[DAGSchedulerEvent]("dag-scheduler-event-loop") with Logging {

/**
* The main event loop of the DAG scheduler.
*/
def receive = {
override def onReceive(event: DAGSchedulerEvent): Unit = event match {
case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite,
listener, properties)
@@ -1398,7 +1366,17 @@ private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGSchedule
dagScheduler.resubmitFailedStages()
}

override def postStop() {
override def onError(e: Throwable): Unit = {
logError("DAGSchedulerEventProcessLoop failed; shutting down SparkContext", e)
try {
dagScheduler.doCancelAllJobs()
} catch {
case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t)
}
dagScheduler.sc.stop()
}

override def onStop() {
// Cancel any active jobs in the onStop hook
dagScheduler.cleanUpAfterSchedulerStop()
}
@@ -1408,9 +1386,5 @@ private[spark] object DAGScheduler {
// The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
// this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
// as more failure events come in
val RESUBMIT_TIMEOUT = 200.milliseconds

// The time, in millis, to wake up between polls of the completion queue in order to potentially
// resubmit failed stages
val POLL_TIMEOUT = 10L
val RESUBMIT_TIMEOUT = 200
}
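With the Akka scheduler gone, the delayed resubmission of fetch-failed stages is driven by a single-threaded ScheduledExecutorService that posts ResubmitFailedStages back onto the event loop after RESUBMIT_TIMEOUT milliseconds. A simplified, self-contained sketch of that pattern; the ResubmitSketch object and its string queue are illustrative stand-ins, not part of the patch:

import java.util.concurrent.{Executors, LinkedBlockingQueue, TimeUnit}

object ResubmitSketch {
  // Stand-in for the scheduler's event queue; in the patch this role is played by eventProcessLoop.
  private val events = new LinkedBlockingQueue[String]()

  // Single-threaded scheduler, analogous to the "dag-scheduler-message" pool added in this commit.
  private val messageScheduler = Executors.newScheduledThreadPool(1)

  def main(args: Array[String]): Unit = {
    val resubmitTimeoutMs = 200L // mirrors DAGScheduler.RESUBMIT_TIMEOUT
    // Schedule a one-shot task that posts the resubmit event after the delay.
    messageScheduler.schedule(new Runnable {
      override def run(): Unit = events.put("ResubmitFailedStages")
    }, resubmitTimeoutMs, TimeUnit.MILLISECONDS)

    println(events.take()) // blocks for roughly the delay, then prints the event
    messageScheduler.shutdown()
  }
}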
110 changes: 110 additions & 0 deletions core/src/main/scala/org/apache/spark/util/EventLoop.scala
@@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util

import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque}

import scala.util.control.NonFatal

import org.apache.spark.Logging

/**
* An event loop to receive events from the caller and process all events in the event thread. It
* will start an exclusive event thread to process all events.
*/
abstract class EventLoop[E](name: String) extends Logging {

private val eventQueue: BlockingQueue[E] = new LinkedBlockingDeque[E]()

private val eventThread = new Thread(name) {
setDaemon(true)

override def run(): Unit = {
try {
while (true) {
val event = eventQueue.take()
try {
onReceive(event)
} catch {
case NonFatal(e) => {
try {
onError(e)
} catch {
case NonFatal(e) => logError("Unexpected error in " + name, e)
}
}
}
}
} catch {
case ie: InterruptedException => // exit even if eventQueue is not empty
case NonFatal(e) => logError("Unexpected error in " + name, e)
}
}

}

def start(): Unit = {
// Call onStart before starting the event thread to make sure it happens before onReceive
onStart()
eventThread.start()
}

def stop(): Unit = {
eventThread.interrupt()
eventThread.join()
// Call onStop after the event thread exits to make sure onReceive happens before onStop
onStop()
}

/**
* Put the event into the event queue. The event thread will process it later.
*/
def post(event: E): Unit = {
eventQueue.put(event)
}

/**
* Return whether the event thread has been started and has not yet been stopped.
*/
def isActive: Boolean = eventThread.isAlive

/**
* Invoked when `start()` is called, before the event thread starts.
*/
def onStart(): Unit = {}

/**
* Invoked when `stop()` is called, after the event thread exits.
*/
def onStop(): Unit = {}

/**
* Invoked in the event thread when an event is polled from the event queue.
*
* Note: avoid calling blocking actions in `onReceive`, or the event thread will be blocked and
* unable to process events in time. If you need to perform blocking actions, run them in
* another thread.
*/
def onReceive(event: E): Unit

/**
* Invoked if `onReceive` throws any non-fatal error. `onError` itself must not throw any non-fatal error.
*/
def onError(e: Throwable): Unit

}
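For illustration, here is a minimal, hypothetical subclass showing the intended lifecycle of the new EventLoop: implement onReceive and onError, then start, post, and stop. The LogEventLoop name and its string events are made up for this sketch, which assumes spark-core (providing org.apache.spark.util.EventLoop) is on the classpath:

import org.apache.spark.util.EventLoop

// Hypothetical example: an event loop that simply logs string events.
class LogEventLoop extends EventLoop[String]("log-event-loop") {
  override def onReceive(event: String): Unit = {
    // Runs on the dedicated "log-event-loop" daemon thread; keep this non-blocking.
    println(s"received: $event")
  }

  override def onError(e: Throwable): Unit = {
    // Called on the event thread if onReceive throws a non-fatal error.
    println(s"failed to handle an event: ${e.getMessage}")
  }
}

object LogEventLoopDemo {
  def main(args: Array[String]): Unit = {
    val loop = new LogEventLoop
    loop.start()          // calls onStart, then starts the daemon event thread
    loop.post("hello")    // enqueued and handled asynchronously by onReceive
    Thread.sleep(100)     // toy demo only: give the event thread time to drain the queue
    loop.stop()           // interrupts and joins the event thread, then calls onStop
  }
}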