More refactoring and cleanup; begin cleaning iterator interfaces
JoshRosen committed May 4, 2015
1 parent 3490512 commit 3aeaff7
Showing 11 changed files with 197 additions and 185 deletions.
@@ -42,10 +42,11 @@
import org.apache.spark.storage.ShuffleBlockId;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
import org.apache.spark.unsafe.sort.ExternalSorterIterator;
import org.apache.spark.unsafe.sort.UnsafeSorterIterator;
import org.apache.spark.unsafe.sort.UnsafeExternalSorter;
import static org.apache.spark.unsafe.sort.UnsafeSorter.PrefixComparator;
import static org.apache.spark.unsafe.sort.UnsafeSorter.RecordComparator;
import org.apache.spark.unsafe.sort.PrefixComparator;

import org.apache.spark.unsafe.sort.RecordComparator;

// IntelliJ gets confused and claims that this class should be abstract, but this actually compiles
public class UnsafeShuffleWriter<K, V> implements ShuffleWriter<K, V> {
@@ -104,7 +105,7 @@ private void freeMemory() {
// TODO: free sorter memory
}

private ExternalSorterIterator sortRecords(
private UnsafeSorterIterator sortRecords(
scala.collection.Iterator<? extends Product2<K, V>> records) throws Exception {
final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
memoryManager,
@@ -142,7 +143,7 @@ private ExternalSorterIterator sortRecords(
return sorter.getSortedIterator();
}

private long[] writeSortedRecordsToFile(ExternalSorterIterator sortedRecords) throws IOException {
private long[] writeSortedRecordsToFile(UnsafeSorterIterator sortedRecords) throws IOException {
final File outputFile = shuffleBlockManager.getDataFile(shuffleId, mapId);
final ShuffleBlockId blockId =
new ShuffleBlockId(shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID());
@@ -154,7 +155,7 @@ private long[] writeSortedRecordsToFile(ExternalSorterIterator sortedRecords) th
final byte[] arr = new byte[SER_BUFFER_SIZE];
while (sortedRecords.hasNext()) {
sortedRecords.loadNext();
final int partition = (int) sortedRecords.keyPrefix;
final int partition = (int) sortedRecords.getKeyPrefix();
assert (partition >= currentPartition);
if (partition != currentPartition) {
// Switch to the new partition
@@ -168,13 +169,13 @@ private long[] writeSortedRecordsToFile(ExternalSorterIterator sortedRecords) th
}

PlatformDependent.copyMemory(
sortedRecords.baseObject,
sortedRecords.baseOffset + 4,
sortedRecords.getBaseObject(),
sortedRecords.getBaseOffset() + 4,
arr,
PlatformDependent.BYTE_ARRAY_OFFSET,
sortedRecords.recordLength);
sortedRecords.getRecordLength());
assert (writer != null); // To suppress an IntelliJ warning
writer.write(arr, 0, sortedRecords.recordLength);
writer.write(arr, 0, sortedRecords.getRecordLength());
// TODO: add a test that detects whether we leave this call out:
writer.recordWritten();
}
@@ -17,15 +17,10 @@

package org.apache.spark.unsafe.sort;

public abstract class ExternalSorterIterator {

public Object baseObject;
public long baseOffset;
public int recordLength;
public long keyPrefix;

public abstract boolean hasNext();

public abstract void loadNext();

/**
* Compares 8-byte key prefixes in prefix sort. Subclasses may implement type-specific
* comparisons, such as lexicographic comparison for strings.
*/
public abstract class PrefixComparator {
public abstract int compare(long prefix1, long prefix2);
}
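
For reference, a minimal sketch (not part of this commit) of a concrete subclass, assuming the 8-byte prefixes can simply be ordered as signed longs:

package org.apache.spark.unsafe.sort;

// Hypothetical example: orders records by the raw value of their 8-byte prefix,
// treated as a signed long.
final class SignedLongPrefixComparator extends PrefixComparator {
  @Override
  public int compare(long prefix1, long prefix2) {
    return (prefix1 < prefix2) ? -1 : ((prefix1 > prefix2) ? 1 : 0);
  }
}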
@@ -0,0 +1,37 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.unsafe.sort;

/**
* Compares records for ordering. In cases where the entire sorting key can fit in the 8-byte
* prefix, this may simply return 0.
*/
public abstract class RecordComparator {

/**
* Compare two records for order.
*
* @return a negative integer, zero, or a positive integer as the first record is less than,
* equal to, or greater than the second.
*/
public abstract int compare(
Object leftBaseObject,
long leftBaseOffset,
Object rightBaseObject,
long rightBaseOffset);
}
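
As an illustration only (not part of this commit), here is a subclass for the case the Javadoc above describes, where the whole sorting key fits in the 8-byte prefix and the record-level comparison can always report equality:

package org.apache.spark.unsafe.sort;

// Hypothetical example: the key is fully captured by the prefix, so the record bytes
// never need to be inspected and compare() can unconditionally return 0.
final class PrefixOnlyRecordComparator extends RecordComparator {
  @Override
  public int compare(
      Object leftBaseObject,
      long leftBaseOffset,
      Object rightBaseObject,
      long rightBaseOffset) {
    return 0;
  }
}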
@@ -33,8 +33,6 @@
import java.util.Iterator;
import java.util.LinkedList;

import static org.apache.spark.unsafe.sort.UnsafeSorter.*;

/**
* External sorter based on {@link UnsafeSorter}.
*/
@@ -111,13 +109,16 @@ public void spill() throws IOException {
final UnsafeSorterSpillWriter spillWriter =
new UnsafeSorterSpillWriter(blockManager, fileBufferSize, writeMetrics);
spillWriters.add(spillWriter);
final Iterator<RecordPointerAndKeyPrefix> sortedRecords = sorter.getSortedIterator();
final UnsafeSorterIterator sortedRecords = sorter.getSortedIterator();
while (sortedRecords.hasNext()) {
final RecordPointerAndKeyPrefix recordPointer = sortedRecords.next();
final Object baseObject = memoryManager.getPage(recordPointer.recordPointer);
final long baseOffset = memoryManager.getOffsetInPage(recordPointer.recordPointer);
sortedRecords.loadNext();
final Object baseObject = sortedRecords.getBaseObject();
final long baseOffset = sortedRecords.getBaseOffset();
// TODO: this assumption that the first long holds a length is not enforced via our interfaces
// We need to either always store this via the write path (e.g. not require the caller to do
// it), or provide interfaces / hooks for customizing the physical storage format etc.
final int recordLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset);
spillWriter.write(baseObject, baseOffset, recordLength, recordPointer.keyPrefix);
spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix());
}
spillWriter.close();
final long sorterMemoryUsage = sorter.getMemoryUsage();
@@ -220,14 +221,14 @@ public void insertRecord(
sorter.insertRecord(recordAddress, prefix);
}

public ExternalSorterIterator getSortedIterator() throws IOException {
public UnsafeSorterIterator getSortedIterator() throws IOException {
final UnsafeSorterSpillMerger spillMerger =
new UnsafeSorterSpillMerger(recordComparator, prefixComparator);
for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
spillMerger.addSpill(spillWriter.getReader(blockManager));
}
spillWriters.clear();
spillMerger.addSpill(sorter.getMergeableIterator());
spillMerger.addSpill(sorter.getSortedIterator());
return spillMerger.getSortedIterator();
}
}
@@ -17,8 +17,8 @@

package org.apache.spark.unsafe.sort;

import static org.apache.spark.unsafe.sort.UnsafeSorter.RecordPointerAndKeyPrefix;
import org.apache.spark.util.collection.SortDataFormat;
import static org.apache.spark.unsafe.sort.UnsafeSortDataFormat.RecordPointerAndKeyPrefix;

/**
* Supports sorting an array of (record pointer, key prefix) pairs. Used in {@link UnsafeSorter}.
@@ -28,6 +28,19 @@
*/
final class UnsafeSortDataFormat extends SortDataFormat<RecordPointerAndKeyPrefix, long[]> {

static final class RecordPointerAndKeyPrefix {
/**
* A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a
* description of how these addresses are encoded.
*/
public long recordPointer;

/**
* A key prefix, for use in comparisons.
*/
public long keyPrefix;
}

public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat();

private UnsafeSortDataFormat() { }
93 changes: 13 additions & 80 deletions core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java
@@ -18,10 +18,10 @@
package org.apache.spark.unsafe.sort;

import java.util.Comparator;
import java.util.Iterator;

import org.apache.spark.util.collection.Sorter;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
import static org.apache.spark.unsafe.sort.UnsafeSortDataFormat.RecordPointerAndKeyPrefix;

/**
* Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records
@@ -32,45 +32,6 @@
*/
public final class UnsafeSorter {

public static final class RecordPointerAndKeyPrefix {
/**
* A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a
* description of how these addresses are encoded.
*/
public long recordPointer;

/**
* A key prefix, for use in comparisons.
*/
public long keyPrefix;
}

/**
* Compares records for ordering. In cases where the entire sorting key can fit in the 8-byte
* prefix, this may simply return 0.
*/
public static abstract class RecordComparator {
/**
* Compare two records for order.
*
* @return a negative integer, zero, or a positive integer as the first record is less than,
* equal to, or greater than the second.
*/
public abstract int compare(
Object leftBaseObject,
long leftBaseOffset,
Object rightBaseObject,
long rightBaseOffset);
}

/**
* Compares 8-byte key prefixes in prefix sort. Subclasses may implement type-specific
* comparisons, such as lexicographic comparison for strings.
*/
public static abstract class PrefixComparator {
public abstract int compare(long prefix1, long prefix2);
}

private final TaskMemoryManager memoryManager;
private final Sorter<RecordPointerAndKeyPrefix, long[]> sorter;
private final Comparator<RecordPointerAndKeyPrefix> sortComparator;
@@ -148,69 +109,41 @@ public void insertRecord(long objectAddress, long keyPrefix) {
* Return an iterator over record pointers in sorted order. For efficiency, all calls to
* {@code next()} will return the same mutable object.
*/
public Iterator<RecordPointerAndKeyPrefix> getSortedIterator() {
sorter.sort(sortBuffer, 0, sortBufferInsertPosition / 2, sortComparator);
return new Iterator<RecordPointerAndKeyPrefix>() {
private int position = 0;
private final RecordPointerAndKeyPrefix keyPointerAndPrefix = new RecordPointerAndKeyPrefix();

@Override
public boolean hasNext() {
return position < sortBufferInsertPosition;
}

@Override
public RecordPointerAndKeyPrefix next() {
keyPointerAndPrefix.recordPointer = sortBuffer[position];
keyPointerAndPrefix.keyPrefix = sortBuffer[position + 1];
position += 2;
return keyPointerAndPrefix;
}

@Override
public void remove() {
throw new UnsupportedOperationException();
}
};
}

public UnsafeSorterSpillMerger.MergeableIterator getMergeableIterator() {
public UnsafeSorterIterator getSortedIterator() {
sorter.sort(sortBuffer, 0, sortBufferInsertPosition / 2, sortComparator);
return new UnsafeSorterSpillMerger.MergeableIterator() {
return new UnsafeSorterIterator() {

private int position = 0;
private Object baseObject;
private long baseOffset;
private long keyPrefix;
private int recordLength;

@Override
public boolean hasNext() {
return position < sortBufferInsertPosition;
}

@Override
public void loadNextRecord() {
public void loadNext() {
final long recordPointer = sortBuffer[position];
keyPrefix = sortBuffer[position + 1];
position += 2;
baseObject = memoryManager.getPage(recordPointer);
baseOffset = memoryManager.getOffsetInPage(recordPointer);
keyPrefix = sortBuffer[position + 1];
position += 2;
}

@Override
public long getPrefix() {
return keyPrefix;
}
public Object getBaseObject() { return baseObject; }

@Override
public Object getBaseObject() {
return baseObject;
}
public long getBaseOffset() { return baseOffset; }

@Override
public long getBaseOffset() {
return baseOffset;
}
public int getRecordLength() { return recordLength; }

@Override
public long getKeyPrefix() { return keyPrefix; }
};
}
}
@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.unsafe.sort;

import java.io.IOException;

public abstract class UnsafeSorterIterator {

public abstract boolean hasNext();

public abstract void loadNext() throws IOException;

public abstract Object getBaseObject();

public abstract long getBaseOffset();

public abstract int getRecordLength();

public abstract long getKeyPrefix();
}
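
A minimal consumption sketch (not part of this commit), mirroring the loop in UnsafeShuffleWriter above; the helper class, method name, and fixed-size scratch buffer are assumptions for illustration:

package org.apache.spark.unsafe.sort;

import org.apache.spark.unsafe.PlatformDependent;

// Hypothetical helper showing how a caller can drain an UnsafeSorterIterator.
final class UnsafeSorterIteratorExample {

  // Assumes every record fits into the 64 KB scratch buffer (an assumption for this sketch).
  static void drain(UnsafeSorterIterator iter) throws java.io.IOException {
    final byte[] scratch = new byte[64 * 1024];
    while (iter.hasNext()) {
      iter.loadNext();
      PlatformDependent.copyMemory(
          iter.getBaseObject(),
          iter.getBaseOffset(),
          scratch,
          PlatformDependent.BYTE_ARRAY_OFFSET,
          iter.getRecordLength());
      // scratch[0 .. getRecordLength()) now holds the record bytes;
      // iter.getKeyPrefix() gives the corresponding sort prefix.
    }
  }
}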