[SPARK-18546][CORE] Fix merging shuffle spills when using encryption.
The problem exists because it's not possible to simply concatenate
encrypted partition data from different spill files; currently each spilled
partition has its own initialization vector (IV) to set up encryption, but
the final merged file should contain a single IV for each merged partition,
otherwise iterating over each record becomes really hard.

To fix that, UnsafeShuffleWriter now decrypts the partitions when merging
and re-encrypts them into the merged output, so that the merged file
contains a single initialization vector at the start of each partition's
data.
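
In rough terms, the new file-stream merge builds the output stream chain
once per merged partition and the input chain once per spill partition, so
each encryption wrapper sees a complete stream with its own IV. A condensed
sketch of that loop, simplified from the mergeSpillsWithFileStream changes
below; it assumes the surrounding UnsafeShuffleWriter fields and locals and
omits error handling and byte counting:

    for (int partition = 0; partition < numPartitions; partition++) {
      // Output side: one encryption wrapper, and thus one IV, per merged partition.
      OutputStream partitionOutput = new CloseShieldOutputStream(
          new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream));
      partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput);
      if (compressionCodec != null) {
        partitionOutput = compressionCodec.compressedOutputStream(partitionOutput);
      }
      for (int i = 0; i < spills.length; i++) {
        long lengthInSpill = spills[i].partitionLengths[partition];
        if (lengthInSpill > 0) {
          // Input side: decrypt (and decompress) each spill's copy of the partition.
          InputStream in =
              new LimitedInputStream(spillInputStreams[i], lengthInSpill, false);
          in = blockManager.serializerManager().wrapForEncryption(in);
          if (compressionCodec != null) {
            in = compressionCodec.compressedInputStream(in);
          }
          ByteStreams.copy(in, partitionOutput);
          in.close();  // the underlying spill stream stays open for later partitions
        }
      }
      // Closing the chain flushes the crypto/codec state; CloseShieldOutputStream
      // keeps the merged output file itself open for the next partition.
      partitionOutput.close();
    }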

Because that's not possible using the fast transferTo path, when encryption
is enabled UnsafeShuffleWriter falls back to using file streams when
merging. It may be possible to use a hybrid approach when encryption is
enabled, using an intermediate direct buffer when reading from files and
encrypting the data, but that's better left for a separate patch.
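
The resulting merge selection, condensed from the mergeSpills() change
below; the surrounding branches follow the method's existing structure,
where the fast file-stream merge passes no codec and the slow path passes
the configured compression codec:

    final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled();
    if (fastMergeEnabled && fastMergeIsSupported) {
      if (transferToEnabled && !encryptionEnabled) {
        // Fast path: raw byte concatenation via NIO transferTo; only safe without encryption.
        partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
      } else {
        // Fast merge through file streams: no codec interpretation needed, but encryption
        // (if enabled) is still unwrapped and re-wrapped per merged partition.
        partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null);
      }
    } else {
      // Slow path: decompress and recompress each partition while merging.
      partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec);
    }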

As part of the change, DiskBlockObjectWriter now takes a SerializerManager
instead of a "wrap stream" closure, since that makes it easier to test the
code without having to mock SerializerManager functionality.
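
For illustration, a test can now build a writer against a real
SerializerManager instead of a hand-rolled wrapping closure. A hypothetical
snippet mirroring the new constructor order used in UnsafeShuffleWriterSuite
below; file, serializerInstance, writeMetrics and blockId are placeholders
that the surrounding test would provide:

    // file, serializerInstance, writeMetrics and blockId come from the test fixture.
    SparkConf conf = new SparkConf();
    SerializerManager serializerManager =
        new SerializerManager(new KryoSerializer(conf), conf);
    DiskBlockObjectWriter writer = new DiskBlockObjectWriter(
        file,               // destination file
        serializerManager,  // now wraps streams for compression/encryption itself
        serializerInstance,
        32 * 1024,          // buffer size, in bytes
        false,              // syncWrites
        writeMetrics,
        blockId);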

Tested with newly added unit tests (UnsafeShuffleWriterSuite for the write
side and ExternalAppendOnlyMapSuite for integration), and by running some
apps that failed without the fix.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #15982 from vanzin/SPARK-18546.

(cherry picked from commit 93e9d88)
Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>
Marcelo Vanzin committed Nov 30, 2016
1 parent 9e96ac5 commit c2c2fdc
Showing 10 changed files with 145 additions and 119 deletions.
UnsafeShuffleWriter.java
@@ -40,6 +40,8 @@
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.io.CompressionCodec;
import org.apache.spark.io.CompressionCodec$;
import org.apache.commons.io.output.CloseShieldOutputStream;
import org.apache.commons.io.output.CountingOutputStream;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
@@ -264,6 +266,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException
sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true);
final boolean fastMergeIsSupported = !compressionEnabled ||
CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec);
final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled();
try {
if (spills.length == 0) {
new FileOutputStream(outputFile).close(); // Create an empty file
@@ -289,7 +292,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException
// Compression is disabled or we are using an IO compression codec that supports
// decompression of concatenated compressed streams, so we can perform a fast spill merge
// that doesn't need to interpret the spilled bytes.
if (transferToEnabled) {
if (transferToEnabled && !encryptionEnabled) {
logger.debug("Using transferTo-based fast merge");
partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
} else {
@@ -320,9 +323,9 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException
/**
* Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge,
* {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in
* cases where the IO compression codec does not support concatenation of compressed data, or in
* cases where users have explicitly disabled use of {@code transferTo} in order to work around
* kernel bugs.
* cases where the IO compression codec does not support concatenation of compressed data, when
* encryption is enabled, or when users have explicitly disabled use of {@code transferTo} in
* order to work around kernel bugs.
*
* @param spills the spills to merge.
* @param outputFile the file to write the merged data to.
@@ -337,42 +340,47 @@ private long[] mergeSpillsWithFileStream(
final int numPartitions = partitioner.numPartitions();
final long[] partitionLengths = new long[numPartitions];
final InputStream[] spillInputStreams = new FileInputStream[spills.length];
OutputStream mergedFileOutputStream = null;

// Use a counting output stream to avoid having to close the underlying file and ask
// the file system for its size after each partition is written.
final CountingOutputStream mergedFileOutputStream = new CountingOutputStream(
new FileOutputStream(outputFile));

boolean threwException = true;
try {
for (int i = 0; i < spills.length; i++) {
spillInputStreams[i] = new FileInputStream(spills[i].file);
}
for (int partition = 0; partition < numPartitions; partition++) {
final long initialFileLength = outputFile.length();
mergedFileOutputStream =
new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true));
final long initialFileLength = mergedFileOutputStream.getByteCount();
// Shield the underlying output stream from close() calls, so that we can close the higher
// level streams to make sure all data is really flushed and internal state is cleaned.
OutputStream partitionOutput = new CloseShieldOutputStream(
new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream));
partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput);
if (compressionCodec != null) {
mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream);
partitionOutput = compressionCodec.compressedOutputStream(partitionOutput);
}

for (int i = 0; i < spills.length; i++) {
final long partitionLengthInSpill = spills[i].partitionLengths[partition];
if (partitionLengthInSpill > 0) {
InputStream partitionInputStream = null;
boolean innerThrewException = true;
InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i],
partitionLengthInSpill, false);
try {
partitionInputStream =
new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill, false);
partitionInputStream = blockManager.serializerManager().wrapForEncryption(
partitionInputStream);
if (compressionCodec != null) {
partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
}
ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
innerThrewException = false;
ByteStreams.copy(partitionInputStream, partitionOutput);
} finally {
Closeables.close(partitionInputStream, innerThrewException);
partitionInputStream.close();
}
}
}
mergedFileOutputStream.flush();
mergedFileOutputStream.close();
partitionLengths[partition] = (outputFile.length() - initialFileLength);
partitionOutput.flush();
partitionOutput.close();
partitionLengths[partition] = (mergedFileOutputStream.getByteCount() - initialFileLength);
}
threwException = false;
} finally {
SerializerManager.scala
@@ -75,6 +75,8 @@ private[spark] class SerializerManager(
* loaded yet. */
private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf)

def encryptionEnabled: Boolean = encryptionKey.isDefined

def canUseKryo(ct: ClassTag[_]): Boolean = {
primitiveAndPrimitiveArrayClassTags.contains(ct) || ct == stringClassTag
}
@@ -129,7 +131,7 @@
/**
* Wrap an input stream for encryption if shuffle encryption is enabled
*/
private[this] def wrapForEncryption(s: InputStream): InputStream = {
def wrapForEncryption(s: InputStream): InputStream = {
encryptionKey
.map { key => CryptoStreamUtils.createCryptoInputStream(s, conf, key) }
.getOrElse(s)
@@ -138,7 +140,7 @@
/**
* Wrap an output stream for encryption if shuffle encryption is enabled
*/
private[this] def wrapForEncryption(s: OutputStream): OutputStream = {
def wrapForEncryption(s: OutputStream): OutputStream = {
encryptionKey
.map { key => CryptoStreamUtils.createCryptoOutputStream(s, conf, key) }
.getOrElse(s)
BlockManager.scala
@@ -62,7 +62,7 @@ private[spark] class BlockManager(
executorId: String,
rpcEnv: RpcEnv,
val master: BlockManagerMaster,
serializerManager: SerializerManager,
val serializerManager: SerializerManager,
val conf: SparkConf,
memoryManager: MemoryManager,
mapOutputTracker: MapOutputTracker,
@@ -745,9 +745,8 @@
serializerInstance: SerializerInstance,
bufferSize: Int,
writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = {
val wrapStream: OutputStream => OutputStream = serializerManager.wrapStream(blockId, _)
val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
new DiskBlockObjectWriter(file, serializerInstance, bufferSize, wrapStream,
new DiskBlockObjectWriter(file, serializerManager, serializerInstance, bufferSize,
syncWrites, writeMetrics, blockId)
}

DiskBlockObjectWriter.scala
@@ -22,7 +22,7 @@ import java.nio.channels.FileChannel

import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.serializer.{SerializationStream, SerializerInstance}
import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
import org.apache.spark.util.Utils

/**
@@ -37,9 +37,9 @@ import org.apache.spark.util.Utils
*/
private[spark] class DiskBlockObjectWriter(
val file: File,
serializerManager: SerializerManager,
serializerInstance: SerializerInstance,
bufferSize: Int,
wrapStream: OutputStream => OutputStream,
syncWrites: Boolean,
// These write metrics concurrently shared with other active DiskBlockObjectWriters who
// are themselves performing writes. All updates must be relative.
@@ -116,7 +116,7 @@ private[spark] class DiskBlockObjectWriter(
initialized = true
}

bs = wrapStream(mcs)
bs = serializerManager.wrapStream(blockId, mcs)
objOut = serializerInstance.serializeStream(bs)
streamOpen = true
this
UnsafeShuffleWriterSuite.java
@@ -26,11 +26,9 @@
import scala.Tuple2;
import scala.Tuple2$;
import scala.collection.Iterator;
import scala.runtime.AbstractFunction1;

import com.google.common.collect.HashMultiset;
import com.google.common.collect.Iterators;
import com.google.common.io.ByteStreams;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -53,6 +51,7 @@
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.security.CryptoStreamUtils;
import org.apache.spark.serializer.*;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.storage.*;
@@ -77,7 +76,6 @@ public class UnsafeShuffleWriterSuite {
final LinkedList<File> spillFilesCreated = new LinkedList<>();
SparkConf conf;
final Serializer serializer = new KryoSerializer(new SparkConf());
final SerializerManager serializerManager = new SerializerManager(serializer, new SparkConf());
TaskMetrics taskMetrics;

@Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
@@ -86,17 +84,6 @@ public class UnsafeShuffleWriterSuite {
@Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
@Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep;

private final class WrapStream extends AbstractFunction1<OutputStream, OutputStream> {
@Override
public OutputStream apply(OutputStream stream) {
if (conf.getBoolean("spark.shuffle.compress", true)) {
return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream);
} else {
return stream;
}
}
}

@After
public void tearDown() {
Utils.deleteRecursively(tempDir);
@@ -121,6 +108,11 @@ public void setUp() throws IOException {
memoryManager = new TestMemoryManager(conf);
taskMemoryManager = new TaskMemoryManager(memoryManager, 0);

// Some tests will override this manager because they change the configuration. This is a
// default for tests that don't need a specific one.
SerializerManager manager = new SerializerManager(serializer, conf);
when(blockManager.serializerManager()).thenReturn(manager);

when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
when(blockManager.getDiskWriter(
any(BlockId.class),
@@ -131,12 +123,11 @@ public void setUp() throws IOException {
@Override
public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
Object[] args = invocationOnMock.getArguments();

return new DiskBlockObjectWriter(
(File) args[1],
blockManager.serializerManager(),
(SerializerInstance) args[2],
(Integer) args[3],
new WrapStream(),
false,
(ShuffleWriteMetrics) args[4],
(BlockId) args[0]
@@ -201,9 +192,10 @@ private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException {
for (int i = 0; i < NUM_PARTITITONS; i++) {
final long partitionSize = partitionSizesInMergedFile[i];
if (partitionSize > 0) {
InputStream in = new FileInputStream(mergedOutputFile);
ByteStreams.skipFully(in, startOffset);
in = new LimitedInputStream(in, partitionSize);
FileInputStream fin = new FileInputStream(mergedOutputFile);
fin.getChannel().position(startOffset);
InputStream in = new LimitedInputStream(fin, partitionSize);
in = blockManager.serializerManager().wrapForEncryption(in);
if (conf.getBoolean("spark.shuffle.compress", true)) {
in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in);
}
@@ -294,14 +286,32 @@ public void writeWithoutSpilling() throws Exception {
}

private void testMergingSpills(
boolean transferToEnabled,
String compressionCodecName) throws IOException {
final boolean transferToEnabled,
String compressionCodecName,
boolean encrypt) throws Exception {
if (compressionCodecName != null) {
conf.set("spark.shuffle.compress", "true");
conf.set("spark.io.compression.codec", compressionCodecName);
} else {
conf.set("spark.shuffle.compress", "false");
}
conf.set(org.apache.spark.internal.config.package$.MODULE$.IO_ENCRYPTION_ENABLED(), encrypt);

SerializerManager manager;
if (encrypt) {
manager = new SerializerManager(serializer, conf,
Option.apply(CryptoStreamUtils.createKey(conf)));
} else {
manager = new SerializerManager(serializer, conf);
}

when(blockManager.serializerManager()).thenReturn(manager);
testMergingSpills(transferToEnabled, encrypt);
}

private void testMergingSpills(
boolean transferToEnabled,
boolean encrypted) throws IOException {
final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled);
final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) {
@@ -324,6 +334,7 @@ private void testMergingSpills(
for (long size: partitionSizesInMergedFile) {
sumOfPartitionSizes += size;
}

assertEquals(sumOfPartitionSizes, mergedOutputFile.length());

assertEquals(HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile()));
@@ -338,42 +349,72 @@ private void testMergingSpills(

@Test
public void mergeSpillsWithTransferToAndLZF() throws Exception {
testMergingSpills(true, LZFCompressionCodec.class.getName());
testMergingSpills(true, LZFCompressionCodec.class.getName(), false);
}

@Test
public void mergeSpillsWithFileStreamAndLZF() throws Exception {
testMergingSpills(false, LZFCompressionCodec.class.getName());
testMergingSpills(false, LZFCompressionCodec.class.getName(), false);
}

@Test
public void mergeSpillsWithTransferToAndLZ4() throws Exception {
testMergingSpills(true, LZ4CompressionCodec.class.getName());
testMergingSpills(true, LZ4CompressionCodec.class.getName(), false);
}

@Test
public void mergeSpillsWithFileStreamAndLZ4() throws Exception {
testMergingSpills(false, LZ4CompressionCodec.class.getName());
testMergingSpills(false, LZ4CompressionCodec.class.getName(), false);
}

@Test
public void mergeSpillsWithTransferToAndSnappy() throws Exception {
testMergingSpills(true, SnappyCompressionCodec.class.getName());
testMergingSpills(true, SnappyCompressionCodec.class.getName(), false);
}

@Test
public void mergeSpillsWithFileStreamAndSnappy() throws Exception {
testMergingSpills(false, SnappyCompressionCodec.class.getName());
testMergingSpills(false, SnappyCompressionCodec.class.getName(), false);
}

@Test
public void mergeSpillsWithTransferToAndNoCompression() throws Exception {
testMergingSpills(true, null);
testMergingSpills(true, null, false);
}

@Test
public void mergeSpillsWithFileStreamAndNoCompression() throws Exception {
testMergingSpills(false, null);
testMergingSpills(false, null, false);
}

@Test
public void mergeSpillsWithCompressionAndEncryption() throws Exception {
// This should actually be translated to a "file stream merge" internally, just have the
// test to make sure that it's the case.
testMergingSpills(true, LZ4CompressionCodec.class.getName(), true);
}

@Test
public void mergeSpillsWithFileStreamAndCompressionAndEncryption() throws Exception {
testMergingSpills(false, LZ4CompressionCodec.class.getName(), true);
}

@Test
public void mergeSpillsWithCompressionAndEncryptionSlowPath() throws Exception {
conf.set("spark.shuffle.unsafe.fastMergeEnabled", "false");
testMergingSpills(false, LZ4CompressionCodec.class.getName(), true);
}

@Test
public void mergeSpillsWithEncryptionAndNoCompression() throws Exception {
// This should actually be translated to a "file stream merge" internally, just have the
// test to make sure that it's the case.
testMergingSpills(true, null, true);
}

@Test
public void mergeSpillsWithFileStreamAndEncryptionAndNoCompression() throws Exception {
testMergingSpills(false, null, true);
}

@Test
@@ -531,4 +572,5 @@ public void testPeakMemoryUsed() throws Exception {
writer.stop(false);
}
}

}
