6 changes: 3 additions & 3 deletions core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -45,7 +45,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
case Some(blockResult) =>
// Partition is already materialized, so just return its values
context.taskMetrics.inputMetrics = Some(blockResult.inputMetrics)
new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]])
new InterruptibleIterator(context, blockResult.dataAsIterator().asInstanceOf[Iterator[T]])

case None =>
// Acquire a lock for loading this partition
@@ -114,7 +114,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
logInfo(s"Whoever was loading $id failed; we'll try it ourselves")
loading.add(id)
}
values.map(_.data.asInstanceOf[Iterator[T]])
values.map(_.dataAsIterator().asInstanceOf[Iterator[T]])
}
}
}
@@ -144,7 +144,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
updatedBlocks ++=
blockManager.putIterator(key, values, level, tellMaster = true, effectiveStorageLevel)
blockManager.get(key) match {
case Some(v) => v.data.asInstanceOf[Iterator[T]]
case Some(v) => v.dataAsIterator().asInstanceOf[Iterator[T]]
case None =>
logInfo(s"Failure to store $key")
throw new BlockException(key, s"Block manager failed to return cached value for $key!")
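Aside on the API these CacheManager hunks rely on: the diff replaces the old `data` field with a `dataAsIterator()` call, and the files below use a sibling `dataAsBytes()`. The following is a rough, self-contained sketch of that kind of block-result shape; the class and member names here are illustrative stand-ins, not Spark's actual BlockResult.

import java.nio.ByteBuffer

// Hypothetical simplification: a cached block keeps its raw bytes plus a
// deserializer, and callers pick whichever view they need.
class SketchBlockResult(bytes: ByteBuffer, deserialize: ByteBuffer => Iterator[Any]) {
  // View the block as deserialized records (what CacheManager.getOrCompute wants).
  def dataAsIterator(): Iterator[Any] = deserialize(bytes.duplicate())
  // View the block as raw bytes (what TorrentBroadcast and TaskResultGetter want).
  def dataAsBytes(): ByteBuffer = bytes.duplicate()
}

object SketchBlockResultDemo extends App {
  val payload = ByteBuffer.wrap("a,b,c".getBytes("UTF-8"))
  val result = new SketchBlockResult(payload, buf => {
    val arr = new Array[Byte](buf.remaining()); buf.get(arr)
    new String(arr, "UTF-8").split(",").iterator
  })
  println(result.dataAsIterator().toList)   // List(a, b, c)
  println(result.dataAsBytes().remaining()) // 5
}

Presumably the point of the two explicit accessors is that each call site states whether it wants deserialized records or raw bytes, instead of every path going through one pre-built iterator.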
@@ -27,7 +27,7 @@ import scala.util.Random
import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.storage.{BlockResult, BroadcastBlockId, StorageLevel}
import org.apache.spark.util.{ByteBufferInputStream, Utils}
import org.apache.spark.util.io.ByteArrayChunkOutputStream

@@ -122,8 +122,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
// First try getLocalBytes because there is a chance that previous attempts to fetch the
// broadcast blocks have already fetched some of the blocks. In that case, some blocks
// would be available locally (on this executor).
def getLocal: Option[ByteBuffer] = bm.getLocalBytes(pieceId)
def getRemote: Option[ByteBuffer] = bm.getRemoteBytes(pieceId).map { block =>
def getLocal: Option[ByteBuffer] = bm.getLocal(pieceId).map(_.dataAsBytes())
def getRemote: Option[ByteBuffer] = bm.getRemote(pieceId).map(_.dataAsBytes()).map { block =>
// If we found the block from remote executors/driver's BlockManager, put the block
// in this executor's BlockManager.
SparkEnv.get.blockManager.putBytes(
@@ -164,9 +164,9 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
private def readBroadcastBlock(): T = Utils.tryOrIOException {
TorrentBroadcast.synchronized {
setConf(SparkEnv.get.conf)
SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match {
case Some(x) =>
x.asInstanceOf[T]
SparkEnv.get.blockManager.getLocal(broadcastId) match {
case Some(result) =>
result.dataAsIterator().next().asInstanceOf[T]

case None =>
logInfo("Started reading broadcast variable " + id)
@@ -184,7 +184,6 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
}
}
}

}


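The readBroadcastBlock() change above trades the `map(_.data.next())` chain for an explicit match on the Option returned by getLocal. Below is a minimal standalone sketch of that local-first, remote-fallback read pattern, using hypothetical stand-ins rather than Spark's BlockManager API.

object BroadcastReadSketch extends App {
  // Hypothetical stand-ins for the BlockManager lookups used in the hunk above.
  def getLocal(id: String): Option[Iterator[Any]] = None                      // nothing cached yet
  def getRemote(id: String): Option[Iterator[Any]] = Some(Iterator("value"))  // simulated remote fetch

  def readBlock[T](id: String): T = getLocal(id) match {
    case Some(it) =>
      // Fast path: the block is already materialized on this executor.
      it.next().asInstanceOf[T]
    case None =>
      // Slow path: fetch remotely; the real code also re-caches the pieces locally.
      getRemote(id)
        .map(_.next().asInstanceOf[T])
        .getOrElse(throw new java.io.IOException(s"Failed to get broadcast block $id"))
  }

  println(readBlock[String]("broadcast_0"))  // prints "value"
}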
@@ -29,16 +29,18 @@ import org.apache.spark.util.Utils
/**
* Runs a thread pool that deserializes and remotely fetches (if necessary) task results.
*/
private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl)
private[spark] class TaskResultGetter(env: SparkEnv, scheduler: TaskSchedulerImpl)
extends Logging {

private val THREADS = sparkEnv.conf.getInt("spark.resultGetter.threads", 4)
private val blockManager = env.blockManager

private val THREADS = env.conf.getInt("spark.resultGetter.threads", 4)
private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool(
THREADS, "task-result-getter")

protected val serializer = new ThreadLocal[SerializerInstance] {
override def initialValue(): SerializerInstance = {
sparkEnv.closureSerializer.newInstance()
env.closureSerializer.newInstance()
}
}

@@ -56,12 +58,12 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl)
case IndirectTaskResult(blockId, size) =>
if (!taskSetManager.canFetchMoreResults(size)) {
// dropped by executor if size is larger than maxResultSize
sparkEnv.blockManager.master.removeBlock(blockId)
blockManager.master.removeBlock(blockId)
return
}
logDebug("Fetching indirect task result for TID %s".format(tid))
scheduler.handleTaskGettingResult(taskSetManager, tid)
val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
val serializedTaskResult = blockManager.getRemote(blockId).map(_.dataAsBytes())
if (!serializedTaskResult.isDefined) {
/* We won't be able to get the task result if the machine that ran the task failed
* between when the task ended and when we tried to fetch the result, or if the
@@ -72,7 +74,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl)
}
val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]](
serializedTaskResult.get)
sparkEnv.blockManager.master.removeBlock(blockId)
blockManager.master.removeBlock(blockId)
(deserializedResult, size)
}

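For orientation, the TaskResultGetter hunks keep the existing indirect-result flow (size check, remote fetch, deserialize, drop the block) and only change how the bytes are obtained and where the BlockManager reference lives. A compressed sketch of that flow follows, built from made-up placeholder helpers rather than Spark's real signatures.

import java.nio.ByteBuffer

object IndirectResultSketch extends App {
  // Hypothetical stand-ins for the pieces the real code wires together.
  final case class TaskResult(value: String)
  def canFetchMore(size: Long): Boolean = size <= 1024L * 1024L
  def fetchRemoteBytes(blockId: String): Option[ByteBuffer] =
    Some(ByteBuffer.wrap("42".getBytes("UTF-8")))
  def deserialize(buf: ByteBuffer): TaskResult = {
    val arr = new Array[Byte](buf.remaining()); buf.get(arr)
    TaskResult(new String(arr, "UTF-8"))
  }
  def removeBlock(blockId: String): Unit = ()   // the real code asks the master to drop the block

  def getIndirectResult(blockId: String, size: Long): Option[TaskResult] = {
    if (!canFetchMore(size)) { removeBlock(blockId); return None }  // over the result-size limit
    val bytes = fetchRemoteBytes(blockId)
    if (bytes.isEmpty) return None                                  // executor went away before the fetch
    val result = deserialize(bytes.get)
    removeBlock(blockId)                                            // free the block once deserialized
    Some(result)
  }

  println(getIndirectResult("taskresult_7", 2))   // Some(TaskResult(42))
}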
@@ -129,7 +129,15 @@ class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream {
this
}

override def flush() { output.flush() }
override def flush() {
output.flush()
// Kryo does not flush its underlying stream, so let's do that manually to preserve the expected
// semantics.
val underlyingStream = output.getOutputStream
if (underlyingStream != null) {
underlyingStream.flush()
}
}
override def close() { output.close() }
}

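The new flush() body reflects the behavior called out in the comment above: Kryo's Output.flush() pushes its own buffer down to the wrapped stream but does not flush that stream in turn. The same failure mode can be demonstrated with plain java.io buffering; the sketch below needs no Kryo dependency and only illustrates why the extra flush matters.

import java.io.{BufferedOutputStream, ByteArrayOutputStream}

object FlushSketch extends App {
  val sink = new ByteArrayOutputStream()
  val buffered = new BufferedOutputStream(sink, 1024)

  buffered.write("hello".getBytes("UTF-8"))
  // Nothing has reached the sink yet: the bytes sit in BufferedOutputStream's buffer,
  // just as serialized records can sit in a buffer below Kryo's Output.
  println(sink.size())   // 0

  buffered.flush()       // flush the wrapper and, transitively, the underlying stream
  println(sink.size())   // 5
}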
@@ -104,6 +104,7 @@ class FileShuffleBlockManager(conf: SparkConf)
*/
def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer,
writeMetrics: ShuffleWriteMetrics) = {
val blockSerde = new BlockSerializer(conf, serializer)
new ShuffleWriterGroup {
shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets))
private val shuffleState = shuffleStates(shuffleId)
@@ -113,7 +114,7 @@ class FileShuffleBlockManager(conf: SparkConf)
fileGroup = getUnusedFileGroup()
Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize,
new DiskBlockObjectWriter(conf, blockId, fileGroup(bucketId), bufferSize, blockSerde,
writeMetrics)
}
} else {
@@ -129,7 +130,7 @@ class FileShuffleBlockManager(conf: SparkConf)
logWarning(s"Failed to remove existing shuffle file $blockFile")
}
}
blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics)
new DiskBlockObjectWriter(conf, blockId, blockFile, bufferSize, blockSerde, writeMetrics)
}
}

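These FileShuffleBlockManager hunks construct the disk writers directly and hand them a single `BlockSerializer(conf, serializer)` rather than a bare Serializer plus separately derived settings. A small sketch of that bundling idea follows; the BlockSerde class below is an assumption about the pattern, not Spark's BlockSerializer.

import java.io.{ByteArrayOutputStream, OutputStream}
import java.util.zip.GZIPOutputStream

object BlockSerdeSketch extends App {
  // Hypothetical stand-ins for SparkConf and Serializer.
  final case class Conf(compress: Boolean)
  trait Serializer { def serialize(obj: Any, out: OutputStream): Unit }

  // Bundle the serializer with the conf-driven stream wrapping (compression here)
  // so every writer built from the pair behaves the same way.
  class BlockSerde(conf: Conf, serializer: Serializer) {
    def wrap(out: OutputStream): OutputStream =
      if (conf.compress) new GZIPOutputStream(out) else out
    def write(obj: Any, out: OutputStream): Unit = {
      val wrapped = wrap(out)
      serializer.serialize(obj, wrapped)
      wrapped.close()
    }
  }

  val plainSerializer = new Serializer {
    def serialize(obj: Any, out: OutputStream): Unit = out.write(obj.toString.getBytes("UTF-8"))
  }
  val serde = new BlockSerde(Conf(compress = true), plainSerializer)
  val sink = new ByteArrayOutputStream()
  serde.write("shuffle record", sink)
  println(s"wrote ${sink.size()} compressed bytes")
}

Passing one object that already knows the conf keeps writers and fetchers that share it wrapping their streams consistently, which is presumably why the next file hands the same wrapper to the shuffle fetcher.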
@@ -23,7 +23,7 @@ import scala.collection.mutable.HashMap
import org.apache.spark._
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
import org.apache.spark.storage._
import org.apache.spark.util.CompletionIterator

private[hash] object BlockStoreShuffleFetcher extends Logging {
@@ -77,7 +77,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
SparkEnv.get.blockManager.shuffleClient,
blockManager,
blocksByAddress,
serializer,
new BlockSerializer(SparkEnv.get.conf, serializer),
SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024)
val itr = blockFetcherItr.flatMap(unpackBlock)
