Add notes + tests for maximum record / page sizes.
JoshRosen committed May 10, 2015
1 parent 9d1ee7c commit fcd9a3c
Showing 8 changed files with 113 additions and 40 deletions.
@@ -19,9 +19,24 @@

/**
* Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer.
* <p>
* Within the long, the data is laid out as follows:
* <pre>
* [24 bit partition number][13 bit memory page number][27 bit offset in page]
* </pre>
* This implies that the maximum addressable page size is 2^27 bytes = 128 megabytes, assuming that
* our offsets in pages are not 8-byte-word-aligned. Since we have 2^13 pages (based on the
* 13-bit page numbers assigned by {@link org.apache.spark.unsafe.memory.TaskMemoryManager}), this
* implies that we can address 2^13 * 128 megabytes = 1 terabyte of RAM per task.
* <p>
* Assuming word-alignment would allow for a 1 gigabyte maximum page size, but we leave this
* optimization to future work as it will require more careful design to ensure that addresses are
* properly aligned (e.g. by padding records).
*/
final class PackedRecordPointer {

static final int MAXIMUM_PAGE_SIZE_BYTES = 1 << 27; // 128 megabytes

/** Bit mask for the lower 40 bits of a long. */
private static final long MASK_LONG_LOWER_40_BITS = 0xFFFFFFFFFFL;

@@ -55,7 +70,11 @@ public static long packPointer(long recordPointer, int partitionId) {
return (((long) partitionId) << 40) | compressedAddress;
}

public long packedRecordPointer;
private long packedRecordPointer;

public void set(long packedRecordPointer) {
this.packedRecordPointer = packedRecordPointer;
}

public int getPartitionId() {
return (int) ((packedRecordPointer & MASK_LONG_UPPER_24_BITS) >>> 40);
@@ -68,7 +87,4 @@ public long getRecordPointer() {
return pageNumber | offsetInPage;
}

public int getRecordLength() {
return -1; // TODO
}
}
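
For illustration, here is a minimal standalone sketch, not part of this patch, of packing and unpacking the [24-bit partition][13-bit page][27-bit offset] layout described in the Javadoc above. The class and method names are hypothetical, and unlike the real packPointer it does not first compress a TaskMemoryManager-encoded address; it only shows the bit layout.

// Illustrative sketch (not from the patch): pack and unpack the
// [24-bit partition][13-bit page][27-bit offset] layout described above.
class PackedPointerLayoutSketch {

  static long pack(int partitionId, int pageNumber, int offsetInPage) {
    // Assumes partitionId fits in 24 bits, pageNumber in 13 bits, and
    // offsetInPage in 27 bits (i.e. offsets below 2^27 bytes = 128 megabytes).
    return (((long) partitionId) << 40)
        | (((long) pageNumber) << 27)
        | (offsetInPage & 0x7FFFFFFL);
  }

  static int partitionId(long packed)  { return (int) (packed >>> 40); }
  static int pageNumber(long packed)   { return (int) ((packed >>> 27) & 0x1FFFL); }
  static int offsetInPage(long packed) { return (int) (packed & 0x7FFFFFFL); }

  public static void main(String[] args) {
    final long packed = pack(360, 5, 1234567);
    // Prints "360 5 1234567": each field survives the round trip.
    System.out.println(partitionId(packed) + " " + pageNumber(packed) + " " + offsetInPage(packed));
  }
}
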
@@ -57,8 +57,9 @@ final class UnsafeShuffleExternalSorter {

private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class);

private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this / don't duplicate
private static final int PAGE_SIZE = 1024 * 1024; // TODO: tune this
@VisibleForTesting
static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES;

private final int initialSize;
private final int numPartitions;
@@ -88,13 +89,13 @@ final class UnsafeShuffleExternalSorter {
private long freeSpaceInCurrentPage = 0;

public UnsafeShuffleExternalSorter(
TaskMemoryManager memoryManager,
ShuffleMemoryManager shuffleMemoryManager,
BlockManager blockManager,
TaskContext taskContext,
int initialSize,
int numPartitions,
SparkConf conf) throws IOException {
TaskMemoryManager memoryManager,
ShuffleMemoryManager shuffleMemoryManager,
BlockManager blockManager,
TaskContext taskContext,
int initialSize,
int numPartitions,
SparkConf conf) throws IOException {
this.memoryManager = memoryManager;
this.shuffleMemoryManager = shuffleMemoryManager;
this.blockManager = blockManager;
@@ -140,8 +141,9 @@ private SpillInfo writeSpillFile() throws IOException {

// Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
// be an API to directly transfer bytes from managed memory to the disk writer, we buffer
// records in a byte array. This array only needs to be big enough to hold a single record.
final byte[] arr = new byte[SER_BUFFER_SIZE];
// data through a byte array. This array does not need to be large enough to hold a single
// record; records larger than the buffer are copied to the disk writer in multiple chunks.
final byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE];

// Because this output will be read during shuffle, its compression codec must be controlled by
// spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
@@ -186,16 +188,23 @@
}

final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
final int recordLength = PlatformDependent.UNSAFE.getInt(
memoryManager.getPage(recordPointer), memoryManager.getOffsetInPage(recordPointer));
PlatformDependent.copyMemory(
memoryManager.getPage(recordPointer),
memoryManager.getOffsetInPage(recordPointer) + 4, // skip over record length
arr,
PlatformDependent.BYTE_ARRAY_OFFSET,
recordLength);
assert (writer != null); // To suppress an IntelliJ warning
writer.write(arr, 0, recordLength);
final Object recordPage = memoryManager.getPage(recordPointer);
final long recordOffsetInPage = memoryManager.getOffsetInPage(recordPointer);
int dataRemaining = PlatformDependent.UNSAFE.getInt(recordPage, recordOffsetInPage);
long recordReadPosition = recordOffsetInPage + 4; // skip over record length
while (dataRemaining > 0) {
final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining);
PlatformDependent.copyMemory(
recordPage,
recordReadPosition,
writeBuffer,
PlatformDependent.BYTE_ARRAY_OFFSET,
toTransfer);
assert (writer != null); // To suppress an IntelliJ warning
writer.write(writeBuffer, 0, toTransfer);
recordReadPosition += toTransfer;
dataRemaining -= toTransfer;
}
// TODO: add a test that detects whether we leave this call out:
writer.recordWritten();
}
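
The loop above streams each record through the fixed-size write buffer instead of sizing the buffer to the largest record. A minimal standalone sketch of the same chunking pattern, using a plain byte array in place of an off-heap page (names are illustrative, not from the patch):

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;

// Illustrative sketch: copy a record that may be larger than the write buffer
// to an OutputStream in buffer-sized chunks, mirroring the loop above.
class ChunkedWriteSketch {

  static final int WRITE_BUFFER_SIZE = 1024 * 1024;

  static void writeRecord(byte[] record, OutputStream out) throws IOException {
    final byte[] writeBuffer = new byte[WRITE_BUFFER_SIZE];
    int dataRemaining = record.length;
    int readPosition = 0;
    while (dataRemaining > 0) {
      final int toTransfer = Math.min(WRITE_BUFFER_SIZE, dataRemaining);
      System.arraycopy(record, readPosition, writeBuffer, 0, toTransfer);
      out.write(writeBuffer, 0, toTransfer);
      readPosition += toTransfer;
      dataRemaining -= toTransfer;
    }
  }

  public static void main(String[] args) throws IOException {
    // A record 2.5x the buffer size is written in three chunks.
    final byte[] record = new byte[(int) (WRITE_BUFFER_SIZE * 2.5)];
    final ByteArrayOutputStream out = new ByteArrayOutputStream();
    writeRecord(record, out);
    System.out.println(out.size() == record.length);  // true
  }
}
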
@@ -38,7 +38,7 @@ public PackedRecordPointer newKey() {

@Override
public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) {
reuse.packedRecordPointer = data[pos];
reuse.set(data[pos]);
return reuse;
}

@@ -95,7 +95,7 @@ public boolean hasNext() {

@Override
public void loadNext() {
packedRecordPointer.packedRecordPointer = sortBuffer[position];
packedRecordPointer.set(sortBuffer[position]);
position++;
}
};
@@ -54,7 +54,8 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {

private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class);

private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this
@VisibleForTesting
static final int MAXIMUM_RECORD_SIZE = 1024 * 1024 * 64; // 64 megabytes
private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();

private final BlockManager blockManager;
@@ -108,19 +109,26 @@ public UnsafeShuffleWriter(
this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
}

public void write(Iterator<Product2<K, V>> records) {
public void write(Iterator<Product2<K, V>> records) throws IOException {
write(JavaConversions.asScalaIterator(records));
}

@Override
public void write(scala.collection.Iterator<Product2<K, V>> records) {
public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
try {
while (records.hasNext()) {
insertRecordIntoSorter(records.next());
}
closeAndWriteOutput();
} catch (Exception e) {
PlatformDependent.throwException(e);
// Unfortunately, we have to catch Exception here in order to ensure proper cleanup after
// errors because Spark's Scala code, or users' custom Serializers, might throw arbitrary
// unchecked exceptions.
try {
sorter.cleanupAfterError();
} finally {
throw new IOException("Error during shuffle write", e);
}
}
}

@@ -134,7 +142,7 @@ private void open() throws IOException {
4096, // Initial size (TODO: tune this!)
partitioner.numPartitions(),
sparkConf);
serArray = new byte[SER_BUFFER_SIZE];
serArray = new byte[MAXIMUM_RECORD_SIZE];
serByteBuffer = ByteBuffer.wrap(serArray);
serOutputStream = serializer.serializeStream(new ByteBufferOutputStream(serByteBuffer));
}
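
Because the serialization buffer is a fixed 64 megabyte array wrapped in a ByteBuffer, a record whose serialized form exceeds MAXIMUM_RECORD_SIZE cannot be inserted, and the write() wrapper above surfaces the failure as an IOException. A minimal sketch of the fixed-capacity idea follows, with scaled-down sizes; the exact exception raised through Spark's ByteBufferOutputStream may differ, so this only illustrates the underlying ByteBuffer behavior.

import java.nio.BufferOverflowException;
import java.nio.ByteBuffer;

// Illustrative sketch: a fixed-capacity ByteBuffer used as a serialization
// target rejects records larger than its capacity.
class MaxRecordSizeSketch {

  public static void main(String[] args) {
    final int maxRecordSize = 64;  // stand-in for the 64 megabyte limit above
    final ByteBuffer serBuffer = ByteBuffer.wrap(new byte[maxRecordSize]);

    serBuffer.put(new byte[maxRecordSize]);  // exactly at the limit: fits
    serBuffer.clear();
    try {
      serBuffer.put(new byte[maxRecordSize * 2]);  // over the limit
    } catch (BufferOverflowException e) {
      System.out.println("record too large for the serialization buffer");
    }
  }
}
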
@@ -17,13 +17,16 @@

package org.apache.spark.shuffle

import java.io.IOException

import org.apache.spark.scheduler.MapStatus

/**
* Obtained inside a map task to write out records to the shuffle system.
*/
private[spark] abstract class ShuffleWriter[K, V] {
/** Write a sequence of records to this task's output */
@throws[IOException]
def write(records: Iterator[Product2[K, V]]): Unit

/** Close this writer, passing along whether the map completed */
@@ -34,10 +34,10 @@ public void heap() {
final MemoryBlock page0 = memoryManager.allocatePage(100);
final MemoryBlock page1 = memoryManager.allocatePage(100);
final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, 42);
PackedRecordPointer packedPointerWrapper = new PackedRecordPointer();
packedPointerWrapper.packedRecordPointer = PackedRecordPointer.packPointer(addressInPage1, 360);
Assert.assertEquals(360, packedPointerWrapper.getPartitionId());
Assert.assertEquals(addressInPage1, packedPointerWrapper.getRecordPointer());
PackedRecordPointer packedPointer = new PackedRecordPointer();
packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
Assert.assertEquals(360, packedPointer.getPartitionId());
Assert.assertEquals(addressInPage1, packedPointer.getRecordPointer());
memoryManager.cleanUpAllAllocatedMemory();
}

@@ -48,10 +48,10 @@ public void offHeap() {
final MemoryBlock page0 = memoryManager.allocatePage(100);
final MemoryBlock page1 = memoryManager.allocatePage(100);
final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, 42);
PackedRecordPointer packedPointerWrapper = new PackedRecordPointer();
packedPointerWrapper.packedRecordPointer = PackedRecordPointer.packPointer(addressInPage1, 360);
Assert.assertEquals(360, packedPointerWrapper.getPartitionId());
Assert.assertEquals(addressInPage1, packedPointerWrapper.getRecordPointer());
PackedRecordPointer packedPointer = new PackedRecordPointer();
packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
Assert.assertEquals(360, packedPointer.getPartitionId());
Assert.assertEquals(addressInPage1, packedPointer.getRecordPointer());
memoryManager.cleanUpAllAllocatedMemory();
}
}
@@ -18,6 +18,7 @@
package org.apache.spark.shuffle.unsafe;

import java.io.*;
import java.nio.ByteBuffer;
import java.util.*;

import scala.*;
@@ -287,6 +288,42 @@ public void mergeSpillsWithFileStream() throws Exception {
testMergingSpills(false);
}

@Test
public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception {
final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
final ArrayList<Product2<Object, Object>> dataToWrite =
new ArrayList<Product2<Object, Object>>();
final byte[] bytes = new byte[(int) (UnsafeShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)];
new Random(42).nextBytes(bytes);
dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(bytes)));
writer.write(dataToWrite.iterator());
writer.stop(true);
Assert.assertEquals(
HashMultiset.create(dataToWrite),
HashMultiset.create(readRecordsFromFile()));
assertSpillFilesWereCleanedUp();
}

@Test
public void writeRecordsThatAreBiggerThanMaximumRecordSize() throws Exception {
final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
final ArrayList<Product2<Object, Object>> dataToWrite =
new ArrayList<Product2<Object, Object>>();
final byte[] bytes = new byte[UnsafeShuffleWriter.MAXIMUM_RECORD_SIZE * 2];
new Random(42).nextBytes(bytes);
dataToWrite.add(new Tuple2<Object, Object>(1, bytes));
try {
// Insert a record and force a spill so that there's something to clean up:
writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
writer.forceSorterToSpill();
writer.write(dataToWrite.iterator());
Assert.fail("Expected exception to be thrown");
} catch (IOException e) {
// Pass
}
assertSpillFilesWereCleanedUp();
}

@Test
public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException {
final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);