Skip to content

Commit

Permalink
Always add the object to driver's block manager.
Browse files Browse the repository at this point in the history
  • Loading branch information
rxin committed Aug 19, 2014
1 parent 0d8ed5b commit 5bacb9d
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ private[spark] class TorrentBroadcast[T: ClassTag](
*/
@transient private var _value: T = obj

private val broadcastId = BroadcastBlockId(id)

/** Total number of blocks this broadcast variable contains. */
private val numBlocks: Int = writeBlocks()

private val broadcastId = BroadcastBlockId(id)

override protected def getValue() = _value

/**
Expand All @@ -75,40 +75,57 @@ private[spark] class TorrentBroadcast[T: ClassTag](
* @return number of blocks this broadcast variable is divided into
*/
private def writeBlocks(): Int = {
val blocks = TorrentBroadcast.blockifyObject(_value)
blocks.zipWithIndex.foreach { case (block, i) =>
SparkEnv.get.blockManager.putBytes(
BroadcastBlockId(id, "piece" + i),
block,
StorageLevel.MEMORY_AND_DISK_SER,
tellMaster = true)
// For local mode, just put the object in the BlockManager so we can find it later.
SparkEnv.get.blockManager.putSingle(
broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false)

if (!isLocal) {
val blocks = TorrentBroadcast.blockifyObject(_value)
blocks.zipWithIndex.foreach { case (block, i) =>
SparkEnv.get.blockManager.putBytes(
BroadcastBlockId(id, "piece" + i),
block,
StorageLevel.MEMORY_AND_DISK_SER,
tellMaster = true)
}
blocks.length
} else {
0
}
blocks.length
}

/** Fetch torrent blocks from the driver and/or other executors. */
private def readBlocks(): Array[ByteBuffer] = {
// Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
// to the driver, so other executors can pull these chunks from this executor as well.
val blocks = new Array[ByteBuffer](numBlocks)
val bm = SparkEnv.get.blockManager

for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
val pieceId = BroadcastBlockId(id, "piece" + pid)
// Note that we use getBytes rather than getRemoteBytes here because there is a chance
// that previous attempts to fetch the broadcast blocks have already fetched some of the
// blocks. In that case, some blocks would be available locally (on this executor).
SparkEnv.get.blockManager.getBytes(pieceId) match {
case Some(block) =>
blocks(pid) = block
SparkEnv.get.blockManager.putBytes(
pieceId,
block,
StorageLevel.MEMORY_AND_DISK_SER,
tellMaster = true)

case None =>
throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
// First try getLocalBytes because there is a chance that previous attempts to fetch the
// broadcast blocks have already fetched some of the blocks. In that case, some blocks
// would be available locally (on this executor).
var blockOpt = bm.getLocalBytes(pieceId)
if (!blockOpt.isDefined) {
blockOpt = bm.getRemoteBytes(pieceId)
blockOpt match {
case Some(block) =>
// If we found the block from remote executors/driver's BlockManager, put the block
// in this executor's BlockManager.
SparkEnv.get.blockManager.putBytes(
pieceId,
block,
StorageLevel.MEMORY_AND_DISK_SER,
tellMaster = true)

case None =>
throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
}
}
// If we get here, the option is defined.
blocks(pid) = blockOpt.get
}
blocks
}
Expand Down
10 changes: 0 additions & 10 deletions core/src/main/scala/org/apache/spark/storage/BlockManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -517,16 +517,6 @@ private[spark] class BlockManager(
None
}

def getBytes(blockId: BlockId): Option[ByteBuffer] = {
val local = getLocalBytes(blockId)
if (local.isDefined) {
local
} else {
val remote = getRemoteBytes(blockId)
remote
}
}

/**
* Get a block from the block manager (either local or remote).
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,25 +189,19 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
def afterCreation(broadcastId: Long, bmm: BlockManagerMaster) {
var blockId = BroadcastBlockId(broadcastId)
var statuses = bmm.getBlockStatus(blockId, askSlaves = true)
assert(statuses.size === 0)
assert(statuses.size === 1)

blockId = BroadcastBlockId(broadcastId, "piece0")
statuses = bmm.getBlockStatus(blockId, askSlaves = true)
assert(statuses.size === 1)
statuses.head match { case (bm, status) =>
assert(bm.executorId === "<driver>", "Block should only be on the driver")
assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK_SER)
assert(status.memSize > 0, "Block should be in memory store on the driver")
assert(status.diskSize === 0, "Block should not be in disk store on the driver")
}
assert(statuses.size === (if (distributed) 1 else 0))
}

// Verify that blocks are persisted in both the executors and the driver
def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) {
var blockId = BroadcastBlockId(broadcastId)
var statuses = bmm.getBlockStatus(blockId, askSlaves = true)
if (distributed) {
assert(statuses.size === numSlaves)
assert(statuses.size === numSlaves + 1)
} else {
assert(statuses.size === 1)
}
Expand All @@ -217,20 +211,20 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
if (distributed) {
assert(statuses.size === numSlaves + 1)
} else {
assert(statuses.size === 1)
assert(statuses.size === 0)
}
}

// Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
// is true.
def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster) {
var blockId = BroadcastBlockId(broadcastId)
var expectedNumBlocks = if (removeFromDriver) 0 else if (distributed) 0 else 1
var expectedNumBlocks = if (removeFromDriver) 0 else 1
var statuses = bmm.getBlockStatus(blockId, askSlaves = true)
assert(statuses.size === expectedNumBlocks)

blockId = BroadcastBlockId(broadcastId, "piece0")
expectedNumBlocks = if (removeFromDriver) 0 else 1
expectedNumBlocks = if (removeFromDriver || !distributed) 0 else 1
statuses = bmm.getBlockStatus(blockId, askSlaves = true)
assert(statuses.size === expectedNumBlocks)
}
Expand Down

0 comments on commit 5bacb9d

Please sign in to comment.