Modified a bunch of HashMaps in Spark to use TimeStampedHashMap and made various modules use CleanupTask to periodically clean up metadata.

tdas committed Nov 27, 2012
1 parent 0fe2fc4 commit b18d708
Showing 7 changed files with 165 additions and 18 deletions.
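
Every change below follows one pattern: a HashMap that previously grew without bound becomes a TimeStampedHashMap, a CleanupTask is registered to periodically evict entries older than a configurable threshold, and the task is cancelled on shutdown. A minimal sketch of that wiring using the two new classes from this commit (the SomeTracker component and its map types are illustrative, not part of the diff):

import spark.util.{CleanupTask, TimeStampedHashMap}

class SomeTracker {  // hypothetical component, for illustration only
  // Drop-in replacement for mutable.HashMap that stamps each insertion time.
  private val metadata = new TimeStampedHashMap[Int, String]

  // Calls metadata.cleanup(threshTime) on a daemon timer, driven by the
  // spark.cleanup.delay and spark.cleanup.period system properties.
  private val cleanupTask = new CleanupTask("SomeTracker", metadata.cleanup)

  def stop() {
    cleanupTask.cancel()  // stop the timer when the component shuts down
  }
}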
6 changes: 5 additions & 1 deletion core/src/main/scala/spark/CacheTracker.scala
@@ -14,6 +14,7 @@ import scala.collection.mutable.HashSet

import spark.storage.BlockManager
import spark.storage.StorageLevel
import util.{CleanupTask, TimeStampedHashMap}

private[spark] sealed trait CacheTrackerMessage

@@ -30,14 +31,16 @@ private[spark] case object StopCacheTracker extends CacheTrackerMessage

private[spark] class CacheTrackerActor extends Actor with Logging {
// TODO: Should probably store (String, CacheType) tuples
private val locs = new HashMap[Int, Array[List[String]]]
private val locs = new TimeStampedHashMap[Int, Array[List[String]]]

/**
* A map from the slave's host name to its cache size.
*/
private val slaveCapacity = new HashMap[String, Long]
private val slaveUsage = new HashMap[String, Long]

private val cleanupTask = new CleanupTask("CacheTracker", locs.cleanup)

private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L)
private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L)
private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host)
@@ -86,6 +89,7 @@ private[spark] class CacheTrackerActor extends Actor with Logging {
case StopCacheTracker =>
logInfo("Stopping CacheTrackerActor")
sender ! true
cleanupTask.cancel()
context.stop(self)
}
}
27 changes: 18 additions & 9 deletions core/src/main/scala/spark/MapOutputTracker.scala
@@ -17,6 +17,7 @@ import scala.collection.mutable.HashSet
import scheduler.MapStatus
import spark.storage.BlockManagerId
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import util.{CleanupTask, TimeStampedHashMap}

private[spark] sealed trait MapOutputTrackerMessage
private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String)
@@ -43,7 +44,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea

val timeout = 10.seconds

var mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]
var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]

// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
@@ -52,7 +53,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea

// Cache a serialized version of the output statuses for each shuffle to send them out faster
var cacheGeneration = generation
val cachedSerializedStatuses = new HashMap[Int, Array[Byte]]
val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]

var trackerActor: ActorRef = if (isMaster) {
val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName)
@@ -63,6 +64,8 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
actorSystem.actorFor(url)
}

val cleanupTask = new CleanupTask("MapOutputTracker", this.cleanup)

// Send a message to the trackerActor and get its result within a default timeout, or
// throw a SparkException if this fails.
def askTracker(message: Any): Any = {
@@ -83,14 +86,14 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
}

def registerShuffle(shuffleId: Int, numMaps: Int) {
if (mapStatuses.get(shuffleId) != null) {
if (mapStatuses.get(shuffleId) != None) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
mapStatuses.put(shuffleId, new Array[MapStatus](numMaps))
}

def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
var array = mapStatuses.get(shuffleId)
var array = mapStatuses(shuffleId)
array.synchronized {
array(mapId) = status
}
@@ -107,7 +110,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
}

def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
var array = mapStatuses.get(shuffleId)
var array = mapStatuses(shuffleId)
if (array != null) {
array.synchronized {
if (array(mapId).address == bmAddress) {
@@ -125,7 +128,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea

// Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
val statuses = mapStatuses.get(shuffleId)
val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
fetching.synchronized {
@@ -138,7 +141,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
case e: InterruptedException =>
}
}
return mapStatuses.get(shuffleId).map(status =>
return mapStatuses(shuffleId).map(status =>
(status.address, MapOutputTracker.decompressSize(status.compressedSizes(reduceId))))
} else {
fetching += shuffleId
@@ -164,9 +167,15 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
}
}

def cleanup(cleanupTime: Long) {
mapStatuses.cleanup(cleanupTime)
cachedSerializedStatuses.cleanup(cleanupTime)
}

def stop() {
communicate(StopMapOutputTracker)
mapStatuses.clear()
cleanupTask.cancel()
trackerActor = null
}

@@ -192,7 +201,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
generationLock.synchronized {
if (newGen > generation) {
logInfo("Updating generation to " + newGen + " and clearing cache")
mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]
mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
generation = newGen
}
}
@@ -210,7 +219,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
case Some(bytes) =>
return bytes
case None =>
statuses = mapStatuses.get(shuffleId)
statuses = mapStatuses(shuffleId)
generationGotten = generation
}
}
13 changes: 11 additions & 2 deletions core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -14,6 +14,7 @@ import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult
import spark.storage.BlockManagerMaster
import spark.storage.BlockManagerId
import util.{CleanupTask, TimeStampedHashMap}

/**
* A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for
@@ -61,9 +62,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with

val nextStageId = new AtomicInteger(0)

val idToStage = new HashMap[Int, Stage]
val idToStage = new TimeStampedHashMap[Int, Stage]

val shuffleToMapStage = new HashMap[Int, Stage]
val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]

var cacheLocs = new HashMap[Int, Array[List[String]]]

@@ -83,6 +84,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val activeJobs = new HashSet[ActiveJob]
val resultStageToJob = new HashMap[Stage, ActiveJob]

val cleanupTask = new CleanupTask("DAGScheduler", this.cleanup)

// Start a thread to run the DAGScheduler event loop
new Thread("DAGScheduler") {
setDaemon(true)
@@ -591,8 +594,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
return Nil
}

def cleanup(cleanupTime: Long) {
idToStage.cleanup(cleanupTime)
shuffleToMapStage.cleanup(cleanupTime)
}

def stop() {
eventQueue.put(StopDAGScheduler)
cleanupTask.cancel()
taskSched.stop()
}
}
6 changes: 4 additions & 2 deletions core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -14,17 +14,19 @@ import com.ning.compress.lzf.LZFOutputStream

import spark._
import spark.storage._
import util.{TimeStampedHashMap, CleanupTask}

private[spark] object ShuffleMapTask {

// A simple map from the stage id to the serialized byte array of a task.
// Served as a cache for task serialization because serialization can be
// expensive on the master node if it needs to launch thousands of tasks.
val serializedInfoCache = new JHashMap[Int, Array[Byte]]
val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
val cleanupTask = new CleanupTask("ShuffleMapTask", serializedInfoCache.cleanup)

def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = {
synchronized {
val old = serializedInfoCache.get(stageId)
val old = serializedInfoCache.get(stageId).orNull
if (old != null) {
return old
} else {
31 changes: 31 additions & 0 deletions core/src/main/scala/spark/util/CleanupTask.scala
@@ -0,0 +1,31 @@
package spark.util

import java.util.concurrent.{TimeUnit, ScheduledFuture, Executors}
import java.util.{TimerTask, Timer}
import spark.Logging

class CleanupTask(name: String, cleanupFunc: (Long) => Unit) extends Logging {
val delayMins = System.getProperty("spark.cleanup.delay", "-100").toInt
val periodMins = System.getProperty("spark.cleanup.period", (delayMins / 10).toString).toInt
val timer = new Timer(name + " cleanup timer", true)
val task = new TimerTask {
def run() {
try {
if (delayMins > 0) {
cleanupFunc(System.currentTimeMillis() - (delayMins * 60 * 1000))
logInfo("Ran cleanup task for " + name)
}
} catch {
case e: Exception => logError("Error running cleanup task for " + name, e)
}
}
}
if (periodMins > 0) {
timer.schedule(task, periodMins * 60 * 1000, periodMins * 60 * 1000)
}

def cancel() {
timer.cancel()
}
}
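
With the default spark.cleanup.delay of -100, periodMins is also negative and the timer is never scheduled, so cleanup stays off until a positive delay is configured. A usage sketch (the property value and task name are illustrative):

import spark.util.{CleanupTask, TimeStampedHashMap}

// Evict entries stamped more than 30 minutes ago; the check runs every
// 3 minutes because spark.cleanup.period defaults to a tenth of the delay.
System.setProperty("spark.cleanup.delay", "30")

val map = new TimeStampedHashMap[Int, String]
val task = new CleanupTask("example", map.cleanup)
// ... use the map as usual; stale entries disappear in the background ...
task.cancel()  // cancels the underlying daemon Timer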
87 changes: 87 additions & 0 deletions core/src/main/scala/spark/util/TimeStampedHashMap.scala
@@ -0,0 +1,87 @@
package spark.util

import scala.collection.JavaConversions._
import scala.collection.mutable.{HashMap, Map}
import java.util.concurrent.ConcurrentHashMap

/**
* This is a custom implementation of scala.collection.mutable.Map which stores the insertion
* time stamp along with each key-value pair. Key-value pairs that are older than a particular
* threshold time can then be removed using the cleanup method. This is intended to be a drop-in
* replacement for scala.collection.mutable.HashMap.
*/
class TimeStampedHashMap[A, B] extends Map[A, B]() {
val internalMap = new ConcurrentHashMap[A, (B, Long)]()

def get(key: A): Option[B] = {
val value = internalMap.get(key)
if (value != null) Some(value._1) else None
}

def iterator: Iterator[(A, B)] = {
val jIterator = internalMap.entrySet().iterator()
jIterator.map(kv => (kv.getKey, kv.getValue._1))
}

override def + [B1 >: B](kv: (A, B1)): Map[A, B1] = {
val newMap = new TimeStampedHashMap[A, B1]
newMap.internalMap.putAll(this.internalMap)
newMap.internalMap.put(kv._1, (kv._2, currentTime))
newMap
}

override def - (key: A): Map[A, B] = {
internalMap.remove(key)
this
}

override def += (kv: (A, B)): this.type = {
internalMap.put(kv._1, (kv._2, currentTime))
this
}

override def -= (key: A): this.type = {
internalMap.remove(key)
this
}

override def update(key: A, value: B) {
this += ((key, value))
}

override def apply(key: A): B = {
val value = internalMap.get(key)
if (value == null) throw new NoSuchElementException()
value._1
}

override def filter(p: ((A, B)) => Boolean): Map[A, B] = {
internalMap.map(kv => (kv._1, kv._2._1)).filter(p)
}

override def empty: Map[A, B] = new TimeStampedHashMap[A, B]()

override def size(): Int = internalMap.size()

override def foreach[U](f: ((A, B)) => U): Unit = {
val iterator = internalMap.entrySet().iterator()
while(iterator.hasNext) {
val entry = iterator.next()
val kv = (entry.getKey, entry.getValue._1)
f(kv)
}
}

def cleanup(threshTime: Long) {
val iterator = internalMap.entrySet().iterator()
while(iterator.hasNext) {
val entry = iterator.next()
if (entry.getValue._2 < threshTime) {
iterator.remove()
}
}
}

private def currentTime: Long = System.currentTimeMillis()

}
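
A short illustrative run of the map itself, showing that it honors the mutable.Map contract while cleanup(threshTime) drops pairs stamped before the threshold:

val m = new TimeStampedHashMap[String, Int]
m("old") = 1                      // update() records the insertion time
Thread.sleep(10)                  // ensure the threshold falls after "old"
val threshold = System.currentTimeMillis()
m("new") = 2
m.cleanup(threshold)              // evicts "old"; "new" was stamped later
println(m.get("old"))             // None
println(m.get("new"))             // Some(2)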
13 changes: 9 additions & 4 deletions streaming/src/main/scala/spark/streaming/StreamingContext.scala
@@ -43,7 +43,7 @@ class StreamingContext private (
* @param batchDuration The time interval at which streaming data will be divided into batches
*/
def this(master: String, frameworkName: String, batchDuration: Time) =
this(new SparkContext(master, frameworkName), null, batchDuration)
this(StreamingContext.createNewSparkContext(master, frameworkName), null, batchDuration)

/**
* Recreates the StreamingContext from a checkpoint file.
@@ -214,11 +214,8 @@
"Checkpoint directory has been set, but the graph checkpointing interval has " +
"not been set. Please use StreamingContext.checkpoint() to set the interval."
)


}


/**
* This function starts the execution of the streams.
*/
@@ -265,6 +262,14 @@


object StreamingContext {

def createNewSparkContext(master: String, frameworkName: String): SparkContext = {
if (System.getProperty("spark.cleanup.delay", "-1").toInt < 0) {
System.setProperty("spark.cleanup.delay", "60")
}
new SparkContext(master, frameworkName)
}

implicit def toPairDStreamFunctions[K: ClassManifest, V: ClassManifest](stream: DStream[(K,V)]) = {
new PairDStreamFunctions[K, V](stream)
}
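
Streaming jobs run indefinitely and produce metadata every batch, so createNewSparkContext above switches cleanup on with a 60 minute delay whenever the user has not picked a value. Sketched below, assuming the property starts unset (master and app name are illustrative):

// Unset property: the streaming helper enables metadata cleanup.
val sc = StreamingContext.createNewSparkContext("local", "example")
assert(System.getProperty("spark.cleanup.delay").toInt == 60)

// To tune or disable cleanup, set the property before creating the context:
// System.setProperty("spark.cleanup.delay", "120")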
