Skip to content

Commit

Permalink
copy taskMetrics only when isLocal is true
Browse files Browse the repository at this point in the history
  • Loading branch information
davies committed Sep 10, 2014
1 parent 5ca26dc commit 754b5b8
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 17 deletions.
8 changes: 7 additions & 1 deletion core/src/main/scala/org/apache/spark/executor/Executor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,13 @@ private[spark] class Executor(
if (!taskRunner.attemptedTask.isEmpty) {
Option(taskRunner.task).flatMap(_.metrics).foreach { metrics =>
metrics.updateShuffleReadMetrics
tasksMetrics += ((taskRunner.taskId, metrics))
if (isLocal) {
// Make a deep copy of the metrics (serialize/deserialize round-trip) so the
// receiver does not share this mutable object — presumably only needed in
// local mode, where driver and executor run in the same JVM.
val copiedMetrics = Utils.deserialize[TaskMetrics](Utils.serialize(metrics))
tasksMetrics += ((taskRunner.taskId, copiedMetrics))
} else {
tasksMetrics += ((taskRunner.taskId, metrics))
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import org.apache.spark.scheduler._
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.ui.jobs.UIData._
import org.apache.spark.util.Utils

/**
* :: DeveloperApi ::
Expand Down Expand Up @@ -242,9 +241,8 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
updateAggregateMetrics(stageData, executorMetricsUpdate.execId, taskMetrics,
t.taskMetrics)

// Overwrite task metrics with deepcopy
// TODO: only serialize it in local mode
t.taskMetrics = Some(Utils.deserialize[TaskMetrics](Utils.serialize(taskMetrics)))
// Overwrite task metrics
t.taskMetrics = Some(taskMetrics)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
val taskType = Utils.getFormattedClassName(new ShuffleMapTask(0))
val execId = "exe-1"

def updateTaskMetrics(taskMetrics: TaskMetrics, base: Int) = {
def makeTaskMetrics(base: Int) = {
val taskMetrics = new TaskMetrics()
val shuffleReadMetrics = new ShuffleReadMetrics()
val shuffleWriteMetrics = new ShuffleWriteMetrics()
taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics))
Expand Down Expand Up @@ -173,16 +174,10 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
listener.onTaskStart(SparkListenerTaskStart(1, 0, makeTaskInfo(1236L)))
listener.onTaskStart(SparkListenerTaskStart(1, 0, makeTaskInfo(1237L)))

val metrics4 = new TaskMetrics
val metrics5 = new TaskMetrics
val metrics6 = new TaskMetrics
val metrics7 = new TaskMetrics
listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate(execId, Array(
(1234L, 0, 0, updateTaskMetrics(metrics4, 0)))))
listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate(execId, Array(
(1235L, 0, 0, updateTaskMetrics(metrics5, 100)))))
listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate(execId, Array(
(1236L, 1, 0, updateTaskMetrics(metrics6, 200)))))
(1234L, 0, 0, makeTaskMetrics(0)),
(1235L, 0, 0, makeTaskMetrics(100)),
(1236L, 1, 0, makeTaskMetrics(200)))))

var stage0Data = listener.stageIdToData.get((0, 0)).get
var stage1Data = listener.stageIdToData.get((1, 0)).get
Expand All @@ -207,10 +202,10 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc

// task that was included in a heartbeat
listener.onTaskEnd(SparkListenerTaskEnd(0, 0, taskType, Success, makeTaskInfo(1234L, 1),
updateTaskMetrics(metrics4, 300)))
makeTaskMetrics(300)))
// task that wasn't included in a heartbeat
listener.onTaskEnd(SparkListenerTaskEnd(1, 0, taskType, Success, makeTaskInfo(1237L, 1),
updateTaskMetrics(metrics7, 400)))
makeTaskMetrics(400)))

stage0Data = listener.stageIdToData.get((0, 0)).get
stage1Data = listener.stageIdToData.get((1, 0)).get
Expand Down

0 comments on commit 754b5b8

Please sign in to comment.