From abf7bfe4ddbb2603272ef3926776ceefcc07ff7f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 29 Apr 2015 14:34:15 -0700 Subject: [PATCH] Add basic test case. --- .../unsafe/sort/UnsafeSortDataFormat.java | 18 +-- .../spark/unsafe/sort/UnsafeSorter.java | 33 +++-- .../spark/unsafe/sort/UnsafeSorterSuite.java | 136 +++++++++++++++++- 3 files changed, 157 insertions(+), 30 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java index 6bae742e2bdab..9955e3fcaabbb 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java @@ -17,6 +17,7 @@ package org.apache.spark.unsafe.sort; +import static org.apache.spark.unsafe.sort.UnsafeSorter.KeyPointerAndPrefix; import org.apache.spark.util.collection.SortDataFormat; /** @@ -26,24 +27,11 @@ * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. */ final class UnsafeSortDataFormat - extends SortDataFormat { + extends SortDataFormat { public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat(); - private UnsafeSortDataFormat() { }; - - public static final class KeyPointerAndPrefix { - /** - * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a - * description of how these addresses are encoded. - */ - long recordPointer; - - /** - * A key prefix, for use in comparisons. - */ - long keyPrefix; - } + private UnsafeSortDataFormat() { } @Override public KeyPointerAndPrefix getKey(long[] data, int pos) { diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java index 9e8a4d707d181..6da89004d2f53 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java @@ -20,13 +20,24 @@ import java.util.Comparator; import java.util.Iterator; -import org.apache.spark.unsafe.memory.MemoryLocation; import org.apache.spark.util.collection.Sorter; import org.apache.spark.unsafe.memory.TaskMemoryManager; -import static org.apache.spark.unsafe.sort.UnsafeSortDataFormat.KeyPointerAndPrefix; public final class UnsafeSorter { + public static final class KeyPointerAndPrefix { + /** + * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a + * description of how these addresses are encoded. + */ + long recordPointer; + + /** + * A key prefix, for use in comparisons. + */ + long keyPrefix; + } + public static abstract class RecordComparator { public abstract int compare( Object leftBaseObject, @@ -105,11 +116,11 @@ public void insertRecord(long objectAddress) { sortBufferInsertPosition += 2; } - public Iterator getSortedIterator() { - final MemoryLocation memoryLocation = new MemoryLocation(); + public Iterator getSortedIterator() { sorter.sort(sortBuffer, 0, sortBufferInsertPosition, sortComparator); - return new Iterator() { - int position = 0; + return new Iterator() { + private int position = 0; + private final KeyPointerAndPrefix keyPointerAndPrefix = new KeyPointerAndPrefix(); @Override public boolean hasNext() { @@ -117,13 +128,11 @@ public boolean hasNext() { } @Override - public MemoryLocation next() { - final long address = sortBuffer[position]; + public KeyPointerAndPrefix next() { + keyPointerAndPrefix.recordPointer = sortBuffer[position]; + keyPointerAndPrefix.keyPrefix = sortBuffer[position + 1]; position += 2; - final Object baseObject = memoryManager.getPage(address); - final long baseOffset = memoryManager.getOffsetInPage(address); - memoryLocation.setObjAndOffset(baseObject, baseOffset); - return memoryLocation; + return keyPointerAndPrefix; } @Override diff --git a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java index f96c8ebd723c9..c22edfb412e1b 100644 --- a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java @@ -1,7 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.unsafe.sort; -/** - * Created by joshrosen on 4/29/15. - */ +import java.util.Arrays; +import java.util.Iterator; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.spark.HashPartitioner; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + public class UnsafeSorterSuite { + + private static String getStringFromDataPage(Object baseObject, long baseOffset) { + final int strLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset); + final byte[] strBytes = new byte[strLength]; + PlatformDependent.UNSAFE.copyMemory( + baseObject, + baseOffset + 8, + strBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, strLength); + return new String(strBytes); + } + + /** + * Tests the type of sorting that's used in the non-combiner path of sort-based shuffle. + */ + @Test + public void testSortingOnlyByPartitionId() throws Exception { + final String[] dataToSort = new String[] { + "Boba", + "Pearls", + "Tapioca", + "Taho", + "Condensed Milk", + "Jasmine", + "Milk Tea", + "Lychee", + "Mango" + }; + final TaskMemoryManager memoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final MemoryBlock dataPage = memoryManager.allocatePage(2048); + final Object baseObject = dataPage.getBaseObject(); + // Write the records into the data page: + long position = dataPage.getBaseOffset(); + for (String str : dataToSort) { + final byte[] strBytes = str.getBytes("utf-8"); + PlatformDependent.UNSAFE.putLong(baseObject, position, strBytes.length); + position += 8; + PlatformDependent.copyMemory( + strBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + baseObject, + position, + strBytes.length); + position += strBytes.length; + } + // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so + // use a dummy comparator + final UnsafeSorter.RecordComparator recordComparator = new UnsafeSorter.RecordComparator() { + @Override + public int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset) { + return 0; + } + }; + // Compute key prefixes based on the records' partition ids + final HashPartitioner hashPartitioner = new HashPartitioner(4); + final UnsafeSorter.PrefixComputer prefixComputer = new UnsafeSorter.PrefixComputer() { + @Override + public long computePrefix(Object baseObject, long baseOffset) { + final String str = getStringFromDataPage(baseObject, baseOffset); + final int partitionId = hashPartitioner.getPartition(str); + return (long) partitionId; + } + }; + // Use integer comparison for comparing prefixes (which are partition ids, in this case) + final UnsafeSorter.PrefixComparator prefixComparator = new UnsafeSorter.PrefixComparator() { + @Override + public int compare(long prefix1, long prefix2) { + return (int) prefix1 - (int) prefix2; + } + }; + final UnsafeSorter sorter = + new UnsafeSorter(memoryManager, recordComparator, prefixComputer, prefixComparator); + // Given a page of records, insert those records into the sorter one-by-one: + position = dataPage.getBaseOffset(); + for (int i = 0; i < dataToSort.length; i++) { + // position now points to the start of a record (which holds its length). + final long recordLength = PlatformDependent.UNSAFE.getLong(baseObject, position); + final long address = memoryManager.encodePageNumberAndOffset(dataPage, position); + sorter.insertRecord(address); + position += 8 + recordLength; + } + final Iterator iter = sorter.getSortedIterator(); + int iterLength = 0; + long prevPrefix = -1; + Arrays.sort(dataToSort); + while (iter.hasNext()) { + final UnsafeSorter.KeyPointerAndPrefix pointerAndPrefix = iter.next(); + final Object recordBaseObject = memoryManager.getPage(pointerAndPrefix.recordPointer); + final long recordBaseOffset = memoryManager.getOffsetInPage(pointerAndPrefix.recordPointer); + final String str = getStringFromDataPage(recordBaseObject, recordBaseOffset); + Assert.assertTrue("String should be valid", Arrays.binarySearch(dataToSort, str) != -1); + Assert.assertTrue("Prefix " + pointerAndPrefix.keyPrefix + " should be >= previous prefix " + + prevPrefix, pointerAndPrefix.keyPrefix >= prevPrefix); + prevPrefix = pointerAndPrefix.keyPrefix; + iterLength++; + } + Assert.assertEquals(dataToSort.length, iterLength); + } }