From ae9a8cb78635386a3c57ef71e27138196e671f2f Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Mon, 12 Apr 2021 12:21:25 -0700
Subject: [PATCH] Use TaskSchedulingPlugin to suggest task scheduling.

---
 .../spark/internal/config/package.scala       |  9 +++
 .../scheduler/TaskSchedulingPlugin.scala      | 60 ++++++++++++++++++
 .../spark/scheduler/TaskSetManager.scala      | 19 ++++--
 .../spark/scheduler/TaskSetManagerSuite.scala | 61 +++++++++++++++++++
 4 files changed, 146 insertions(+), 3 deletions(-)
 create mode 100644 core/src/main/scala/org/apache/spark/scheduler/TaskSchedulingPlugin.scala

diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 1a18856e4156c..54a3f5c4f4c28 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -1502,6 +1502,15 @@ package object config {
     .booleanConf
     .createWithDefault(true)
 
+  private[spark] val TASK_SCHEDULING_PLUGIN_CLASSNAME =
+    ConfigBuilder("spark.task.scheduling.pluginClassName")
+      .doc("The class name of the plugin used to provide scheduling suggestions to the " +
+        "Spark task scheduler. The class must implement the `TaskSchedulingPlugin` trait.")
+      .version("3.2.0")
+      .internal()
+      .stringConf
+      .createOptional
+
   private[spark] val STORAGE_LOCAL_DISK_BY_EXECUTORS_CACHE_SIZE =
     ConfigBuilder("spark.storage.localDiskByExecutors.cacheSize")
       .doc("The max number of executors for which the local dirs are stored. This size is " +
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulingPlugin.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulingPlugin.scala
new file mode 100644
index 0000000000000..d4fe24de0d775
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulingPlugin.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+/**
+ * This trait provides a plugin interface for suggesting task scheduling order to
+ * the Spark scheduler.
+ */
+private[spark] trait TaskSchedulingPlugin {
+
+  /**
+   * Ranks the given Spark tasks that are waiting to be scheduled on the offered
+   * executor. That is, the head of the returned sequence points to the most preferred
+   * task to schedule on the given executor. Note that the returned values are index
+   * offsets rather than indexes. For example, if the given task indexes are [1, 2, 3]
+   * and the plugin returns [1, 2, 0], the ranked task indexes are actually [2, 3, 1].
+   *
+   * @param execId The id of the offered executor.
+   * @param host The host of the offered executor.
+   * @param tasks The full list of tasks in the task set.
+   * @param taskIndexes The indexes of tasks eligible for scheduling on the executor/host.
+   * @return The index offsets of tasks, ranked by the preference of scheduling.
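+   *         Returning `taskIndexes.indices.reverse` reproduces the scheduler's default
+   *         behavior of dequeuing from the tail of the pending list.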
+   */
+  def rankTasks(
+    execId: String, host: String, tasks: Seq[Task[_]], taskIndexes: Seq[Int]): Seq[Int]
+
+  /**
+   * The Spark scheduler takes the task ranking returned by `rankTasks`. Once the
+   * scheduler decides which task to actually schedule, it calls this method to
+   * inform the plugin. Note that the scheduler might not choose the top-ranked
+   * task, for example when that task cannot be run on the offered executor; the
+   * plugin may decide what action to take when that happens.
+   */
+  def informScheduledTask(message: TaskScheduledResult): Unit
+}
+
+private[spark] trait TaskScheduledResult {
+  def scheduledTask: Task[_]
+  def scheduledTaskIndex: Int
+}
+
+private[spark] case class TaskWaitingForSchedule(scheduledTask: Task[_], scheduledTaskIndex: Int)
+  extends TaskScheduledResult
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 3b72103f9930d..c295dd4f0aed8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -200,6 +200,10 @@ private[spark] class TaskSetManager(
     t.epoch = epoch
   }
 
+  val schedulingPlugin = conf.get(config.TASK_SCHEDULING_PLUGIN_CLASSNAME).map { plugin =>
+    Utils.loadExtensions(classOf[TaskSchedulingPlugin], Seq(plugin), conf).head
+  }
+
   // Add all our tasks to the pending lists. We do this in reverse order
   // of task index so that tasks with low indices get launched first.
   addPendingTasks()
@@ -299,9 +303,11 @@ private[spark] class TaskSetManager(
       host: String,
       list: ArrayBuffer[Int],
       speculative: Boolean = false): Option[Int] = {
-    var indexOffset = list.size
-    while (indexOffset > 0) {
-      indexOffset -= 1
+    // Get the preferred task ranking from the plugin, if any; otherwise dequeue from the tail.
+    val rankedIndexOffsets = schedulingPlugin.map(_.rankTasks(execId, host, tasks, list))
+      .getOrElse(Range(list.size - 1, -1, -1))
+
+    rankedIndexOffsets.foreach { indexOffset =>
       val index = list(indexOffset)
       if (!isTaskExcludededOnExecOrNode(index, execId, host) &&
         !(speculative && hasAttemptOnHost(index, host))) {
@@ -363,6 +369,13 @@ private[spark] class TaskSetManager(
     if (speculative && task.isDefined) {
       speculatableTasks -= task.get
     }
+    // Let the scheduling plugin know which task was chosen.
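+    // Note that the chosen task may differ from the plugin's top-ranked suggestion, e.g.
+    // when that task is excluded on this executor or already has an attempt on this host.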
+    task.foreach { taskIndex =>
+      schedulingPlugin.foreach(
+        _.informScheduledTask(TaskWaitingForSchedule(tasks(taskIndex), taskIndex)))
+    }
     task
   }
 
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 3841425fa5ae2..624b5145537d0 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -2242,6 +2242,44 @@ class TaskSetManagerSuite
     // After 3s have elapsed now the task is marked as speculative task
     assert(sched.speculativeTasks.size == 1)
   }
+
+  test("SPARK-35022: TaskSet with scheduling plugin") {
+    sc = new SparkContext("local", "test")
+    sc.conf.set(config.TASK_SCHEDULING_PLUGIN_CLASSNAME, classOf[TestSchedulingPlugin].getName)
+
+    sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
+    val taskSet = FakeTask.createTaskSet(5)
+    val clock = new ManualClock
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock)
+    val accumUpdates = taskSet.tasks.head.metrics.internalAccums
+
+    // Offer a host with NO_PREF as the constraint; we should get a NO_PREF task
+    // immediately since that's all we have.
+    val taskOption1 = manager.resourceOffer("exec1", "host1", NO_PREF)._1
+    assert(taskOption1.isDefined)
+
+    clock.advance(1)
+
+    // `TestSchedulingPlugin` asks to schedule the task with the largest task index.
+    val scheduledTask1 = taskOption1.get
+    assert(scheduledTask1.index == 4)
+
+    // Tell it the task has finished.
+    manager.handleSuccessfulTask(scheduledTask1.taskId,
+      createTaskResult(scheduledTask1.taskId.toInt, accumUpdates))
+    assert(sched.endedTasks(scheduledTask1.index) === Success)
+
+    val taskOption2 = manager.resourceOffer("exec1", "host1", NO_PREF)._1
+    assert(taskOption2.isDefined)
+
+    clock.advance(1)
+    val scheduledTask2 = taskOption2.get
+
+    assert(scheduledTask2.index == 3)
+    manager.handleSuccessfulTask(scheduledTask2.taskId,
+      createTaskResult(scheduledTask2.taskId.toInt, accumUpdates))
+    assert(sched.endedTasks(scheduledTask2.index) === Success)
+  }
 }
 
 class FakeLongTasks(stageId: Int, partitionId: Int) extends FakeTask(stageId, partitionId) {
@@ -2253,3 +2291,26 @@ class FakeLongTasks(stageId: Int, partitionId: Int) extends FakeTask(stageId, partitionId) {
     0
   }
 }
+
+class TestSchedulingPlugin extends TaskSchedulingPlugin {
+  private var topRanked: Int = -1
+
+  override def rankTasks(
+    execId: String, host: String, tasks: Seq[Task[_]], taskIndexes: Seq[Int]): Seq[Int] = {
+    if (taskIndexes.isEmpty) {
+      topRanked = -1
+      taskIndexes
+    } else {
+      // Tell `TaskSetManager` to schedule the task with the largest task index.
+      topRanked = taskIndexes(0)
+      Seq(0) ++ Range(taskIndexes.size - 1, 0, -1)
+    }
+  }
+
+  override def informScheduledTask(message: TaskScheduledResult): Unit = {
+    if (topRanked != -1 && topRanked != message.scheduledTaskIndex) {
+      throw new IllegalStateException(
+        s"Expected to schedule task index $topRanked, but got ${message.scheduledTaskIndex}")
+    }
+  }
+}
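
Example usage (an illustrative sketch, not part of the patch): the plugin below and its
ranking heuristic are hypothetical. Because `TaskSchedulingPlugin` is `private[spark]`, an
implementation currently has to live under the `org.apache.spark` package tree, and
`Utils.loadExtensions` expects a class with a zero-argument (or `SparkConf`-accepting)
constructor.

package org.apache.spark.scheduler

// Hypothetical plugin: prefer tasks whose preferred locations include the
// offered host, keeping the scheduler's default tail-first order otherwise.
class HostLocalityFirstPlugin extends TaskSchedulingPlugin {

  override def rankTasks(
    execId: String, host: String, tasks: Seq[Task[_]], taskIndexes: Seq[Int]): Seq[Int] = {
    // Partition the index offsets by whether the task prefers the offered host.
    val (local, others) = taskIndexes.indices.partition { offset =>
      tasks(taskIndexes(offset)).preferredLocations.exists(_.host == host)
    }
    // Host-local tasks first; within each group, keep the scheduler's default
    // preference of dequeuing from the tail of the pending list.
    local.reverse ++ others.reverse
  }

  override def informScheduledTask(message: TaskScheduledResult): Unit = {
    // Stateless sketch: nothing to reconcile when the scheduler picks a task
    // other than the top-ranked one.
  }
}

With the sketch on the classpath, it would be enabled via
spark.task.scheduling.pluginClassName=org.apache.spark.scheduler.HostLocalityFirstPlugin.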