From 3b2e59c096b9a26bd76a72a8097c4d2fe356b676 Mon Sep 17 00:00:00 2001
From: zsxwing
Date: Tue, 13 Jan 2015 16:40:33 +0800
Subject: [PATCH] Add EventLoop and change DAGScheduler to use an EventLoop

---
 .../apache/spark/scheduler/DAGScheduler.scala | 112 +++++++-----------
 .../org/apache/spark/util/EventLoop.scala     | 110 +++++++++++++++++
 .../spark/scheduler/DAGSchedulerSuite.scala   |  45 +++----
 .../apache/spark/util/EventLoopSuite.scala    |  83 +++++++++++++
 4 files changed, 254 insertions(+), 96 deletions(-)
 create mode 100644 core/src/main/scala/org/apache/spark/util/EventLoop.scala
 create mode 100644 core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala

diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 61d09d73e17cb..2e60d696c81e3 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler
 
 import java.io.NotSerializableException
 import java.util.Properties
+import java.util.concurrent.{TimeUnit, Executors}
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack}
@@ -28,8 +29,6 @@ import scala.language.postfixOps
 import scala.reflect.ClassTag
 import scala.util.control.NonFatal
 
-import akka.actor._
-import akka.actor.SupervisorStrategy.Stop
 import akka.pattern.ask
 import akka.util.Timeout
 
@@ -39,7 +38,7 @@ import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage._
-import org.apache.spark.util.{CallSite, SystemClock, Clock, Utils}
+import org.apache.spark.util.{CallSite, EventLoop, SystemClock, Clock, Utils}
 import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
 
 /**
@@ -67,8 +66,6 @@ class DAGScheduler(
     clock: Clock = SystemClock)
   extends Logging {
 
-  import DAGScheduler._
-
   def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
     this(
       sc,
@@ -112,14 +109,10 @@ class DAGScheduler(
   // stray messages to detect.
   private val failedEpoch = new HashMap[String, Long]
 
-  private val dagSchedulerActorSupervisor =
-    env.actorSystem.actorOf(Props(new DAGSchedulerActorSupervisor(this)))
-
   // A closure serializer that we reuse.
   // This is only safe because DAGScheduler runs in a single thread.
   private val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
 
-  private[scheduler] var eventProcessActor: ActorRef = _
 
   /** If enabled, we may run certain actions like take() and first() locally. */
   private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false)
@@ -127,27 +120,20 @@ class DAGScheduler(
   /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */
   private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false)
 
-  private def initializeEventProcessActor() {
-    // blocking the thread until supervisor is started, which ensures eventProcessActor is
-    // not null before any job is submitted
-    implicit val timeout = Timeout(30 seconds)
-    val initEventActorReply =
-      dagSchedulerActorSupervisor ? Props(new DAGSchedulerEventProcessActor(this))
-    eventProcessActor = Await.result(initEventActorReply, timeout.duration).
-      asInstanceOf[ActorRef]
-  }
+  private val messageScheduler =
+    Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("dag-scheduler-message"))
 
-  initializeEventProcessActor()
+  private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
   taskScheduler.setDAGScheduler(this)
 
   // Called by TaskScheduler to report task's starting.
   def taskStarted(task: Task[_], taskInfo: TaskInfo) {
-    eventProcessActor ! BeginEvent(task, taskInfo)
+    eventProcessLoop.post(BeginEvent(task, taskInfo))
   }
 
   // Called to report that a task has completed and results are being fetched remotely.
   def taskGettingResult(taskInfo: TaskInfo) {
-    eventProcessActor ! GettingResultEvent(taskInfo)
+    eventProcessLoop.post(GettingResultEvent(taskInfo))
   }
 
   // Called by TaskScheduler to report task completions or failures.
@@ -158,7 +144,8 @@ class DAGScheduler(
       accumUpdates: Map[Long, Any],
       taskInfo: TaskInfo,
       taskMetrics: TaskMetrics) {
-    eventProcessActor ! CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)
+    eventProcessLoop.post(
+      CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics))
   }
 
   /**
@@ -180,18 +167,18 @@ class DAGScheduler(
 
   // Called by TaskScheduler when an executor fails.
   def executorLost(execId: String) {
-    eventProcessActor ! ExecutorLost(execId)
+    eventProcessLoop.post(ExecutorLost(execId))
   }
 
   // Called by TaskScheduler when a host is added
   def executorAdded(execId: String, host: String) {
-    eventProcessActor ! ExecutorAdded(execId, host)
+    eventProcessLoop.post(ExecutorAdded(execId, host))
   }
 
   // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
   // cancellation of the job itself.
   def taskSetFailed(taskSet: TaskSet, reason: String) {
-    eventProcessActor ! TaskSetFailed(taskSet, reason)
+    eventProcessLoop.post(TaskSetFailed(taskSet, reason))
   }
 
   private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = {
@@ -496,8 +483,8 @@ class DAGScheduler(
     assert(partitions.size > 0)
     val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
     val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
-    eventProcessActor ! JobSubmitted(
-      jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
+    eventProcessLoop.post(JobSubmitted(
+      jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties))
     waiter
   }
 
@@ -537,8 +524,8 @@ class DAGScheduler(
     val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
     val partitions = (0 until rdd.partitions.size).toArray
     val jobId = nextJobId.getAndIncrement()
-    eventProcessActor ! JobSubmitted(
-      jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties)
+    eventProcessLoop.post(JobSubmitted(
+      jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties))
     listener.awaitResult()    // Will throw an exception if the job fails
   }
 
@@ -547,19 +534,19 @@ class DAGScheduler(
    */
   def cancelJob(jobId: Int) {
     logInfo("Asked to cancel job " + jobId)
-    eventProcessActor ! JobCancelled(jobId)
+    eventProcessLoop.post(JobCancelled(jobId))
   }
 
   def cancelJobGroup(groupId: String) {
     logInfo("Asked to cancel job group " + groupId)
-    eventProcessActor ! JobGroupCancelled(groupId)
+    eventProcessLoop.post(JobGroupCancelled(groupId))
   }
 
   /**
    * Cancel all jobs that are running or waiting in the queue.
    */
   def cancelAllJobs() {
-    eventProcessActor ! AllJobsCancelled
+    eventProcessLoop.post(AllJobsCancelled)
   }
 
   private[scheduler] def doCancelAllJobs() {
@@ -575,7 +562,7 @@ class DAGScheduler(
    * Cancel all jobs associated with a running or scheduled stage.
    */
   def cancelStage(stageId: Int) {
-    eventProcessActor ! StageCancelled(stageId)
+    eventProcessLoop.post(StageCancelled(stageId))
   }
 
   /**
@@ -1059,16 +1046,16 @@ class DAGScheduler(
 
         if (disallowStageRetryForTest) {
           abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
-        } else if (failedStages.isEmpty && eventProcessActor != null) {
+        } else if (failedStages.isEmpty && eventProcessLoop != null) {
           // Don't schedule an event to resubmit failed stages if failed isn't empty, because
-          // in that case the event will already have been scheduled. eventProcessActor may be
+          // in that case the event will already have been scheduled. eventProcessLoop may be
           // null during unit tests.
           // TODO: Cancel running tasks in the stage
-          import env.actorSystem.dispatcher
           logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
             s"$failedStage (${failedStage.name}) due to fetch failure")
-          env.actorSystem.scheduler.scheduleOnce(
-            RESUBMIT_TIMEOUT, eventProcessActor, ResubmitFailedStages)
+          messageScheduler.schedule(new Runnable {
+            override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
+          }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
         }
         failedStages += failedStage
         failedStages += mapStage
@@ -1326,40 +1313,21 @@ class DAGScheduler(
 
   def stop() {
     logInfo("Stopping DAGScheduler")
-    dagSchedulerActorSupervisor ! PoisonPill
+    eventProcessLoop.stop()
     taskScheduler.stop()
   }
-}
 
-private[scheduler] class DAGSchedulerActorSupervisor(dagScheduler: DAGScheduler)
-  extends Actor with Logging {
-
-  override val supervisorStrategy =
-    OneForOneStrategy() {
-      case x: Exception =>
-        logError("eventProcesserActor failed; shutting down SparkContext", x)
-        try {
-          dagScheduler.doCancelAllJobs()
-        } catch {
-          case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t)
-        }
-        dagScheduler.sc.stop()
-        Stop
-    }
-
-  def receive = {
-    case p: Props => sender ! context.actorOf(p)
-    case _ => logWarning("received unknown message in DAGSchedulerActorSupervisor")
-  }
+  // Start the event thread at the end of the constructor
+  eventProcessLoop.start()
 }
 
-private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGScheduler)
-  extends Actor with Logging {
+private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler)
+  extends EventLoop[DAGSchedulerEvent]("dag-scheduler-event-loop") with Logging {
 
   /**
    * The main event loop of the DAG scheduler.
   */
-  def receive = {
+  override def onReceive(event: DAGSchedulerEvent): Unit = event match {
     case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
       dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite,
         listener, properties)
@@ -1398,7 +1366,17 @@ private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGSchedule
       dagScheduler.resubmitFailedStages()
   }
 
-  override def postStop() {
-    // Cancel any active jobs in postStop hook
+  override def onError(e: Throwable): Unit = {
+    logError("DAGSchedulerEventProcessLoop failed; shutting down SparkContext", e)
+    try {
+      dagScheduler.doCancelAllJobs()
+    } catch {
+      case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t)
+    }
+    dagScheduler.sc.stop()
+  }
+
+  override def onStop() {
+    // Cancel any active jobs in the onStop hook
     dagScheduler.cleanUpAfterSchedulerStop()
   }
 
@@ -1408,9 +1386,5 @@ private[spark] object DAGScheduler {
   // The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
   // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
   // as more failure events come in
-  val RESUBMIT_TIMEOUT = 200.milliseconds
-
-  // The time, in millis, to wake up between polls of the completion queue in order to potentially
-  // resubmit failed stages
-  val POLL_TIMEOUT = 10L
+  val RESUBMIT_TIMEOUT = 200
 }
diff --git a/core/src/main/scala/org/apache/spark/util/EventLoop.scala b/core/src/main/scala/org/apache/spark/util/EventLoop.scala
new file mode 100644
index 0000000000000..261bd19937f62
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/EventLoop.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque}
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.Logging
+
+/**
+ * An event loop that receives events from the caller and processes them all on a dedicated
+ * event thread, which this class starts and owns exclusively.
+ */ +abstract class EventLoop[E](name: String) extends Logging { + + private val eventQueue: BlockingQueue[E] = new LinkedBlockingDeque[E]() + + private val eventThread = new Thread(name) { + setDaemon(true) + + override def run(): Unit = { + try { + while (true) { + val event = eventQueue.take() + try { + onReceive(event) + } catch { + case NonFatal(e) => { + try { + onError(e) + } catch { + case NonFatal(e) => logError("Unexpected error in " + name, e) + } + } + } + } + } catch { + case ie: InterruptedException => // exit even if eventQueue is not empty + case NonFatal(e) => logError("Unexpected error in " + name, e) + } + } + + } + + def start(): Unit = { + // Call onStart before starting the event thread to make sure it happens before onReceive + onStart() + eventThread.start() + } + + def stop(): Unit = { + eventThread.interrupt() + eventThread.join() + // Call onStop after the event thread exits to make sure onReceive happens before onStop + onStop() + } + + /** + * Put the event into the event queue. The event thread will process it later. + */ + def post(event: E): Unit = { + eventQueue.put(event) + } + + /** + * Return if the event thread has already been started but not yet stopped. + */ + def isActive: Boolean = eventThread.isAlive + + /** + * Invoke when `start()` is called. It's also invoked before the event thread starts. + */ + def onStart(): Unit = {} + + /** + * Invoke when `stop()` is called and the event thread exits. + */ + def onStop(): Unit = {} + + /** + * Invoke in the event thread when polling events from the event queue. + * + * Note: Should avoid calling blocking actions in `onReceive`, or the event thread will be blocked + * and cannot process events in time. If you want to call some blocking actions, run them in + * another thread. + */ + def onReceive(event: E): Unit + + /** + * Invoke if `onReceive` throws any non fatal error. `onError` must not throw any non fatal error. + */ + def onError(e: Throwable): Unit + +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index d30eb10bbe947..733a1e37849ba 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -20,8 +20,6 @@ package org.apache.spark.scheduler import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap, Map} import scala.language.reflectiveCalls -import akka.actor._ -import akka.testkit.{ImplicitSender, TestKit, TestActorRef} import org.scalatest.{BeforeAndAfter, FunSuiteLike} import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -33,11 +31,20 @@ import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} import org.apache.spark.util.CallSite import org.apache.spark.executor.TaskMetrics -class BuggyDAGEventProcessActor extends Actor { - val state = 0 - def receive = { - case _ => throw new SparkException("error") +import scala.util.control.NonFatal + +class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler) + extends DAGSchedulerEventProcessLoop(dagScheduler) { + + override def post(event: DAGSchedulerEvent): Unit = { + try { + // Forward event to `onReceive` directly to avoid processing event asynchronously. 
+      onReceive(event)
+    } catch {
+      case NonFatal(e) => onError(e)
+    }
   }
+
 }
 
 /**
@@ -65,8 +72,7 @@ class MyRDD(
 
 class DAGSchedulerSuiteDummyException extends Exception
 
-class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with FunSuiteLike
-  with ImplicitSender with BeforeAndAfter with LocalSparkContext with Timeouts {
+class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSparkContext with Timeouts {
 
   val conf = new SparkConf
   /** Set of TaskSets the DAGScheduler has requested executed. */
@@ -113,7 +119,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
 
   var mapOutputTracker: MapOutputTrackerMaster = null
   var scheduler: DAGScheduler = null
-  var dagEventProcessTestActor: TestActorRef[DAGSchedulerEventProcessActor] = null
+  var dagEventProcessLoopTester: DAGSchedulerEventProcessLoop = null
 
   /**
    * Set of cache locations to return from our mock BlockManagerMaster.
@@ -167,13 +173,11 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
         runLocallyWithinThread(job)
       }
     }
-    dagEventProcessTestActor = TestActorRef[DAGSchedulerEventProcessActor](
-      Props(classOf[DAGSchedulerEventProcessActor], scheduler))(system)
+    dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler)
   }
 
   override def afterAll() {
     super.afterAll()
-    TestKit.shutdownActorSystem(system)
   }
 
   /**
@@ -190,7 +194,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
    * DAGScheduler event loop.
    */
   private def runEvent(event: DAGSchedulerEvent) {
-    dagEventProcessTestActor.receive(event)
+    dagEventProcessLoopTester.post(event)
   }
 
   /**
@@ -397,8 +401,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
         runLocallyWithinThread(job)
       }
     }
-    dagEventProcessTestActor = TestActorRef[DAGSchedulerEventProcessActor](
-      Props(classOf[DAGSchedulerEventProcessActor], noKillScheduler))(system)
+    dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(noKillScheduler)
     val jobId = submit(new MyRDD(sc, 1, Nil), Array(0))
     cancel(jobId)
     // Because the job wasn't actually cancelled, we shouldn't have received a failure message.
@@ -726,18 +729,6 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
     assert(sc.parallelize(1 to 10, 2).first() === 1)
   }
 
-  test("DAGSchedulerActorSupervisor closes the SparkContext when EventProcessActor crashes") {
-    val actorSystem = ActorSystem("test")
-    val supervisor = actorSystem.actorOf(
-      Props(classOf[DAGSchedulerActorSupervisor], scheduler), "dagSupervisor")
-    supervisor ! Props[BuggyDAGEventProcessActor]
-    val child = expectMsgType[ActorRef]
-    watch(child)
-    child ! "hi"
-    expectMsgPF(){ case Terminated(child) => () }
-    assert(scheduler.sc.dagScheduler === null)
-  }
-
   test("accumulator not calculated for resubmitted result stage") {
     //just for register
     val accum = new Accumulator[Int](0, AccumulatorParam.IntAccumulatorParam)
diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala
new file mode 100644
index 0000000000000..cc162062bccbd
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.concurrent.atomic.AtomicReference
+
+import scala.collection.mutable
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.FunSuite
+
+class EventLoopSuite extends FunSuite {
+
+  test("EventLoop") {
+    val buffer = new mutable.ArrayBuffer[Int] with mutable.SynchronizedBuffer[Int]
+    val eventLoop = new EventLoop[Int]("test") {
+
+      override def onReceive(event: Int): Unit = {
+        buffer += event
+      }
+
+      override def onError(e: Throwable): Unit = {}
+    }
+    eventLoop.start()
+    (1 to 100).foreach(eventLoop.post)
+    eventually(timeout(5 seconds), interval(200 millis)) {
+      assert((1 to 100) === buffer.toSeq)
+    }
+    eventLoop.stop()
+  }
+
+  test("EventLoop: start and stop") {
+    val eventLoop = new EventLoop[Int]("test") {
+
+      override def onReceive(event: Int): Unit = {}
+
+      override def onError(e: Throwable): Unit = {}
+    }
+    assert(false === eventLoop.isActive)
+    eventLoop.start()
+    assert(true === eventLoop.isActive)
+    eventLoop.stop()
+    assert(false === eventLoop.isActive)
+  }
+
+  test("EventLoop: onError") {
+    val e = new RuntimeException("Oops")
+    val receivedError = new AtomicReference[Throwable]()
+    val eventLoop = new EventLoop[Int]("test") {
+
+      override def onReceive(event: Int): Unit = {
+        throw e
+      }
+
+      override def onError(e: Throwable): Unit = {
+        receivedError.set(e)
+      }
+    }
+    eventLoop.start()
+    eventLoop.post(1)
+    eventually(timeout(5 seconds), interval(200 millis)) {
+      assert(e === receivedError.get)
+    }
+    eventLoop.stop()
+  }
+}
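
As a usage reference, here is a minimal sketch of how client code drives the EventLoop API added by this patch. The event type DemoEvent and the object EventLoopDemo are hypothetical names used only for illustration; EventLoop[E], start(), post(), stop(), onReceive, and onError are the API introduced above.

    import org.apache.spark.util.EventLoop

    object EventLoopDemo {
      sealed trait DemoEvent
      case class Message(body: String) extends DemoEvent
      case object Flush extends DemoEvent

      def main(args: Array[String]): Unit = {
        val loop = new EventLoop[DemoEvent]("demo-event-loop") {
          // Runs on the single "demo-event-loop" daemon thread, one event at a time.
          override def onReceive(event: DemoEvent): Unit = event match {
            case Message(body) => println(s"got: $body")
            case Flush => println("flush requested")
          }

          // Invoked on the event thread if onReceive throws a non-fatal error.
          override def onError(e: Throwable): Unit = {
            println(s"event handling failed: $e")
          }
        }

        loop.start()                 // runs onStart(), then starts the daemon event thread
        loop.post(Message("hello"))  // enqueues the event and returns immediately
        loop.post(Flush)
        Thread.sleep(100)            // crude, demo-only: give the loop time to drain the queue
        loop.stop()                  // interrupts the thread, joins it, then runs onStop()
      }
    }

This mirrors how DAGSchedulerEventProcessLoop is wired into DAGScheduler: a single-threaded queue consumer replaces the previous Akka actor, so event handling stays serialized while the scheduler drops its dependency on actor supervision, and onError takes over the role the DAGSchedulerActorSupervisor's supervisor strategy used to play.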