SPARK-1937: fix issue with task locality #892

Closed · wants to merge 15 commits
core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

```diff
@@ -210,11 +210,14 @@ private[spark] class TaskSchedulerImpl(
     SparkEnv.set(sc.env)

     // Mark each slave as alive and remember its hostname
+    // Also track if new executor is added
+    var newExecAvail = false
     for (o <- offers) {
       executorIdToHost(o.executorId) = o.host
       if (!executorsByHost.contains(o.host)) {
         executorsByHost(o.host) = new HashSet[String]()
         executorAdded(o.executorId, o.host)
+        newExecAvail = true
       }
     }
```
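For context, here is a minimal, runnable sketch of the bookkeeping pattern this hunk introduces, with hypothetical simplified types (`WorkerOffer` carries only the two fields used here): new executors are detected while offers are registered, and the flag is consumed later so each TaskSetManager can be notified once per scheduling round.

```scala
import scala.collection.mutable.{HashMap, HashSet}

// Hypothetical stand-ins; the real logic lives in TaskSchedulerImpl.resourceOffers.
case class WorkerOffer(executorId: String, host: String)

class OfferBookkeeping {
  val executorIdToHost = new HashMap[String, String]
  val executorsByHost = new HashMap[String, HashSet[String]]

  // Returns true when this round of offers revealed an executor on a
  // previously unseen host, mirroring the `newExecAvail` flag above.
  def registerOffers(offers: Seq[WorkerOffer]): Boolean = {
    var newExecAvail = false
    for (o <- offers) {
      executorIdToHost(o.executorId) = o.host
      if (!executorsByHost.contains(o.host)) {
        executorsByHost(o.host) = new HashSet[String]
        newExecAvail = true
      }
      executorsByHost(o.host) += o.executorId
    }
    newExecAvail
  }
}
```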

```diff
@@ -227,12 +230,15 @@ private[spark] class TaskSchedulerImpl(
     for (taskSet <- sortedTaskSets) {
       logDebug("parentName: %s, name: %s, runningTasks: %s".format(
         taskSet.parent.name, taskSet.name, taskSet.runningTasks))
+      if (newExecAvail) {
+        taskSet.executorAdded()
+      }
     }

     // Take each TaskSet in our scheduling order, and then offer it each node in increasing order
     // of locality levels so that it gets a chance to launch local tasks on all of them.
     var launchedTask = false
-    for (taskSet <- sortedTaskSets; maxLocality <- TaskLocality.values) {
+    for (taskSet <- sortedTaskSets; maxLocality <- taskSet.myLocalityLevels) {
       do {
         launchedTask = false
         for (i <- 0 until shuffledOffers.size) {
```
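To see why iterating `taskSet.myLocalityLevels` matters, here is a hedged, self-contained sketch (all types are hypothetical stand-ins): a TaskSet whose only valid level is ANY is no longer offered at PROCESS_LOCAL, NODE_LOCAL, or RACK_LOCAL first, so it cannot be held back by delay-scheduling waits for levels it does not even contain.

```scala
object DelaySchedulingLoopSketch extends App {
  // Hypothetical stand-ins for the real scheduler types.
  object TaskLocality extends Enumeration {
    val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value
  }
  import TaskLocality._

  case class FakeTaskSet(name: String, myLocalityLevels: Seq[Value], var remaining: Int)

  // Stand-in for TaskSetManager.resourceOffer: launch one task if any remain.
  def tryLaunch(ts: FakeTaskSet, maxLocality: Value): Boolean =
    if (ts.remaining > 0) {
      ts.remaining -= 1
      println(s"launched a ${ts.name} task at max locality $maxLocality")
      true
    } else {
      false
    }

  val sortedTaskSets = Seq(
    FakeTaskSet("node-local", Seq(NODE_LOCAL, ANY), remaining = 2),
    FakeTaskSet("no-prefs", Seq(ANY), remaining = 2)) // only ever offered at ANY

  // The revised loop shape: each TaskSet steps through its own levels only.
  var launchedTask = false
  for (taskSet <- sortedTaskSets; maxLocality <- taskSet.myLocalityLevels) {
    do {
      launchedTask = tryLaunch(taskSet, maxLocality)
    } while (launchedTask)
  }
}
```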
core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala (30 additions, 13 deletions)
```diff
@@ -118,7 +118,7 @@ private[spark] class TaskSetManager(
   private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]]

   // Set containing pending tasks with no locality preferences.
-  val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
+  var pendingTasksWithNoPrefs = new ArrayBuffer[Int]

   // Set containing all pending tasks (also used as a stack, as above).
   val allPendingTasks = new ArrayBuffer[Int]
```
```diff
@@ -153,8 +153,8 @@ private[spark] class TaskSetManager(
   }

   // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling
-  val myLocalityLevels = computeValidLocalityLevels()
-  val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level
+  var myLocalityLevels = computeValidLocalityLevels()
+  var localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level

   // Delay scheduling variables: we keep track of our current locality level and the time we
   // last launched a task at that level, and move up a level when localityWaits[curLevel] expires.
```
```diff
@@ -181,16 +181,14 @@ private[spark] class TaskSetManager(
     var hadAliveLocations = false
     for (loc <- tasks(index).preferredLocations) {
       for (execId <- loc.executorId) {
-        if (sched.isExecutorAlive(execId)) {
-          addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer))
-          hadAliveLocations = true
-        }
+        addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer))
       }
-      if (sched.hasExecutorsAliveOnHost(loc.host)) {
-        addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer))
-        for (rack <- sched.getRackForHost(loc.host)) {
-          addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer))
-        }
-        hadAliveLocations = true
-      }
+      addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer))
+      for (rack <- sched.getRackForHost(loc.host)) {
+        addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer))
+        hadAliveLocations = true
+      }
     }
   }
```

Review thread on the added `hadAliveLocations = true` line:

Contributor: I guess technically we might have no hosts in this rack, but right now our TaskScheduler doesn't track that. Maybe we should open another JIRA to track it. I can imagine this happening in really large clusters.

Author: Do you mean the TaskScheduler should provide something like "hasHostOnRack", and we have to check that before setting hadAliveLocations to true?

Contributor: Yeah, but we can do it in another JIRA.

Author: Sure :-)
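The `hasHostOnRack` idea deferred to a separate JIRA above could look roughly like this. This is a hypothetical sketch, not code from this PR: the scheduler would keep a rack-to-hosts index so a rack-level preference only counts as alive when some registered host actually sits on that rack.

```scala
import scala.collection.mutable

class RackTracker {
  // rack -> hosts currently registered on it
  private val hostsByRack = new mutable.HashMap[String, mutable.HashSet[String]]

  def addHost(host: String, rack: String): Unit =
    hostsByRack.getOrElseUpdate(rack, new mutable.HashSet[String]) += host

  def removeHost(host: String, rack: String): Unit =
    hostsByRack.get(rack).foreach { hosts =>
      hosts -= host
      if (hosts.isEmpty) hostsByRack -= rack // drop racks with no hosts left
    }

  // The check discussed above: is any known host on this rack?
  def hasHostOnRack(rack: String): Boolean = hostsByRack.contains(rack)
}
```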
```diff
@@ -725,10 +723,12 @@ private[spark] class TaskSetManager(
   private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = {
     import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY}
     val levels = new ArrayBuffer[TaskLocality.TaskLocality]
-    if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0) {
+    if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0 &&
+        pendingTasksForExecutor.keySet.exists(sched.isExecutorAlive(_))) {
       levels += PROCESS_LOCAL
     }
-    if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0) {
+    if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0 &&
+        pendingTasksForHost.keySet.exists(sched.hasExecutorsAliveOnHost(_))) {
       levels += NODE_LOCAL
     }
     if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) {
```
```diff
@@ -738,4 +738,21 @@ private[spark] class TaskSetManager(
     logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", "))
     levels.toArray
   }
+
+  // Re-compute pendingTasksWithNoPrefs since new preferred locations may become available
+  def executorAdded() {
+    def newLocAvail(index: Int): Boolean = {
+      for (loc <- tasks(index).preferredLocations) {
+        if (sched.hasExecutorsAliveOnHost(loc.host) ||
+            sched.getRackForHost(loc.host).isDefined) {
+          return true
+        }
+      }
+      false
+    }
+    logInfo("Re-computing pending task lists.")
+    pendingTasksWithNoPrefs = pendingTasksWithNoPrefs.filter(!newLocAvail(_))
+    myLocalityLevels = computeValidLocalityLevels()
+    localityWaits = myLocalityLevels.map(getLocalityWait)
+  }
 }
```
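Putting the pieces together, here is a minimal runnable sketch of the `executorAdded()` re-filter, with hypothetical simplified types: a task leaves the no-prefs list once one of its preferred locations becomes reachable via a live host or a resolvable rack, while a task with no preferences at all stays put. The real method then also recomputes `myLocalityLevels` and `localityWaits`, as shown in the diff above.

```scala
import scala.collection.mutable.ArrayBuffer

object NoPrefsRefilterSketch extends App {
  // Hypothetical stand-in for Spark's TaskLocation.
  case class TaskLocation(host: String, executorId: Option[String] = None)

  val preferredLocations: Map[Int, Seq[TaskLocation]] = Map(
    0 -> Seq(TaskLocation("host1")),
    1 -> Seq(TaskLocation("host2")),
    2 -> Seq.empty) // a genuinely preference-free task is never filtered out

  val aliveHosts = Set("host1")                           // host1 just came up
  def rackForHost(host: String): Option[String] = None    // no rack info here

  // Mirrors newLocAvail(): does any preferred location look reachable now?
  def newLocAvail(index: Int): Boolean =
    preferredLocations(index).exists { loc =>
      aliveHosts.contains(loc.host) || rackForHost(loc.host).isDefined
    }

  var pendingTasksWithNoPrefs = ArrayBuffer(0, 1, 2)
  pendingTasksWithNoPrefs = pendingTasksWithNoPrefs.filter(!newLocAvail(_))
  println(pendingTasksWithNoPrefs) // ArrayBuffer(1, 2): task 0 now has a live host
}
```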
core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala

```diff
@@ -77,6 +77,10 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */)
   override def isExecutorAlive(execId: String): Boolean = executors.contains(execId)

   override def hasExecutorsAliveOnHost(host: String): Boolean = executors.values.exists(_ == host)
+
+  def addExecutor(execId: String, host: String) {
+    executors.put(execId, host)
+  }
 }

 class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
```
```diff
@@ -384,6 +388,36 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
     assert(sched.taskSetsFailed.contains(taskSet.id))
   }

+  test("new executors get added") {
+    sc = new SparkContext("local", "test")
+    val sched = new FakeTaskScheduler(sc)
+    val taskSet = FakeTask.createTaskSet(4,
+      Seq(TaskLocation("host1", "execA")),
+      Seq(TaskLocation("host1", "execB")),
+      Seq(TaskLocation("host2", "execC")),
+      Seq())
+    val clock = new FakeClock
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
+    // All tasks added to no-pref list since no preferred location is available
+    assert(manager.pendingTasksWithNoPrefs.size === 4)
+    // Only ANY is valid
+    assert(manager.myLocalityLevels.sameElements(Array(ANY)))
+    // Add a new executor
+    sched.addExecutor("execD", "host1")
+    manager.executorAdded()
+    // Task 0 and 1 should be removed from no-pref list
+    assert(manager.pendingTasksWithNoPrefs.size === 2)
+    // Valid locality should contain NODE_LOCAL and ANY
+    assert(manager.myLocalityLevels.sameElements(Array(NODE_LOCAL, ANY)))
+    // Add another executor
+    sched.addExecutor("execC", "host2")
+    manager.executorAdded()
+    // No-pref list now only contains task 3
+    assert(manager.pendingTasksWithNoPrefs.size === 1)
+    // Valid locality should contain PROCESS_LOCAL, NODE_LOCAL and ANY
+    assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY)))
+  }
+
   def createTaskResult(id: Int): DirectTaskResult[Int] = {
     val valueSer = SparkEnv.get.serializer.newInstance()
     new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics)
```
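Since the recomputed `localityWaits` control how long the scheduler holds out at each level before relaxing, a hedged sketch of resolving per-level waits from configuration may help. The `spark.locality.wait*` keys are Spark's documented settings; the `SimpleConf` type and `waitFor` helper are hypothetical stand-ins, not the real `getLocalityWait`.

```scala
object LocalityWaitSketch extends App {
  type SimpleConf = Map[String, String]

  object TaskLocality extends Enumeration {
    val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value
  }
  import TaskLocality._

  // Each level falls back to the global spark.locality.wait default (3000 ms).
  def waitFor(conf: SimpleConf, level: Value): Long = {
    val default = conf.getOrElse("spark.locality.wait", "3000")
    val key = level match {
      case PROCESS_LOCAL => "spark.locality.wait.process"
      case NODE_LOCAL    => "spark.locality.wait.node"
      case RACK_LOCAL    => "spark.locality.wait.rack"
      case ANY           => return 0L // no delay once willing to run anywhere
    }
    conf.getOrElse(key, default).toLong
  }

  val conf: SimpleConf = Map("spark.locality.wait.node" -> "1000")
  println(Seq(PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY).map(l => l -> waitFor(conf, l)))
  // List((PROCESS_LOCAL,3000), (NODE_LOCAL,1000), (RACK_LOCAL,3000), (ANY,0))
}
```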