Fix page size in local mode
Instead of always using Runtime.getRuntime.availableProcessors
for the page size, we should respect the number of slots specified
by the driver, e.g. local[32]. Note that this commit only fixes
the issue for local modes.
Andrew Or committed Oct 8, 2015
1 parent 82d275f · commit 0e140c2
Showing 3 changed files with 40 additions and 15 deletions.
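To see why the core count matters before diving into the diff, here is a minimal, self-contained sketch of the behavior being fixed. The heuristic below (execution memory split per core, rounded to a power of two and clamped), along with its names and constants, is an illustrative assumption rather than Spark's actual page-size code; the point is only that the result scales with the number of cores, which before this commit was always taken from Runtime.getRuntime.availableProcessors, even when the master string was, say, local[4] on a 48-core machine.

// Illustrative only: a hypothetical per-core page-size heuristic.
// The formula and constants are assumptions for demonstration, not Spark's actual code.
object PageSizeSketch {
  private val minPageSize = 1L * 1024 * 1024   // 1 MB
  private val maxPageSize = 64L * 1024 * 1024  // 64 MB

  private def nextPowerOf2(n: Long): Long =
    if (n <= 1) 1 else java.lang.Long.highestOneBit(n - 1) << 1

  // A page size derived from available memory and the number of task slots.
  def pageSizeFor(memoryBytes: Long, cores: Int): Long = {
    val perCore = memoryBytes / math.max(cores, 1) / 16  // arbitrary safety factor
    math.min(maxPageSize, math.max(minPageSize, nextPowerOf2(perCore)))
  }

  def main(args: Array[String]): Unit = {
    val memory = 4L * 1024 * 1024 * 1024  // 4 GB of execution memory
    val machineCores = Runtime.getRuntime.availableProcessors()
    // Before the fix: the core count always came from the machine (e.g. 48 on a big box),
    // yielding smaller pages even when the user asked for local[4].
    println(s"machine cores ($machineCores): ${pageSizeFor(memory, machineCores)} bytes")
    // After the fix: the driver's slot count from the master string (e.g. local[4]) is used.
    println(s"local[4]: ${pageSizeFor(memory, 4)} bytes")
  }
}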
48 changes: 35 additions & 13 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -279,7 +279,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
conf: SparkConf,
isLocal: Boolean,
listenerBus: LiveListenerBus): SparkEnv = {
-    SparkEnv.createDriverEnv(conf, isLocal, listenerBus)
+    SparkEnv.createDriverEnv(conf, isLocal, listenerBus, SparkContext.numDriverCores(master))
}

private[spark] def env: SparkEnv = _env
@@ -2570,25 +2570,29 @@ object SparkContext extends Logging {
res
}

+  /**
+   * The number of driver cores to use for execution in local mode, 0 otherwise.
+   */
+  private[spark] def numDriverCores(master: String): Int = {
+    def convertToInt(threads: String): Int = {
+      if (threads == "*") Runtime.getRuntime.availableProcessors() else threads.toInt
+    }
+    master match {
+      case "local" => 1
+      case SparkMasterRegex.LOCAL_N_REGEX(threads) => convertToInt(threads)
+      case SparkMasterRegex.LOCAL_N_FAILURES_REGEX(threads, _) => convertToInt(threads)
+      case _ => 0 // driver is not used for execution
+    }
+  }

/**
* Create a task scheduler based on a given master URL.
* Return a 2-tuple of the scheduler backend and the task scheduler.
*/
private def createTaskScheduler(
sc: SparkContext,
master: String): (SchedulerBackend, TaskScheduler) = {
-    // Regular expression used for local[N] and local[*] master formats
-    val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r
-    // Regular expression for local[N, maxRetries], used in tests with failing tasks
-    val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r
-    // Regular expression for simulating a Spark cluster of [N, cores, memory] locally
-    val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
-    // Regular expression for connecting to Spark deploy clusters
-    val SPARK_REGEX = """spark://(.*)""".r
-    // Regular expression for connection to Mesos cluster by mesos:// or zk:// url
-    val MESOS_REGEX = """(mesos|zk)://.*""".r
-    // Regular expression for connection to Simr cluster
-    val SIMR_REGEX = """simr://(.*)""".r
+    import SparkMasterRegex._

// When running locally, don't try to re-execute tasks on failure.
val MAX_LOCAL_TASK_FAILURES = 1
@@ -2729,6 +2733,24 @@
}
}

+/**
+ * A collection of regexes for extracting information from the master string.
+ */
+private object SparkMasterRegex {
+  // Regular expression used for local[N] and local[*] master formats
+  val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r
+  // Regular expression for local[N, maxRetries], used in tests with failing tasks
+  val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r
+  // Regular expression for simulating a Spark cluster of [N, cores, memory] locally
+  val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
+  // Regular expression for connecting to Spark deploy clusters
+  val SPARK_REGEX = """spark://(.*)""".r
+  // Regular expression for connection to Mesos cluster by mesos:// or zk:// url
+  val MESOS_REGEX = """(mesos|zk)://.*""".r
+  // Regular expression for connection to Simr cluster
+  val SIMR_REGEX = """simr://(.*)""".r
+}

/**
* A class encapsulating how to convert some type T to Writable. It stores both the Writable class
* corresponding to T (e.g. IntWritable for Int) and a function for doing the conversion.
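As a usage reference, the following standalone sketch mirrors the numDriverCores logic and the two local-mode regexes added above so it can be run outside Spark. The real method is private[spark] on object SparkContext, so this copy exists purely for illustration.

// Mirrors the numDriverCores logic and SparkMasterRegex patterns from the diff above.
// Standalone copy for illustration; the real method lives in object SparkContext.
object NumDriverCoresSketch {
  val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r
  val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r

  def convertToInt(threads: String): Int =
    if (threads == "*") Runtime.getRuntime.availableProcessors() else threads.toInt

  def numDriverCores(master: String): Int = master match {
    case "local" => 1
    case LOCAL_N_REGEX(threads) => convertToInt(threads)
    case LOCAL_N_FAILURES_REGEX(threads, _) => convertToInt(threads)
    case _ => 0 // driver is not used for execution
  }

  def main(args: Array[String]): Unit = {
    println(numDriverCores("local"))             // 1
    println(numDriverCores("local[32]"))         // 32
    println(numDriverCores("local[*]"))          // number of machine cores
    println(numDriverCores("local[4, 3]"))       // 4 (the retry count is ignored)
    println(numDriverCores("spark://host:7077")) // 0: executors, not the driver, run tasks
  }
}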
4 changes: 3 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -187,6 +187,7 @@ object SparkEnv extends Logging {
conf: SparkConf,
isLocal: Boolean,
listenerBus: LiveListenerBus,
+      numCores: Int,
mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
assert(conf.contains("spark.driver.host"), "spark.driver.host is not set on the driver!")
assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!")
@@ -199,6 +200,7 @@
port,
isDriver = true,
isLocal = isLocal,
+      numUsableCores = numCores,
listenerBus = listenerBus,
mockOutputCommitCoordinator = mockOutputCommitCoordinator
)
@@ -238,8 +240,8 @@
port: Int,
isDriver: Boolean,
isLocal: Boolean,
+      numUsableCores: Int,
listenerBus: LiveListenerBus = null,
-      numUsableCores: Int = 0,
mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {

// Listener bus is only used on the driver
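One detail worth noting in the hunk above: numUsableCores loses its `= 0` default and moves ahead of the defaulted listenerBus parameter, so every caller of SparkEnv.create must now state a core count explicitly while still being able to omit the trailing defaults positionally. A tiny sketch of that Scala convention follows; the names are illustrative, not the actual SparkEnv signature.

// Illustrative only: why required parameters are grouped before defaulted ones.
object ParamOrderSketch {
  // Keeping the required core count ahead of the defaulted parameter means
  // positional callers can still skip the defaults, but cannot forget the cores.
  def createEnv(host: String, numUsableCores: Int, listenerBus: AnyRef = null): String =
    s"env(host=$host, cores=$numUsableCores, bus=$listenerBus)"

  def main(args: Array[String]): Unit = {
    println(createEnv("driver-host", 4)) // defaults still usable positionally
    // createEnv("driver-host")          // does not compile: the core count is now mandatory
  }
}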
@@ -87,7 +87,8 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter {
outputCommitCoordinator = spy(new OutputCommitCoordinator(conf, isDriver = true))
// Use Mockito.spy() to maintain the default infrastructure everywhere else.
// This mocking allows us to control the coordinator responses in test cases.
-      SparkEnv.createDriverEnv(conf, isLocal, listenerBus, Some(outputCommitCoordinator))
+      SparkEnv.createDriverEnv(conf, isLocal, listenerBus,
+        SparkContext.numDriverCores(master), Some(outputCommitCoordinator))
}
}
// Use Mockito.spy() to maintain the default infrastructure everywhere else
