Commit b148c6f

address comments

cloud-fan committed Apr 19, 2019
1 parent eb427f7
Showing 8 changed files with 24 additions and 15 deletions.
DAGScheduler.scala
@@ -1394,7 +1394,8 @@ private[spark] class DAGScheduler(
         // finished. Here we notify the task scheduler to skip running tasks for the same partition,
         // to save resource.
         if (task.stageAttemptId < stage.latestInfo.attemptNumber()) {
-          taskScheduler.notifyPartitionCompletion(stageId, task.partitionId)
+          taskScheduler.notifyPartitionCompletion(
+            stageId, task.partitionId, event.taskInfo.duration)
         }

         task match {
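
For context: when a task from an earlier (zombie) stage attempt finishes successfully, the DAGScheduler tells the task scheduler that the partition is done, so the current attempt can skip it. This commit additionally threads the finished task's wall-clock duration through that notification. The sketch below is a toy model, not Spark's actual TaskInfo class, of what `event.taskInfo.duration` represents: the time from launch to finish of the completed task.

// Toy stand-in for the duration carried by a task-completion event.
// `ToyTaskInfo` is an illustrative name, not a Spark type.
case class ToyTaskInfo(launchTime: Long, finishTime: Long) {
  // Duration is only meaningful once the task has actually finished.
  def duration: Long = {
    require(finishTime != 0, "duration is undefined for an unfinished task")
    finishTime - launchTime
  }
}

object ToyTaskInfoDemo extends App {
  val info = ToyTaskInfo(launchTime = 1000L, finishTime = 4500L)
  println(s"task took ${info.duration} ms") // prints: task took 3500 ms
}
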
TaskResultGetter.scala
@@ -155,9 +155,13 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl)
     }
   }

-  def enqueuePartitionCompletionNotification(stageId: Int, partitionId: Int): Unit = {
+  // This method calls `TaskSchedulerImpl.handlePartitionCompleted` asynchronously. We do not want
+  // DAGScheduler to call `TaskSchedulerImpl.handlePartitionCompleted` directly, as it's
+  // synchronized and may hurt the throughput of the scheduler.
+  def enqueuePartitionCompletionNotification(
+      stageId: Int, partitionId: Int, taskDuration: Long): Unit = {
     getTaskResultExecutor.execute(() => Utils.logUncaughtExceptions {
-      scheduler.handlePartitionCompleted(stageId, partitionId)
+      scheduler.handlePartitionCompleted(stageId, partitionId, taskDuration)
     })
   }

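
The new comment documents the design choice: the DAGScheduler event loop should never call the lock-protected `TaskSchedulerImpl.handlePartitionCompleted` directly, so the call is enqueued on the task-result-getter thread pool instead. Below is a minimal, self-contained sketch of that hand-off pattern; `Scheduler` and `notificationPool` are illustrative names, not Spark's internals.

import java.util.concurrent.Executors

object AsyncNotificationSketch {
  class Scheduler {
    // Stand-in for a lock-protected, potentially slow bookkeeping method.
    def handlePartitionCompleted(stageId: Int, partitionId: Int, taskDuration: Long): Unit =
      synchronized {
        println(s"stage $stageId: partition $partitionId done in $taskDuration ms")
      }
  }

  private val notificationPool = Executors.newSingleThreadExecutor()
  private val scheduler = new Scheduler

  // Returns immediately; the synchronized call runs on the pool thread,
  // so lock contention slows the worker instead of the caller's event loop.
  def enqueue(stageId: Int, partitionId: Int, taskDuration: Long): Unit =
    notificationPool.execute(() =>
      scheduler.handlePartitionCompleted(stageId, partitionId, taskDuration))

  def main(args: Array[String]): Unit = {
    enqueue(0, 8, 3500L)
    notificationPool.shutdown()
  }
}
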
TaskScheduler.scala
@@ -70,7 +70,7 @@ private[spark] trait TaskScheduler {

   // Notify the corresponding `TaskSetManager`s of the stage, that a partition has already completed
   // and they can skip running tasks for it.
-  def notifyPartitionCompletion(stageId: Int, partitionId: Int)
+  def notifyPartitionCompletion(stageId: Int, partitionId: Int, taskDuration: Long)

   // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
   def setDAGScheduler(dagScheduler: DAGScheduler): Unit
TaskSchedulerImpl.scala
@@ -301,8 +301,9 @@ private[spark] class TaskSchedulerImpl(
     }
   }

-  override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
-    taskResultGetter.enqueuePartitionCompletionNotification(stageId, partitionId)
+  override def notifyPartitionCompletion(
+      stageId: Int, partitionId: Int, taskDuration: Long): Unit = {
+    taskResultGetter.enqueuePartitionCompletionNotification(stageId, partitionId, taskDuration)
   }

   /**
@@ -652,9 +653,10 @@ private[spark] class TaskSchedulerImpl(
    */
   private[scheduler] def handlePartitionCompleted(
       stageId: Int,
-      partitionId: Int) = synchronized {
+      partitionId: Int,
+      taskDuration: Long) = synchronized {
     taskSetsByStageIdAndAttempt.get(stageId).foreach(_.values.filter(!_.isZombie).foreach { tsm =>
-      tsm.markPartitionCompleted(partitionId)
+      tsm.markPartitionCompleted(partitionId, taskDuration)
     })
   }

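
Note the fan-out in `handlePartitionCompleted`: every non-zombie task set attempt registered for the stage receives the notification, since several attempts of a stage can coexist and zombie attempts no longer launch tasks. A toy sketch of the same traversal, with illustrative names rather than Spark's internals:

object FanOutSketch extends App {
  final class ToyTaskSetManager(val attempt: Int, var isZombie: Boolean = false) {
    def markPartitionCompleted(partitionId: Int, taskDuration: Long): Unit =
      println(s"attempt $attempt: partition $partitionId completed ($taskDuration ms)")
  }

  // stageId -> (attemptNumber -> manager)
  val taskSetsByStageIdAndAttempt = Map(
    0 -> Map(0 -> new ToyTaskSetManager(0, isZombie = true), 1 -> new ToyTaskSetManager(1)))

  // Only attempt 1 prints: zombie attempts are filtered out before the callback.
  taskSetsByStageIdAndAttempt.get(0).foreach(_.values.filter(!_.isZombie).foreach { tsm =>
    tsm.markPartitionCompleted(8, 3500L)
  })
}
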
TaskSetManager.scala
@@ -816,12 +816,11 @@ private[spark] class TaskSetManager(
     maybeFinishTaskSet()
   }

-  private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = {
+  private[scheduler] def markPartitionCompleted(partitionId: Int, taskDuration: Long): Unit = {
     partitionToIndex.get(partitionId).foreach { index =>
       if (!successful(index)) {
         if (speculationEnabled && !isZombie) {
-          // The task is skipped, its duration should be 0.
-          successfulTaskDurations.insert(0)
+          successfulTaskDurations.insert(taskDuration)
         }
         tasksSuccessful += 1
         successful(index) = true
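
This is the behavioral heart of the commit. `successfulTaskDurations` feeds the speculation logic, where the speculation threshold is derived from the median successful duration times a multiplier, so recording 0 for every skipped partition drags the median down and can trigger speculative tasks too eagerly. Recording the real duration of the task that completed the partition keeps the statistic honest. A toy illustration using a plain sorted-list median rather than Spark's MedianHeap:

object SpeculationMedianSketch extends App {
  val multiplier = 1.5 // stand-in for spark.speculation.multiplier

  def median(xs: Seq[Long]): Double = {
    val s = xs.sorted
    val n = s.length
    if (n % 2 == 1) s(n / 2).toDouble else (s(n / 2 - 1) + s(n / 2)) / 2.0
  }

  val realDurations = Seq(3000L, 3200L, 3500L, 3600L)

  // Old behavior: a skipped partition contributes a 0 duration.
  val withZero = realDurations :+ 0L
  // New behavior: it contributes the finished task's actual duration.
  val withReal = realDurations :+ 3400L

  println(f"threshold with 0:    ${median(withZero) * multiplier}%.0f ms") // 4800 ms
  println(f"threshold with real: ${median(withReal) * multiplier}%.0f ms") // 5100 ms
}
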
DAGSchedulerSuite.scala
@@ -158,7 +158,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLimits {
         taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
     override def killAllTaskAttempts(
         stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
-    override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
+    override def notifyPartitionCompletion(
+        stageId: Int, partitionId: Int, taskDuration: Long): Unit = {
       taskSets.filter(_.stageId == stageId).lastOption.foreach { ts =>
         val tasks = ts.tasks.filter(_.partitionId == partitionId)
         assert(tasks.length == 1)
@@ -668,7 +669,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLimits {
       stageId: Int, interruptThread: Boolean, reason: String): Unit = {
       throw new UnsupportedOperationException
     }
-    override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
+    override def notifyPartitionCompletion(
+        stageId: Int, partitionId: Int, taskDuration: Long): Unit = {
       throw new UnsupportedOperationException
     }
     override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {}
ExternalClusterManagerSuite.scala
@@ -84,7 +84,8 @@ private class DummyTaskScheduler extends TaskScheduler {
       taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
   override def killAllTaskAttempts(
       stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
-  override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {}
+  override def notifyPartitionCompletion(
+      stageId: Int, partitionId: Int, taskDuration: Long): Unit = {}
   override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {}
   override def defaultParallelism(): Int = 2
   override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
TaskSetManagerSuite.scala
@@ -1394,7 +1394,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logging {

     val taskSetManager = sched.taskSetManagerForAttempt(0, 0).get
     assert(taskSetManager.runningTasks === 8)
-    taskSetManager.markPartitionCompleted(8)
+    taskSetManager.markPartitionCompleted(8, 0)
     assert(!taskSetManager.successfulTaskDurations.isEmpty())
     taskSetManager.checkSpeculatableTasks(0)
   }
