From 27b18b09aca66cb2dac8f779701569200deba43a Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Sun, 10 May 2015 14:37:53 -0700
Subject: [PATCH] Test for inserting records AT the max record size.

---
 .../unsafe/UnsafeShuffleExternalSorter.java   |  4 +-
 .../shuffle/unsafe/UnsafeShuffleWriter.java   | 25 ++++---
 .../unsafe/UnsafeShuffleWriterSuite.java      | 67 +++++++++++++++----
 3 files changed, 70 insertions(+), 26 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
index 6e0d8da410231..c9d818034c899 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
@@ -57,9 +57,11 @@ final class UnsafeShuffleExternalSorter {
 
   private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class);
 
+  private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES;
   @VisibleForTesting
   static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
-  private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES;
+  @VisibleForTesting
+  static final int MAX_RECORD_SIZE = PAGE_SIZE - 4;
 
   private final int initialSize;
   private final int numPartitions;
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
index db9f8648a93b4..5bf04617854bb 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
@@ -18,7 +18,6 @@
 package org.apache.spark.shuffle.unsafe;
 
 import java.io.*;
-import java.nio.ByteBuffer;
 import java.nio.channels.FileChannel;
 import java.util.Iterator;
 
@@ -73,8 +72,14 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
 
   private MapStatus mapStatus = null;
   private UnsafeShuffleExternalSorter sorter = null;
-  private byte[] serArray = null;
-  private ByteBuffer serByteBuffer;
+
+  /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
+  private static final class MyByteArrayOutputStream extends ByteArrayOutputStream {
+    public MyByteArrayOutputStream(int size) { super(size); }
+    public byte[] getBuf() { return buf; }
+  }
+
+  private MyByteArrayOutputStream serBuffer;
   private SerializationStream serOutputStream;
 
   /**
@@ -142,9 +147,8 @@ private void open() throws IOException {
       4096, // Initial size (TODO: tune this!)
       partitioner.numPartitions(),
       sparkConf);
-    serArray = new byte[MAXIMUM_RECORD_SIZE];
-    serByteBuffer = ByteBuffer.wrap(serArray);
-    serOutputStream = serializer.serializeStream(new ByteBufferOutputStream(serByteBuffer));
+    serBuffer = new MyByteArrayOutputStream(1024 * 1024);
+    serOutputStream = serializer.serializeStream(serBuffer);
   }
 
   @VisibleForTesting
@@ -152,8 +156,7 @@ void closeAndWriteOutput() throws IOException {
     if (sorter == null) {
       open();
     }
-    serArray = null;
-    serByteBuffer = null;
+    serBuffer = null;
     serOutputStream = null;
     final SpillInfo[] spills = sorter.closeAndGetSpills();
     sorter = null;
@@ -178,16 +181,16 @@ void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
     }
     final K key = record._1();
     final int partitionId = partitioner.getPartition(key);
-    serByteBuffer.position(0);
+    serBuffer.reset();
     serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
     serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
     serOutputStream.flush();
 
-    final int serializedRecordSize = serByteBuffer.position();
+    final int serializedRecordSize = serBuffer.size();
     assert (serializedRecordSize > 0);
 
     sorter.insertRecord(
-      serArray, PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
+      serBuffer.getBuf(), PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
   }
 
   @VisibleForTesting
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
index 9002126bb7a4a..48ba85f917b87 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
@@ -23,6 +23,7 @@
 
 import scala.*;
 import scala.collection.Iterator;
+import scala.reflect.ClassTag;
 import scala.runtime.AbstractFunction1;
 
 import com.google.common.collect.HashMultiset;
@@ -44,11 +45,8 @@
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.executor.TaskMetrics;
 import org.apache.spark.network.util.LimitedInputStream;
+import org.apache.spark.serializer.*;
 import org.apache.spark.scheduler.MapStatus;
-import org.apache.spark.serializer.DeserializationStream;
-import org.apache.spark.serializer.KryoSerializer;
-import org.apache.spark.serializer.Serializer;
-import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
 import org.apache.spark.shuffle.ShuffleMemoryManager;
 import org.apache.spark.storage.*;
@@ -305,18 +303,59 @@ public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception
   }
 
   @Test
-  public void writeRecordsThatAreBiggerThanMaximumRecordSize() throws Exception {
+  public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception {
+    // Use a custom serializer so that we have exact control over the size of serialized data.
+    final Serializer byteArraySerializer = new Serializer() {
+      @Override
+      public SerializerInstance newInstance() {
+        return new SerializerInstance() {
+          @Override
+          public SerializationStream serializeStream(final OutputStream s) {
+            return new SerializationStream() {
+              @Override
+              public void flush() { }
+
+              @Override
+              public <T> SerializationStream writeObject(T t, ClassTag<T> ev1) {
+                byte[] bytes = (byte[]) t;
+                try {
+                  s.write(bytes);
+                } catch (IOException e) {
+                  throw new RuntimeException(e);
+                }
+                return this;
+              }
+
+              @Override
+              public void close() { }
+            };
+          }
+          public <T> ByteBuffer serialize(T t, ClassTag<T> ev1) { return null; }
+          public DeserializationStream deserializeStream(InputStream s) { return null; }
+          public <T> T deserialize(ByteBuffer b, ClassLoader l, ClassTag<T> ev1) { return null; }
+          public <T> T deserialize(ByteBuffer bytes, ClassTag<T> ev1) { return null; }
+        };
+      }
+    };
+    when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(byteArraySerializer));
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
-    final ArrayList<Product2<Object, Object>> dataToWrite =
-      new ArrayList<Product2<Object, Object>>();
-    final byte[] bytes = new byte[UnsafeShuffleWriter.MAXIMUM_RECORD_SIZE * 2];
-    new Random(42).nextBytes(bytes);
-    dataToWrite.add(new Tuple2<Object, Object>(1, bytes));
+    // Insert a record and force a spill so that there's something to clean up:
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(new byte[1], new byte[1]));
+    writer.forceSorterToSpill();
+    // We should be able to write a record that's right _at_ the max record size
+    final byte[] atMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE];
+    new Random(42).nextBytes(atMaxRecordSize);
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(new byte[0], atMaxRecordSize));
+    writer.forceSorterToSpill();
+    // Inserting a record that's larger than the max record size should fail:
+    final byte[] exceedsMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE + 1];
+    new Random(42).nextBytes(exceedsMaxRecordSize);
+    Product2<Object, Object> hugeRecord =
+      new Tuple2<Object, Object>(new byte[0], exceedsMaxRecordSize);
     try {
-      // Insert a record and force a spill so that there's something to clean up:
-      writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
-      writer.forceSorterToSpill();
-      writer.write(dataToWrite.iterator());
+      // Here, we write through the public `write()` interface instead of the test-only
+      // `insertRecordIntoSorter` interface:
+      writer.write(Collections.singletonList(hugeRecord).iterator());
       Assert.fail("Expected exception to be thrown");
     } catch (IOException e) {
       // Pass