Use stubs instead of mocks for DAGSchedulerSuite.
Stephen Haberman committed Feb 9, 2013
1 parent 9cfa068 · commit 921be76
Showing 8 changed files with 253 additions and 527 deletions.
18 changes: 4 additions & 14 deletions core/src/main/scala/spark/MapOutputTracker.scala
@@ -38,9 +38,10 @@ private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Ac
}
}

private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolean) extends Logging {
private[spark] class MapOutputTracker extends Logging {

val timeout = 10.seconds
// Set to the MapOutputTrackerActor living on the driver
var trackerActor: ActorRef = _

var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]

@@ -53,24 +54,13 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolea
var cacheGeneration = generation
val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]

val actorName: String = "MapOutputTracker"
var trackerActor: ActorRef = if (isDriver) {
val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName)
logInfo("Registered MapOutputTrackerActor actor")
actor
} else {
val ip = System.getProperty("spark.driver.host", "localhost")
val port = System.getProperty("spark.driver.port", "7077").toInt
val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName)
actorSystem.actorFor(url)
}

val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup)

// Send a message to the trackerActor and get its result within a default timeout, or
// throw a SparkException if this fails.
def askTracker(message: Any): Any = {
try {
val timeout = 10.seconds
val future = trackerActor.ask(message)(timeout)
return Await.result(future, timeout)
} catch {
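After this change MapOutputTracker no longer builds its own actor: whoever constructs it assigns trackerActor. A rough sketch of the two sides, pieced together from the hunks in this commit (driverHost and driverPort stand in for the spark.driver.host/port system properties read in SparkEnv):

    // Driver side: create the actor and register it under a well-known name.
    val tracker = new MapOutputTracker()
    tracker.trackerActor = actorSystem.actorOf(
      Props(new MapOutputTrackerActor(tracker)), name = "MapOutputTracker")

    // Executor side: look up the driver's actor by URL instead of creating one.
    val url = "akka://spark@%s:%s/user/MapOutputTracker".format(driverHost, driverPort)
    tracker.trackerActor = actorSystem.actorFor(url)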
31 changes: 24 additions & 7 deletions core/src/main/scala/spark/SparkEnv.scala
@@ -1,7 +1,6 @@
package spark

import akka.actor.ActorSystem
import akka.actor.ActorSystemImpl
import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem}
import akka.remote.RemoteActorRefProvider

import serializer.Serializer
@@ -83,11 +82,23 @@ object SparkEnv extends Logging {
}

val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer")

def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
if (isDriver) {
logInfo("Registering " + name)
actorSystem.actorOf(Props(newActor), name = name)
} else {
val driverIp: String = System.getProperty("spark.driver.host", "localhost")
val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
val url = "akka://spark@%s:%s/user/%s".format(driverIp, driverPort, name)
logInfo("Connecting to " + name + ": " + url)
actorSystem.actorFor(url)
}
}

val driverIp: String = System.getProperty("spark.driver.host", "localhost")
val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
val blockManagerMaster = new BlockManagerMaster(
actorSystem, isDriver, isLocal, driverIp, driverPort)
val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
"BlockManagerMaster",
new spark.storage.BlockManagerMasterActor(isLocal)))
val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer)

val connectionManager = blockManager.connectionManager
@@ -99,7 +110,12 @@ object SparkEnv extends Logging {

val cacheManager = new CacheManager(blockManager)

val mapOutputTracker = new MapOutputTracker(actorSystem, isDriver)
// Have to assign trackerActor after initialization as MapOutputTrackerActor
// requires the MapOutputTracker itself
val mapOutputTracker = new MapOutputTracker()
mapOutputTracker.trackerActor = registerOrLookup(
"MapOutputTracker",
new MapOutputTrackerActor(mapOutputTracker))

val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
@@ -137,4 +153,5 @@ object SparkEnv extends Logging {
httpFileServer,
sparkFilesDir)
}

}
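Routing everything through plain ActorRef fields is what makes suite-level stubbing possible: a test can hand a component a throwaway actor instead of mocking the whole actor system. A minimal sketch, assuming a hypothetical DummyTrackerActor that simply acknowledges every message (it is not part of this commit):

    import akka.actor.{Actor, ActorSystem, Props}

    class DummyTrackerActor extends Actor {
      // Acknowledge everything so askTracker's Await.result returns promptly.
      def receive = { case _ => sender ! true }
    }

    val system = ActorSystem("test")
    val tracker = new MapOutputTracker()
    tracker.trackerActor = system.actorOf(Props(new DummyTrackerActor))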
4 changes: 2 additions & 2 deletions core/src/main/scala/spark/storage/BlockManager.scala
@@ -88,7 +88,7 @@ class BlockManager(

val host = System.getProperty("spark.hostname", Utils.localHostName())

val slaveActor = master.actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)

// Pending reregistration action being executed asynchronously or null if none
@@ -946,7 +946,7 @@ class BlockManager(
heartBeatTask.cancel()
}
connectionManager.stop()
master.actorSystem.stop(slaveActor)
actorSystem.stop(slaveActor)
blockInfo.clear()
memoryStore.clear()
diskStore.clear()
24 changes: 2 additions & 22 deletions core/src/main/scala/spark/storage/BlockManagerMaster.scala
@@ -15,32 +15,12 @@ import akka.util.duration._

import spark.{Logging, SparkException, Utils}

private[spark] class BlockManagerMaster(
val actorSystem: ActorSystem,
isDriver: Boolean,
isLocal: Boolean,
driverIp: String,
driverPort: Int)
extends Logging {
private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Logging {

val AKKA_RETRY_ATTEMPTS: Int = System.getProperty("spark.akka.num.retries", "3").toInt
val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt

val DRIVER_AKKA_ACTOR_NAME = "BlockMasterManager"

val timeout = 10.seconds
var driverActor: ActorRef = {
if (isDriver) {
val driverActor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)),
name = DRIVER_AKKA_ACTOR_NAME)
logInfo("Registered BlockManagerMaster Actor")
driverActor
} else {
val url = "akka://spark@%s:%s/user/%s".format(driverIp, driverPort, DRIVER_AKKA_ACTOR_NAME)
logInfo("Connecting to BlockManagerMaster: " + url)
actorSystem.actorFor(url)
}
}

/** Remove a dead executor from the driver actor. This is only called on the driver side. */
def removeExecutor(execId: String) {
@@ -59,7 +39,7 @@ private[spark] class BlockManagerMaster(

/** Register the BlockManager's id with the driver. */
def registerBlockManager(
blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
logInfo("Trying to register BlockManager")
tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor))
logInfo("Registered BlockManager")
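Since BlockManagerMaster is now just a thin wrapper around a driverActor ref, a suite can also subclass it and no-op the calls it does not care about, which is the "stubs instead of mocks" idea in the commit title. A hypothetical stub for illustration only, not taken from this commit's hunks:

    val stubMaster = new BlockManagerMaster(null) {
      // No real driver actor behind this stub, so swallow the calls outright.
      override def removeExecutor(execId: String) {}
      override def registerBlockManager(
          blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {}
    }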
5 changes: 2 additions & 3 deletions core/src/main/scala/spark/storage/ThreadingTest.scala
@@ -75,9 +75,8 @@ private[spark] object ThreadingTest {
System.setProperty("spark.kryoserializer.buffer.mb", "1")
val actorSystem = ActorSystem("test")
val serializer = new KryoSerializer
val driverIp: String = System.getProperty("spark.driver.host", "localhost")
val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true, driverIp, driverPort)
val blockManagerMaster = new BlockManagerMaster(
actorSystem.actorOf(Props(new BlockManagerMasterActor(true))))
val blockManager = new BlockManager(
"<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024)
val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
62 changes: 32 additions & 30 deletions core/src/test/scala/spark/MapOutputTrackerSuite.scala
@@ -31,13 +31,15 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {

test("master start and stop") {
val actorSystem = ActorSystem("test")
val tracker = new MapOutputTracker(actorSystem, true)
val tracker = new MapOutputTracker()
tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
tracker.stop()
}

test("master register and fetch") {
val actorSystem = ActorSystem("test")
val tracker = new MapOutputTracker(actorSystem, true)
val tracker = new MapOutputTracker()
tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@@ -55,7 +57,8 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {

test("master register and unregister and fetch") {
val actorSystem = ActorSystem("test")
val tracker = new MapOutputTracker(actorSystem, true)
val tracker = new MapOutputTracker()
tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@@ -77,35 +80,34 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
}

test("remote fetch") {
try {
System.clearProperty("spark.driver.host") // In case some previous test had set it
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", "localhost", 0)
System.setProperty("spark.driver.port", boundPort.toString)
val masterTracker = new MapOutputTracker(actorSystem, true)
val slaveTracker = new MapOutputTracker(actorSystem, false)
masterTracker.registerShuffle(10, 1)
masterTracker.incrementGeneration()
slaveTracker.updateGeneration(masterTracker.getGeneration)
intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", "localhost", 0)
val masterTracker = new MapOutputTracker()
masterTracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(masterTracker)), "MapOutputTracker")

val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", "localhost", 0)
val slaveTracker = new MapOutputTracker()
slaveTracker.trackerActor = slaveSystem.actorFor("akka://spark@localhost:" + boundPort + "/user/MapOutputTracker")

masterTracker.registerShuffle(10, 1)
masterTracker.incrementGeneration()
slaveTracker.updateGeneration(masterTracker.getGeneration)
intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }

val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
masterTracker.registerMapOutput(10, 0, new MapStatus(
BlockManagerId("a", "hostA", 1000), Array(compressedSize1000)))
masterTracker.incrementGeneration()
slaveTracker.updateGeneration(masterTracker.getGeneration)
assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
Seq((BlockManagerId("a", "hostA", 1000), size1000)))
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
masterTracker.registerMapOutput(10, 0, new MapStatus(
BlockManagerId("a", "hostA", 1000), Array(compressedSize1000)))
masterTracker.incrementGeneration()
slaveTracker.updateGeneration(masterTracker.getGeneration)
assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
Seq((BlockManagerId("a", "hostA", 1000), size1000)))

masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
masterTracker.incrementGeneration()
slaveTracker.updateGeneration(masterTracker.getGeneration)
intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
masterTracker.incrementGeneration()
slaveTracker.updateGeneration(masterTracker.getGeneration)
intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }

// failure should be cached
intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
} finally {
System.clearProperty("spark.driver.port")
}
// failure should be cached
intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
}
}
