Add framework for broadcast cleanup
As of this commit, Spark does not clean up broadcast blocks.
This will be done in the next commit.
andrewor14 committed Mar 26, 2014
1 parent ba52e00 commit d0edef3
Showing 12 changed files with 249 additions and 112 deletions.
134 changes: 86 additions & 48 deletions core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -21,105 +21,106 @@ import java.lang.ref.{ReferenceQueue, WeakReference}

import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD

/**
* Classes that represent cleaning tasks.
*/
private sealed trait CleanupTask
private case class CleanRDD(rddId: Int) extends CleanupTask
private case class CleanShuffle(shuffleId: Int) extends CleanupTask
private case class CleanBroadcast(broadcastId: Long) extends CleanupTask

/**
* A WeakReference associated with a CleanupTask.
*
* When the referent object becomes only weakly reachable, the corresponding
* CleanupTaskWeakReference is automatically added to the given reference queue.
*/
private class CleanupTaskWeakReference(
val task: CleanupTask,
referent: AnyRef,
referenceQueue: ReferenceQueue[AnyRef])
extends WeakReference(referent, referenceQueue)

/**
* An asynchronous cleaner for RDD, shuffle, and broadcast state.
*
* This maintains a weak reference for each RDD, ShuffleDependency, and Broadcast of interest,
* to be processed when the associated object goes out of scope of the application. Actual
* cleanup is performed in a separate daemon thread.
*/
private[spark] class ContextCleaner(sc: SparkContext) extends Logging {

private val referenceBuffer = new ArrayBuffer[CleanupTaskWeakReference]
with SynchronizedBuffer[CleanupTaskWeakReference]

private val referenceQueue = new ReferenceQueue[AnyRef]

private val listeners = new ArrayBuffer[CleanerListener]
with SynchronizedBuffer[CleanerListener]

private val cleaningThread = new Thread() { override def run() { keepCleaning() }}

@volatile private var stopped = false

/** Attach a listener object to get information of when objects are cleaned. */
def attachListener(listener: CleanerListener) {
listeners += listener
}

/** Start the cleaner. */
def start() {
cleaningThread.setDaemon(true)
cleaningThread.setName("ContextCleaner")
cleaningThread.start()
}

/** Stop the cleaner. */
def stop() {
stopped = true
cleaningThread.interrupt()
}

/** Register an RDD for cleanup when it is garbage collected. */
def registerRDDForCleanup(rdd: RDD[_]) {
registerForCleanup(rdd, CleanRDD(rdd.id))
}

/** Register a ShuffleDependency for cleanup when it is garbage collected. */
def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _]) {
registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId))
}

/** Register a Broadcast for cleanup when it is garbage collected. */
def registerBroadcastForCleanup[T](broadcast: Broadcast[T]) {
registerForCleanup(broadcast, CleanBroadcast(broadcast.id))
}

/** Register an object for cleanup. */
private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) {
referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue)
}

/** Keep cleaning RDD, shuffle, and broadcast state. */
private def keepCleaning() {
while (!stopped) {
try {
val reference = Option(referenceQueue.remove(ContextCleaner.REF_QUEUE_POLL_TIMEOUT))
.map(_.asInstanceOf[CleanupTaskWeakReference])
reference.map(_.task).foreach { task =>
logDebug("Got cleaning task " + task)
referenceBuffer -= reference.get
task match {
case CleanRDD(rddId) => doCleanupRDD(rddId)
case CleanShuffle(shuffleId) => doCleanupShuffle(shuffleId)
case CleanBroadcast(broadcastId) => doCleanupBroadcast(broadcastId)
}
}
} catch {
case ie: InterruptedException =>
if (!stopped) logWarning("Cleaning thread interrupted")
case t: Throwable => logError("Error in cleaning thread", t)
}
}
@@ -129,7 +130,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
private def doCleanupRDD(rddId: Int) {
try {
logDebug("Cleaning RDD " + rddId)
sc.unpersistRDD(rddId, blocking = false)
listeners.foreach(_.rddCleaned(rddId))
logInfo("Cleaned RDD " + rddId)
} catch {
@@ -150,10 +151,47 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
}
}

/** Perform broadcast cleanup. */
private def doCleanupBroadcast(broadcastId: Long) {
try {
logDebug("Cleaning broadcast " + broadcastId)
broadcastManager.unbroadcast(broadcastId, removeFromDriver = true)
listeners.foreach(_.broadcastCleaned(broadcastId))
logInfo("Cleaned broadcast " + broadcastId)
} catch {
case t: Throwable => logError("Error cleaning broadcast " + broadcastId, t)
}
}

private def blockManagerMaster = sc.env.blockManager.master
private def broadcastManager = sc.env.broadcastManager
private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]

// Used for testing

private[spark] def cleanupRDD(rdd: RDD[_]) {
doCleanupRDD(rdd.id)
}

private[spark] def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) {
doCleanupShuffle(shuffleDependency.shuffleId)
}

private[spark] def cleanupBroadcast[T](broadcast: Broadcast[T]) {
doCleanupBroadcast(broadcast.id)
}

}

private object ContextCleaner {
private val REF_QUEUE_POLL_TIMEOUT = 100
}

/**
* Listener class used for testing when any item has been cleaned by the Cleaner class.
*/
private[spark] trait CleanerListener {
def rddCleaned(rddId: Int)
def shuffleCleaned(shuffleId: Int)
def broadcastCleaned(broadcastId: Long)
}
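
The cleaner's core mechanism is the JVM's ReferenceQueue: a WeakReference constructed with a queue is enqueued automatically once its referent is no longer strongly reachable. The following self-contained sketch demonstrates that pattern outside Spark; Task, CleanThing, and WeakRefDemo are illustrative names, not Spark classes, and System.gc() is only a best-effort hint, so the poll can time out.

import java.lang.ref.{ReferenceQueue, WeakReference}

// Illustrative stand-ins for CleanupTask / CleanupTaskWeakReference.
sealed trait Task
case class CleanThing(id: Int) extends Task

class TaskWeakReference(val task: Task, referent: AnyRef, queue: ReferenceQueue[AnyRef])
  extends WeakReference[AnyRef](referent, queue)

object WeakRefDemo extends App {
  val queue = new ReferenceQueue[AnyRef]
  var payload: AnyRef = new Object                  // the object whose lifetime we track
  val tracker = new TaskWeakReference(CleanThing(42), payload, queue)

  payload = null                                    // drop the only strong reference
  System.gc()                                       // request collection (best effort)

  // Poll with a timeout, as ContextCleaner's cleaning thread does.
  Option(queue.remove(1000)) match {
    case Some(ref) => println("Enqueued: " + ref.asInstanceOf[TaskWeakReference].task)
    case None      => println("Referent not collected within the timeout")
  }
}

Like referenceBuffer in ContextCleaner, the tracker val keeps the weak reference itself strongly reachable; without it, the reference object could be collected before it is ever enqueued.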
6 changes: 5 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -642,7 +642,11 @@ class SparkContext(
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
* The variable will be sent to each cluster only once.
*/
def broadcast[T](value: T) = {
val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
cleaner.registerBroadcastForCleanup(bc)
bc
}

/**
* Add a file to be downloaded with this Spark job on every node.
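
A hedged usage sketch of the new registration path, assuming an already-constructed SparkContext named sc: broadcast values created through SparkContext.broadcast are now tracked by the ContextCleaner without any change to user code.

// Assumes an existing SparkContext `sc`.
val lookup = sc.broadcast(Map(1 -> "a", 2 -> "b"))   // silently registered for cleanup
val labels = sc.parallelize(Seq(1, 2, 2)).map(i => lookup.value(i))
println(labels.collect().mkString(", "))             // prints: a, b, b
// Once `lookup` becomes unreachable and is garbage collected, a
// CleanBroadcast(lookup.id) task is enqueued for the cleaning thread.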
core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -50,6 +50,12 @@ import java.io.Serializable
abstract class Broadcast[T](val id: Long) extends Serializable {
def value: T

/**
* Remove all persisted state associated with this broadcast.
* @param removeFromDriver Whether to remove state from the driver.
*/
def unpersist(removeFromDriver: Boolean)

// We cannot have an abstract readObject here due to some weird issues with
// readObject having to be 'private' in sub-classes.

core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
@@ -29,5 +29,6 @@ import org.apache.spark.SparkConf
trait BroadcastFactory {
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
def unbroadcast(id: Long, removeFromDriver: Boolean)
def stop(): Unit
}
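
To make the extended contract concrete, here is a hypothetical in-memory factory. LocalBroadcastFactory is not part of this commit (Spark's real implementations are the HTTP and torrent broadcast factories), and it keeps all state on the driver, ignoring the executor side entirely.

import scala.collection.mutable

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.broadcast.{Broadcast, BroadcastFactory}

// Hypothetical driver-only factory illustrating the unbroadcast contract.
class LocalBroadcastFactory extends BroadcastFactory {
  private val values = mutable.Map[Long, Any]()

  def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit = { }

  def newBroadcast[T](value0: T, isLocal: Boolean, id: Long): Broadcast[T] = {
    values(id) = value0
    new Broadcast[T](id) {
      def value: T = values(id).asInstanceOf[T]
      def unpersist(removeFromDriver: Boolean) {
        unbroadcast(id, removeFromDriver)
      }
    }
  }

  // Remove all state for the given broadcast. Since everything lives on
  // the driver here, removeFromDriver decides whether anything is dropped.
  def unbroadcast(id: Long, removeFromDriver: Boolean) {
    if (removeFromDriver) values.remove(id)
  }

  def stop(): Unit = { }
}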
core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
@@ -60,4 +60,8 @@ private[spark] class BroadcastManager(
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
}

def unbroadcast(id: Long, removeFromDriver: Boolean) {
broadcastFactory.unbroadcast(id, removeFromDriver)
}

}
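
Because BroadcastManager.unbroadcast simply forwards to the configured factory, end-to-end cleanup can be observed through the new CleanerListener hook. A hedged test-style sketch follows; CleanerListener and SparkContext's cleaner field are private[spark], so this compiles only inside that package, and sc is again an assumed SparkContext.

import scala.collection.mutable

// Records which broadcasts the cleaner has processed.
val cleanedBroadcasts = new mutable.HashSet[Long] with mutable.SynchronizedSet[Long]
sc.cleaner.attachListener(new CleanerListener {
  def rddCleaned(rddId: Int) { }
  def shuffleCleaned(shuffleId: Int) { }
  def broadcastCleaned(broadcastId: Long) { cleanedBroadcasts += broadcastId }
})
// After a broadcast is garbage collected, cleanedBroadcasts should
// eventually contain its id.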
(7 more changed files not shown.)
