diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 3422a5f204b12..89b4cab88109d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1122,6 +1122,25 @@ class DAGScheduler( } } + private def postTaskEnd(event: CompletionEvent): Unit = { + val taskMetrics: TaskMetrics = + if (event.accumUpdates.nonEmpty) { + try { + TaskMetrics.fromAccumulators(event.accumUpdates) + } catch { + case NonFatal(e) => + val taskId = event.taskInfo.taskId + logError(s"Error when attempting to reconstruct metrics for task $taskId", e) + null + } + } else { + null + } + + listenerBus.post(SparkListenerTaskEnd(event.task.stageId, event.task.stageAttemptId, + Utils.getFormattedClassName(event.task), event.reason, event.taskInfo, taskMetrics)) + } + /** * Responds to a task finishing. This is called inside the event loop so it assumes that it can * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside. @@ -1138,34 +1157,36 @@ class DAGScheduler( event.taskInfo.attemptNumber, // this is a task attempt number event.reason) - // Reconstruct task metrics. Note: this may be null if the task has failed. - val taskMetrics: TaskMetrics = - if (event.accumUpdates.nonEmpty) { - try { - TaskMetrics.fromAccumulators(event.accumUpdates) - } catch { - case NonFatal(e) => - logError(s"Error when attempting to reconstruct metrics for task $taskId", e) - null - } - } else { - null - } - - // The stage may have already finished when we get this event -- eg. maybe it was a - // speculative task. It is important that we send the TaskEnd event in any case, so listeners - // are properly notified and can chose to handle it. For instance, some listeners are - // doing their own accounting and if they don't get the task end event they think - // tasks are still running when they really aren't. - listenerBus.post(SparkListenerTaskEnd( - stageId, task.stageAttemptId, taskType, event.reason, event.taskInfo, taskMetrics)) - if (!stageIdToStage.contains(task.stageId)) { + // The stage may have already finished when we get this event -- eg. maybe it was a + // speculative task. It is important that we send the TaskEnd event in any case, so listeners + // are properly notified and can chose to handle it. For instance, some listeners are + // doing their own accounting and if they don't get the task end event they think + // tasks are still running when they really aren't. + postTaskEnd(event) + // Skip all the actions if the stage has been cancelled. return } val stage = stageIdToStage(task.stageId) + + // Make sure the task's accumulators are updated before any other processing happens, so that + // we can post a task end event before any jobs or stages are updated. The accumulators are + // only updated in certain cases. + event.reason match { + case Success => + stage match { + case rs: ResultStage if rs.activeJob.isEmpty => + // Ignore update if task's job has finished. + case _ => + updateAccumulators(event) + } + case _: ExceptionFailure => updateAccumulators(event) + case _ => + } + postTaskEnd(event) + event.reason match { case Success => task match { @@ -1176,7 +1197,6 @@ class DAGScheduler( resultStage.activeJob match { case Some(job) => if (!job.finished(rt.outputId)) { - updateAccumulators(event) job.finished(rt.outputId) = true job.numFinished += 1 // If the whole job has finished, remove it @@ -1203,7 +1223,6 @@ class DAGScheduler( case smt: ShuffleMapTask => val shuffleStage = stage.asInstanceOf[ShuffleMapStage] - updateAccumulators(event) val status = event.result.asInstanceOf[MapStatus] val execId = status.location.executorId logDebug("ShuffleMapTask finished on " + execId) @@ -1374,8 +1393,7 @@ class DAGScheduler( // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits case exceptionFailure: ExceptionFailure => - // Tasks failed with exceptions might still have accumulator updates. - updateAccumulators(event) + // Nothing left to do, already handled above for accumulator updates. case TaskResultLost => // Do nothing here; the TaskScheduler handles these failures and resubmits the task. diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 453be26ed8d0c..3b5df657d45cf 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler import java.util.Properties -import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} import scala.annotation.meta.param import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} @@ -2346,6 +2346,36 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou (Success, 1))) } + test("task end event should have updated accumulators (SPARK-20342)") { + val tasks = 10 + + val accumId = new AtomicLong() + val foundCount = new AtomicLong() + val listener = new SparkListener() { + override def onTaskEnd(event: SparkListenerTaskEnd): Unit = { + event.taskInfo.accumulables.find(_.id == accumId.get).foreach { _ => + foundCount.incrementAndGet() + } + } + } + sc.addSparkListener(listener) + + // Try a few times in a loop to make sure. This is not guaranteed to fail when the bug exists, + // but it should at least make the test flaky. If the bug is fixed, this should always pass. + (1 to 10).foreach { i => + foundCount.set(0L) + + val accum = sc.longAccumulator(s"accum$i") + accumId.set(accum.id) + + sc.parallelize(1 to tasks, tasks).foreach { _ => + accum.add(1L) + } + sc.listenerBus.waitUntilEmpty(1000) + assert(foundCount.get() === tasks) + } + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID.