From b148c6f119be4b0a5da0c2600add506cd930a647 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Fri, 19 Apr 2019 09:00:01 +0800
Subject: [PATCH] address comments

---
 .../org/apache/spark/scheduler/DAGScheduler.scala      |  3 ++-
 .../org/apache/spark/scheduler/TaskResultGetter.scala  |  8 ++++++--
 .../org/apache/spark/scheduler/TaskScheduler.scala     |  2 +-
 .../org/apache/spark/scheduler/TaskSchedulerImpl.scala | 10 ++++++----
 .../org/apache/spark/scheduler/TaskSetManager.scala    |  5 ++---
 .../org/apache/spark/scheduler/DAGSchedulerSuite.scala |  6 ++++--
 .../spark/scheduler/ExternalClusterManagerSuite.scala  |  3 ++-
 .../apache/spark/scheduler/TaskSetManagerSuite.scala   |  2 +-
 8 files changed, 24 insertions(+), 15 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 98aa4772b9dd4..af11a319083d5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1394,7 +1394,8 @@ private[spark] class DAGScheduler(
     // finished. Here we notify the task scheduler to skip running tasks for the same partition,
     // to save resource.
     if (task.stageAttemptId < stage.latestInfo.attemptNumber()) {
-      taskScheduler.notifyPartitionCompletion(stageId, task.partitionId)
+      taskScheduler.notifyPartitionCompletion(
+        stageId, task.partitionId, event.taskInfo.duration)
     }
 
     task match {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 72f4a4128bb22..09c4d9b5bce04 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -155,9 +155,13 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
     }
   }
 
-  def enqueuePartitionCompletionNotification(stageId: Int, partitionId: Int): Unit = {
+  // This method calls `TaskSchedulerImpl.handlePartitionCompleted` asynchronously. We do not want
+  // DAGScheduler to call `TaskSchedulerImpl.handlePartitionCompleted` directly, as it's
+  // synchronized and may hurt the throughput of the scheduler.
+  def enqueuePartitionCompletionNotification(
+      stageId: Int, partitionId: Int, taskDuration: Long): Unit = {
     getTaskResultExecutor.execute(() => Utils.logUncaughtExceptions {
-      scheduler.handlePartitionCompleted(stageId, partitionId)
+      scheduler.handlePartitionCompleted(stageId, partitionId, taskDuration)
     })
   }
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index bfdbf0217210a..1862e16824277 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -70,7 +70,7 @@ private[spark] trait TaskScheduler {
 
   // Notify the corresponding `TaskSetManager`s of the stage, that a partition has already completed
   // and they can skip running tasks for it.
-  def notifyPartitionCompletion(stageId: Int, partitionId: Int)
+  def notifyPartitionCompletion(stageId: Int, partitionId: Int, taskDuration: Long)
 
   // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
   def setDAGScheduler(dagScheduler: DAGScheduler): Unit
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 7ac6d8131b42d..7e820c32fa78d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -301,8 +301,9 @@ private[spark] class TaskSchedulerImpl(
     }
   }
 
-  override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
-    taskResultGetter.enqueuePartitionCompletionNotification(stageId, partitionId)
+  override def notifyPartitionCompletion(
+      stageId: Int, partitionId: Int, taskDuration: Long): Unit = {
+    taskResultGetter.enqueuePartitionCompletionNotification(stageId, partitionId, taskDuration)
   }
 
   /**
@@ -652,9 +653,10 @@ private[spark] class TaskSchedulerImpl(
    */
   private[scheduler] def handlePartitionCompleted(
       stageId: Int,
-      partitionId: Int) = synchronized {
+      partitionId: Int,
+      taskDuration: Long) = synchronized {
     taskSetsByStageIdAndAttempt.get(stageId).foreach(_.values.filter(!_.isZombie).foreach { tsm =>
-      tsm.markPartitionCompleted(partitionId)
+      tsm.markPartitionCompleted(partitionId, taskDuration)
     })
   }
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index ef9cb528f3e64..b3aa814537500 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -816,12 +816,11 @@ private[spark] class TaskSetManager(
     maybeFinishTaskSet()
   }
 
-  private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = {
+  private[scheduler] def markPartitionCompleted(partitionId: Int, taskDuration: Long): Unit = {
     partitionToIndex.get(partitionId).foreach { index =>
       if (!successful(index)) {
         if (speculationEnabled && !isZombie) {
-          // The task is skipped, its duration should be 0.
-          successfulTaskDurations.insert(0)
+          successfulTaskDurations.insert(taskDuration)
         }
         tasksSuccessful += 1
         successful(index) = true
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 15091c672f1ef..c8ae834e01e19 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -158,7 +158,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
       taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
     override def killAllTaskAttempts(
       stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
-    override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
+    override def notifyPartitionCompletion(
+        stageId: Int, partitionId: Int, taskDuration: Long): Unit = {
       taskSets.filter(_.stageId == stageId).lastOption.foreach { ts =>
         val tasks = ts.tasks.filter(_.partitionId == partitionId)
         assert(tasks.length == 1)
@@ -668,7 +669,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
         stageId: Int, interruptThread: Boolean, reason: String): Unit = {
       throw new UnsupportedOperationException
     }
-    override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
+    override def notifyPartitionCompletion(
+        stageId: Int, partitionId: Int, taskDuration: Long): Unit = {
       throw new UnsupportedOperationException
     }
     override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
index ead34e535723f..347064dc9aadf 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
@@ -84,7 +84,8 @@ private class DummyTaskScheduler extends TaskScheduler {
     taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
   override def killAllTaskAttempts(
     stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
-  override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {}
+  override def notifyPartitionCompletion(
+    stageId: Int, partitionId: Int, taskDuration: Long): Unit = {}
   override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {}
   override def defaultParallelism(): Int = 2
   override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 8d8d7994964b8..0666bc335abac 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -1394,7 +1394,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
 
     val taskSetManager = sched.taskSetManagerForAttempt(0, 0).get
     assert(taskSetManager.runningTasks === 8)
-    taskSetManager.markPartitionCompleted(8)
+    taskSetManager.markPartitionCompleted(8, 0)
     assert(!taskSetManager.successfulTaskDurations.isEmpty())
     taskSetManager.checkSpeculatableTasks(0)
   }
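
Note on the change, for reviewers: the completed task's wall-clock duration (`event.taskInfo.duration`) now flows from `DAGScheduler` through `TaskScheduler.notifyPartitionCompletion` and `TaskResultGetter.enqueuePartitionCompletionNotification` into `TaskSetManager.successfulTaskDurations`, which the speculation heuristic reads. The sketch below is a minimal, self-contained illustration of why the previous hard-coded `insert(0)` was harmful; it is not Spark code, and the object name, the sorted-Seq median stand-in, and the 10000/9000 ms durations are assumptions made up for the example. Spark's `spark.speculation.multiplier` defaults to 1.5.

object SpeculationMedianSketch {
  // Stand-in for TaskSetManager.successfulTaskDurations (Spark uses a MedianHeap;
  // this simple sorted-Seq upper median is just for illustration).
  private def median(durations: Seq[Long]): Long = {
    val sorted = durations.sorted
    sorted(sorted.length / 2)
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical stage re-run: 3 tasks of the new attempt finish in ~10s each,
    // while 7 partitions were already completed by an older stage attempt and are
    // only marked completed (skipped) in the new attempt's TaskSetManager.
    val finishedHere = Seq.fill(3)(10000L)
    val skipped = 7

    // Before this patch: markPartitionCompleted inserted 0 per skipped partition.
    val withZeros = finishedHere ++ Seq.fill(skipped)(0L)
    // After this patch: the duration reported by the old attempt's task is used
    // (9000 ms is an assumed value for the example).
    val withRealDurations = finishedHere ++ Seq.fill(skipped)(9000L)

    val multiplier = 1.5 // default of spark.speculation.multiplier
    println(s"threshold with zeros inserted:  ${multiplier * median(withZeros)} ms")
    println(s"threshold with real durations:  ${multiplier * median(withRealDurations)} ms")
    // Zeros collapse the median to 0 ms, so every still-running task exceeds the
    // speculation threshold and gets speculated; real durations keep the
    // threshold at a sensible ~13500 ms.
  }
}

The patch also keeps the notification asynchronous: as the new comment in TaskResultGetter.scala explains, `handlePartitionCompleted` is synchronized, so routing it through the task-result thread pool avoids hurting the scheduler's throughput.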