Test for inserting records AT the max record size.
JoshRosen committed May 10, 2015
1 parent fcd9a3c commit 27b18b0
Showing 3 changed files with 70 additions and 26 deletions.
UnsafeShuffleExternalSorter.java

@@ -57,9 +57,11 @@ final class UnsafeShuffleExternalSorter {

private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class);

private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES;
@VisibleForTesting
static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES;
@VisibleForTesting
static final int MAX_RECORD_SIZE = PAGE_SIZE - 4;

private final int initialSize;
private final int numPartitions;
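The new MAX_RECORD_SIZE constant is PAGE_SIZE - 4, which matches a layout in which each record is written into a page as a 4-byte length prefix followed by the record bytes, so the largest payload that fits in one page is the page size minus that prefix. The following sketch illustrates that arithmetic only; it is not the Spark implementation, the page is an on-heap array stand-in, and the page size is an arbitrary example value rather than PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES.

// Illustration only: why MAX_RECORD_SIZE = PAGE_SIZE - 4. Each record occupies a
// 4-byte length prefix plus its payload inside a single page, so the payload can be
// at most PAGE_SIZE - 4 bytes.
final class RecordLayoutSketch {
  static final int PAGE_SIZE = 1 << 20;             // example value only
  static final int MAX_RECORD_SIZE = PAGE_SIZE - 4;

  private final byte[] page = new byte[PAGE_SIZE];  // on-heap stand-in for a memory page
  private int cursor = 0;

  /** Returns false if the record plus its 4-byte prefix does not fit in the page. */
  boolean tryInsert(byte[] record) {
    if (record.length > MAX_RECORD_SIZE || cursor + 4 + record.length > PAGE_SIZE) {
      return false;
    }
    page[cursor]     = (byte) (record.length >>> 24);  // big-endian length prefix
    page[cursor + 1] = (byte) (record.length >>> 16);
    page[cursor + 2] = (byte) (record.length >>> 8);
    page[cursor + 3] = (byte) record.length;
    cursor += 4;
    System.arraycopy(record, 0, page, cursor, record.length);
    cursor += record.length;
    return true;
  }
}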
UnsafeShuffleWriter.java

@@ -18,7 +18,6 @@
package org.apache.spark.shuffle.unsafe;

import java.io.*;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.util.Iterator;

@@ -73,8 +72,14 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {

private MapStatus mapStatus = null;
private UnsafeShuffleExternalSorter sorter = null;
private byte[] serArray = null;
private ByteBuffer serByteBuffer;

/** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
private static final class MyByteArrayOutputStream extends ByteArrayOutputStream {
public MyByteArrayOutputStream(int size) { super(size); }
public byte[] getBuf() { return buf; }
}

private MyByteArrayOutputStream serBuffer;
private SerializationStream serOutputStream;

/**
@@ -142,18 +147,16 @@ private void open() throws IOException {
4096, // Initial size (TODO: tune this!)
partitioner.numPartitions(),
sparkConf);
serArray = new byte[MAXIMUM_RECORD_SIZE];
serByteBuffer = ByteBuffer.wrap(serArray);
serOutputStream = serializer.serializeStream(new ByteBufferOutputStream(serByteBuffer));
serBuffer = new MyByteArrayOutputStream(1024 * 1024);
serOutputStream = serializer.serializeStream(serBuffer);
}

@VisibleForTesting
void closeAndWriteOutput() throws IOException {
if (sorter == null) {
open();
}
serArray = null;
serByteBuffer = null;
serBuffer = null;
serOutputStream = null;
final SpillInfo[] spills = sorter.closeAndGetSpills();
sorter = null;
@@ -178,16 +181,16 @@ void insertRecordIntoSorter(Product2<K, V> record) throws IOException{
}
final K key = record._1();
final int partitionId = partitioner.getPartition(key);
serByteBuffer.position(0);
serBuffer.reset();
serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
serOutputStream.flush();

final int serializedRecordSize = serByteBuffer.position();
final int serializedRecordSize = serBuffer.size();
assert (serializedRecordSize > 0);

sorter.insertRecord(
serArray, PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
serBuffer.getBuf(), PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
}

@VisibleForTesting
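Replacing the fixed serArray/serByteBuffer pair with a ByteArrayOutputStream subclass has two effects: the serialization buffer can now grow past its initial 1 MB instead of capping the record size, and exposing the protected buf field lets the writer hand the backing array straight to the sorter without the defensive copy that toByteArray() would make. Below is a self-contained sketch of that pattern; consume() is a hypothetical stand-in for the sorter.insertRecord(...) call above.

import java.io.ByteArrayOutputStream;

// Minimal sketch of the exposed-buffer pattern: reuse one growable buffer per record
// and read its backing array directly instead of copying it with toByteArray().
final class ExposedBufferExample {

  /** Same idea as MyByteArrayOutputStream in the diff: make the protected buf readable. */
  static final class ExposedByteArrayOutputStream extends ByteArrayOutputStream {
    ExposedByteArrayOutputStream(int size) { super(size); }
    byte[] getBuf() { return buf; }
  }

  private final ExposedByteArrayOutputStream serBuffer =
    new ExposedByteArrayOutputStream(1024 * 1024);

  void writeRecord(byte[] serializedKeyAndValue) {
    serBuffer.reset();  // reuse the buffer for every record
    serBuffer.write(serializedKeyAndValue, 0, serializedKeyAndValue.length);
    // Only the first size() bytes of getBuf() are valid; the backing array may be larger.
    consume(serBuffer.getBuf(), serBuffer.size());
  }

  private void consume(byte[] bytes, int length) {
    // Hypothetical stand-in for:
    // sorter.insertRecord(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, length, partitionId);
  }
}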
UnsafeShuffleWriterSuite.java

@@ -23,6 +23,7 @@

import scala.*;
import scala.collection.Iterator;
import scala.reflect.ClassTag;
import scala.runtime.AbstractFunction1;

import com.google.common.collect.HashMultiset;
@@ -44,11 +45,8 @@
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.serializer.*;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.serializer.DeserializationStream;
import org.apache.spark.serializer.KryoSerializer;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.storage.*;
@@ -305,18 +303,59 @@ public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception
}

@Test
public void writeRecordsThatAreBiggerThanMaximumRecordSize() throws Exception {
public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception {
// Use a custom serializer so that we have exact control over the size of serialized data.
final Serializer byteArraySerializer = new Serializer() {
@Override
public SerializerInstance newInstance() {
return new SerializerInstance() {
@Override
public SerializationStream serializeStream(final OutputStream s) {
return new SerializationStream() {
@Override
public void flush() { }

@Override
public <T> SerializationStream writeObject(T t, ClassTag<T> ev1) {
byte[] bytes = (byte[]) t;
try {
s.write(bytes);
} catch (IOException e) {
throw new RuntimeException(e);
}
return this;
}

@Override
public void close() { }
};
}
public <T> ByteBuffer serialize(T t, ClassTag<T> ev1) { return null; }
public DeserializationStream deserializeStream(InputStream s) { return null; }
public <T> T deserialize(ByteBuffer b, ClassLoader l, ClassTag<T> ev1) { return null; }
public <T> T deserialize(ByteBuffer bytes, ClassTag<T> ev1) { return null; }
};
}
};
when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(byteArraySerializer));
final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
final ArrayList<Product2<Object, Object>> dataToWrite =
new ArrayList<Product2<Object, Object>>();
final byte[] bytes = new byte[UnsafeShuffleWriter.MAXIMUM_RECORD_SIZE * 2];
new Random(42).nextBytes(bytes);
dataToWrite.add(new Tuple2<Object, Object>(1, bytes));
// Insert a record and force a spill so that there's something to clean up:
writer.insertRecordIntoSorter(new Tuple2<Object, Object>(new byte[1], new byte[1]));
writer.forceSorterToSpill();
// We should be able to write a record that's right _at_ the max record size
final byte[] atMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE];
new Random(42).nextBytes(atMaxRecordSize);
writer.insertRecordIntoSorter(new Tuple2<Object, Object>(new byte[0], atMaxRecordSize));
writer.forceSorterToSpill();
// Inserting a record that's larger than the max record size should fail:
final byte[] exceedsMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE + 1];
new Random(42).nextBytes(exceedsMaxRecordSize);
Product2<Object, Object> hugeRecord =
new Tuple2<Object, Object>(new byte[0], exceedsMaxRecordSize);
try {
// Insert a record and force a spill so that there's something to clean up:
writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
writer.forceSorterToSpill();
writer.write(dataToWrite.iterator());
// Here, we write through the public `write()` interface instead of the test-only
// `insertRecordIntoSorter` interface:
writer.write(Collections.singletonList(hugeRecord).iterator());
Assert.fail("Expected exception to be thrown");
} catch (IOException e) {
// Pass
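The diff does not show where the IOException that this test expects is raised. Purely as an assumption, one plausible shape for it is a length check at the start of the sorter's insert path, sketched below with a simplified, hypothetical signature (the real insertRecord also takes a base object and an offset).

import java.io.IOException;

// Assumption only -- not code from this commit. A length check like this would
// produce the IOException the test above expects once a record exceeds the
// maximum record size.
final class MaxRecordSizeGuardSketch {
  private final int maxRecordSize;

  MaxRecordSizeGuardSketch(int maxRecordSize) {  // e.g. UnsafeShuffleExternalSorter.MAX_RECORD_SIZE
    this.maxRecordSize = maxRecordSize;
  }

  void insertRecord(byte[] recordBytes, int lengthInBytes, int partitionId) throws IOException {
    if (lengthInBytes > maxRecordSize) {
      throw new IOException("Record is too large: " + lengthInBytes
        + " bytes exceeds the maximum record size (" + maxRecordSize + ")");
    }
    // ... reserve page space, write the 4-byte length prefix, copy the record bytes ...
  }
}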
