Commit: Fixes to multiple spilling-related bugs.

JoshRosen committed Jul 7, 2015
1 parent 82e21c1 commit 8d7fbe7

Showing 5 changed files with 55 additions and 34 deletions.
UnsafeExternalSorter.java
@@ -126,17 +126,15 @@ public void spill() throws IOException {
       spillWriters.size() > 1 ? " times" : " time");
 
     final UnsafeSorterSpillWriter spillWriter =
-      new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics);
+      new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics,
+        sorter.numRecords());
     spillWriters.add(spillWriter);
     final UnsafeSorterIterator sortedRecords = sorter.getSortedIterator();
     while (sortedRecords.hasNext()) {
       sortedRecords.loadNext();
       final Object baseObject = sortedRecords.getBaseObject();
       final long baseOffset = sortedRecords.getBaseOffset();
-      // TODO: this assumption that the first long holds a length is not enforced via our interfaces
-      // We need to either always store this via the write path (e.g. not require the caller to do
-      // it), or provide interfaces / hooks for customizing the physical storage format etc.
-      final int recordLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset);
+      final int recordLength = sortedRecords.getRecordLength();
       spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix());
     }
     spillWriter.close();

UnsafeInMemorySorter.java
@@ -89,6 +89,13 @@ public UnsafeInMemorySorter(
     this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
   }
 
+  /**
+   * @return the number of records that have been inserted into this sorter.
+   */
+  public int numRecords() {
+    return pointerArrayInsertPosition / 2;
+  }
+
   public long getMemoryUsage() {
     return pointerArray.length * 8L;
   }
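
(Aside, not part of the diff: the division by two works because each inserted record occupies two adjacent slots in the sorter's pointer array, one long for the encoded record pointer and one for its key prefix. A minimal standalone sketch of that bookkeeping, with illustrative values:)

    // Two pointer-array slots per record, so records inserted == position / 2.
    long[] pointerArray = new long[16];
    int pointerArrayInsertPosition = 0;
    pointerArray[pointerArrayInsertPosition++] = 0x0000000100000040L; // encoded record pointer
    pointerArray[pointerArrayInsertPosition++] = 42L;                 // key prefix
    int numRecords = pointerArrayInsertPosition / 2;                  // == 1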
@@ -106,7 +113,8 @@ public void expandPointerArray() {
   }
 
   /**
-   * Inserts a record to be sorted.
+   * Inserts a record to be sorted. Assumes that the record pointer points to a record length
+   * stored as a 4-byte integer, followed by the record's bytes.
    *
    * @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}.
    * @param keyPrefix a user-defined key prefix
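(The new javadoc above pins down the record layout insertRecord() assumes: a 4-byte length header followed by the record's bytes. A standalone illustration using plain java.nio, not the sorter's actual write path:)

    import java.nio.ByteBuffer;

    // Standalone sketch of the assumed record layout in a data page.
    public class RecordLayoutSketch {
      public static void main(String[] args) {
        byte[] page = new byte[64];                 // stand-in for a data page
        ByteBuffer buf = ByteBuffer.wrap(page);
        byte[] record = {10, 20, 30, 40};
        buf.putInt(record.length);                  // record length as a 4-byte int
        buf.put(record);                            // record bytes follow the length
        System.out.println(ByteBuffer.wrap(page).getInt()); // prints 4
      }
    }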

UnsafeSorterSpillReader.java
@@ -25,42 +25,47 @@
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.unsafe.PlatformDependent;
 
+/**
+ * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description
+ * of the file format).
+ */
 final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
 
-  private final File file;
   private InputStream in;
   private DataInputStream din;
 
-  private final byte[] arr = new byte[1024 * 1024]; // TODO: tune this (maybe grow dynamically)?
-  private int nextRecordLength;
-
+  // Variables that change with every record read:
+  private int recordLength;
   private long keyPrefix;
+  private int numRecordsRemaining;
+
+  private final byte[] arr = new byte[1024 * 1024]; // TODO: tune this (maybe grow dynamically)?
   private final Object baseObject = arr;
   private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET;
 
   public UnsafeSorterSpillReader(
       BlockManager blockManager,
       File file,
       BlockId blockId) throws IOException {
-    this.file = file;
-    assert (file.length() > 0);
+    assert (file.length() > 0);
     final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file));
     this.in = blockManager.wrapForCompression(blockId, bs);
     this.din = new DataInputStream(this.in);
-    nextRecordLength = din.readInt();
+    numRecordsRemaining = din.readInt();
   }
 
   @Override
   public boolean hasNext() {
-    return (in != null);
+    return (numRecordsRemaining > 0);
   }
 
   @Override
   public void loadNext() throws IOException {
+    recordLength = din.readInt();
     keyPrefix = din.readLong();
-    ByteStreams.readFully(in, arr, 0, nextRecordLength);
-    nextRecordLength = din.readInt();
-    if (nextRecordLength == UnsafeSorterSpillWriter.EOF_MARKER) {
+    ByteStreams.readFully(in, arr, 0, recordLength);
+    numRecordsRemaining--;
+    if (numRecordsRemaining == 0) {
       in.close();
       in = null;
       din = null;
@@ -79,7 +84,7 @@ public long getBaseOffset() {
 
   @Override
   public int getRecordLength() {
-    return 0;
+    return recordLength;
   }
 
   @Override
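(After this change the reader's lifecycle is driven entirely by the record count read in the constructor; nothing in the stream marks the end. A self-contained sketch of the same countdown pattern over a plain DataInputStream — a hypothetical helper, not the Spark class:)

    import java.io.*;

    // Countdown-style reader for a [count][[len][prefix][data]]... stream.
    class CountdownReaderSketch {
      static void readAll(File f) throws IOException {
        DataInputStream din = new DataInputStream(
          new BufferedInputStream(new FileInputStream(f)));
        int remaining = din.readInt();            // total records, written up front
        byte[] buf = new byte[1024];
        while (remaining > 0) {
          int len = din.readInt();                // record length
          long prefix = din.readLong();           // key prefix
          if (len > buf.length) {
            buf = new byte[len];                  // grow if a record is larger
          }
          din.readFully(buf, 0, len);             // record bytes
          remaining--;                            // count down; no EOF marker to probe for
        }
        din.close();                              // close once the count reaches zero
      }
    }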

UnsafeSorterSpillWriter.java
@@ -30,10 +30,14 @@
 import org.apache.spark.storage.TempLocalBlockId;
 import org.apache.spark.unsafe.PlatformDependent;
 
+/**
+ * Spills a list of sorted records to disk. Spill files have the following format:
+ *
+ * [# of records (int)] [[len (int)][prefix (long)][data (bytes)]...]
+ */
 final class UnsafeSorterSpillWriter {
 
   static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
-  static final int EOF_MARKER = -1;
 
   // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
   // be an API to directly transfer bytes from managed memory to the disk writer, we buffer
@@ -42,22 +46,29 @@ final class UnsafeSorterSpillWriter {
 
   private final File file;
   private final BlockId blockId;
+  private final int numRecordsToWrite;
   private BlockObjectWriter writer;
+  private int numRecordsSpilled = 0;
 
   public UnsafeSorterSpillWriter(
       BlockManager blockManager,
       int fileBufferSize,
-      ShuffleWriteMetrics writeMetrics) {
+      ShuffleWriteMetrics writeMetrics,
+      int numRecordsToWrite) throws IOException {
     final Tuple2<TempLocalBlockId, File> spilledFileInfo =
       blockManager.diskBlockManager().createTempLocalBlock();
     this.file = spilledFileInfo._2();
     this.blockId = spilledFileInfo._1();
+    this.numRecordsToWrite = numRecordsToWrite;
     // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
     // Our write path doesn't actually use this serializer (since we end up calling the `write()`
     // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
     // around this, we pass a dummy no-op serializer.
     writer = blockManager.getDiskWriter(
       blockId, file, DummySerializerInstance.INSTANCE, fileBufferSize, writeMetrics);
+    // Write the number of records
+    writeIntToBuffer(numRecordsToWrite, 0);
+    writer.write(writeBuffer, 0, 4);
   }
 
   // Based on DataOutputStream.writeLong.
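(The writeIntToBuffer/writeLongToBuffer helpers referenced by this comment are collapsed in this view. In the style of java.io.DataOutputStream, they pack values big-endian into writeBuffer; roughly, as a sketch:)

    // Big-endian packing of an int into writeBuffer, DataOutputStream-style.
    private void writeIntToBuffer(int v, int offset) {
      writeBuffer[offset + 0] = (byte) (v >>> 24);
      writeBuffer[offset + 1] = (byte) (v >>> 16);
      writeBuffer[offset + 2] = (byte) (v >>> 8);
      writeBuffer[offset + 3] = (byte) (v >>> 0);
    }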
@@ -85,6 +96,12 @@ public void write(
       long baseOffset,
       int recordLength,
       long keyPrefix) throws IOException {
+    if (numRecordsSpilled == numRecordsToWrite) {
+      throw new IllegalStateException(
+        "Number of records written exceeded numRecordsToWrite = " + numRecordsToWrite);
+    } else {
+      numRecordsSpilled++;
+    }
     writeIntToBuffer(recordLength, 0);
     writeLongToBuffer(keyPrefix, 4);
     int dataRemaining = recordLength;
@@ -107,8 +124,6 @@ public void write(
   }
 
   public void close() throws IOException {
-    writeIntToBuffer(EOF_MARKER, 0);
-    writer.write(writeBuffer, 0, 4);
     writer.commitAndClose();
     writer = null;
     writeBuffer = null;
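(Taken together, the writer now declares the record count in a header instead of appending an EOF marker, and refuses writes beyond that count. A self-contained sketch of the same protocol over plain java.io — illustrative, not the Spark class:)

    import java.io.*;

    // Writer for the documented format: [# of records (int)] [[len][prefix][data]]...
    public class CountingSpillWriterSketch {
      private final DataOutputStream out;
      private final int numRecordsToWrite;
      private int numRecordsSpilled = 0;

      public CountingSpillWriterSketch(File f, int numRecordsToWrite) throws IOException {
        this.out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(f)));
        this.numRecordsToWrite = numRecordsToWrite;
        out.writeInt(numRecordsToWrite);     // header: count up front, no trailing marker
      }

      public void write(byte[] data, long keyPrefix) throws IOException {
        if (numRecordsSpilled == numRecordsToWrite) {
          throw new IllegalStateException("wrote more records than declared");
        }
        numRecordsSpilled++;
        out.writeInt(data.length);           // [len (int)]
        out.writeLong(keyPrefix);            // [prefix (long)]
        out.write(data);                     // [data (bytes)]
      }

      public void close() throws IOException {
        out.close();                         // the header already says how many records follow
      }
    }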

UnsafeExternalSorterSuite.java
@@ -153,22 +153,17 @@ public void testSortingOnlyByPrefix() throws Exception {
     insertNumber(sorter, 3);
     sorter.spill();
     insertNumber(sorter, 4);
+    sorter.spill();
     insertNumber(sorter, 2);
 
     UnsafeSorterIterator iter = sorter.getSortedIterator();
 
-    iter.loadNext();
-    assertEquals(1, iter.getKeyPrefix());
-    iter.loadNext();
-    assertEquals(2, iter.getKeyPrefix());
-    iter.loadNext();
-    assertEquals(3, iter.getKeyPrefix());
-    iter.loadNext();
-    assertEquals(4, iter.getKeyPrefix());
-    iter.loadNext();
-    assertEquals(5, iter.getKeyPrefix());
-    assertFalse(iter.hasNext());
-    // TODO: check that the values are also read back properly.
+    for (int i = 1; i <= 5; i++) {
+      iter.loadNext();
+      assertEquals(i, iter.getKeyPrefix());
+      assertEquals(4, iter.getRecordLength());
+      // TODO: read rest of value.
+    }
 
     // TODO: test for cleanup:
     // assert(tempDir.isEmpty)
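
(For context, insertNumber is defined elsewhere in this suite. A plausible shape, stated as an assumption rather than quoted from the file: it stores the int as a 4-byte record and reuses the value as the sort prefix, which is why the loop expects getRecordLength() == 4 and prefixes 1 through 5:)

    // Assumed shape of the suite's insertNumber helper (not shown in this diff):
    private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception {
      final int[] arr = new int[] { value };
      // A 4-byte record whose payload is the int itself; the value doubles as the prefix.
      sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value);
    }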