Skip to content

Commit

Permalink
Address Andrew's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
zsxwing committed Jan 30, 2016
1 parent ef3983b commit 97e39c0
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 35 deletions.
18 changes: 13 additions & 5 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1535,9 +1535,14 @@ abstract class RDD[T: ClassTag](

private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None

// Whether checkpoint all RDDs that are marked with the checkpoint flag.
// Whether to checkpoint all ancestor RDDs that are marked for checkpointing. By default,
// we stop as soon as we find the first such RDD, an optimization that allows us to write
// less data but is not safe for all workloads. E.g. in streaming we may checkpoint both
// an RDD and its parent in every batch, in which case the parent may never be checkpointed
// and its lineage never truncated, leading to OOMs in the long run (SPARK-6847).
private val checkpointAllMarked =
Option(sc.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED)).map(_.toBoolean).getOrElse(false)
Option(sc.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS))
.map(_.toBoolean).getOrElse(false)

/** Returns the first parent RDD */
protected[spark] def firstParent[U: ClassTag]: RDD[U] = {
Expand Down Expand Up @@ -1583,8 +1588,10 @@ abstract class RDD[T: ClassTag](
doCheckpointCalled = true
if (checkpointData.isDefined) {
if (checkpointAllMarked) {
// Checkpoint dependencies first because dependencies will be set to
// ReliableCheckpointRDD after checkpointing.
// TODO We can collect all the RDDs that need to be checkpointed, and then checkpoint
// them in parallel.
// Checkpoint parents first because our lineage will be truncated after we
// checkpoint ourselves
dependencies.foreach(_.rdd.doCheckpoint())
}
checkpointData.get.checkpoint()
Expand Down Expand Up @@ -1706,7 +1713,8 @@ abstract class RDD[T: ClassTag](
*/
object RDD {

private[spark] val CHECKPOINT_ALL_MARKED = "spark.checkpoint.checkpointAllMarked"
private[spark] val CHECKPOINT_ALL_MARKED_ANCESTORS =
"spark.checkpoint.checkpointAllMarkedAncestors"

// The following implicit functions were in SparkContext before 1.3 and users had to
// `import SparkContext._` to enable them. Now we move them here to make the compiler find
Expand Down
14 changes: 10 additions & 4 deletions core/src/test/scala/org/apache/spark/CheckpointSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -513,18 +513,24 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
assert(rdd.partitions.size === 0)
}

runTest("checkpoint all marked RDDs") { reliableCheckpoint: Boolean =>
sc.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED, "true")
runTest("checkpointAllMarkedAncestors") { reliableCheckpoint: Boolean =>
testCheckpointAllMarkedAncestors(reliableCheckpoint, checkpointAllMarkedAncestors = true)
testCheckpointAllMarkedAncestors(reliableCheckpoint, checkpointAllMarkedAncestors = false)
}

private def testCheckpointAllMarkedAncestors(
reliableCheckpoint: Boolean, checkpointAllMarkedAncestors: Boolean): Unit = {
sc.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, checkpointAllMarkedAncestors.toString)
try {
val rdd1 = sc.parallelize(1 to 10)
checkpoint(rdd1, reliableCheckpoint)
val rdd2 = rdd1.map(_ + 1)
checkpoint(rdd2, reliableCheckpoint)
rdd2.count()
assert(rdd1.isCheckpointed === true)
assert(rdd1.isCheckpointed === checkpointAllMarkedAncestors)
assert(rdd2.isCheckpointed === true)
} finally {
sc.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED, null)
sc.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, null)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,9 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
// Update: This is probably redundant after threadlocal stuff in SparkEnv has been removed.
SparkEnv.set(ssc.env)

// Enable "spark.checkpoint.checkpointAllMarked" to make sure that all RDDs marked with the
// checkpoint flag are all checkpointed to avoid the stack overflow issue. See SPARK-6847
ssc.sparkContext.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED, "true")
// Checkpoint all RDDs marked for checkpointing to ensure their lineages are
// truncated periodically. Otherwise, we may run into stack overflows (SPARK-6847).
ssc.sparkContext.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, "true")
Try {
jobScheduler.receiverTracker.allocateBlocksToBatch(time) // allocate received blocks to batch
graph.generateJobs(time) // generate jobs using allocated block
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,9 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
s"""Streaming job from <a href="$batchUrl">$batchLinkText</a>""")
ssc.sc.setLocalProperty(BATCH_TIME_PROPERTY_KEY, job.time.milliseconds.toString)
ssc.sc.setLocalProperty(OUTPUT_OP_ID_PROPERTY_KEY, job.outputOpId.toString)
// Enable "spark.checkpoint.checkpointAllMarked" to make sure that all RDDs marked with the
// checkpoint flag are all checkpointed to avoid the stack overflow issue. See SPARK-6847
ssc.sparkContext.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED, "true")
// Checkpoint all RDDs marked for checkpointing to ensure their lineages are
// truncated periodically. Otherwise, we may run into stack overflows (SPARK-6847).
ssc.sparkContext.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, "true")

// We need to assign `eventLoop` to a temp variable. Otherwise, because
// `JobScheduler.stop(false)` may set `eventLoop` to null when this method is running, then
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -828,12 +828,16 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester
//
// 1) input rdd input rdd input rdd
// | | |
// v v v
// 2) cogroup rdd ---> cogroup rdd ---> cogroup rdd ...
// | / | / |
// v / v / v
// 3) map rdd --- map rdd --- map rdd ...
// |
// | | |
// v v v
// 4) cogroup rdd ---> cogroup rdd ---> cogroup rdd ...
// | / | / |
// v / v / v
// 5) map rdd --- map rdd --- map rdd ...
//
// Every batch depends on its previous batch, so "updateStateByKey" needs to do checkpoint to
Expand All @@ -853,35 +857,36 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester
val updateFunc = (values: Seq[Int], state: Option[Int]) => {
Some(values.sum + state.getOrElse(0))
}
@volatile var checkpointAllMarkedRDDsEnable = false
@volatile var shouldCheckpointAllMarkedRDDs = false
@volatile var rddsCheckpointed = false
inputDStream.map(i => (i, i))
.updateStateByKey(updateFunc).checkpoint(batchDuration)
.updateStateByKey(updateFunc).checkpoint(batchDuration)
.foreachRDD { rdd =>
checkpointAllMarkedRDDsEnable =
Option(rdd.sparkContext.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED)).
map(_.toBoolean).getOrElse(false)

val stateRDDs = {
def findAllMarkedRDDs(_rdd: RDD[_], buffer: ArrayBuffer[RDD[_]]): Unit = {
if (_rdd.checkpointData.isDefined) {
buffer += _rdd
}
_rdd.dependencies.foreach(dep => findAllMarkedRDDs(dep.rdd, buffer))
/**
* Collect every RDD marked for checkpointing among the specified RDD and its ancestors.
*
* Walks `rdd.dependencies` recursively (depth-first). An RDD counts as marked when its
* `checkpointData` is defined. The given RDD, if marked, comes first in the result,
* followed by the marked RDDs found through each dependency in order. No de-duplication
* is done, so an RDD reachable through several dependency paths is listed once per path.
*/
def findAllMarkedRDDs(rdd: RDD[_]): List[RDD[_]] = {
val markedRDDs = rdd.dependencies.flatMap(dep => findAllMarkedRDDs(dep.rdd)).toList
if (rdd.checkpointData.isDefined) {
rdd :: markedRDDs
} else {
markedRDDs
}
}

shouldCheckpointAllMarkedRDDs =
Option(rdd.sparkContext.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS)).
map(_.toBoolean).getOrElse(false)

val buffer = new ArrayBuffer[RDD[_]]
findAllMarkedRDDs(rdd, buffer)
buffer.toSeq
val stateRDDs = findAllMarkedRDDs(rdd)
rdd.count()
// Check that the two state RDDs are both checkpointed
rddsCheckpointed = stateRDDs.size == 2 && stateRDDs.forall(_.isCheckpointed)
}
rdd.count()
// Check the two state RDDs are both checkpointed
rddsCheckpointed = stateRDDs.size == 2 && stateRDDs.forall(_.isCheckpointed)
}
ssc.start()
batchCounter.waitUntilBatchesCompleted(1, 10000)
assert(checkpointAllMarkedRDDsEnable === true)
assert(shouldCheckpointAllMarkedRDDs === true)
assert(rddsCheckpointed === true)
}

Expand Down

0 comments on commit 97e39c0

Please sign in to comment.