From 5ec645d945783457baed9e151337b2735c1b307f Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Fri, 28 Mar 2014 15:10:41 -0700 Subject: [PATCH] In memory shuffle (cherry-picked from amplab/graphx#135) --- .../spark/scheduler/ShuffleMapTask.scala | 7 +++++- .../apache/spark/storage/BlockManager.scala | 8 +++---- .../spark/storage/BlockObjectWriter.scala | 24 ++++++++----------- .../apache/spark/storage/MemoryStore.scala | 9 +++++++ .../spark/storage/ShuffleBlockManager.scala | 13 +++++++++- .../org/apache/spark/graphx/Pregel.scala | 7 ++++++ 6 files changed, 48 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 2a9edf4a76b97..1d68cefad657f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -29,6 +29,7 @@ import org.apache.spark.rdd.RDDCheckpointData import org.apache.spark.serializer.Serializer import org.apache.spark.storage._ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} +import java.nio.ByteBuffer private[spark] object ShuffleMapTask { @@ -168,7 +169,11 @@ private[spark] class ShuffleMapTask( var totalBytes = 0L var totalTime = 0L val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter => - writer.commit() + // writer.commit() + val bytes = writer.commit() + if (bytes != null) { + blockManager.putBytes(writer.blockId, ByteBuffer.wrap(bytes), StorageLevel.MEMORY_ONLY_SER, tellMaster = false) + } writer.close() val size = writer.fileSegment().length totalBytes += size diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 71584b6eb102a..21fe5cb881e6c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -57,7 +57,7 @@ private[spark] class BlockManager( private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] - private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) + private[storage] val memoryStore = new MemoryStore(this, maxMemory) private[storage] val diskStore = new DiskStore(this, diskBlockManager) // If we use Netty for shuffle, start a new Netty-based shuffle sender service. @@ -293,7 +293,7 @@ private[spark] class BlockManager( * never deletes (recent) items. */ def getLocalFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = { - diskStore.getValues(blockId, serializer).orElse( + memoryStore.getValues(blockId, serializer).orElse( sys.error("Block " + blockId + " not found on disk, though it should be")) } @@ -313,7 +313,7 @@ private[spark] class BlockManager( // As an optimization for map output fetches, if the block is for a shuffle, return it // without acquiring a lock; the disk store never deletes (recent) items so this should work if (blockId.isShuffle) { - diskStore.getBytes(blockId) match { + memoryStore.getBytes(blockId) match { case Some(bytes) => Some(bytes) case None => @@ -831,7 +831,7 @@ private[spark] class BlockManager( if (info != null) info.synchronized { // Removals are idempotent in disk store and memory store. At worst, we get a warning. val removedFromMemory = memoryStore.remove(blockId) - val removedFromDisk = diskStore.remove(blockId) + val removedFromDisk = false //diskStore.remove(blockId) if (!removedFromMemory && !removedFromDisk) { logWarning("Block " + blockId + " could not be removed as it was not found in either " + "the disk or memory store") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 696b930a26b9e..4d7e1852f9c8f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.io.{FileOutputStream, File, OutputStream} +import java.io.{ByteArrayOutputStream, FileOutputStream, File, OutputStream} import java.nio.channels.FileChannel import it.unimi.dsi.fastutil.io.FastBufferedOutputStream @@ -44,7 +44,7 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) { * Flush the partial writes and commit them as a single atomic block. Return the * number of bytes written for this commit. */ - def commit(): Long + def commit(): Array[Byte] /** * Reverts writes that haven't been flushed yet. Callers should invoke this function @@ -106,7 +106,7 @@ private[spark] class DiskBlockObjectWriter( /** The file channel, used for repositioning / truncating the file. */ private var channel: FileChannel = null private var bs: OutputStream = null - private var fos: FileOutputStream = null + private var fos: ByteArrayOutputStream = null private var ts: TimeTrackingOutputStream = null private var objOut: SerializationStream = null private val initialPosition = file.length() @@ -115,9 +115,8 @@ private[spark] class DiskBlockObjectWriter( private var _timeWriting = 0L override def open(): BlockObjectWriter = { - fos = new FileOutputStream(file, true) + fos = new ByteArrayOutputStream() ts = new TimeTrackingOutputStream(fos) - channel = fos.getChannel() lastValidPosition = initialPosition bs = compressStream(new FastBufferedOutputStream(ts, bufferSize)) objOut = serializer.newInstance().serializeStream(bs) @@ -130,9 +129,6 @@ private[spark] class DiskBlockObjectWriter( if (syncWrites) { // Force outstanding writes to disk and track how long it takes objOut.flush() - val start = System.nanoTime() - fos.getFD.sync() - _timeWriting += System.nanoTime() - start } objOut.close() @@ -149,18 +145,18 @@ private[spark] class DiskBlockObjectWriter( override def isOpen: Boolean = objOut != null - override def commit(): Long = { + override def commit(): Array[Byte] = { if (initialized) { // NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the // serializer stream and the lower level stream. objOut.flush() bs.flush() val prevPos = lastValidPosition - lastValidPosition = channel.position() - lastValidPosition - prevPos + lastValidPosition = fos.size() + fos.toByteArray } else { // lastValidPosition is zero if stream is uninitialized - lastValidPosition + null } } @@ -170,7 +166,7 @@ private[spark] class DiskBlockObjectWriter( // truncate the file to the last valid position. objOut.flush() bs.flush() - channel.truncate(lastValidPosition) + throw new UnsupportedOperationException("Revert temporarily broken due to in memory shuffle code changes.") } } @@ -182,7 +178,7 @@ private[spark] class DiskBlockObjectWriter( } override def fileSegment(): FileSegment = { - new FileSegment(file, initialPosition, bytesWritten) + new FileSegment(null, initialPosition, bytesWritten) } // Only valid if called after close() diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 488f1ea9628f5..7afae0310626a 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -23,6 +23,7 @@ import java.util.LinkedHashMap import scala.collection.mutable.ArrayBuffer import org.apache.spark.util.{SizeEstimator, Utils} +import org.apache.spark.serializer.Serializer /** * Stores blocks in memory, either as ArrayBuffers of deserialized Java objects or as @@ -119,6 +120,14 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } + /** + * A version of getValues that allows a custom serializer. This is used as part of the + * shuffle short-circuit code. + */ + def getValues(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = { + getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer)) + } + override def remove(blockId: BlockId): Boolean = { entries.synchronized { val entry = entries.remove(blockId) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index bb07c8cb134cc..a80b197438793 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -187,6 +187,17 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { } }) } + + def removeAllShuffleStuff() { + for (state <- shuffleStates.values; + group <- state.allFileGroups; + (mapId, _) <- group.mapIdToIndex.iterator; + reducerId <- 0 until group.files.length) { + val blockId = new ShuffleBlockId(group.shuffleId, mapId, reducerId) + blockManager.removeBlock(blockId, tellMaster = false) + } + shuffleStates.clear() + } } private[spark] @@ -200,7 +211,7 @@ object ShuffleBlockManager { * Stores the absolute index of each mapId in the files of this group. For instance, * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0. */ - private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]() + val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]() /** * Stores consecutive offsets of blocks into each reducer file, ordered by position in the file. diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index 9b9c6a8e8ffcb..5d98d3b83b69b 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -19,6 +19,7 @@ package org.apache.spark.graphx import scala.reflect.ClassTag import org.apache.spark.Logging +import org.apache.spark.SparkEnv /** @@ -143,6 +144,12 @@ object Pregel extends Logging { // hides oldMessages (depended on by newVerts), newVerts (depended on by messages), and the // vertices of prevG (depended on by newVerts, oldMessages, and the vertices of g). activeMessages = messages.count() + + // Very ugly code to clear the in-memory shuffle data + messages.foreachPartition { iter => + SparkEnv.get.blockManager.shuffleBlockManager.removeAllShuffleStuff() + } + // Unpersist the RDDs hidden by newly-materialized RDDs oldMessages.unpersist(blocking=false) newVerts.unpersist(blocking=false)