Skip to content

Commit

Permalink
Add basic test case.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshRosen committed May 1, 2015
1 parent 81d52c5 commit abf7bfe
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.unsafe.sort;

import static org.apache.spark.unsafe.sort.UnsafeSorter.KeyPointerAndPrefix;
import org.apache.spark.util.collection.SortDataFormat;

/**
Expand All @@ -26,24 +27,11 @@
* index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
*/
final class UnsafeSortDataFormat
extends SortDataFormat<UnsafeSortDataFormat.KeyPointerAndPrefix, long[]> {
extends SortDataFormat<KeyPointerAndPrefix, long[]> {

public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat();

private UnsafeSortDataFormat() { };

public static final class KeyPointerAndPrefix {
/**
* A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a
* description of how these addresses are encoded.
*/
long recordPointer;

/**
* A key prefix, for use in comparisons.
*/
long keyPrefix;
}
private UnsafeSortDataFormat() { }

@Override
public KeyPointerAndPrefix getKey(long[] data, int pos) {
Expand Down
33 changes: 21 additions & 12 deletions core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,24 @@
import java.util.Comparator;
import java.util.Iterator;

import org.apache.spark.unsafe.memory.MemoryLocation;
import org.apache.spark.util.collection.Sorter;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
import static org.apache.spark.unsafe.sort.UnsafeSortDataFormat.KeyPointerAndPrefix;

public final class UnsafeSorter {

public static final class KeyPointerAndPrefix {
/**
* A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a
* description of how these addresses are encoded.
*/
long recordPointer;

/**
* A key prefix, for use in comparisons.
*/
long keyPrefix;
}

public static abstract class RecordComparator {
public abstract int compare(
Object leftBaseObject,
Expand Down Expand Up @@ -105,25 +116,23 @@ public void insertRecord(long objectAddress) {
sortBufferInsertPosition += 2;
}

public Iterator<MemoryLocation> getSortedIterator() {
final MemoryLocation memoryLocation = new MemoryLocation();
public Iterator<KeyPointerAndPrefix> getSortedIterator() {
sorter.sort(sortBuffer, 0, sortBufferInsertPosition, sortComparator);
return new Iterator<MemoryLocation>() {
int position = 0;
return new Iterator<KeyPointerAndPrefix>() {
private int position = 0;
private final KeyPointerAndPrefix keyPointerAndPrefix = new KeyPointerAndPrefix();

@Override
public boolean hasNext() {
return position < sortBufferInsertPosition;
}

@Override
public MemoryLocation next() {
final long address = sortBuffer[position];
public KeyPointerAndPrefix next() {
keyPointerAndPrefix.recordPointer = sortBuffer[position];
keyPointerAndPrefix.keyPrefix = sortBuffer[position + 1];
position += 2;
final Object baseObject = memoryManager.getPage(address);
final long baseOffset = memoryManager.getOffsetInPage(address);
memoryLocation.setObjAndOffset(baseObject, baseOffset);
return memoryLocation;
return keyPointerAndPrefix;
}

@Override
Expand Down
136 changes: 133 additions & 3 deletions core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,137 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.unsafe.sort;

/**
* Created by joshrosen on 4/29/15.
*/
import java.util.Arrays;
import java.util.Iterator;

import org.junit.Assert;
import org.junit.Test;

import org.apache.spark.HashPartitioner;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
import org.apache.spark.unsafe.memory.MemoryAllocator;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.memory.TaskMemoryManager;

public class UnsafeSorterSuite {

private static String getStringFromDataPage(Object baseObject, long baseOffset) {
final int strLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset);
final byte[] strBytes = new byte[strLength];
PlatformDependent.UNSAFE.copyMemory(
baseObject,
baseOffset + 8,
strBytes,
PlatformDependent.BYTE_ARRAY_OFFSET, strLength);
return new String(strBytes);
}

/**
* Tests the type of sorting that's used in the non-combiner path of sort-based shuffle.
*/
@Test
public void testSortingOnlyByPartitionId() throws Exception {
final String[] dataToSort = new String[] {
"Boba",
"Pearls",
"Tapioca",
"Taho",
"Condensed Milk",
"Jasmine",
"Milk Tea",
"Lychee",
"Mango"
};
final TaskMemoryManager memoryManager =
new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
final MemoryBlock dataPage = memoryManager.allocatePage(2048);
final Object baseObject = dataPage.getBaseObject();
// Write the records into the data page:
long position = dataPage.getBaseOffset();
for (String str : dataToSort) {
final byte[] strBytes = str.getBytes("utf-8");
PlatformDependent.UNSAFE.putLong(baseObject, position, strBytes.length);
position += 8;
PlatformDependent.copyMemory(
strBytes,
PlatformDependent.BYTE_ARRAY_OFFSET,
baseObject,
position,
strBytes.length);
position += strBytes.length;
}
// Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
// use a dummy comparator
final UnsafeSorter.RecordComparator recordComparator = new UnsafeSorter.RecordComparator() {
@Override
public int compare(
Object leftBaseObject,
long leftBaseOffset,
Object rightBaseObject,
long rightBaseOffset) {
return 0;
}
};
// Compute key prefixes based on the records' partition ids
final HashPartitioner hashPartitioner = new HashPartitioner(4);
final UnsafeSorter.PrefixComputer prefixComputer = new UnsafeSorter.PrefixComputer() {
@Override
public long computePrefix(Object baseObject, long baseOffset) {
final String str = getStringFromDataPage(baseObject, baseOffset);
final int partitionId = hashPartitioner.getPartition(str);
return (long) partitionId;
}
};
// Use integer comparison for comparing prefixes (which are partition ids, in this case)
final UnsafeSorter.PrefixComparator prefixComparator = new UnsafeSorter.PrefixComparator() {
@Override
public int compare(long prefix1, long prefix2) {
return (int) prefix1 - (int) prefix2;
}
};
final UnsafeSorter sorter =
new UnsafeSorter(memoryManager, recordComparator, prefixComputer, prefixComparator);
// Given a page of records, insert those records into the sorter one-by-one:
position = dataPage.getBaseOffset();
for (int i = 0; i < dataToSort.length; i++) {
// position now points to the start of a record (which holds its length).
final long recordLength = PlatformDependent.UNSAFE.getLong(baseObject, position);
final long address = memoryManager.encodePageNumberAndOffset(dataPage, position);
sorter.insertRecord(address);
position += 8 + recordLength;
}
final Iterator<UnsafeSorter.KeyPointerAndPrefix> iter = sorter.getSortedIterator();
int iterLength = 0;
long prevPrefix = -1;
Arrays.sort(dataToSort);
while (iter.hasNext()) {
final UnsafeSorter.KeyPointerAndPrefix pointerAndPrefix = iter.next();
final Object recordBaseObject = memoryManager.getPage(pointerAndPrefix.recordPointer);
final long recordBaseOffset = memoryManager.getOffsetInPage(pointerAndPrefix.recordPointer);
final String str = getStringFromDataPage(recordBaseObject, recordBaseOffset);
Assert.assertTrue("String should be valid", Arrays.binarySearch(dataToSort, str) != -1);
Assert.assertTrue("Prefix " + pointerAndPrefix.keyPrefix + " should be >= previous prefix " +
prevPrefix, pointerAndPrefix.keyPrefix >= prevPrefix);
prevPrefix = pointerAndPrefix.keyPrefix;
iterLength++;
}
Assert.assertEquals(dataToSort.length, iterLength);
}
}

0 comments on commit abf7bfe

Please sign in to comment.