use TaskMetrics to gather all stats; lots of plumbing to get it all the way back to driver
commit b7d9e2439445da9b1ca8709f4ad8fcac9927dd76 1 parent 04e828f
Imran Rashid authored
9 core/src/main/scala/spark/executor/TaskMetrics.scala
@@ -7,10 +7,13 @@ case class TaskMetrics(
val remoteBlocksFetched: Option[Int],
val localBlocksFetched: Option[Int],
val remoteFetchWaitTime: Option[Long],
- val remoteBytesRead: Option[Long]
+ val remoteBytesRead: Option[Long],
+ val shuffleBytesWritten: Option[Long]
)
object TaskMetrics {
- private[spark] def apply(task: Task[_]) : TaskMetrics =
- TaskMetrics(None, None, None, task.remoteFetchWaitTime, task.remoteReadBytes)
+ private[spark] def apply(task: Task[_]) : TaskMetrics = {
+ TaskMetrics(task.totalBlocksFetched, task.remoteBlocksFetched, task.localBlocksFetched,
+ task.remoteFetchWaitTime, task.remoteReadBytes, task.shuffleBytesWritten)
+ }
}
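Every field on TaskMetrics is an Option, so counters a task never records (for example, shuffle writes on a pure result task) simply stay None all the way back at the driver. A minimal reading sketch (my own example, not part of this commit):

import spark.executor.TaskMetrics

// Sketch only: pattern-match on the optional counter instead of calling .get,
// since non-shuffle tasks never populate shuffleBytesWritten.
def describeShuffleWrite(metrics: TaskMetrics): String =
  metrics.shuffleBytesWritten match {
    case Some(bytes) => "wrote " + bytes + " shuffle bytes"
    case None        => "no shuffle output recorded"
  }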
3  core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -116,6 +116,9 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
context.task.remoteFetchTime = Some(fetchItr.remoteFetchTime)
context.task.remoteFetchWaitTime = Some(fetchItr.remoteFetchWaitTime)
context.task.remoteReadBytes = Some(fetchItr.remoteBytesRead)
+ context.task.totalBlocksFetched = Some(fetchItr.totalBlocks)
+ context.task.localBlocksFetched = Some(fetchItr.numLocalBlocks)
+ context.task.remoteBlocksFetched = Some(fetchItr.numRemoteBlocks)
}
}
JavaConversions.mapAsScalaMap(map).iterator
15 core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -7,9 +7,10 @@ import java.util.concurrent.Future
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue, Map}
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import spark._
+import executor.TaskMetrics
import spark.partial.ApproximateActionListener
import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult
@@ -42,8 +43,9 @@ class DAGScheduler(
reason: TaskEndReason,
result: Any,
accumUpdates: Map[Long, Any],
- taskInfo: TaskInfo) {
- eventQueue.put(CompletionEvent(task, reason, result, accumUpdates, taskInfo))
+ taskInfo: TaskInfo,
+ taskMetrics: TaskMetrics) {
+ eventQueue.put(CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics))
}
// Called by TaskScheduler when an executor fails.
@@ -77,7 +79,7 @@ class DAGScheduler(
private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]
- private val sparkListeners = Traversable[SparkListener]()
+ private[spark] var sparkListeners = ArrayBuffer[SparkListener]()
var cacheLocs = new HashMap[Int, Array[List[String]]]
@@ -491,6 +493,7 @@ class DAGScheduler(
}
pendingTasks(stage) -= task
stageToInfos(stage).taskInfos += event.taskInfo
+ stageToInfos(stage).taskMetrics += event.taskMetrics
task match {
case rt: ResultTask[_, _] =>
resultStageToJob.get(stage) match {
@@ -512,10 +515,6 @@ class DAGScheduler(
case smt: ShuffleMapTask =>
val status = event.result.asInstanceOf[MapStatus]
- smt.totalBytesWritten match {
- case Some(b) => stageToInfos(stage).shuffleBytesWritten += b
- case None => throw new RuntimeException("shuffle stask completed without tracking bytes written")
- }
val execId = status.location.executorId
logDebug("ShuffleMapTask finished on " + execId)
if (failedGeneration.contains(execId) && smt.generation <= failedGeneration(execId)) {
4 core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
@@ -4,6 +4,7 @@ import cluster.TaskInfo
import scala.collection.mutable.Map
import spark._
+import executor.TaskMetrics
/**
* Types of events that can be handled by the DAGScheduler. The DAGScheduler uses an event queue
@@ -27,7 +28,8 @@ private[spark] case class CompletionEvent(
reason: TaskEndReason,
result: Any,
accumUpdates: Map[Long, Any],
- taskInfo: TaskInfo)
+ taskInfo: TaskInfo,
+ taskMetrics: TaskMetrics)
extends DAGSchedulerEvent
private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
5 core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -81,9 +81,6 @@ private[spark] class ShuffleMapTask(
with Externalizable
with Logging {
-
- var totalBytesWritten : Option[Long] = None
-
protected def this() = this(0, null, null, 0, null)
var split = if (rdd == null) {
@@ -144,7 +141,7 @@ private[spark] class ShuffleMapTask(
totalBytes += size
compressedSizes(i) = MapOutputTracker.compressSize(size)
}
- totalBytesWritten = Some(totalBytes)
+ shuffleBytesWritten = Some(totalBytes)
return new MapStatus(blockManager.blockManagerId, compressedSizes)
} finally {
20 core/src/main/scala/spark/scheduler/StageInfo.scala
@@ -3,17 +3,29 @@ package spark.scheduler
import cluster.TaskInfo
import collection._
import spark.util.Distribution
+import spark.executor.TaskMetrics
case class StageInfo(
val stage: Stage,
val taskInfos: mutable.Buffer[TaskInfo] = mutable.Buffer[TaskInfo](),
- val shuffleBytesWritten : mutable.Buffer[Long] = mutable.Buffer[Long](),
- val shuffleBytesRead : mutable.Buffer[Long] = mutable.Buffer[Long]()
+ val taskMetrics: mutable.Buffer[TaskMetrics] = mutable.Buffer[TaskMetrics]()
) {
- def name = stage.rdd.name + "(" + stage.origin + ")"
+ override def toString = stage.rdd.toString
def getTaskRuntimeDistribution = {
- new Distribution(taskInfos.map{_.duration.toDouble})
+ Distribution(taskInfos.map{_.duration.toDouble})
+ }
+
+ def getShuffleBytesWrittenDistribution = {
+ Distribution(taskMetrics.flatMap{_.shuffleBytesWritten.map{_.toDouble}})
+ }
+
+ def getRemoteFetchWaitTimeDistribution = {
+ Distribution(taskMetrics.flatMap{_.remoteFetchWaitTime.map{_.toDouble}})
+ }
+
+ def getRemoteBytesReadDistribution = {
+ Distribution(taskMetrics.flatMap{_.remoteBytesRead.map{_.toDouble}})
}
}
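With StageInfo now keeping the raw per-task TaskMetrics, the per-stage summaries are computed on demand from whichever counters are present. A rough usage sketch (my own example; it assumes, as the Distribution(...) call sites above suggest, that the companion apply returns an Option when the buffer may be empty):

import spark.scheduler.StageInfo

// Hypothetical consumer: a StageInfo could come from the DAGScheduler's
// stageToInfos map once the stage has finished.
def printShuffleWrites(info: StageInfo) {
  info.getShuffleBytesWrittenDistribution.foreach { dist =>
    // dist summarizes shuffleBytesWritten across the tasks that reported it
    println("shuffle bytes written for " + info + ": " + dist)
  }
}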
6 core/src/main/scala/spark/scheduler/Task.scala
@@ -21,6 +21,12 @@ private[spark] abstract class Task[T](val stageId: Int) extends Serializable {
var remoteReadBytes : Option[Long] = None
var remoteFetchWaitTime : Option[Long] = None
var remoteFetchTime : Option[Long] = None
+ var totalBlocksFetched : Option[Int] = None
+ var remoteBlocksFetched: Option[Int] = None
+ var localBlocksFetched: Option[Int] = None
+
+ var shuffleBytesWritten : Option[Long] = None
+
}
/**
4 core/src/main/scala/spark/scheduler/TaskResult.scala
@@ -9,7 +9,7 @@ import spark.executor.TaskMetrics
// TODO: Use of distributed cache to return result is a hack to get around
// what seems to be a bug with messages over 60KB in libprocess; fix it
private[spark]
-class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], val metrics: TaskMetrics) extends Externalizable {
+class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics: TaskMetrics) extends Externalizable {
def this() = this(null.asInstanceOf[T], null, null)
override def writeExternal(out: ObjectOutput) {
@@ -19,6 +19,7 @@ class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], val metrics:
out.writeLong(key)
out.writeObject(value)
}
+ out.writeObject(metrics)
}
override def readExternal(in: ObjectInput) {
@@ -32,5 +33,6 @@ class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], val metrics:
accumUpdates(in.readLong()) = in.readObject()
}
}
+ metrics = in.readObject().asInstanceOf[TaskMetrics]
}
}
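The Externalizable contract requires readExternal to consume fields in exactly the order writeExternal produced them, which is why metrics is written after the accumulator updates and read back after them too. An illustrative round trip (my own sketch; it would have to live under the spark package, since TaskResult is private[spark]):

import java.io._

// Serialize an object and read it back; for an Externalizable class such as
// TaskResult this exercises writeExternal and readExternal, including the new
// metrics field written and read last.
def roundTrip[T <: AnyRef](obj: T): T = {
  val buffer = new ByteArrayOutputStream()
  val out = new ObjectOutputStream(buffer)
  out.writeObject(obj)
  out.close()
  val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
  in.readObject().asInstanceOf[T]
}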
6 core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
@@ -1,16 +1,18 @@
package spark.scheduler
-import cluster.TaskInfo
+import spark.scheduler.cluster.TaskInfo
import scala.collection.mutable.Map
import spark.TaskEndReason
+import spark.executor.TaskMetrics
/**
* Interface for getting events back from the TaskScheduler.
*/
private[spark] trait TaskSchedulerListener {
// A task has finished or failed.
- def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any], taskInfo: TaskInfo): Unit
+ def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any],
+ taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit
// A node was lost from the cluster.
def executorLost(execId: String): Unit
2  core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
@@ -1,5 +1,7 @@
package spark.scheduler.cluster
+import spark.executor.TaskMetrics
+
/**
* Information about a running task attempt inside a TaskSet.
*/
6 core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
@@ -259,7 +259,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
tid, info.duration, tasksFinished, numTasks))
// Deserialize task result and pass it to the scheduler
val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
- sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info)
+ sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
// Mark finished and stop if we've finished all the tasks
finished(index) = true
if (tasksFinished == numTasks) {
@@ -290,7 +290,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
reason match {
case fetchFailed: FetchFailed =>
logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
- sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info)
+ sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
finished(index) = true
tasksFinished += 1
sched.taskSetFinished(this)
@@ -378,7 +378,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
addPendingTask(index)
// Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
// stage finishes when a total of tasks.size tasks finish.
- sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info)
+ sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null)
}
}
}
4 core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -87,7 +87,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
// If the threadpool has not already been shutdown, notify DAGScheduler
if (!Thread.currentThread().isInterrupted)
- listener.taskEnded(task, Success, resultToReturn, accumUpdates, info)
+ listener.taskEnded(task, Success, resultToReturn, accumUpdates, info, null)
} catch {
case t: Throwable => {
logError("Exception in task " + idInJob, t)
@@ -98,7 +98,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
} else {
// TODO: Do something nicer here to return all the way to the user
if (!Thread.currentThread().isInterrupted)
- listener.taskEnded(task, new ExceptionFailure(t), null, null, info)
+ listener.taskEnded(task, new ExceptionFailure(t), null, null, info, null)
}
}
}
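The fetch-failure, resubmission, and local-scheduler paths above all pass null for the metrics, so anything downstream that reads CompletionEvent.taskMetrics has to tolerate a missing value. A defensive sketch (my own, not part of this commit):

import scala.collection.mutable.Buffer
import spark.executor.TaskMetrics

// Option(...) maps a possibly-null TaskMetrics to Some/None, so a consumer can
// collect only the tasks that actually reported metrics.
def recordMetrics(collected: Buffer[TaskMetrics], maybeMetrics: TaskMetrics) {
  Option(maybeMetrics).foreach { m => collected += m }
}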
10 core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
@@ -265,7 +265,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
assert(taskSet.tasks.size >= results.size)
for ((result, i) <- results.zipWithIndex) {
if (i < taskSet.tasks.size) {
- runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any](), null))
+ runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any](), null, null))
}
}
}
@@ -463,14 +463,14 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
val noAccum = Map[Long, Any]()
// We rely on the event queue being ordered and increasing the generation number by 1
// should be ignored for being too old
- runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null))
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null))
// should work because it's a non-failed host
- runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum, null))
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum, null, null))
// should be ignored for being too old
- runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null))
+ runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null))
taskSet.tasks(1).generation = newGeneration
val secondStage = interceptStage(reduceRdd) {
- runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum, null))
+ runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum, null, null))
}
assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))