[SPARK-2033] Automatically cleanup checkpoint #855

Closed · wants to merge 2 commits
44 changes: 32 additions & 12 deletions core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -22,7 +22,7 @@ import java.lang.ref.{ReferenceQueue, WeakReference}
import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.{RDDCheckpointData, RDD}
import org.apache.spark.util.Utils

/**
@@ -33,6 +33,7 @@ private case class CleanRDD(rddId: Int) extends CleanupTask
private case class CleanShuffle(shuffleId: Int) extends CleanupTask
private case class CleanBroadcast(broadcastId: Long) extends CleanupTask
private case class CleanAccum(accId: Long) extends CleanupTask
private case class CleanCheckpoint(rddId: Int) extends CleanupTask

/**
* A WeakReference associated with a CleanupTask.
@@ -94,12 +95,12 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
@volatile private var stopped = false

/** Attach a listener object to get information of when objects are cleaned. */
def attachListener(listener: CleanerListener) {
def attachListener(listener: CleanerListener): Unit = {
listeners += listener
}

/** Start the cleaner. */
def start() {
def start(): Unit = {
cleaningThread.setDaemon(true)
cleaningThread.setName("Spark Context Cleaner")
cleaningThread.start()
@@ -108,7 +109,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
/**
* Stop the cleaning thread and wait until the thread has finished running its current task.
*/
def stop() {
def stop(): Unit = {
stopped = true
// Interrupt the cleaning thread, but wait until the current task has finished before
// doing so. This guards against the race condition where a cleaning thread may
@@ -121,7 +122,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
}

/** Register a RDD for cleanup when it is garbage collected. */
def registerRDDForCleanup(rdd: RDD[_]) {
def registerRDDForCleanup(rdd: RDD[_]): Unit = {
registerForCleanup(rdd, CleanRDD(rdd.id))
}

@@ -130,17 +131,22 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
}

/** Register a ShuffleDependency for cleanup when it is garbage collected. */
def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]) {
def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]): Unit = {
registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId))
}

/** Register a Broadcast for cleanup when it is garbage collected. */
def registerBroadcastForCleanup[T](broadcast: Broadcast[T]) {
def registerBroadcastForCleanup[T](broadcast: Broadcast[T]): Unit = {
registerForCleanup(broadcast, CleanBroadcast(broadcast.id))
}

/** Register a RDDCheckpointData for cleanup when it is garbage collected. */
def registerRDDCheckpointDataForCleanup[T](rdd: RDD[_], parentId: Int): Unit = {
registerForCleanup(rdd, CleanCheckpoint(parentId))
}

/** Register an object for cleanup. */
private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) {
private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = {
referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue)
}

@@ -164,6 +170,8 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks)
case CleanAccum(accId) =>
doCleanupAccum(accId, blocking = blockOnCleanupTasks)
case CleanCheckpoint(rddId) =>
doCleanCheckpoint(rddId)
}
}
}
@@ -175,7 +183,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
}

/** Perform RDD cleanup. */
def doCleanupRDD(rddId: Int, blocking: Boolean) {
def doCleanupRDD(rddId: Int, blocking: Boolean): Unit = {
try {
logDebug("Cleaning RDD " + rddId)
sc.unpersistRDD(rddId, blocking)
@@ -187,7 +195,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
}

/** Perform shuffle cleanup, asynchronously. */
def doCleanupShuffle(shuffleId: Int, blocking: Boolean) {
def doCleanupShuffle(shuffleId: Int, blocking: Boolean): Unit = {
try {
logDebug("Cleaning shuffle " + shuffleId)
mapOutputTrackerMaster.unregisterShuffle(shuffleId)
@@ -200,7 +208,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
}

/** Perform broadcast cleanup. */
def doCleanupBroadcast(broadcastId: Long, blocking: Boolean) {
def doCleanupBroadcast(broadcastId: Long, blocking: Boolean): Unit = {
try {
logDebug(s"Cleaning broadcast $broadcastId")
broadcastManager.unbroadcast(broadcastId, true, blocking)
@@ -212,7 +220,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
}

/** Perform accumulator cleanup. */
def doCleanupAccum(accId: Long, blocking: Boolean) {
def doCleanupAccum(accId: Long, blocking: Boolean): Unit = {
try {
logDebug("Cleaning accumulator " + accId)
Accumulators.remove(accId)
@@ -223,6 +231,18 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
}
}

/** Perform checkpoint cleanup. */
def doCleanCheckpoint(rddId: Int): Unit = {
try {
logDebug("Cleaning rdd checkpoint data " + rddId)
RDDCheckpointData.clearRDDCheckpointData(sc, rddId)
logInfo("Cleaned rdd checkpoint data " + rddId)
}
catch {
case e: Exception => logError("Error cleaning rdd checkpoint data " + rddId, e)
}
}

private def blockManagerMaster = sc.env.blockManager.master
private def broadcastManager = sc.env.broadcastManager
private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
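For background, the bookkeeping in ContextCleaner is built on java.lang.ref weak references and a ReferenceQueue (see the imports in the first hunk and the CleanupTaskWeakReference registered above). The following standalone sketch is not Spark code; it only assumes the JDK and illustrates the pattern: a WeakReference subclass carries a cleanup task, and polling the queue reveals when the tracked object has been garbage collected.

```scala
import java.lang.ref.{ReferenceQueue, WeakReference}

object WeakRefCleanupSketch {
  sealed trait CleanupTask
  final case class CleanCheckpoint(rddId: Int) extends CleanupTask

  // Mirrors the idea of CleanupTaskWeakReference: a weak reference that remembers its task.
  class TaskWeakReference(val task: CleanupTask, referent: AnyRef, queue: ReferenceQueue[AnyRef])
    extends WeakReference[AnyRef](referent, queue)

  def main(args: Array[String]): Unit = {
    val queue = new ReferenceQueue[AnyRef]
    var tracked: AnyRef = new Object
    // The reference object itself must stay strongly reachable, hence the val.
    val ref = new TaskWeakReference(CleanCheckpoint(rddId = 42), tracked, queue)

    tracked = null // drop the only strong reference to the tracked object
    System.gc()    // request a collection; not guaranteed, but usually sufficient for a demo

    // Blocks up to 1 second; returns null if nothing was enqueued in that window.
    queue.remove(1000) match {
      case r: TaskWeakReference => println(s"referent collected, would now run: ${r.task}")
      case _                    => println("referent not collected yet")
    }
  }
}
```

In ContextCleaner the same idea is driven by the daemon cleaning thread started in start(), which maps each dequeued task (CleanRDD, CleanShuffle, CleanBroadcast, CleanAccum, and now CleanCheckpoint) to the matching doCleanup* method.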
27 changes: 22 additions & 5 deletions core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -21,7 +21,7 @@ import scala.reflect.ClassTag

import org.apache.hadoop.fs.Path

import org.apache.spark.{Logging, Partition, SerializableWritable, SparkException}
import org.apache.spark._
import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask}

/**
@@ -83,7 +83,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
}

// Create the output path for the checkpoint
val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id)
val path = RDDCheckpointData.rddCheckpointDataPath(rdd.context, rdd.id).get
val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
if (!fs.mkdirs(path)) {
throw new SparkException("Failed to create checkpoint path " + path)
@@ -92,8 +92,13 @@
// Save to file, and reload it as an RDD
val broadcastedConf = rdd.context.broadcast(
new SerializableWritable(rdd.context.hadoopConfiguration))
rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _)
val newRDD = new CheckpointRDD[T](rdd.context, path.toString)
if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) {
[Inline review discussion]

Contributor: I don't think we should have this config. If the cleaner is not defined, then we probably won't clean the checkpoints anyway.

Contributor (Author): Streaming-related checkpoints should not be automatically cleared. The default should not automatically clean up checkpoints, but it can be set to true by the user.

Contributor: OK, that's fine. This is also an internal config, so we can always change it later.

rdd.context.cleaner.foreach { cleaner =>
cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id)
}
}
rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _)
if (newRDD.partitions.length != rdd.partitions.length) {
throw new SparkException(
"Checkpoint RDD " + newRDD + "(" + newRDD.partitions.length + ") has different " +
@@ -130,5 +135,17 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
}
}

// Used for synchronization
private[spark] object RDDCheckpointData
private[spark] object RDDCheckpointData {
def rddCheckpointDataPath(sc: SparkContext, rddId: Int): Option[Path] = {
sc.checkpointDir.map { dir => new Path(dir, "rdd-" + rddId) }
}

def clearRDDCheckpointData(sc: SparkContext, rddId: Int): Unit = {
rddCheckpointDataPath(sc, rddId).foreach { path =>
val fs = path.getFileSystem(sc.hadoopConfiguration)
if (fs.exists(path)) {
fs.delete(path, true)
}
}
}
}
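To show how the opt-in flag discussed in the review thread above would be used, here is a hedged sketch. The config key comes from this patch and defaults to false; the master, app name, and checkpoint path are illustrative only.

```scala
import org.apache.spark.{SparkConf, SparkContext}

object CheckpointCleanupExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("checkpoint-cleanup-example")
      .set("spark.cleaner.referenceTracking.cleanCheckpoints", "true") // off by default
    val sc = new SparkContext(conf)
    sc.setCheckpointDir("/tmp/checkpoint-example") // any directory reachable by all executors

    var rdd = sc.parallelize(1 to 1000).map(i => (i, i))
    rdd.checkpoint()
    rdd.cache()
    rdd.collect() // materializes the RDD and writes checkpoint data under rdd-<id>

    rdd = null    // once the RDD is unreachable and a GC runs, the ContextCleaner
    System.gc()   // can delete the rdd-<id> checkpoint directory created above

    sc.stop()
  }
}
```

With the flag left at its default, the checkpoint directory is kept even after the RDD is collected, which matches the author's point about streaming checkpoints above.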
49 changes: 48 additions & 1 deletion core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -28,7 +28,8 @@ import org.scalatest.concurrent.{PatienceConfiguration, Eventually}
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._

import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.{RDDCheckpointData, RDD}
import org.apache.spark.storage._
import org.apache.spark.shuffle.hash.HashShuffleManager
import org.apache.spark.shuffle.sort.SortShuffleManager
@@ -205,6 +206,52 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase {
postGCTester.assertCleanup()
}

test("automatically cleanup checkpoint") {
val checkpointDir = java.io.File.createTempFile("temp", "")
checkpointDir.deleteOnExit()
checkpointDir.delete()
var rdd = newPairRDD
sc.setCheckpointDir(checkpointDir.toString)
rdd.checkpoint()
rdd.cache()
rdd.collect()
var rddId = rdd.id

// Confirm the checkpoint directory exists
assert(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).isDefined)
val path = RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get
val fs = path.getFileSystem(sc.hadoopConfiguration)
assert(fs.exists(path))

// the checkpoint is not cleaned by default (without the configuration set)
var postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil)
rdd = null // Make RDD out of scope
runGC()
postGCTester.assertCleanup()
assert(fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get))

sc.stop()
val conf = new SparkConf().setMaster("local[2]").setAppName("cleanupCheckpoint").
set("spark.cleaner.referenceTracking.cleanCheckpoints", "true")
sc = new SparkContext(conf)
rdd = newPairRDD
sc.setCheckpointDir(checkpointDir.toString)
rdd.checkpoint()
rdd.cache()
rdd.collect()
rddId = rdd.id

// Confirm the checkpoint directory exists
assert(fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get))

// Test that GC causes checkpoint data cleanup after dereferencing the RDD
postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil)
rdd = null // Make RDD out of scope
runGC()
postGCTester.assertCleanup()
assert(!fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get))
}

test("automatically cleanup RDD + shuffle + broadcast") {
val numRdds = 100
val numBroadcasts = 4 // Broadcasts are more costly
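For completeness, a small hedged helper mirroring what the test above asserts: checkpoint data for an RDD lives in a "rdd-<id>" subdirectory of the configured checkpoint directory (the same layout rddCheckpointDataPath builds), and its existence can be probed through Hadoop's FileSystem API. The helper name and the explicit checkpointDir parameter are illustrative; the patch's rddCheckpointDataPath itself is private[spark].

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

object CheckpointPathCheck {
  // checkpointDir is the directory previously passed to sc.setCheckpointDir;
  // hadoopConf would typically be sc.hadoopConfiguration.
  def checkpointPathExists(checkpointDir: String, rddId: Int, hadoopConf: Configuration): Boolean = {
    val path = new Path(checkpointDir, "rdd-" + rddId) // same layout the patch uses
    val fs = path.getFileSystem(hadoopConf)
    fs.exists(path)
  }
}
```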