TaskSchedulerImpl.scala
@@ -54,7 +54,7 @@ import org.apache.spark.util.{AccumulatorV2, ThreadUtils, Utils}
private[spark] class TaskSchedulerImpl private[scheduler](
val sc: SparkContext,
val maxTaskFailures: Int,
blacklistTrackerOpt: Option[BlacklistTracker],
private[scheduler] val blacklistTrackerOpt: Option[BlacklistTracker],
isLocal: Boolean = false)
extends TaskScheduler with Logging
{
@@ -337,8 +337,7 @@ private[spark] class TaskSchedulerImpl private[scheduler](
}
}.getOrElse(offers)

// Randomly shuffle offers to avoid always placing tasks on the same set of workers.
val shuffledOffers = Random.shuffle(filteredOffers)
val shuffledOffers = shuffleOffers(filteredOffers)
// Build a list of tasks to assign to each worker.
val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores))
val availableCpus = shuffledOffers.map(o => o.cores).toArray
@@ -375,6 +374,14 @@ private[spark] class TaskSchedulerImpl private[scheduler](
return tasks
}

/**
* Shuffle offers around to avoid always placing tasks on the same workers. Exposed to allow
* overriding in tests, so it can be deterministic.
*/
protected def shuffleOffers(offers: IndexedSeq[WorkerOffer]): IndexedSeq[WorkerOffer] = {
Random.shuffle(offers)
}

def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
var failedExecutor: Option[String] = None
var reason: Option[ExecutorLossReason] = None
TaskSetManager.scala
@@ -163,7 +163,12 @@ private[spark] class TaskSetManager(
addPendingTask(i)
}

// Figure out which locality levels we have in our TaskSet, so we can do delay scheduling
/**
* Track the set of locality levels which are valid given the tasks' locality preferences and
* the set of currently available executors. This is updated as executors are added and removed.
* This allows a performance optimization of skipping levels that aren't relevant (e.g., skip
* PROCESS_LOCAL if no tasks could run PROCESS_LOCAL on the current set of executors).
*/
var myLocalityLevels = computeValidLocalityLevels()
Contributor Author:
As I was figuring out the purpose of this field so I could write the comment, I made a couple of observations:

  1. For each executor we add or remove, it's an O(numExecutors) operation to update the locality levels, so overall it's O(numExecutors^2) to add a bunch of them. That's minor on small clusters, but I wonder if this is an issue when you're using dynamic allocation and scaling up and down to 1000s of executors. It's all happening while holding a lock on the TaskSchedulerImpl, too.

  2. Though we recompute the valid locality levels as executors come and go, we do not as tasks complete. That's not a problem -- as offers come in, we still go through the right task lists. But it does make me wonder whether this business of updating the locality levels for the current set of executors is useful at all, or whether we should just always use all levels.
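To make observation (1) concrete, here is a toy sketch of the cost pattern (made-up names, not the actual Spark code): each executor registration triggers a recompute that scans every executor known so far, so registering n executors performs on the order of n^2 liveness checks.

```scala
object LocalityRecomputeCost {
  def main(args: Array[String]): Unit = {
    val alive = scala.collection.mutable.LinkedHashSet[String]()
    var livenessChecks = 0L

    // Stand-in for computeValidLocalityLevels(): one check per known executor.
    def recomputeValidLevels(): Unit = alive.foreach { _ => livenessChecks += 1 }

    // Stand-in for executor registration: update bookkeeping, then recompute levels.
    def addExecutor(id: String): Unit = { alive += id; recomputeValidLevels() }

    (1 to 1000).foreach(i => addExecutor(s"exec$i"))
    // 1 + 2 + ... + 1000 = 500500 checks for 1000 executors, i.e. O(n^2) overall.
    println(s"executors=${alive.size}, liveness checks=$livenessChecks")
  }
}
```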

Contributor:
(1) does seem like an issue. I also mostly agree on (2), since the logic for avoiding unnecessary delay-timeout waits is already handled (separately from the myLocalityLevels calculation) here. My only hesitation is that myLocalityLevels does allow avoiding the delay timeout in cases where tasks have constraints to run on executors that haven't been granted to the application, so that use case seems like it might merit keeping the code (also, if you agree, can you update the myLocalityLevels comment?). In any case, I'd do this in a separate PR.
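As a rough illustration of that use case, here is a sketch in the style of the tests added below (assuming the same taskScheduler setup; hostX/execX are made-up names for an executor the application was never offered). Because no alive executor matches the preference, PROCESS_LOCAL is simply left out of myLocalityLevels, so the TaskSetManager never burns a delay timeout waiting at that level:

```scala
// A task that prefers an executor the application has never been granted.
val taskSet = FakeTask.createTaskSet(
  1, stageId = 2, stageAttemptId = 0, Seq(TaskLocation("hostX", "execX")))
taskScheduler.submitTasks(taskSet)
val tsm = taskScheduler.taskSetManagerForAttempt(2, 0).get
// No alive executor matches, so PROCESS_LOCAL is skipped entirely rather than
// waited on for spark.locality.wait.
assert(!tsm.myLocalityLevels.contains(TaskLocality.PROCESS_LOCAL))
```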

var localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level

@@ -961,18 +966,18 @@ private[spark] class TaskSetManager(
private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = {
import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY}
val levels = new ArrayBuffer[TaskLocality.TaskLocality]
if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0 &&
if (!pendingTasksForExecutor.isEmpty &&
pendingTasksForExecutor.keySet.exists(sched.isExecutorAlive(_))) {
levels += PROCESS_LOCAL
}
if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0 &&
if (!pendingTasksForHost.isEmpty &&
pendingTasksForHost.keySet.exists(sched.hasExecutorsAliveOnHost(_))) {
levels += NODE_LOCAL
}
if (!pendingTasksWithNoPrefs.isEmpty) {
levels += NO_PREF
}
if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0 &&
if (!pendingTasksForRack.isEmpty &&
pendingTasksForRack.keySet.exists(sched.hasHostAliveOnRack(_))) {
levels += RACK_LOCAL
}
TaskSchedulerImplSuite.scala
@@ -29,7 +29,7 @@ import org.scalatest.mock.MockitoSugar
import org.apache.spark._
import org.apache.spark.internal.config
import org.apache.spark.internal.Logging
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.ManualClock

class FakeSchedulerBackend extends SchedulerBackend {
def start() {}
@@ -819,4 +819,89 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
assert(!taskScheduler.hasExecutorsAliveOnHost("host0"))
assert(taskScheduler.getExecutorsAliveOnHost("host0").isEmpty)
}

test("Locality should be used for bulk offers even with delay scheduling off") {
val conf = new SparkConf()
.set("spark.locality.wait", "0")
sc = new SparkContext("local", "TaskSchedulerImplSuite", conf)
// we create a manual clock just so we can be sure the clock doesn't advance at all in this test
val clock = new ManualClock()

// We customize the task scheduler just to let us control the way offers are shuffled, so we
// can be sure we try both permutations, and to control the clock on the TaskSetManager.
val taskScheduler = new TaskSchedulerImpl(sc) {
override def shuffleOffers(offers: IndexedSeq[WorkerOffer]): IndexedSeq[WorkerOffer] = {
// Don't shuffle the offers around for this test. Instead, we'll just pass in all
// the permutations we care about directly.
offers
}
override def createTaskSetManager(taskSet: TaskSet, maxTaskFailures: Int): TaskSetManager = {
new TaskSetManager(this, taskSet, maxTaskFailures, blacklistTrackerOpt, clock)
}
}
// Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
new DAGScheduler(sc, taskScheduler) {
override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
override def executorAdded(execId: String, host: String) {}
}
taskScheduler.initialize(new FakeSchedulerBackend)

// Make two different offers -- one in the preferred location, one that is not.
val offers = IndexedSeq(
WorkerOffer("exec1", "host1", 1),
WorkerOffer("exec2", "host2", 1)
)
Seq(false, true).foreach { swapOrder =>
// Submit a taskset with locality preferences.
val taskSet = FakeTask.createTaskSet(
1, stageId = 1, stageAttemptId = 0, Seq(TaskLocation("host1", "exec1")))
taskScheduler.submitTasks(taskSet)
val shuffledOffers = if (swapOrder) offers.reverse else offers
// Regardless of the order of the offers (after the task scheduler shuffles them), we should
// always take advantage of the local offer.
val taskDescs = taskScheduler.resourceOffers(shuffledOffers).flatten
withClue(s"swapOrder = $swapOrder") {
assert(taskDescs.size === 1)
assert(taskDescs.head.executorId === "exec1")
}
}
}

test("With delay scheduling off, tasks can be run at any locality level immediately") {
Contributor:
Sorry, last thing -- I just realized: does this test need to first submit a local resource offer? That would make sure the local executor is considered alive. Otherwise, PROCESS_LOCAL won't be in the set of allowed locality levels because of the code here: https://github.com/apache/spark/pull/16376/files#diff-bad3987c83bd22d46416d3dd9d208e76R966, which makes this test somewhat less effective, if I understand correctly.

Contributor Author:
Yes, you are absolutely right. I've updated it and also added a check to make sure the tsm includes the lower locality levels.

val conf = new SparkConf()
.set("spark.locality.wait", "0")
sc = new SparkContext("local", "TaskSchedulerImplSuite", conf)

// we create a manual clock just so we can be sure the clock doesn't advance at all in this test
val clock = new ManualClock()
val taskScheduler = new TaskSchedulerImpl(sc) {
override def createTaskSetManager(taskSet: TaskSet, maxTaskFailures: Int): TaskSetManager = {
new TaskSetManager(this, taskSet, maxTaskFailures, blacklistTrackerOpt, clock)
}
}
// Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
new DAGScheduler(sc, taskScheduler) {
override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
override def executorAdded(execId: String, host: String) {}
}
taskScheduler.initialize(new FakeSchedulerBackend)
// make an offer on the preferred host so the scheduler knows it's alive. This is necessary
// so that the taskset knows that it *could* take advantage of locality.
taskScheduler.resourceOffers(IndexedSeq(WorkerOffer("exec1", "host1", 1)))

// Submit a taskset with locality preferences.
val taskSet = FakeTask.createTaskSet(
1, stageId = 1, stageAttemptId = 0, Seq(TaskLocation("host1", "exec1")))
taskScheduler.submitTasks(taskSet)
val tsm = taskScheduler.taskSetManagerForAttempt(1, 0).get
// make sure we've set up our test correctly, so that the taskset knows it *could* use local
// offers.
assert(tsm.myLocalityLevels.contains(TaskLocality.NODE_LOCAL))
// make an offer on a non-preferred location. Since the delay is 0, we should still schedule
// immediately.
val taskDescs =
taskScheduler.resourceOffers(IndexedSeq(WorkerOffer("exec2", "host2", 1))).flatten
assert(taskDescs.size === 1)
assert(taskDescs.head.executorId === "exec2")
}
}