[SPARK-12817] Add BlockManager.getOrElseUpdate and remove CacheManager
CacheManager directly calls MemoryStore.unrollSafely() and has its own logic for gracefully falling back to disk when cached data does not fit in memory. However, this logic also exists inside the MemoryStore itself, so this appears to be unnecessary duplication.

Thanks to the addition of block-level read/write locks in #10705, we can refactor the code to remove the CacheManager and replace it with an atomic `BlockManager.getOrElseUpdate()` method.
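
As a rough, self-contained sketch of the contract (the `ToyBlockManager` / `ToyBlockResult` names, string block IDs, and the `canCache` flag are illustrative assumptions, not Spark's internals): the method atomically either returns an already-cached block, or computes the iterator, tries to cache it, and tells the caller which of the two happened. In the patch itself the real method takes a `BlockId`, a `StorageLevel`, and a thunk and returns `Either[BlockResult, Iterator[Any]]`, as the `RDD.getOrCompute` hunk below shows.

```scala
import scala.collection.mutable

// Toy stand-in for BlockResult; carries the cached values plus a size for metrics.
final case class ToyBlockResult[T](data: Iterator[T], count: Long)

class ToyBlockManager {
  private val store = mutable.Map.empty[String, Vector[Any]]

  /** Left = the block is (now) cached and served from the store;
    * Right = caching failed, so the caller must consume the computed iterator directly. */
  def getOrElseUpdate[T](
      blockId: String,
      makeIterator: () => Iterator[T],
      canCache: Boolean = true): Either[ToyBlockResult[T], Iterator[T]] = store.synchronized {
    store.get(blockId) match {
      case Some(values) =>
        // Someone already materialized this block: hand back a cached read.
        Left(ToyBlockResult(values.iterator.asInstanceOf[Iterator[T]], values.size.toLong))
      case None =>
        val computed = makeIterator()
        if (canCache) {
          // Materialize and cache the values, then serve the block out of the store.
          val values: Vector[Any] = computed.toVector
          store(blockId) = values
          Left(ToyBlockResult(values.iterator.asInstanceOf[Iterator[T]], values.size.toLong))
        } else {
          // Could not cache the values: hand the computed iterator back to the caller.
          Right(computed)
        }
    }
  }
}
```

In the real BlockManager the atomicity comes from the per-block read/write locks added in #10705 rather than a coarse lock like the `synchronized` used in this sketch.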

This pull request replaces / subsumes #10748.

/cc andrewor14 and nongli for review. Note that this changes the locking semantics of a couple of internal BlockManager methods (`doPut()` and `lockNewBlockForWriting()`), so please pay attention to the Scaladoc changes and new test cases for those methods.
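
The caller-facing half of that locking change is visible throughout the diff: a successful `putSingle`/`putBytes` no longer leaves the caller holding the block's lock, so the explicit `releaseLock` calls after each put go away. A minimal before/after sketch (the `SimpleBlockStore` trait and its simplified signatures are assumptions for illustration, not Spark's real API):

```scala
object LockingSketch {
  // Simplified stand-in for the put/release surface used by the call sites below;
  // an assumption for illustration, not Spark's real BlockManager interface.
  trait SimpleBlockStore {
    def putSingle(blockId: String, value: Any): Boolean
    def releaseLock(blockId: String): Unit
  }

  // Before this patch: a successful put left the writer holding the block's lock,
  // so every call site had to release it explicitly.
  def storeBefore(bm: SimpleBlockStore, blockId: String, value: Any): Unit = {
    if (bm.putSingle(blockId, value)) {
      bm.releaseLock(blockId)
    } else {
      throw new RuntimeException(s"Failed to store $blockId")
    }
  }

  // After this patch: call sites no longer hold the block's lock after a successful put,
  // so they only need to handle the failure case.
  def storeAfter(bm: SimpleBlockStore, blockId: String, value: Any): Unit = {
    if (!bm.putSingle(blockId, value)) {
      throw new RuntimeException(s"Failed to store $blockId")
    }
  }
}
```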

Author: Josh Rosen <joshrosen@databricks.com>

Closes #11436 from JoshRosen/remove-cachemanager.
JoshRosen authored and Andrew Or committed Mar 2, 2016
1 parent 8f8d8a2 commit d6969ff
Showing 16 changed files with 365 additions and 597 deletions.
179 changes: 0 additions & 179 deletions core/src/main/scala/org/apache/spark/CacheManager.scala

This file was deleted.

4 changes: 0 additions & 4 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -56,7 +56,6 @@ class SparkEnv (
private[spark] val rpcEnv: RpcEnv,
val serializer: Serializer,
val closureSerializer: Serializer,
val cacheManager: CacheManager,
val mapOutputTracker: MapOutputTracker,
val shuffleManager: ShuffleManager,
val broadcastManager: BroadcastManager,
@@ -333,8 +332,6 @@ object SparkEnv extends Logging {

val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)

val cacheManager = new CacheManager(blockManager)

val metricsSystem = if (isDriver) {
// Don't start metrics system right now for Driver.
// We need to wait for the task scheduler to give us an app ID.
@@ -371,7 +368,6 @@
rpcEnv,
serializer,
closureSerializer,
cacheManager,
mapOutputTracker,
shuffleManager,
broadcastManager,
core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -99,18 +99,14 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
// Store a copy of the broadcast variable in the driver so that tasks run on the driver
// do not create a duplicate copy of the broadcast variable's value.
val blockManager = SparkEnv.get.blockManager
if (blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
blockManager.releaseLock(broadcastId)
} else {
if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
throw new SparkException(s"Failed to store $broadcastId in BlockManager")
}
val blocks =
TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
blocks.zipWithIndex.foreach { case (block, i) =>
val pieceId = BroadcastBlockId(id, "piece" + i)
if (blockManager.putBytes(pieceId, block, MEMORY_AND_DISK_SER, tellMaster = true)) {
blockManager.releaseLock(pieceId)
} else {
if (!blockManager.putBytes(pieceId, block, MEMORY_AND_DISK_SER, tellMaster = true)) {
throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager")
}
}
@@ -130,22 +126,24 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
// First try getLocalBytes because there is a chance that previous attempts to fetch the
// broadcast blocks have already fetched some of the blocks. In that case, some blocks
// would be available locally (on this executor).
def getLocal: Option[ByteBuffer] = bm.getLocalBytes(pieceId)
def getRemote: Option[ByteBuffer] = bm.getRemoteBytes(pieceId).map { block =>
// If we found the block from remote executors/driver's BlockManager, put the block
// in this executor's BlockManager.
if (!bm.putBytes(pieceId, block, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
throw new SparkException(
s"Failed to store $pieceId of $broadcastId in local BlockManager")
}
block
bm.getLocalBytes(pieceId) match {
case Some(block) =>
blocks(pid) = block
releaseLock(pieceId)
case None =>
bm.getRemoteBytes(pieceId) match {
case Some(b) =>
// We found the block from remote executors/driver's BlockManager, so put the block
// in this executor's BlockManager.
if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
throw new SparkException(
s"Failed to store $pieceId of $broadcastId in local BlockManager")
}
blocks(pid) = b
case None =>
throw new SparkException(s"Failed to get $pieceId of $broadcastId")
}
}
val block: ByteBuffer = getLocal.orElse(getRemote).getOrElse(
throw new SparkException(s"Failed to get $pieceId of $broadcastId"))
// At this point we are guaranteed to hold a read lock, since we either got the block locally
// or stored the remotely-fetched block and automatically downgraded the write lock.
blocks(pid) = block
releaseLock(pieceId)
}
blocks
}
@@ -191,9 +189,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
// Store the merged copy in BlockManager so other tasks on this executor don't
// need to re-fetch it.
val storageLevel = StorageLevel.MEMORY_AND_DISK
if (blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
releaseLock(broadcastId)
} else {
if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
throw new SparkException(s"Failed to store $broadcastId in BlockManager")
}
obj
5 changes: 1 addition & 4 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -292,11 +292,8 @@ private[spark] class Executor(
ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
} else if (resultSize >= maxRpcMessageSize) {
val blockId = TaskResultBlockId(taskId)
val putSucceeded = env.blockManager.putBytes(
env.blockManager.putBytes(
blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
if (putSucceeded) {
env.blockManager.releaseLock(blockId)
}
logInfo(
s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -66,10 +66,7 @@ class NettyBlockRpcServer(
serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata))
val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
val blockId = BlockId(uploadBlock.blockId)
val putSucceeded = blockManager.putBlockData(blockId, data, level)
if (putSucceeded) {
blockManager.releaseLock(blockId)
}
blockManager.putBlockData(blockId, data, level)
responseContext.onSuccess(ByteBuffer.allocate(0))
}
}
33 changes: 31 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -37,7 +37,7 @@ import org.apache.spark.partial.BoundedDouble
import org.apache.spark.partial.CountEvaluator
import org.apache.spark.partial.GroupedCountEvaluator
import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.StorageLevel
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
import org.apache.spark.util.{BoundedPriorityQueue, Utils}
import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler,
@@ -272,7 +272,7 @@ abstract class RDD[T: ClassTag](
*/
final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
if (storageLevel != StorageLevel.NONE) {
SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
getOrCompute(split, context)
} else {
computeOrReadCheckpoint(split, context)
}
@@ -314,6 +314,35 @@
}
}

  /**
   * Gets or computes an RDD partition. Used by RDD.iterator() when an RDD is cached.
   */
  private[spark] def getOrCompute(partition: Partition, context: TaskContext): Iterator[T] = {
    val blockId = RDDBlockId(id, partition.index)
    var readCachedBlock = true
    // This method is called on executors, so we need to call SparkEnv.get instead of sc.env.
    SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, () => {
      readCachedBlock = false
      computeOrReadCheckpoint(partition, context)
    }) match {
      case Left(blockResult) =>
        if (readCachedBlock) {
          val existingMetrics = context.taskMetrics().registerInputMetrics(blockResult.readMethod)
          existingMetrics.incBytesReadInternal(blockResult.bytes)
          new InterruptibleIterator[T](context, blockResult.data.asInstanceOf[Iterator[T]]) {
            override def next(): T = {
              existingMetrics.incRecordsReadInternal(1)
              delegate.next()
            }
          }
        } else {
          new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]])
        }
      case Right(iter) =>
        new InterruptibleIterator(context, iter.asInstanceOf[Iterator[T]])
    }
  }

/**
* Execute a block of code in a scope such that all new RDDs created in this body will
* be part of the same scope. For more detail, see {{org.apache.spark.rdd.RDDOperationScope}}.
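
For context, this is the ordinary user-level code path that the new `getOrCompute` above serves (standard Spark usage, not part of this patch): the first action on a persisted RDD computes each partition through `getOrElseUpdate` and caches the blocks; the second action gets `Left(BlockResult)` back and reads the cached blocks.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.storage.StorageLevel

object GetOrComputeDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("getOrCompute-demo").setMaster("local[2]")
    val sc = new SparkContext(conf)
    val data = sc.parallelize(1 to 1000000).map(_ * 2).persist(StorageLevel.MEMORY_AND_DISK)
    data.count()  // first action: partitions are computed and cached via getOrElseUpdate
    data.count()  // second action: getOrElseUpdate serves the cached RDD blocks
    sc.stop()
  }
}
```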
