[SPARK-3119] Re-implementation of TorrentBroadcast.
This is a re-implementation of TorrentBroadcast, with the following changes:

1. Removes most of the mutable, transient state from TorrentBroadcast (e.g. totalBytes, number of blocks fetched).
2. Removes the TorrentInfo and TorrentBlock case classes.
3. Replaces the BlockManager.getSingle call in readObject with a getLocal, resulting in one fewer RPC call to the BlockManagerMasterActor to find the location of the block.
4. Removes the metadata block, resulting in one fewer block to fetch.
5. Removes an extra memory copy during deserialization, by using Java's SequenceInputStream (see the sketch below).
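For illustration only (this sketch is not part of the patch), the chunk-and-reassemble idea behind changes 4 and 5 can be shown with plain Java serialization standing in for Spark's pluggable serializer, and with compression omitted; `blockSize` mirrors the 4 MB default of `spark.broadcast.blockSize`:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream,
  ObjectOutputStream, SequenceInputStream}

import scala.collection.JavaConversions.asJavaEnumeration

object ChunkingSketch {
  val blockSize = 4 * 1024 * 1024 // mirrors the spark.broadcast.blockSize default

  /** Serialize an object and split the resulting bytes into fixed-size chunks. */
  def blockify(obj: AnyRef): Array[Array[Byte]] = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(obj)
    oos.close()
    bos.toByteArray.grouped(blockSize).toArray
  }

  /**
   * Reassemble without first copying every chunk into one big byte array:
   * SequenceInputStream walks the per-chunk streams one after another.
   */
  def unblockify(blocks: Array[Array[Byte]]): AnyRef = {
    val is = new SequenceInputStream(
      asJavaEnumeration(blocks.iterator.map(block => new ByteArrayInputStream(block))))
    val ois = new ObjectInputStream(is)
    val obj = ois.readObject()
    ois.close()
    obj
  }

  def main(args: Array[String]): Unit = {
    println(unblockify(blockify("hello, broadcast"))) // hello, broadcast
  }
}
```

Without SequenceInputStream, the chunks would first have to be copied into a single contiguous array (as the old unBlockifyObject did with System.arraycopy), doubling peak memory for large broadcasts.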
rxin committed Aug 19, 2014
1 parent 8257733 commit c1185cd
Showing 3 changed files with 168 additions and 239 deletions.
core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala (11 additions & 0 deletions)
@@ -32,8 +32,19 @@ import org.apache.spark.annotation.DeveloperApi
   */
 @DeveloperApi
 trait BroadcastFactory {
+
   def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
+
+  /**
+   * Creates a new broadcast variable.
+   *
+   * @param value value to broadcast
+   * @param isLocal whether we are in local mode (single JVM process)
+   * @param id unique id representing this broadcast variable
+   */
   def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T]
+
   def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit
+
   def stop(): Unit
 }
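Whichever factory is configured, user code consumes broadcasts through the same SparkContext API. A minimal usage sketch (assuming a Spark 1.x setup, where the `spark.broadcast.factory` setting selects the implementation):

```scala
import org.apache.spark.{SparkConf, SparkContext}

object BroadcastDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("broadcast-demo")
      .setMaster("local[2]")
      // Spark 1.x switch between the Http and Torrent implementations.
      .set("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
    val sc = new SparkContext(conf)

    // The value is shipped to each executor once, not once per task.
    val lookup = sc.broadcast(Map("a" -> 1, "b" -> 2))
    val result = sc.parallelize(Seq("a", "b")).map(k => lookup.value(k)).collect()
    println(result.mkString(",")) // 1,2
    sc.stop()
  }
}
```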
core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala (264 changes: 97 additions & 167 deletions)
@@ -18,7 +18,9 @@
 package org.apache.spark.broadcast
 
 import java.io._
+import java.nio.ByteBuffer
 
+import scala.collection.JavaConversions.asJavaEnumeration
 import scala.reflect.ClassTag
 import scala.util.Random
 
@@ -27,41 +29,87 @@ import org.apache.spark.io.CompressionCodec
 import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
 
 /**
- * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like
- * protocol to do a distributed transfer of the broadcasted data to the executors.
- * The mechanism is as follows. The driver divides the serializes the broadcasted data,
- * divides it into smaller chunks, and stores them in the BlockManager of the driver.
- * These chunks are reported to the BlockManagerMaster so that all the executors can
- * learn the location of those chunks. The first time the broadcast variable (sent as
- * part of task) is deserialized at a executor, all the chunks are fetched using
- * the BlockManager. When all the chunks are fetched (initially from the driver's
- * BlockManager), they are combined and deserialized to recreate the broadcasted data.
- * However, the chunks are also stored in the BlockManager and reported to the
- * BlockManagerMaster. As more executors fetch the chunks, BlockManagerMaster learns
- * multiple locations for each chunk. Hence, subsequent fetches of each chunk will be
- * made to other executors who already have those chunks, resulting in a distributed
- * fetching. This prevents the driver from being the bottleneck in sending out multiple
- * copies of the broadcast data (one per executor) as done by the
- * [[org.apache.spark.broadcast.HttpBroadcast]].
+ * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]].
+ *
+ * The mechanism is as follows:
+ *
+ * The driver divides the serialized object into small chunks and
+ * stores those chunks in the BlockManager of the driver.
+ *
+ * On each executor, the executor first attempts to fetch the object from its BlockManager. If
+ * it does not exist, it then uses remote fetches to fetch the small chunks from the driver and/or
+ * other executors if available. Once it gets the chunks, it puts the chunks in its own
+ * BlockManager, ready for other executors to fetch from.
+ *
+ * This prevents the driver from being the bottleneck in sending out multiple copies of the
+ * broadcast data (one per executor) as done by the [[org.apache.spark.broadcast.HttpBroadcast]].
+ *
+ * @param obj object to broadcast
+ * @param isLocal whether Spark is running in local mode (single JVM process).
+ * @param id A unique identifier for the broadcast variable.
  */
 private[spark] class TorrentBroadcast[T: ClassTag](
-    @transient var value_ : T, isLocal: Boolean, id: Long)
+    obj : T,
+    @transient private val isLocal: Boolean,
+    id: Long)
   extends Broadcast[T](id) with Logging with Serializable {

-  override protected def getValue() = value_
+  override protected def getValue() = _value
+
+  /**
+   * Value of the broadcast object. On driver, this is set directly by the constructor.
+   * On executors, this is reconstructed by [[readObject]], which builds this value by reading
+   * blocks from the driver and/or other executors.
+   */
+  @transient private var _value: T = obj
+
+  /** Total number of blocks this broadcast variable contains. */
+  private val numBlocks: Int = writeBlocks()
 
   private val broadcastId = BroadcastBlockId(id)
 
-  SparkEnv.get.blockManager.putSingle(
-    broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+  /**
+   * Divide the object into multiple blocks and put those blocks in the block manager.
+   *
+   * @return number of blocks this broadcast variable is divided into
+   */
+  private def writeBlocks(): Int = {
+    val blocks = TorrentBroadcast.blockifyObject(_value)
+    blocks.zipWithIndex.foreach { case (block, i) =>
+      // TODO: Use putBytes directly.
+      SparkEnv.get.blockManager.putSingle(
+        BroadcastBlockId(id, "piece" + i),
+        blocks(i),
+        StorageLevel.MEMORY_AND_DISK_SER,
+        tellMaster = true)
+    }
+    blocks.length
+  }
 
-  @transient private var arrayOfBlocks: Array[TorrentBlock] = null
-  @transient private var totalBlocks = -1
-  @transient private var totalBytes = -1
-  @transient private var hasBlocks = 0
-
-  if (!isLocal) {
-    sendBroadcast()
-  }
+  /** Fetch torrent blocks from the driver and/or other executors. */
+  private def readBlocks(): Array[Array[Byte]] = {
+    // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
+    // to the driver, so other executors can pull these chunks from this executor as well.
+    var numBlocksAvailable = 0
+    val blocks = new Array[Array[Byte]](numBlocks)
+
+    for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
+      val pieceId = BroadcastBlockId(id, "piece" + pid)
+      SparkEnv.get.blockManager.getSingle(pieceId) match {
+        case Some(x) =>
+          blocks(pid) = x.asInstanceOf[Array[Byte]]
+          numBlocksAvailable += 1
+          SparkEnv.get.blockManager.putBytes(
+            pieceId,
+            ByteBuffer.wrap(blocks(pid)),
+            StorageLevel.MEMORY_AND_DISK_SER,
+            tellMaster = true)
+
+        case None =>
+          throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
+      }
+    }
+    blocks
+  }

 /**
@@ -79,26 +127,6 @@ private[spark] class TorrentBroadcast[T: ClassTag](
     TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking)
   }
 
-  private def sendBroadcast() {
-    val tInfo = TorrentBroadcast.blockifyObject(value_)
-    totalBlocks = tInfo.totalBlocks
-    totalBytes = tInfo.totalBytes
-    hasBlocks = tInfo.totalBlocks
-
-    // Store meta-info
-    val metaId = BroadcastBlockId(id, "meta")
-    val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
-    SparkEnv.get.blockManager.putSingle(
-      metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, tellMaster = true)
-
-    // Store individual pieces
-    for (i <- 0 until totalBlocks) {
-      val pieceId = BroadcastBlockId(id, "piece" + i)
-      SparkEnv.get.blockManager.putSingle(
-        pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true)
-    }
-  }
-
   /** Used by the JVM when serializing this object. */
   private def writeObject(out: ObjectOutputStream) {
     assertValid()
@@ -109,99 +137,30 @@ private[spark] class TorrentBroadcast[T: ClassTag](
   private def readObject(in: ObjectInputStream) {
     in.defaultReadObject()
     TorrentBroadcast.synchronized {
-      SparkEnv.get.blockManager.getSingle(broadcastId) match {
+      SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match {
         case Some(x) =>
-          value_ = x.asInstanceOf[T]
+          _value = x.asInstanceOf[T]
 
         case None =>
-          val start = System.nanoTime
           logInfo("Started reading broadcast variable " + id)
-
-          // Initialize @transient variables that will receive garbage values from the master.
-          resetWorkerVariables()
-
-          if (receiveBroadcast()) {
-            value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
-
-            /* Store the merged copy in cache so that the next worker doesn't need to rebuild it.
-             * This creates a trade-off between memory usage and latency. Storing copy doubles
-             * the memory footprint; not storing doubles deserialization cost. Also,
-             * this does not need to be reported to BlockManagerMaster since other executors
-             * does not need to access this block (they only need to fetch the chunks,
-             * which are reported).
-             */
-            SparkEnv.get.blockManager.putSingle(
-              broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
-
-            // Remove arrayOfBlocks from memory once value_ is on local cache
-            resetWorkerVariables()
-          } else {
-            logError("Reading broadcast variable " + id + " failed")
-          }
-
-          val time = (System.nanoTime - start) / 1e9
+          val start = System.nanoTime()
+          val blocks = readBlocks()
+          val time = (System.nanoTime() - start) / 1e9
           logInfo("Reading broadcast variable " + id + " took " + time + " s")
+
+          _value = TorrentBroadcast.unBlockifyObject[T](blocks)
+          // Store the merged copy in BlockManager so other tasks on this executor don't
+          // need to re-fetch it.
+          SparkEnv.get.blockManager.putSingle(
+            broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
       }
     }
   }
 
-  private def resetWorkerVariables() {
-    arrayOfBlocks = null
-    totalBytes = -1
-    totalBlocks = -1
-    hasBlocks = 0
-  }
-
-  private def receiveBroadcast(): Boolean = {
-    // Receive meta-info about the size of broadcast data,
-    // the number of chunks it is divided into, etc.
-    val metaId = BroadcastBlockId(id, "meta")
-    var attemptId = 10
-    while (attemptId > 0 && totalBlocks == -1) {
-      SparkEnv.get.blockManager.getSingle(metaId) match {
-        case Some(x) =>
-          val tInfo = x.asInstanceOf[TorrentInfo]
-          totalBlocks = tInfo.totalBlocks
-          totalBytes = tInfo.totalBytes
-          arrayOfBlocks = new Array[TorrentBlock](totalBlocks)
-          hasBlocks = 0
-
-        case None =>
-          Thread.sleep(500)
-      }
-      attemptId -= 1
-    }
-
-    if (totalBlocks == -1) {
-      return false
-    }
-
-    /*
-     * Fetch actual chunks of data. Note that all these chunks are stored in
-     * the BlockManager and reported to the master, so that other executors
-     * can find out and pull the chunks from this executor.
-     */
-    val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
-    for (pid <- recvOrder) {
-      val pieceId = BroadcastBlockId(id, "piece" + pid)
-      SparkEnv.get.blockManager.getSingle(pieceId) match {
-        case Some(x) =>
-          arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
-          hasBlocks += 1
-
-          SparkEnv.get.blockManager.putSingle(
-            pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, tellMaster = true)
-
-        case None =>
-          throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
-      }
-    }
-
-    hasBlocks == totalBlocks
-  }
-
 }
 
-private[broadcast] object TorrentBroadcast extends Logging {
-
+private object TorrentBroadcast extends Logging {
+  /** Size of each block. Default value is 4MB. */
   private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
   private var initialized = false
   private var conf: SparkConf = null
@@ -223,52 +182,37 @@ private[broadcast] object TorrentBroadcast extends Logging {
     initialized = false
   }
 
-  def blockifyObject[T: ClassTag](obj: T): TorrentInfo = {
+  def blockifyObject[T: ClassTag](obj: T): Array[Array[Byte]] = {
+    // TODO: Create a special ByteArrayOutputStream that splits the output directly into chunks
+    // so we don't need to do the extra memory copy.
     val bos = new ByteArrayOutputStream()
     val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos
     val ser = SparkEnv.get.serializer.newInstance()
     val serOut = ser.serializeStream(out)
     serOut.writeObject[T](obj).close()
     val byteArray = bos.toByteArray
     val bais = new ByteArrayInputStream(byteArray)
+    val numBlocks = math.ceil(byteArray.length.toDouble / BLOCK_SIZE).toInt
+    val blocks = new Array[Array[Byte]](numBlocks)
 
-    var blockNum = byteArray.length / BLOCK_SIZE
-    if (byteArray.length % BLOCK_SIZE != 0) {
-      blockNum += 1
-    }
-
-    val blocks = new Array[TorrentBlock](blockNum)
     var blockId = 0
 
     for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
       val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
       val tempByteArray = new Array[Byte](thisBlockSize)
       bais.read(tempByteArray, 0, thisBlockSize)
-
-      blocks(blockId) = new TorrentBlock(blockId, tempByteArray)
+      blocks(blockId) = tempByteArray
       blockId += 1
     }
     bais.close()
 
-    val info = TorrentInfo(blocks, blockNum, byteArray.length)
-    info.hasBlocks = blockNum
-    info
+    blocks
   }
 
-  def unBlockifyObject[T: ClassTag](
-      arrayOfBlocks: Array[TorrentBlock],
-      totalBytes: Int,
-      totalBlocks: Int): T = {
-    val retByteArray = new Array[Byte](totalBytes)
-    for (i <- 0 until totalBlocks) {
-      System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
-        i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length)
-    }
-
-    val in: InputStream = {
-      val arrIn = new ByteArrayInputStream(retByteArray)
-      if (compress) compressionCodec.compressedInputStream(arrIn) else arrIn
-    }
+  def unBlockifyObject[T: ClassTag](blocks: Array[Array[Byte]]): T = {
+    val is = new SequenceInputStream(
+      asJavaEnumeration(blocks.iterator.map(block => new ByteArrayInputStream(block))))
+    val in: InputStream = if (compress) compressionCodec.compressedInputStream(is) else is
+
     val ser = SparkEnv.get.serializer.newInstance()
     val serIn = ser.deserializeStream(in)
     val obj = serIn.readObject[T]()
@@ -284,17 +228,3 @@ private[broadcast] object TorrentBroadcast extends Logging {
     SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
   }
 }

-private[broadcast] case class TorrentBlock(
-    blockID: Int,
-    byteArray: Array[Byte])
-  extends Serializable
-
-private[broadcast] case class TorrentInfo(
-    @transient arrayOfBlocks: Array[TorrentBlock],
-    totalBlocks: Int,
-    totalBytes: Int)
-  extends Serializable {
-
-  @transient var hasBlocks = 0
-}
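Read as a whole, the new executor-side path is: check the local BlockManager for an assembled copy, otherwise fetch the pieces in random order, re-announce each fetched piece, and cache the merged result. A toy single-process simulation of that flow (every name here is an illustrative stand-in, not Spark's BlockManager API):

```scala
import scala.collection.mutable
import scala.util.Random

object ReadPathSimulation {
  // Stand-in for the cluster-wide view the BlockManagerMaster provides.
  val remotePieces = mutable.Map.empty[String, Array[Byte]]

  def readBroadcast(id: Long, numBlocks: Int,
                    localStore: mutable.Map[String, Array[Byte]]): Array[Byte] = {
    localStore.getOrElse(s"broadcast_$id", {
      // No assembled local copy: fetch pieces in random order so that
      // concurrent readers hit different sources first.
      val blocks = new Array[Array[Byte]](numBlocks)
      for (pid <- Random.shuffle(0 until numBlocks)) {
        val pieceId = s"broadcast_${id}_piece$pid"
        val bytes = remotePieces.getOrElse(pieceId, sys.error(s"Failed to get $pieceId"))
        blocks(pid) = bytes
        localStore(pieceId) = bytes // serve this piece to other executors from now on
      }
      // Cache the merged copy locally; the real code uses tellMaster = false here
      // because peers only ever fetch the pieces, never the merged block.
      val merged = blocks.flatten
      localStore(s"broadcast_$id") = merged
      merged
    })
  }

  def main(args: Array[String]): Unit = {
    remotePieces ++= Seq(
      "broadcast_7_piece0" -> Array[Byte](1, 2),
      "broadcast_7_piece1" -> Array[Byte](3))
    val local = mutable.Map.empty[String, Array[Byte]]
    println(readBroadcast(7, 2, local).toSeq) // prints WrappedArray(1, 2, 3)
  }
}
```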
