SPARK-2099. Report progress while task is running. #1056

Closed · wants to merge 7 commits
46 changes: 46 additions & 0 deletions core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark

import akka.actor.Actor
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.scheduler.TaskScheduler

/**
* A heartbeat from executors to the driver. This is a shared message used by several internal
* components to convey liveness or execution information for in-progress tasks.
*/
private[spark] case class Heartbeat(
executorId: String,
taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics
blockManagerId: BlockManagerId)

private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean)

/**
* Lives in the driver to receive heartbeats from executors.
*/
private[spark] class HeartbeatReceiver(scheduler: TaskScheduler) extends Actor {
override def receive = {
case Heartbeat(executorId, taskMetrics, blockManagerId) =>
val response = HeartbeatResponse(
!scheduler.executorHeartbeatReceived(executorId, taskMetrics, blockManagerId))
sender ! response
}
}
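
For orientation, here is a minimal sketch of the ask round trip an executor performs against this actor. `HeartbeatSketch` and `sendHeartbeat` are illustrative names, and the receiver ref is assumed to be resolved elsewhere (e.g. via AkkaUtils.makeDriverRef):

    package org.apache.spark

    import scala.concurrent.Await
    import scala.concurrent.duration._

    import akka.actor.ActorRef
    import akka.pattern.ask
    import akka.util.Timeout

    // Hypothetical executor-side caller of HeartbeatReceiver.
    object HeartbeatSketch {
      def sendHeartbeat(receiverRef: ActorRef, heartbeat: Heartbeat): Unit = {
        implicit val timeout = Timeout(30.seconds)
        val future = (receiverRef ? heartbeat).mapTo[HeartbeatResponse]
        val response = Await.result(future, timeout.duration)
        if (response.reregisterBlockManager) {
          // The driver did not recognize this executor's BlockManager;
          // the real executor reacts by calling env.blockManager.reregister().
        }
      }
    }

Note that reregisterBlockManager is the negation of executorHeartbeatReceived's return value: the scheduler answers true when it recognizes the block manager, and the response flips that into a re-register instruction.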
4 changes: 4 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -36,6 +36,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, Sequence
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob}
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.mesos.MesosNativeLibrary
import akka.actor.Props

import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.broadcast.Broadcast
@@ -307,6 +308,8 @@ class SparkContext(config: SparkConf) extends Logging {

// Create and start the scheduler
private[spark] var taskScheduler = SparkContext.createTaskScheduler(this, master)
private val heartbeatReceiver = env.actorSystem.actorOf(
Props(new HeartbeatReceiver(taskScheduler)), "HeartbeatReceiver")
@volatile private[spark] var dagScheduler: DAGScheduler = _
try {
dagScheduler = new DAGScheduler(this)
@@ -992,6 +995,7 @@ class SparkContext(config: SparkConf) extends Logging {
if (dagSchedulerCopy != null) {
env.metricsSystem.report()
metadataCleaner.cancel()
env.actorSystem.stop(heartbeatReceiver)
cleaner.foreach(_.stop())
dagSchedulerCopy.stop()
taskScheduler = null
8 changes: 1 addition & 7 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -193,13 +193,7 @@ object SparkEnv extends Logging {
logInfo("Registering " + name)
actorSystem.actorOf(Props(newActor), name = name)
} else {
val driverHost: String = conf.get("spark.driver.host", "localhost")
val driverPort: Int = conf.getInt("spark.driver.port", 7077)
Utils.checkHost(driverHost, "Expected hostname")
val url = s"akka.tcp://spark@$driverHost:$driverPort/user/$name"
val timeout = AkkaUtils.lookupTimeout(conf)
logInfo(s"Connecting to $name: $url")
Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout)
AkkaUtils.makeDriverRef(name, conf, actorSystem)
}
}

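
The deleted block is the lookup logic this PR centralizes as AkkaUtils.makeDriverRef. A sketch of that helper, reconstructed from the lines removed above (the actual method may differ in detail; the wrapper object is illustrative):

    package org.apache.spark

    import scala.concurrent.Await

    import akka.actor.{ActorRef, ActorSystem}

    import org.apache.spark.util.{AkkaUtils, Utils}

    object MakeDriverRefSketch {
      // Resolve a named actor living in the driver's actor system.
      def makeDriverRef(name: String, conf: SparkConf, actorSystem: ActorSystem): ActorRef = {
        val driverHost: String = conf.get("spark.driver.host", "localhost")
        val driverPort: Int = conf.getInt("spark.driver.port", 7077)
        Utils.checkHost(driverHost, "Expected hostname")
        val url = s"akka.tcp://spark@$driverHost:$driverPort/user/$name"
        val timeout = AkkaUtils.lookupTimeout(conf)
        // Block until the named remote actor resolves; fail if it cannot be found.
        Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout)
      }
    }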
55 changes: 49 additions & 6 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -23,7 +23,7 @@ import java.nio.ByteBuffer
import java.util.concurrent._

import scala.collection.JavaConversions._
import scala.collection.mutable.HashMap
import scala.collection.mutable.{ArrayBuffer, HashMap}

import org.apache.spark._
import org.apache.spark.scheduler._
@@ -48,6 +48,8 @@ private[spark] class Executor(

private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))

@volatile private var isStopped = false

// No ip or host:port - just hostname
Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname")
// must not have port specified.
@@ -107,6 +109,8 @@
// Maintains the list of running tasks.
private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]

startDriverHeartbeater()

def launchTask(
context: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) {
val tr = new TaskRunner(context, taskId, taskName, serializedTask)
@@ -121,8 +125,10 @@
}
}

def stop(): Unit = {
def stop() {
env.metricsSystem.report()
isStopped = true
threadPool.shutdown()
}

/** Get the Yarn approved local directories. */
@@ -141,11 +147,12 @@
}

class TaskRunner(
execBackend: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer)
execBackend: ExecutorBackend, val taskId: Long, taskName: String, serializedTask: ByteBuffer)
extends Runnable {

@volatile private var killed = false
@volatile private var task: Task[Any] = _
@volatile var task: Task[Any] = _
@volatile var attemptedTask: Option[Task[Any]] = None

def kill(interruptThread: Boolean) {
logInfo(s"Executor is trying to kill $taskName (TID $taskId)")
@@ -162,7 +169,6 @@
val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo(s"Running $taskName (TID $taskId)")
execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
var attemptedTask: Option[Task[Any]] = None
var taskStart: Long = 0
def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
val startGCTime = gcTime
@@ -204,7 +210,6 @@
val afterSerialization = System.currentTimeMillis()

for (m <- task.metrics) {
m.hostname = Utils.localHostName()
m.executorDeserializeTime = taskStart - startTime
m.executorRunTime = taskFinish - taskStart
m.jvmGCTime = gcTime - startGCTime
@@ -354,4 +359,42 @@
}
}
}

def startDriverHeartbeater() {
val interval = conf.getInt("spark.executor.heartbeatInterval", 10000)
val timeout = AkkaUtils.lookupTimeout(conf)
val retryAttempts = AkkaUtils.numRetries(conf)
val retryIntervalMs = AkkaUtils.retryWaitMs(conf)
val heartbeatReceiverRef = AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem)

val t = new Thread() {
[Review — Contributor] use Utils.namedThreadFactory or otherwise generate a named thread

[Review — Contributor] okay, I now see that you do that further down. Utils.namedThreadFactory still looks like the best option.

[Review — Contributor (PR author)] I'm not able to find Utils.namedThreadFactory, but my understanding is that named thread factories are used when using a thread pool, where the threads can't be named directly.

[Review — Contributor] And yes, it goes outside the scope of this PR, but one of the things that bugs me about the Spark code is how frequently we create threads directly with new Thread instead of through some standardized factory, perhaps with better defined and utilized thread pools.

override def run() {
// Sleep a random interval so the heartbeats don't end up in sync
Thread.sleep(interval + (math.random * interval).asInstanceOf[Int])

while (!isStopped) {
val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]()
for (taskRunner <- runningTasks.values()) {
if (!taskRunner.attemptedTask.isEmpty) {
Option(taskRunner.task).flatMap(_.metrics).foreach { metrics =>
tasksMetrics += ((taskRunner.taskId, metrics))
}
}
}

val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId)
val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef,
retryAttempts, retryIntervalMs, timeout)
if (response.reregisterBlockManager) {
logWarning("Told to re-register on heartbeat")
env.blockManager.reregister()
}
Thread.sleep(interval)
}
}
}
t.setDaemon(true)
t.setName("Driver Heartbeater")
t.start()
}
}
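
The heartbeat cadence comes from spark.executor.heartbeatInterval (milliseconds; default 10000 per the getInt call above), and the initial sleep adds up to one extra interval of random jitter so executors launched together do not heartbeat in lockstep. Following up on the named-thread review discussion above, here is a sketch of the factory-based alternative, assuming Guava's ThreadFactoryBuilder; all names are illustrative and this goes beyond what the PR itself does:

    import java.util.concurrent.{Executors, ScheduledExecutorService, ThreadFactory, TimeUnit}

    import com.google.common.util.concurrent.ThreadFactoryBuilder

    object NamedThreadSketch {
      // Build daemon threads named "<prefix>-0", "<prefix>-1", ...
      def namedThreadFactory(prefix: String): ThreadFactory =
        new ThreadFactoryBuilder().setDaemon(true).setNameFormat(prefix + "-%d").build()

      // The heartbeat loop could run on a named scheduled executor instead of a
      // hand-rolled Thread, keeping the jittered start and fixed interval.
      def startHeartbeater(intervalMs: Long, beat: () => Unit): ScheduledExecutorService = {
        val pool = Executors.newSingleThreadScheduledExecutor(namedThreadFactory("driver-heartbeater"))
        val initialDelay = intervalMs + (math.random * intervalMs).toLong
        pool.scheduleAtFixedRate(new Runnable { def run(): Unit = beat() },
          initialDelay, intervalMs, TimeUnit.MILLISECONDS)
        pool
      }
    }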
10 changes: 9 additions & 1 deletion core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -23,6 +23,14 @@ import org.apache.spark.storage.{BlockId, BlockStatus}
/**
* :: DeveloperApi ::
* Metrics tracked during the execution of a task.
*
* This class is used to house metrics both for in-progress and completed tasks. In executors,
* both the task thread and the heartbeat thread write to the TaskMetrics. The heartbeat thread
* reads it to send in-progress metrics, and the task thread reads it to send metrics along with
* the completed task.
*
* So, when adding new fields, take into consideration that the whole object can be serialized for
* shipping off at any time to consumers of the SparkListener interface.
*/
@DeveloperApi
class TaskMetrics extends Serializable {
@@ -143,7 +151,7 @@ class ShuffleReadMetrics extends Serializable {
/**
* Absolute time when this task finished reading shuffle data
*/
var shuffleFinishTime: Long = _
var shuffleFinishTime: Long = -1

/**
* Number of blocks fetched in this shuffle by this task (remote or local)
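
One note on the -1 change above: the compiler default for `var shuffleFinishTime: Long = _` is 0, which a consumer of in-progress metrics cannot tell apart from a legitimate value, whereas -1 is an explicit "not finished yet" marker. A hypothetical consumer-side check:

    import org.apache.spark.executor.ShuffleReadMetrics

    // Hypothetical helper, not part of this PR: with the -1 sentinel, heartbeat
    // consumers can tell an unfinished shuffle read from a real finish time.
    object ShuffleProgressSketch {
      def shuffleFinished(m: ShuffleReadMetrics): Boolean = m.shuffleFinishTime != -1
    }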
21 changes: 19 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -29,7 +29,6 @@ import scala.reflect.ClassTag
import scala.util.control.NonFatal

import akka.actor._
import akka.actor.OneForOneStrategy
import akka.actor.SupervisorStrategy.Stop
import akka.pattern.ask
import akka.util.Timeout
@@ -39,8 +38,9 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId}
import org.apache.spark.storage._
import org.apache.spark.util.{CallSite, SystemClock, Clock, Utils}
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat

/**
* The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of
@@ -154,6 +154,23 @@ class DAGScheduler(
eventProcessActor ! CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)
}

/**
* Update metrics for in-progress tasks and let the master know that the BlockManager is still
* alive. Return true if the driver knows about the given block manager. Otherwise, return false,
* indicating that the block manager should re-register.
*/
def executorHeartbeatReceived(
execId: String,
taskMetrics: Array[(Long, Int, TaskMetrics)], // (taskId, stageId, metrics)
blockManagerId: BlockManagerId): Boolean = {
listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics))
implicit val timeout = Timeout(600 seconds)

Await.result(
blockManagerMaster.driverActor ? BlockManagerHeartbeat(blockManagerId),
timeout.duration).asInstanceOf[Boolean]
}

// Called by TaskScheduler when an executor fails.
def executorLost(execId: String) {
eventProcessActor ! ExecutorLost(execId)
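
The ask above expects the BlockManagerMaster actor to answer BlockManagerHeartbeat with a Boolean. A simplified, hypothetical stand-in for that handler (the real actor's registry and receive clause may look different):

    package org.apache.spark.storage

    import scala.collection.mutable

    import akka.actor.Actor

    import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat

    // Illustrative only: answers true iff the block manager is registered;
    // false ultimately makes the executor re-register it.
    class BlockManagerMasterSketch extends Actor {
      private val registered = mutable.HashSet[BlockManagerId]()

      override def receive = {
        case BlockManagerHeartbeat(blockManagerId) =>
          sender ! registered.contains(blockManagerId)
      }
    }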
11 changes: 11 additions & 0 deletions core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -75,6 +75,12 @@ case class SparkListenerBlockManagerRemoved(blockManagerId: BlockManagerId)
@DeveloperApi
case class SparkListenerUnpersistRDD(rddId: Int) extends SparkListenerEvent

@DeveloperApi
case class SparkListenerExecutorMetricsUpdate(
execId: String,
taskMetrics: Seq[(Long, Int, TaskMetrics)])
extends SparkListenerEvent

@DeveloperApi
case class SparkListenerApplicationStart(appName: String, time: Long, sparkUser: String)
extends SparkListenerEvent
@@ -158,6 +164,11 @@ trait SparkListener {
* Called when the application ends
*/
def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd) { }

/**
* Called when the driver receives task metrics from an executor in a heartbeat.
*/
def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { }
}

/**
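
As a usage sketch of the new event, a listener that reports progress on each heartbeat might look like the following (ProgressListener is an illustrative name; registration through SparkContext.addSparkListener is assumed):

    import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorMetricsUpdate}

    // Illustrative listener: log which tasks reported metrics in a heartbeat.
    class ProgressListener extends SparkListener {
      override def onExecutorMetricsUpdate(
          update: SparkListenerExecutorMetricsUpdate): Unit = {
        for ((taskId, stageId, metrics) <- update.taskMetrics) {
          println(s"heartbeat from executor ${update.execId}: " +
            s"task $taskId in stage $stageId is running on ${metrics.hostname}")
        }
      }
    }

    // sc.addSparkListener(new ProgressListener)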
@@ -68,6 +68,8 @@ private[spark] trait SparkListenerBus extends Logging {
foreachListener(_.onApplicationStart(applicationStart))
case applicationEnd: SparkListenerApplicationEnd =>
foreachListener(_.onApplicationEnd(applicationEnd))
case metricsUpdate: SparkListenerExecutorMetricsUpdate =>
foreachListener(_.onExecutorMetricsUpdate(metricsUpdate))
case SparkListenerShutdown =>
}
}
3 changes: 3 additions & 0 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -26,6 +26,8 @@ import org.apache.spark.TaskContext
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.util.ByteBufferInputStream
import org.apache.spark.util.Utils


/**
* A unit of execution. We have two kinds of Task's in Spark:
@@ -44,6 +46,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex

final def run(attemptId: Long): T = {
context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
context.taskMetrics.hostname = Utils.localHostName()
taskThread = Thread.currentThread()
if (_killed) {
kill(interruptThread = false)
10 changes: 10 additions & 0 deletions core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -18,6 +18,8 @@
package org.apache.spark.scheduler

import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.storage.BlockManagerId

/**
* Low-level task scheduler interface, currently implemented exclusively by TaskSchedulerImpl.
@@ -54,4 +56,12 @@ private[spark] trait TaskScheduler {

// Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
def defaultParallelism(): Int

/**
* Update metrics for in-progress tasks and let the master know that the BlockManager is still
* alive. Return true if the driver knows about the given block manager. Otherwise, return false,
* indicating that the block manager should re-register.
*/
def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)],
blockManagerId: BlockManagerId): Boolean
}
@@ -32,6 +32,9 @@ import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.util.Utils
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.storage.BlockManagerId
import akka.actor.Props

/**
* Schedules tasks for multiple types of clusters by acting through a SchedulerBackend.
@@ -320,6 +323,26 @@ private[spark] class TaskSchedulerImpl(
}
}

/**
* Update metrics for in-progress tasks and let the master know that the BlockManager is still
* alive. Return true if the driver knows about the given block manager. Otherwise, return false,
* indicating that the block manager should re-register.
*/
override def executorHeartbeatReceived(
execId: String,
taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics
blockManagerId: BlockManagerId): Boolean = {
val metricsWithStageIds = taskMetrics.flatMap {
case (id, metrics) => {
taskIdToTaskSetId.get(id)
[Review — Contributor] I think there is an unlikely race here where (a) a task heartbeat gets enqueued to be sent, (b) the task actually finishes, that message is sent, and taskIdToTaskSetId is cleared, then (c) the heartbeat arrives. This is possible since the heartbeater and the task execution run in different threads, so you'd get an NPE here. Though extremely unlikely, it might be good to just log a warning and move on if the task set is not found.

[Review — Contributor] As per our discussion, this is using Options everywhere, so it's good.

.flatMap(activeTaskSets.get)
.map(_.stageId)
.map(x => (id, x, metrics))
}
}
dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId)
}

def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long) {
taskSetManager.handleTaskGettingResult(tid)
}
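
Per the review thread above, both lookups in executorHeartbeatReceived return Options, so a heartbeat that races with task completion simply drops that task's entry rather than throwing. The chain, distilled into a hypothetical helper with the same maps as TaskSchedulerImpl:

    package org.apache.spark.scheduler

    import scala.collection.mutable.HashMap

    object HeartbeatLookupSketch {
      val taskIdToTaskSetId = new HashMap[Long, String]
      val activeTaskSets = new HashMap[String, TaskSetManager]

      def stageIdFor(taskId: Long): Option[Int] =
        taskIdToTaskSetId.get(taskId)   // None once the task has finished
          .flatMap(activeTaskSets.get)  // None once the task set is removed
          .map(_.stageId)
    }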