From 01ebf560a73eb5d4207381b1796b49c670bfcee7 Mon Sep 17 00:00:00 2001
From: Mike Timper
Date: Mon, 13 Apr 2015 15:23:18 -0600
Subject: [PATCH] Changes to support checkpointing to BlockRDD as described in
 http://apache-spark-user-list.1001560.n3.nabble.com/java-lang-StackOverflowError-when-calling-count-td5649.html#a11970

---
 .../main/scala/org/apache/spark/rdd/RDD.scala | 22 ++++++++++++++---
 .../apache/spark/rdd/RDDCheckpointData.scala  | 24 +++++++++++++++----
 2 files changed, 39 insertions(+), 7 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 582cfc1c6e3a6..65007f0ff524e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1199,9 +1199,19 @@ abstract class RDD[T: ClassTag](
    * memory, otherwise saving it on a file will require recomputation.
    */
   def checkpoint() {
-    if (context.checkpointDir.isEmpty) {
+    checkpoint(false)
+  }
+
+  def checkpoint(localStorage: Boolean) {
+    checkpointLocalStorage = localStorage
+    if (!localStorage && context.checkpointDir.isEmpty) {
       throw new Exception("Checkpoint directory has not been set in the SparkContext")
     } else if (checkpointData.isEmpty) {
+      if (localStorage && (storageLevel.replication == 1 || !storageLevel.useDisk)) {
+        // Should this be an exception condition?
+        logWarning("When checkpointing to local storage it is recommended that the RDD's " +
+          "storage level use replication and spill to disk (e.g. MEMORY_AND_DISK_2)")
+      }
       checkpointData = Some(new RDDCheckpointData(this))
       checkpointData.get.markForCheckpoint()
     }
@@ -1260,18 +1270,24 @@ abstract class RDD[T: ClassTag](
   // Avoid handling doCheckpoint multiple times to prevent excessive recursion
   @transient private var doCheckpointCalled = false
 
+  private var checkpointLocalStorage = false
+
   /**
    * Performs the checkpointing of this RDD by saving this. It is called after a job using this RDD
    * has completed (therefore the RDD has been materialized and potentially stored in memory).
    * doCheckpoint() is called recursively on the parent RDDs.
    */
   private[spark] def doCheckpoint() {
+    doCheckpoint(checkpointLocalStorage)
+  }
+
+  private[spark] def doCheckpoint(localStorage: Boolean) {
     if (!doCheckpointCalled) {
       doCheckpointCalled = true
       if (checkpointData.isDefined) {
-        checkpointData.get.doCheckpoint()
+        checkpointData.get.doCheckpoint(localStorage)
       } else {
-        dependencies.foreach(_.rdd.doCheckpoint())
+        dependencies.foreach(_.rdd.doCheckpoint(localStorage))
       }
     }
   }
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
index f67e5f1857979..8df900a39facf 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.{Logging, Partition, SerializableWritable, SparkException}
 import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask}
+import org.apache.spark.storage.StorageUtils
 
 /**
  * Enumeration to manage state transitions of an RDD through checkpointing
@@ -70,8 +71,12 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
     RDDCheckpointData.synchronized { cpFile }
   }
 
-  // Do the checkpointing of the RDD. Called after the first job using that RDD is over.
   def doCheckpoint() {
+    doCheckpoint(false)
+  }
+
+  // Do the checkpointing of the RDD. Called after the first job using that RDD is over.
+  def doCheckpoint(localStorage: Boolean) {
     // If it is marked for checkpointing AND checkpointing is not already in progress,
     // then set it to be in progress, else return
     RDDCheckpointData.synchronized {
@@ -82,18 +87,28 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
       }
     }
 
+    var pathStr: String = ""
+    var newRDD: RDD[T] = null
+    if (localStorage) {
+      pathStr = "local storage"
+      val blockIDs = StorageUtils.getRddBlockLocations(
+        rdd.id, rdd.context.getExecutorStorageStatus).keys.toArray
+      newRDD = new BlockRDD[T](rdd.context, blockIDs)
+    } else {
     // Create the output path for the checkpoint
     val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id)
     val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
     if (!fs.mkdirs(path)) {
       throw new SparkException("Failed to create checkpoint path " + path)
     }
+    pathStr = path.toString
 
     // Save to file, and reload it as an RDD
     val broadcastedConf = rdd.context.broadcast(
       new SerializableWritable(rdd.context.hadoopConfiguration))
     rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _)
-    val newRDD = new CheckpointRDD[T](rdd.context, path.toString)
+    newRDD = new CheckpointRDD[T](rdd.context, pathStr)
+    }
     if (newRDD.partitions.size != rdd.partitions.size) {
       throw new SparkException(
         "Checkpoint RDD " + newRDD + "(" + newRDD.partitions.size + ") has different " +
@@ -102,12 +117,13 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
 
     // Change the dependencies and partitions of the RDD
     RDDCheckpointData.synchronized {
-      cpFile = Some(path.toString)
+      cpFile = Some(pathStr)
       cpRDD = Some(newRDD)
       rdd.markCheckpointed(newRDD)  // Update the RDD's dependencies and partitions
       cpState = Checkpointed
     }
-    logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id)
+    logInfo("Done checkpointing RDD " + rdd.id + " to " + pathStr +
+      ", new parent is RDD " + newRDD.id)
   }
 
   // Get preferred location of a split after checkpointing
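
Note: below is a minimal, hypothetical usage sketch of the checkpoint(localStorage) overload this
patch introduces; it is not part of the patch. Names such as sc and data are placeholders, and it
assumes a build with the patch applied.

    import org.apache.spark.SparkContext
    import org.apache.spark.storage.StorageLevel

    // Hypothetical driver program exercising local-storage checkpointing.
    val sc = new SparkContext("local[2]", "local-checkpoint-sketch")
    val data = sc.parallelize(1 to 1000).map(_ * 2)

    // Persist with replication and disk spill, as the logWarning in checkpoint() suggests,
    // so the cached blocks backing the BlockRDD are less likely to be lost.
    data.persist(StorageLevel.MEMORY_AND_DISK_2)

    // Mark for checkpointing to local block storage instead of the HDFS checkpoint dir.
    data.checkpoint(true)

    // The first action materializes and caches the blocks; doCheckpoint() then runs after the
    // job and replaces the lineage with a BlockRDD over those cached blocks.
    data.count()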