refactor blockmanager: data stored in memory is not encrypted, data written to disk is encrypted
uncleGen committed Feb 17, 2017
1 parent 63d909b commit f9a91d6
Showing 5 changed files with 97 additions and 30 deletions.
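The pattern behind the commit, in one hedged sketch: data cached by the MemoryStore stays plaintext, and encryption or decryption happens only when bytes cross the disk boundary. This is an illustrative stand-in, not Spark code; the XOR "cipher" below exists only to keep the sketch self-contained, whereas the real diff wraps streams via CryptoStreamUtils and SerializerManager.

    // Toy sketch of "encrypt only on the disk path"; the xor helper is a stand-in
    // for the CryptoStreamUtils-based stream wrapping used in the actual change.
    object EncryptOnlyOnDisk {
      private def xor(bytes: Array[Byte], key: Byte): Array[Byte] =
        bytes.map(b => (b ^ key).toByte)

      def cacheInMemory(bytes: Array[Byte]): Array[Byte] = bytes                       // plaintext in memory
      def writeToDisk(bytes: Array[Byte], key: Byte): Array[Byte] = xor(bytes, key)    // ciphertext on disk
      def readFromDisk(stored: Array[Byte], key: Byte): Array[Byte] = xor(stored, key) // decrypt on read
    }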
core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -35,7 +35,7 @@ import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
private[spark] class SerializerManager(
defaultSerializer: Serializer,
conf: SparkConf,
encryptionKey: Option[Array[Byte]]) {
val encryptionKey: Option[Array[Byte]]) {

def this(defaultSerializer: Serializer, conf: SparkConf) = this(defaultSerializer, conf, None)

@@ -148,14 +148,14 @@ private[spark] class SerializerManager(
/**
* Wrap an output stream for compression if block compression is enabled for its block type
*/
private[this] def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s
}

/**
* Wrap an input stream for compression if block compression is enabled for its block type
*/
private[this] def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
}

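What the SerializerManager change buys: exposing encryptionKey as a val and widening wrapForCompression from private[this] lets the stores apply compression and encryption as separate, composable stream wrappers rather than one opaque call. Below is a hedged sketch of that composition for a disk write; the method names come from this diff, but the helper itself is hypothetical, and inside Spark these wrappers are only visible to spark-internal code.

    import java.io.OutputStream

    import org.apache.spark.serializer.SerializerManager
    import org.apache.spark.storage.BlockId

    object DiskStreamWrapping {
      // Hypothetical helper, not part of the diff: compression is applied to the
      // serialized plaintext, and the compressed output is encrypted last, so the
      // bytes that reach the file are compressed-then-encrypted.
      def wrapForDisk(sm: SerializerManager, blockId: BlockId, out: OutputStream): OutputStream =
        sm.wrapForCompression(blockId, sm.wrapForEncryption(out))
    }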
12 changes: 4 additions & 8 deletions core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -91,7 +91,7 @@ private[spark] class BlockManager(
// Actual storage of where blocks are kept
private[spark] val memoryStore =
new MemoryStore(conf, blockInfoManager, serializerManager, memoryManager, this)
private[spark] val diskStore = new DiskStore(conf, diskBlockManager)
private[spark] val diskStore = new DiskStore(conf, serializerManager, diskBlockManager)
memoryManager.setMemoryStore(memoryStore)

// Note: depending on the memory manager, `maxMemory` may actually vary over time.
@@ -458,13 +458,11 @@ private[spark] class BlockManager(
Some(new BlockResult(ci, DataReadMethod.Memory, info.size))
} else if (level.useDisk && diskStore.contains(blockId)) {
val iterToReturn: Iterator[Any] = {
val diskBytes = diskStore.getBytes(blockId)
if (level.deserialized) {
val diskValues = serializerManager.dataDeserializeStream(
blockId,
diskBytes.toInputStream(dispose = true))(info.classTag)
val diskValues = diskStore.getBytesAsValues(blockId, info.classTag)
maybeCacheDiskValuesInMemory(info, blockId, level, diskValues)
} else {
val diskBytes = diskStore.getBytes(blockId)
val stream = maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes)
.map {_.toInputStream(dispose = false)}
.getOrElse { diskBytes.toInputStream(dispose = true) }
@@ -807,9 +805,7 @@ private[spark] class BlockManager(
// Put it in memory first, even if it also has useDisk set to true;
// We will drop it to disk later if the memory store can't hold it.
val putSucceeded = if (level.deserialized) {
val values =
serializerManager.dataDeserializeStream(blockId, bytes.toInputStream())(classTag)
memoryStore.putIteratorAsValues(blockId, values, classTag) match {
memoryStore.putBytesAsValues(blockId, bytes, classTag) match {
case Right(_) => true
case Left(iter) =>
// If putting deserialized values in memory failed, we will put the bytes directly to
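The BlockManager hunks push decryption down into DiskStore: on the read path the deserialized branch asks DiskStore for ready-to-use values, the serialized branch receives already-decrypted bytes, and on the write path caching incoming bytes as values is delegated to memoryStore.putBytesAsValues. A simplified, hedged sketch of the read-side split follows; the real method additionally tries to re-cache what it reads, and this only compiles inside Spark's own tree because DiskStore is private[spark].

    import scala.reflect.ClassTag

    import org.apache.spark.storage.{BlockId, DiskStore, StorageLevel}
    import org.apache.spark.util.io.ChunkedByteBuffer

    object DiskReadSplit {
      // Simplified sketch of the new division of labor; not the actual BlockManager code.
      def readFromDisk[T](
          diskStore: DiskStore,
          blockId: BlockId,
          level: StorageLevel,
          classTag: ClassTag[T]): Either[Iterator[T], ChunkedByteBuffer] = {
        if (level.deserialized) {
          Left(diskStore.getBytesAsValues(blockId, classTag))  // DiskStore decrypts and deserializes
        } else {
          Right(diskStore.getBytes(blockId))                   // DiskStore hands back decrypted bytes
        }
      }
    }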
53 changes: 48 additions & 5 deletions core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -21,17 +21,25 @@ import java.io.{FileOutputStream, IOException, RandomAccessFile}
import java.nio.ByteBuffer
import java.nio.channels.FileChannel.MapMode

import com.google.common.io.Closeables
import scala.reflect.ClassTag

import com.google.common.io.{ByteStreams, Closeables}
import org.apache.commons.io.IOUtils

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils
import org.apache.spark.SparkConf
import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils}
import org.apache.spark.util.io.ChunkedByteBuffer

/**
* Stores BlockManager blocks on disk.
*/
private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) extends Logging {
private[spark] class DiskStore(
conf: SparkConf,
serializerManager: SerializerManager,
diskManager: DiskBlockManager) extends Logging {

private val minMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m")

@@ -73,17 +81,52 @@ private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) extends Logging {
}

def putBytes(blockId: BlockId, bytes: ChunkedByteBuffer): Unit = {
val bytesToStore = if (serializerManager.encryptionEnabled) {
try {
val data = bytes.toByteBuffer
val in = new ByteBufferInputStream(data, true)
val byteBufOut = new ByteBufferOutputStream(data.remaining())
val out = CryptoStreamUtils.createCryptoOutputStream(byteBufOut, conf,
serializerManager.encryptionKey.get)
try {
ByteStreams.copy(in, out)
} finally {
in.close()
out.close()
}
new ChunkedByteBuffer(byteBufOut.toByteBuffer)
} finally {
bytes.dispose()
}
} else {
bytes
}

put(blockId) { fileOutputStream =>
val channel = fileOutputStream.getChannel
Utils.tryWithSafeFinally {
bytes.writeFully(channel)
bytesToStore.writeFully(channel)
} {
channel.close()
}
}
}

def getBytes(blockId: BlockId): ChunkedByteBuffer = {
val bytes = readBytes(blockId)

val in = serializerManager.wrapForEncryption(bytes.toInputStream(dispose = true))
new ChunkedByteBuffer(ByteBuffer.wrap(IOUtils.toByteArray(in)))
}

def getBytesAsValues[T](blockId: BlockId, classTag: ClassTag[T]): Iterator[T] = {
val bytes = readBytes(blockId)

serializerManager
.dataDeserializeStream(blockId, bytes.toInputStream(dispose = true))(classTag)
}

private[storage] def readBytes(blockId: BlockId): ChunkedByteBuffer = {
val file = diskManager.getFile(blockId.name)
val channel = new RandomAccessFile(file, "r").getChannel
Utils.tryWithSafeFinally {
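DiskStore now owns the disk-side crypto: putBytes encrypts the buffer (when encryption is enabled) before writing, getBytes decrypts on the way back, getBytesAsValues additionally deserializes, and readBytes exposes the raw stored bytes (used by the updated DiskStoreSuite below). The round-trip sketch below mirrors the test setup that follows; like the suite, it only compiles inside Spark's own tree because DiskStore and DiskBlockManager are private[spark], and the encryption branch is exercised only if I/O encryption is enabled in the SparkConf.

    import java.nio.ByteBuffer

    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.{KryoSerializer, SerializerManager}
    import org.apache.spark.storage.{BlockId, DiskBlockManager, DiskStore}
    import org.apache.spark.util.io.ChunkedByteBuffer

    object DiskStoreRoundTrip {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
        val serializerManager = new SerializerManager(new KryoSerializer(conf), conf)
        val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true)
        val diskStore = new DiskStore(conf, serializerManager, diskBlockManager)

        val blockId = BlockId("rdd_1_2")
        val plaintext = new ChunkedByteBuffer(ByteBuffer.wrap(Array.tabulate[Byte](32)(_.toByte)))
        diskStore.putBytes(blockId, plaintext)        // encrypted before hitting the file, if enabled
        val decrypted = diskStore.getBytes(blockId)   // always plaintext for the caller
        val raw = diskStore.readBytes(blockId)        // bytes exactly as stored (ciphertext when enabled)
        println(s"stored ${raw.size} bytes, read back ${decrypted.size} bytes")
      }
    }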
core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
@@ -160,6 +160,32 @@ private[spark] class MemoryStore(
}
}

/**
* Attempt to put the given block in memory store as values.
*
* It's possible that it is too large to materialize and store in memory. To avoid
* OOM exceptions, this method will gradually unroll the iterator while periodically checking
* whether there is enough free memory. If the block is successfully materialized, then the
* temporary unroll memory used during the materialization is "transferred" to storage memory,
* so we won't acquire more memory than is actually needed to store the block.
*
* @return in case of success, the estimated size of the stored data. In case of failure, return
* an iterator containing the values of the block. The returned iterator will be backed
* by the combination of the partially-unrolled block and the remaining elements of the
* original input iterator. The caller must either fully consume this iterator or call
* `close()` on it in order to free the storage memory consumed by the partially-unrolled
* block.
*/
def putBytesAsValues[T](
blockId: BlockId,
bytes: ChunkedByteBuffer,
classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = {
val values = serializerManager.dataDeserializeStream(blockId, bytes.toInputStream(),
maybeEncrypted = false)(classTag)

putIteratorAsValues(blockId, values, classTag)
}

/**
* Attempt to put the given block in memory store as values.
*
@@ -344,7 +370,7 @@ private[spark] class MemoryStore(
val serializationStream: SerializationStream = {
val autoPick = !blockId.isInstanceOf[StreamBlockId]
val ser = serializerManager.getSerializer(classTag, autoPick).newInstance()
ser.serializeStream(serializerManager.wrapStream(blockId, redirectableStream))
ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream))
}

// Request enough memory to begin unrolling
@@ -820,9 +846,10 @@ private[storage] class PartiallySerializedBlock[T](
verifyNotConsumedAndNotDiscarded()
consumed = true
// `unrolled`'s underlying buffers will be freed once this input stream is fully read:
ByteStreams.copy(unrolledBuffer.toInputStream(dispose = true), os)
val _os = serializerManager.wrapForEncryption(os)
ByteStreams.copy(unrolledBuffer.toInputStream(dispose = true), _os)
memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
redirectableOutputStream.setOutputStream(os)
redirectableOutputStream.setOutputStream(_os)
while (rest.hasNext) {
serializationStream.writeObject(rest.next())(classTag)
}
@@ -844,7 +871,7 @@ private[storage] class PartiallySerializedBlock[T](
serializationStream.close()
// `unrolled`'s underlying buffers will be freed once this input stream is fully read:
val unrolledIter = serializerManager.dataDeserializeStream(
blockId, unrolledBuffer.toInputStream(dispose = true))(classTag)
blockId, unrolledBuffer.toInputStream(dispose = true), maybeEncrypted = false)(classTag)
// The unroll memory will be freed once `unrolledIter` is fully consumed in
// PartiallyUnrolledIterator. If the iterator is not consumed by the end of the task then any
// extra unroll memory will automatically be freed by a `finally` block in `Task`.
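MemoryStore gains putBytesAsValues, which deserializes the supplied buffer with maybeEncrypted = false, since anything already in memory is plaintext under the new scheme, and then delegates to putIteratorAsValues; serialization inside the store now wraps only for compression, with encryption added separately when a partially serialized block is written out to disk. A hedged sketch of consuming the Either it returns (memoryStore, blockId, bytes and classTag are assumed to be in scope, as in the BlockManager.doPutBytes hunk above):

    // Not Spark code: mirrors how the BlockManager hunk above consumes the result.
    val putSucceeded = memoryStore.putBytesAsValues(blockId, bytes, classTag) match {
      case Right(_) =>
        true                       // fully unrolled; its size is now charged to storage memory
      case Left(partiallyUnrolled) =>
        partiallyUnrolled.close()  // per the Scaladoc: consume fully or close to free unroll memory
        false
    }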
21 changes: 11 additions & 10 deletions core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala
@@ -21,6 +21,7 @@ import java.nio.{ByteBuffer, MappedByteBuffer}
import java.util.Arrays

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.{KryoSerializer, SerializerManager}
import org.apache.spark.util.io.ChunkedByteBuffer
import org.apache.spark.util.Utils

@@ -39,27 +40,27 @@ class DiskStoreSuite extends SparkFunSuite {
val blockId = BlockId("rdd_1_2")
val diskBlockManager = new DiskBlockManager(new SparkConf(), deleteFilesOnStop = true)

val diskStoreMapped = new DiskStore(new SparkConf().set(confKey, "0"), diskBlockManager)
val conf = new SparkConf()
val serializer = new KryoSerializer(conf)
val serializerManager = new SerializerManager(serializer, conf)

conf.set(confKey, "0")
val diskStoreMapped = new DiskStore(conf, serializerManager, diskBlockManager)
diskStoreMapped.putBytes(blockId, byteBuffer)
val mapped = diskStoreMapped.getBytes(blockId)
val mapped = diskStoreMapped.readBytes(blockId)
assert(diskStoreMapped.remove(blockId))

val diskStoreNotMapped = new DiskStore(new SparkConf().set(confKey, "1m"), diskBlockManager)
conf.set(confKey, "1m")
val diskStoreNotMapped = new DiskStore(conf, serializerManager, diskBlockManager)
diskStoreNotMapped.putBytes(blockId, byteBuffer)
val notMapped = diskStoreNotMapped.getBytes(blockId)
val notMapped = diskStoreNotMapped.readBytes(blockId)

// Not possible to do isInstanceOf due to visibility of HeapByteBuffer
assert(notMapped.getChunks().forall(_.getClass.getName.endsWith("HeapByteBuffer")),
"Expected HeapByteBuffer for un-mapped read")
assert(mapped.getChunks().forall(_.isInstanceOf[MappedByteBuffer]),
"Expected MappedByteBuffer for mapped read")

def arrayFromByteBuffer(in: ByteBuffer): Array[Byte] = {
val array = new Array[Byte](in.remaining())
in.get(array)
array
}

assert(Arrays.equals(mapped.toArray, bytes))
assert(Arrays.equals(notMapped.toArray, bytes))
}
