diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 1c875a15687c1..01fe022fad046 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -42,10 +42,11 @@ import org.apache.spark.storage.ShuffleBlockId; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.memory.TaskMemoryManager; -import org.apache.spark.unsafe.sort.ExternalSorterIterator; +import org.apache.spark.unsafe.sort.UnsafeSorterIterator; import org.apache.spark.unsafe.sort.UnsafeExternalSorter; -import static org.apache.spark.unsafe.sort.UnsafeSorter.PrefixComparator; -import static org.apache.spark.unsafe.sort.UnsafeSorter.RecordComparator; +import org.apache.spark.unsafe.sort.PrefixComparator; + +import org.apache.spark.unsafe.sort.RecordComparator; // IntelliJ gets confused and claims that this class should be abstract, but this actually compiles public class UnsafeShuffleWriter implements ShuffleWriter { @@ -104,7 +105,7 @@ private void freeMemory() { // TODO: free sorter memory } - private ExternalSorterIterator sortRecords( + private UnsafeSorterIterator sortRecords( scala.collection.Iterator> records) throws Exception { final UnsafeExternalSorter sorter = new UnsafeExternalSorter( memoryManager, @@ -142,7 +143,7 @@ private ExternalSorterIterator sortRecords( return sorter.getSortedIterator(); } - private long[] writeSortedRecordsToFile(ExternalSorterIterator sortedRecords) throws IOException { + private long[] writeSortedRecordsToFile(UnsafeSorterIterator sortedRecords) throws IOException { final File outputFile = shuffleBlockManager.getDataFile(shuffleId, mapId); final ShuffleBlockId blockId = new ShuffleBlockId(shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID()); @@ -154,7 +155,7 @@ private long[] writeSortedRecordsToFile(ExternalSorterIterator sortedRecords) th final byte[] arr = new byte[SER_BUFFER_SIZE]; while (sortedRecords.hasNext()) { sortedRecords.loadNext(); - final int partition = (int) sortedRecords.keyPrefix; + final int partition = (int) sortedRecords.getKeyPrefix(); assert (partition >= currentPartition); if (partition != currentPartition) { // Switch to the new partition @@ -168,13 +169,13 @@ private long[] writeSortedRecordsToFile(ExternalSorterIterator sortedRecords) th } PlatformDependent.copyMemory( - sortedRecords.baseObject, - sortedRecords.baseOffset + 4, + sortedRecords.getBaseObject(), + sortedRecords.getBaseOffset() + 4, arr, PlatformDependent.BYTE_ARRAY_OFFSET, - sortedRecords.recordLength); + sortedRecords.getRecordLength()); assert (writer != null); // To suppress an IntelliJ warning - writer.write(arr, 0, sortedRecords.recordLength); + writer.write(arr, 0, sortedRecords.getRecordLength()); // TODO: add a test that detects whether we leave this call out: writer.recordWritten(); } diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/ExternalSorterIterator.java b/core/src/main/java/org/apache/spark/unsafe/sort/PrefixComparator.java similarity index 76% rename from core/src/main/java/org/apache/spark/unsafe/sort/ExternalSorterIterator.java rename to core/src/main/java/org/apache/spark/unsafe/sort/PrefixComparator.java index d53a0baaf351f..a8468d8b1cdb9 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/ExternalSorterIterator.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/PrefixComparator.java @@ -17,15 +17,10 @@ package org.apache.spark.unsafe.sort; -public abstract class ExternalSorterIterator { - - public Object baseObject; - public long baseOffset; - public int recordLength; - public long keyPrefix; - - public abstract boolean hasNext(); - - public abstract void loadNext(); - +/** + * Compares 8-byte key prefixes in prefix sort. Subclasses may implement type-specific + * comparisons, such as lexicographic comparison for strings. + */ +public abstract class PrefixComparator { + public abstract int compare(long prefix1, long prefix2); } diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/RecordComparator.java b/core/src/main/java/org/apache/spark/unsafe/sort/RecordComparator.java new file mode 100644 index 0000000000000..99a2b077f9869 --- /dev/null +++ b/core/src/main/java/org/apache/spark/unsafe/sort/RecordComparator.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.sort; + +/** + * Compares records for ordering. In cases where the entire sorting key can fit in the 8-byte + * prefix, this may simply return 0. + */ +public abstract class RecordComparator { + + /** + * Compare two records for order. + * + * @return a negative integer, zero, or a positive integer as the first record is less than, + * equal to, or greater than the second. + */ + public abstract int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset); +} diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java index bf0019c51703f..f455f471e1d13 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java @@ -33,8 +33,6 @@ import java.util.Iterator; import java.util.LinkedList; -import static org.apache.spark.unsafe.sort.UnsafeSorter.*; - /** * External sorter based on {@link UnsafeSorter}. */ @@ -111,13 +109,16 @@ public void spill() throws IOException { final UnsafeSorterSpillWriter spillWriter = new UnsafeSorterSpillWriter(blockManager, fileBufferSize, writeMetrics); spillWriters.add(spillWriter); - final Iterator sortedRecords = sorter.getSortedIterator(); + final UnsafeSorterIterator sortedRecords = sorter.getSortedIterator(); while (sortedRecords.hasNext()) { - final RecordPointerAndKeyPrefix recordPointer = sortedRecords.next(); - final Object baseObject = memoryManager.getPage(recordPointer.recordPointer); - final long baseOffset = memoryManager.getOffsetInPage(recordPointer.recordPointer); + sortedRecords.loadNext(); + final Object baseObject = sortedRecords.getBaseObject(); + final long baseOffset = sortedRecords.getBaseOffset(); + // TODO: this assumption that the first long holds a length is not enforced via our interfaces + // We need to either always store this via the write path (e.g. not require the caller to do + // it), or provide interfaces / hooks for customizing the physical storage format etc. final int recordLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset); - spillWriter.write(baseObject, baseOffset, recordLength, recordPointer.keyPrefix); + spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix()); } spillWriter.close(); final long sorterMemoryUsage = sorter.getMemoryUsage(); @@ -220,14 +221,14 @@ public void insertRecord( sorter.insertRecord(recordAddress, prefix); } - public ExternalSorterIterator getSortedIterator() throws IOException { + public UnsafeSorterIterator getSortedIterator() throws IOException { final UnsafeSorterSpillMerger spillMerger = new UnsafeSorterSpillMerger(recordComparator, prefixComparator); for (UnsafeSorterSpillWriter spillWriter : spillWriters) { spillMerger.addSpill(spillWriter.getReader(blockManager)); } spillWriters.clear(); - spillMerger.addSpill(sorter.getMergeableIterator()); + spillMerger.addSpill(sorter.getSortedIterator()); return spillMerger.getSortedIterator(); } } diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java index 290a87b70cad6..1e0a5c7bd1113 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java @@ -17,8 +17,8 @@ package org.apache.spark.unsafe.sort; -import static org.apache.spark.unsafe.sort.UnsafeSorter.RecordPointerAndKeyPrefix; import org.apache.spark.util.collection.SortDataFormat; +import static org.apache.spark.unsafe.sort.UnsafeSortDataFormat.RecordPointerAndKeyPrefix; /** * Supports sorting an array of (record pointer, key prefix) pairs. Used in {@link UnsafeSorter}. @@ -28,6 +28,19 @@ */ final class UnsafeSortDataFormat extends SortDataFormat { + static final class RecordPointerAndKeyPrefix { + /** + * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a + * description of how these addresses are encoded. + */ + public long recordPointer; + + /** + * A key prefix, for use in comparisons. + */ + public long keyPrefix; + } + public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat(); private UnsafeSortDataFormat() { } diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java index c69d486705f65..cfb85ea55bcd6 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java @@ -18,10 +18,10 @@ package org.apache.spark.unsafe.sort; import java.util.Comparator; -import java.util.Iterator; import org.apache.spark.util.collection.Sorter; import org.apache.spark.unsafe.memory.TaskMemoryManager; +import static org.apache.spark.unsafe.sort.UnsafeSortDataFormat.RecordPointerAndKeyPrefix; /** * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records @@ -32,45 +32,6 @@ */ public final class UnsafeSorter { - public static final class RecordPointerAndKeyPrefix { - /** - * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a - * description of how these addresses are encoded. - */ - public long recordPointer; - - /** - * A key prefix, for use in comparisons. - */ - public long keyPrefix; - } - - /** - * Compares records for ordering. In cases where the entire sorting key can fit in the 8-byte - * prefix, this may simply return 0. - */ - public static abstract class RecordComparator { - /** - * Compare two records for order. - * - * @return a negative integer, zero, or a positive integer as the first record is less than, - * equal to, or greater than the second. - */ - public abstract int compare( - Object leftBaseObject, - long leftBaseOffset, - Object rightBaseObject, - long rightBaseOffset); - } - - /** - * Compares 8-byte key prefixes in prefix sort. Subclasses may implement type-specific - * comparisons, such as lexicographic comparison for strings. - */ - public static abstract class PrefixComparator { - public abstract int compare(long prefix1, long prefix2); - } - private final TaskMemoryManager memoryManager; private final Sorter sorter; private final Comparator sortComparator; @@ -148,40 +109,15 @@ public void insertRecord(long objectAddress, long keyPrefix) { * Return an iterator over record pointers in sorted order. For efficiency, all calls to * {@code next()} will return the same mutable object. */ - public Iterator getSortedIterator() { - sorter.sort(sortBuffer, 0, sortBufferInsertPosition / 2, sortComparator); - return new Iterator() { - private int position = 0; - private final RecordPointerAndKeyPrefix keyPointerAndPrefix = new RecordPointerAndKeyPrefix(); - - @Override - public boolean hasNext() { - return position < sortBufferInsertPosition; - } - - @Override - public RecordPointerAndKeyPrefix next() { - keyPointerAndPrefix.recordPointer = sortBuffer[position]; - keyPointerAndPrefix.keyPrefix = sortBuffer[position + 1]; - position += 2; - return keyPointerAndPrefix; - } - - @Override - public void remove() { - throw new UnsupportedOperationException(); - } - }; - } - - public UnsafeSorterSpillMerger.MergeableIterator getMergeableIterator() { + public UnsafeSorterIterator getSortedIterator() { sorter.sort(sortBuffer, 0, sortBufferInsertPosition / 2, sortComparator); - return new UnsafeSorterSpillMerger.MergeableIterator() { + return new UnsafeSorterIterator() { private int position = 0; private Object baseObject; private long baseOffset; private long keyPrefix; + private int recordLength; @Override public boolean hasNext() { @@ -189,28 +125,25 @@ public boolean hasNext() { } @Override - public void loadNextRecord() { + public void loadNext() { final long recordPointer = sortBuffer[position]; - keyPrefix = sortBuffer[position + 1]; - position += 2; baseObject = memoryManager.getPage(recordPointer); baseOffset = memoryManager.getOffsetInPage(recordPointer); + keyPrefix = sortBuffer[position + 1]; + position += 2; } @Override - public long getPrefix() { - return keyPrefix; - } + public Object getBaseObject() { return baseObject; } @Override - public Object getBaseObject() { - return baseObject; - } + public long getBaseOffset() { return baseOffset; } @Override - public long getBaseOffset() { - return baseOffset; - } + public int getRecordLength() { return recordLength; } + + @Override + public long getKeyPrefix() { return keyPrefix; } }; } } diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterIterator.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterIterator.java new file mode 100644 index 0000000000000..8cac5887f9c3f --- /dev/null +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterIterator.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.sort; + +import java.io.IOException; + +public abstract class UnsafeSorterIterator { + + public abstract boolean hasNext(); + + public abstract void loadNext() throws IOException; + + public abstract Object getBaseObject(); + + public abstract long getBaseOffset(); + + public abstract int getRecordLength(); + + public abstract long getKeyPrefix(); +} diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillMerger.java index bd3f4424724f6..0837843ae442b 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillMerger.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillMerger.java @@ -17,36 +17,23 @@ package org.apache.spark.unsafe.sort; +import java.io.IOException; import java.util.Comparator; import java.util.PriorityQueue; -import static org.apache.spark.unsafe.sort.UnsafeSorter.*; - final class UnsafeSorterSpillMerger { - private final PriorityQueue priorityQueue; - - public static abstract class MergeableIterator { - public abstract boolean hasNext(); - - public abstract void loadNextRecord(); - - public abstract long getPrefix(); - - public abstract Object getBaseObject(); - - public abstract long getBaseOffset(); - } + private final PriorityQueue priorityQueue; public UnsafeSorterSpillMerger( final RecordComparator recordComparator, - final UnsafeSorter.PrefixComparator prefixComparator) { - final Comparator comparator = new Comparator() { + final PrefixComparator prefixComparator) { + final Comparator comparator = new Comparator() { @Override - public int compare(MergeableIterator left, MergeableIterator right) { + public int compare(UnsafeSorterIterator left, UnsafeSorterIterator right) { final int prefixComparisonResult = - prefixComparator.compare(left.getPrefix(), right.getPrefix()); + prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix()); if (prefixComparisonResult == 0) { return recordComparator.compare( left.getBaseObject(), left.getBaseOffset(), @@ -56,20 +43,21 @@ public int compare(MergeableIterator left, MergeableIterator right) { } } }; - priorityQueue = new PriorityQueue(10, comparator); + // TODO: the size is often known; incorporate size hints here. + priorityQueue = new PriorityQueue(10, comparator); } - public void addSpill(MergeableIterator spillReader) { + public void addSpill(UnsafeSorterIterator spillReader) throws IOException { if (spillReader.hasNext()) { - spillReader.loadNextRecord(); + spillReader.loadNext(); } priorityQueue.add(spillReader); } - public ExternalSorterIterator getSortedIterator() { - return new ExternalSorterIterator() { + public UnsafeSorterIterator getSortedIterator() throws IOException { + return new UnsafeSorterIterator() { - private MergeableIterator spillReader; + private UnsafeSorterIterator spillReader; @Override public boolean hasNext() { @@ -77,18 +65,27 @@ public boolean hasNext() { } @Override - public void loadNext() { + public void loadNext() throws IOException { if (spillReader != null) { if (spillReader.hasNext()) { - spillReader.loadNextRecord(); + spillReader.loadNext(); priorityQueue.add(spillReader); } } spillReader = priorityQueue.remove(); - baseObject = spillReader.getBaseObject(); - baseOffset = spillReader.getBaseOffset(); - keyPrefix = spillReader.getPrefix(); } + + @Override + public Object getBaseObject() { return spillReader.getBaseObject(); } + + @Override + public long getBaseOffset() { return spillReader.getBaseOffset(); } + + @Override + public int getRecordLength() { return spillReader.getRecordLength(); } + + @Override + public long getKeyPrefix() { return spillReader.getKeyPrefix(); } }; } } diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java index 3102e5ab3b6f4..696bcd468b0a1 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java @@ -25,16 +25,17 @@ import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.PlatformDependent; -final class UnsafeSorterSpillReader extends UnsafeSorterSpillMerger.MergeableIterator { +final class UnsafeSorterSpillReader extends UnsafeSorterIterator { private final File file; private InputStream in; private DataInputStream din; - private long keyPrefix; private final byte[] arr = new byte[1024 * 1024]; // TODO: tune this (maybe grow dynamically)? - private final Object baseObject = arr; private int nextRecordLength; + + private long keyPrefix; + private final Object baseObject = arr; private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET; public UnsafeSorterSpillReader( @@ -55,26 +56,17 @@ public boolean hasNext() { } @Override - public void loadNextRecord() { - try { - keyPrefix = din.readLong(); - ByteStreams.readFully(in, arr, 0, nextRecordLength); - nextRecordLength = din.readInt(); - if (nextRecordLength == UnsafeSorterSpillWriter.EOF_MARKER) { - in.close(); - in = null; - din = null; - } - } catch (Exception e) { - PlatformDependent.throwException(e); + public void loadNext() throws IOException { + keyPrefix = din.readLong(); + ByteStreams.readFully(in, arr, 0, nextRecordLength); + nextRecordLength = din.readInt(); + if (nextRecordLength == UnsafeSorterSpillWriter.EOF_MARKER) { + in.close(); + in = null; + din = null; } } - @Override - public long getPrefix() { - return keyPrefix; - } - @Override public Object getBaseObject() { return baseObject; @@ -84,4 +76,14 @@ public Object getBaseObject() { public long getBaseOffset() { return baseOffset; } + + @Override + public int getRecordLength() { + return 0; + } + + @Override + public long getKeyPrefix() { + return keyPrefix; + } } diff --git a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java index e745074af075c..c33035c2c116b 100644 --- a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java @@ -21,8 +21,8 @@ import org.apache.spark.HashPartitioner; import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; -import org.apache.spark.TaskContextImpl; import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.executor.TaskMetrics; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.storage.*; @@ -43,7 +43,6 @@ import java.io.File; import java.io.InputStream; import java.io.OutputStream; -import java.util.Iterator; import java.util.UUID; import static org.mockito.Mockito.*; @@ -56,7 +55,7 @@ public class UnsafeExternalSorterSuite { // Compute key prefixes based on the records' partition ids final HashPartitioner hashPartitioner = new HashPartitioner(4); // Use integer comparison for comparing prefixes (which are partition ids, in this case) - final UnsafeSorter.PrefixComparator prefixComparator = new UnsafeSorter.PrefixComparator() { + final PrefixComparator prefixComparator = new PrefixComparator() { @Override public int compare(long prefix1, long prefix2) { return (int) prefix1 - (int) prefix2; @@ -64,7 +63,7 @@ public int compare(long prefix1, long prefix2) { }; // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so // use a dummy comparator - final UnsafeSorter.RecordComparator recordComparator = new UnsafeSorter.RecordComparator() { + final RecordComparator recordComparator = new RecordComparator() { @Override public int compare( Object leftBaseObject, @@ -95,6 +94,7 @@ public void setUp() { blockManager = mock(BlockManager.class); tempDir = new File(Utils.createTempDir$default$1()); taskContext = mock(TaskContext.class); + when(taskContext.taskMetrics()).thenReturn(new TaskMetrics()); when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer>() { @@ -158,18 +158,18 @@ public void testSortingOnlyByPartitionId() throws Exception { insertNumber(sorter, 4); insertNumber(sorter, 2); - ExternalSorterIterator iter = sorter.getSortedIterator(); + UnsafeSorterIterator iter = sorter.getSortedIterator(); iter.loadNext(); - Assert.assertEquals(1, iter.keyPrefix); + Assert.assertEquals(1, iter.getKeyPrefix()); iter.loadNext(); - Assert.assertEquals(2, iter.keyPrefix); + Assert.assertEquals(2, iter.getKeyPrefix()); iter.loadNext(); - Assert.assertEquals(3, iter.keyPrefix); + Assert.assertEquals(3, iter.getKeyPrefix()); iter.loadNext(); - Assert.assertEquals(4, iter.keyPrefix); + Assert.assertEquals(4, iter.getKeyPrefix()); iter.loadNext(); - Assert.assertEquals(5, iter.keyPrefix); + Assert.assertEquals(5, iter.getKeyPrefix()); Assert.assertFalse(iter.hasNext()); // TODO: check that the values are also read back properly. diff --git a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java index aed115f83a368..3a2e9696761de 100644 --- a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java @@ -18,7 +18,6 @@ package org.apache.spark.unsafe.sort; import java.util.Arrays; -import java.util.Iterator; import org.junit.Assert; import org.junit.Test; @@ -48,10 +47,10 @@ private static String getStringFromDataPage(Object baseObject, long baseOffset) public void testSortingEmptyInput() { final UnsafeSorter sorter = new UnsafeSorter( new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)), - mock(UnsafeSorter.RecordComparator.class), - mock(UnsafeSorter.PrefixComparator.class), + mock(RecordComparator.class), + mock(PrefixComparator.class), 100); - final Iterator iter = sorter.getSortedIterator(); + final UnsafeSorterIterator iter = sorter.getSortedIterator(); assert(!iter.hasNext()); } @@ -91,7 +90,7 @@ public void testSortingOnlyByPartitionId() throws Exception { } // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so // use a dummy comparator - final UnsafeSorter.RecordComparator recordComparator = new UnsafeSorter.RecordComparator() { + final RecordComparator recordComparator = new RecordComparator() { @Override public int compare( Object leftBaseObject, @@ -104,7 +103,7 @@ public int compare( // Compute key prefixes based on the records' partition ids final HashPartitioner hashPartitioner = new HashPartitioner(4); // Use integer comparison for comparing prefixes (which are partition ids, in this case) - final UnsafeSorter.PrefixComparator prefixComparator = new UnsafeSorter.PrefixComparator() { + final PrefixComparator prefixComparator = new PrefixComparator() { @Override public int compare(long prefix1, long prefix2) { return (int) prefix1 - (int) prefix2; @@ -123,19 +122,18 @@ public int compare(long prefix1, long prefix2) { sorter.insertRecord(address, partitionId); position += 8 + recordLength; } - final Iterator iter = sorter.getSortedIterator(); + final UnsafeSorterIterator iter = sorter.getSortedIterator(); int iterLength = 0; long prevPrefix = -1; Arrays.sort(dataToSort); while (iter.hasNext()) { - final UnsafeSorter.RecordPointerAndKeyPrefix pointerAndPrefix = iter.next(); - final Object recordBaseObject = memoryManager.getPage(pointerAndPrefix.recordPointer); - final long recordBaseOffset = memoryManager.getOffsetInPage(pointerAndPrefix.recordPointer); - final String str = getStringFromDataPage(recordBaseObject, recordBaseOffset); + iter.loadNext(); + final String str = getStringFromDataPage(iter.getBaseObject(), iter.getBaseOffset()); + final long keyPrefix = iter.getKeyPrefix(); Assert.assertTrue("String should be valid", Arrays.binarySearch(dataToSort, str) != -1); - Assert.assertTrue("Prefix " + pointerAndPrefix.keyPrefix + " should be >= previous prefix " + - prevPrefix, pointerAndPrefix.keyPrefix >= prevPrefix); - prevPrefix = pointerAndPrefix.keyPrefix; + Assert.assertTrue("Prefix " + keyPrefix + " should be >= previous prefix " + + prevPrefix, keyPrefix >= prevPrefix); + prevPrefix = keyPrefix; iterLength++; } Assert.assertEquals(dataToSort.length, iterLength);