Merge pull request #11 from markhamstra/SPARK-979
SKIPME SPARK-979 Randomize order of offers.
jhartlaub committed May 31, 2014
2 parents 55c78ab + 9153973 commit 923c5ba
Showing 5 changed files with 75 additions and 42 deletions.
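
In short: TaskSchedulerImpl.resourceOffers now shuffles the incoming WorkerOffers with scala.util.Random before building its per-worker task lists, so a job with fewer tasks than there are offers is no longer always packed onto the first executors in the backend's list. A minimal, self-contained sketch of that idea follows; SimpleOffer and the single dummy task are illustrative stand-ins, not Spark's own types.

import scala.util.Random

// Illustrative stand-in for Spark's WorkerOffer: an executor and its free cores.
case class SimpleOffer(executorId: String, host: String, cores: Int)

object ShuffledOffersSketch extends App {
  val offers = Seq(
    SimpleOffer("executor0", "host0", 1),
    SimpleOffer("executor1", "host1", 1),
    SimpleOffer("executor2", "host2", 1))

  // Before this patch the scheduler walked `offers` in the order the backend
  // supplied them, so a one-task job always landed on executor0. Shuffling
  // first spreads small jobs across the cluster.
  val shuffledOffers = Random.shuffle(offers)

  // Hand one dummy task to the first shuffled offer with a free core,
  // mirroring the round-robin walk in resourceOffers.
  val chosen = shuffledOffers.find(_.cores > 0)
  println(s"task would run on ${chosen.map(_.executorId).getOrElse("nobody")}")
}
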
core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -25,6 +25,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.concurrent.duration._
import scala.util.Random

import org.apache.spark._
import org.apache.spark.TaskState.TaskState
@@ -207,9 +208,11 @@ private[spark] class TaskSchedulerImpl(
}
}

// Build a list of tasks to assign to each worker
val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores))
val availableCpus = offers.map(o => o.cores).toArray
// Randomly shuffle offers to avoid always placing tasks on the same set of workers.
val shuffledOffers = Random.shuffle(offers)
// Build a list of tasks to assign to each worker.
val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores))
val availableCpus = shuffledOffers.map(o => o.cores).toArray
val sortedTaskSets = rootPool.getSortedTaskSetQueue()
for (taskSet <- sortedTaskSets) {
logDebug("parentName: %s, name: %s, runningTasks: %s".format(
@@ -222,9 +225,9 @@
for (taskSet <- sortedTaskSets; maxLocality <- TaskLocality.values) {
do {
launchedTask = false
for (i <- 0 until offers.size) {
val execId = offers(i).executorId
val host = offers(i).host
for (i <- 0 until shuffledOffers.size) {
val execId = shuffledOffers(i).executorId
val host = shuffledOffers(i).host
for (task <- taskSet.resourceOffer(execId, host, availableCpus(i), maxLocality)) {
tasks(i) += task
val tid = task.taskId
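
One detail worth calling out in the hunk above: tasks and availableCpus are both rebuilt from shuffledOffers rather than from the original offers, so a given index i names the same worker in shuffledOffers, tasks, and availableCpus throughout the assignment loop. A hedged sketch of that invariant, using simplified stand-in types rather than Spark's WorkerOffer and TaskDescription:

import scala.collection.mutable.ArrayBuffer
import scala.util.Random

object ParallelIndexSketch extends App {
  case class Offer(executorId: String, cores: Int) // simplified WorkerOffer

  val offers = Seq(Offer("executor0", 2), Offer("executor1", 1))

  // Everything below is derived from the *shuffled* sequence, so position i
  // refers to the same worker in shuffledOffers, tasks and availableCpus alike.
  val shuffledOffers = Random.shuffle(offers)
  val tasks = shuffledOffers.map(o => new ArrayBuffer[String](o.cores))
  val availableCpus = shuffledOffers.map(_.cores).toArray

  for (i <- shuffledOffers.indices if availableCpus(i) > 0) {
    tasks(i) += s"task-for-${shuffledOffers(i).executorId}"
    availableCpus(i) -= 1
  }
  shuffledOffers.zip(tasks).foreach { case (o, t) =>
    println(s"${o.executorId}: ${t.mkString(", ")}")
  }
}
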
core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.api.python

import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
import org.apache.spark.api.python.PythonRDD

import java.io.{ByteArrayOutputStream, DataOutputStream}

core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
@@ -25,6 +25,13 @@ import scala.collection.mutable.ArrayBuffer

import java.util.Properties

class FakeSchedulerBackend extends SchedulerBackend {
def start() {}
def stop() {}
def reviveOffers() {}
def defaultParallelism() = 1
}

class FakeTaskSetManager(
initPriority: Int,
initStageId: Int,
@@ -107,7 +114,8 @@

class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging {

def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: TaskSchedulerImpl, taskSet: TaskSet): FakeTaskSetManager = {
def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: TaskSchedulerImpl,
taskSet: TaskSet): FakeTaskSetManager = {
new FakeTaskSetManager(priority, stage, numTasks, cs , taskSet)
}

@@ -135,10 +143,7 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
test("FIFO Scheduler Test") {
sc = new SparkContext("local", "ClusterSchedulerSuite")
val clusterScheduler = new TaskSchedulerImpl(sc)
var tasks = ArrayBuffer[Task[_]]()
val task = new FakeTask(0)
tasks += task
val taskSet = new TaskSet(tasks.toArray,0,0,0,null)
val taskSet = FakeTask.createTaskSet(1)

val rootPool = new Pool("", SchedulingMode.FIFO, 0, 0)
val schedulableBuilder = new FIFOSchedulableBuilder(rootPool)
@@ -162,10 +167,7 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
test("Fair Scheduler Test") {
sc = new SparkContext("local", "ClusterSchedulerSuite")
val clusterScheduler = new TaskSchedulerImpl(sc)
var tasks = ArrayBuffer[Task[_]]()
val task = new FakeTask(0)
tasks += task
val taskSet = new TaskSet(tasks.toArray,0,0,0,null)
val taskSet = FakeTask.createTaskSet(1)

val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
System.setProperty("spark.scheduler.allocation.file", xmlPath)
@@ -219,10 +221,7 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
test("Nested Pool Test") {
sc = new SparkContext("local", "ClusterSchedulerSuite")
val clusterScheduler = new TaskSchedulerImpl(sc)
var tasks = ArrayBuffer[Task[_]]()
val task = new FakeTask(0)
tasks += task
val taskSet = new TaskSet(tasks.toArray,0,0,0,null)
val taskSet = FakeTask.createTaskSet(1)

val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0)
val pool0 = new Pool("0", SchedulingMode.FAIR, 3, 1)
@@ -265,4 +264,35 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
checkTaskSetId(rootPool, 6)
checkTaskSetId(rootPool, 2)
}

test("Scheduler does not always schedule tasks on the same workers") {
sc = new SparkContext("local", "ClusterSchedulerSuite")
val taskScheduler = new TaskSchedulerImpl(sc)
taskScheduler.initialize(new FakeSchedulerBackend)
// Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
var dagScheduler = new DAGScheduler(taskScheduler) {
override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
override def executorGained(execId: String, host: String) {}
}

val numFreeCores = 1
val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores),
new WorkerOffer("executor1", "host1", numFreeCores))
// Repeatedly try to schedule a 1-task job, and make sure that it doesn't always
// get scheduled on the same executor. While there is a chance this test will fail
// because the task randomly gets placed on the first executor all 1000 times, the
// probability of that happening is 2^-1000 (so sufficiently small to be considered
// negligible).
val numTrials = 1000
val selectedExecutorIds = 1.to(numTrials).map { _ =>
val taskSet = FakeTask.createTaskSet(1)
taskScheduler.submitTasks(taskSet)
val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
assert(1 === taskDescriptions.length)
taskDescriptions(0).executorId
}
var count = selectedExecutorIds.count(_ == workerOffers(0).executorId)
assert(count > 0)
assert(count < numTrials)
}
}
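
As a sanity check on the probability claim in the comment of the new test above: with two single-core offers and a uniform shuffle, each of the 1000 independent single-task submissions lands on executor0 with probability 1/2, so the two assertions can only fail when every trial picks the same executor. A throwaway sketch of that arithmetic (not part of the patch):

object FlakeBoundSketch extends App {
  val numTrials = 1000
  // P(all trials pick executor0) + P(all trials pick executor1) = 2 * (1/2)^1000.
  val pFlake = 2 * math.pow(0.5, numTrials)
  println(s"probability the randomized-offers test flakes: $pFlake") // ~1.9e-301
}
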
core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala (16 additions, 0 deletions)
@@ -24,3 +24,19 @@ class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int

override def preferredLocations: Seq[TaskLocation] = prefLocs
}

object FakeTask {
/**
* Utility method to create a TaskSet, potentially setting a particular sequence of preferred
* locations for each task (given as varargs) if this sequence is not empty.
*/
def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
if (prefLocs.size != 0 && prefLocs.size != numTasks) {
throw new IllegalArgumentException("Wrong number of task locations")
}
val tasks = Array.tabulate[Task[_]](numTasks) { i =>
new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil)
}
new TaskSet(tasks, 0, 0, 0, null)
}
}
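
For reference, the contract of the varargs parameter above: pass either no preferred-location sequences at all (every task gets Nil) or exactly one Seq[TaskLocation] per task; anything else throws IllegalArgumentException. A small illustration, assuming it is compiled alongside the other org.apache.spark.scheduler test sources (the scratch object itself is hypothetical, not part of the patch):

package org.apache.spark.scheduler

// Hypothetical scratch object, only to show the createTaskSet contract.
object CreateTaskSetExamples {
  // Three tasks, no locality preferences.
  val plainSet: TaskSet = FakeTask.createTaskSet(3)

  // Two tasks, one Seq of preferred locations each; a mismatched count
  // (say, two tasks but only one Seq) would throw IllegalArgumentException.
  val localSet: TaskSet = FakeTask.createTaskSet(2,
    Seq(TaskLocation("host1", "exec1")),
    Seq(TaskLocation("host2")))
}
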
core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -89,7 +89,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
test("TaskSet with no preferences") {
sc = new SparkContext("local", "test")
val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
val taskSet = createTaskSet(1)
val taskSet = FakeTask.createTaskSet(1)
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)

// Offer a host with no CPUs
@@ -115,7 +115,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
test("multiple offers with no preferences") {
sc = new SparkContext("local", "test")
val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
val taskSet = createTaskSet(3)
val taskSet = FakeTask.createTaskSet(3)
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)

// First three offers should all find tasks
@@ -146,7 +146,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
test("basic delay scheduling") {
sc = new SparkContext("local", "test")
val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
val taskSet = createTaskSet(4,
val taskSet = FakeTask.createTaskSet(4,
Seq(TaskLocation("host1", "exec1")),
Seq(TaskLocation("host2", "exec2")),
Seq(TaskLocation("host1"), TaskLocation("host2", "exec2")),
@@ -191,7 +191,7 @@
sc = new SparkContext("local", "test")
val sched = new FakeClusterScheduler(sc,
("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3"))
val taskSet = createTaskSet(5,
val taskSet = FakeTask.createTaskSet(5,
Seq(TaskLocation("host1")),
Seq(TaskLocation("host2")),
Seq(TaskLocation("host2")),
@@ -230,7 +230,7 @@
test("delay scheduling with failed hosts") {
sc = new SparkContext("local", "test")
val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
val taskSet = createTaskSet(3,
val taskSet = FakeTask.createTaskSet(3,
Seq(TaskLocation("host1")),
Seq(TaskLocation("host2")),
Seq(TaskLocation("host3"))
@@ -262,7 +262,7 @@
test("task result lost") {
sc = new SparkContext("local", "test")
val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
val taskSet = createTaskSet(1)
val taskSet = FakeTask.createTaskSet(1)
val clock = new FakeClock
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)

@@ -279,7 +279,7 @@
test("repeated failures lead to task set abortion") {
sc = new SparkContext("local", "test")
val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
val taskSet = createTaskSet(1)
val taskSet = FakeTask.createTaskSet(1)
val clock = new FakeClock
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)

@@ -299,21 +299,6 @@
}
}


/**
* Utility method to create a TaskSet, potentially setting a particular sequence of preferred
* locations for each task (given as varargs) if this sequence is not empty.
*/
def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
if (prefLocs.size != 0 && prefLocs.size != numTasks) {
throw new IllegalArgumentException("Wrong number of task locations")
}
val tasks = Array.tabulate[Task[_]](numTasks) { i =>
new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil)
}
new TaskSet(tasks, 0, 0, 0, null)
}

def createTaskResult(id: Int): DirectTaskResult[Int] = {
val valueSer = SparkEnv.get.serializer.newInstance()
new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics)