diff --git a/core/pom.xml b/core/pom.xml index aee0d92620606..558cc3fb9f2f3 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -343,28 +343,28 @@ test - org.mockito - mockito-core + org.hamcrest + hamcrest-core test - org.scalacheck - scalacheck_${scala.binary.version} + org.hamcrest + hamcrest-library test - junit - junit + org.mockito + mockito-core test - org.hamcrest - hamcrest-core + org.scalacheck + scalacheck_${scala.binary.version} test - org.hamcrest - hamcrest-library + junit + junit test diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java similarity index 91% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java rename to core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java index 3f746b886bc9b..0399abc63c235 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java +++ b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.serializer; import java.io.IOException; import java.io.InputStream; @@ -24,9 +24,7 @@ import scala.reflect.ClassTag; -import org.apache.spark.serializer.DeserializationStream; -import org.apache.spark.serializer.SerializationStream; -import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.annotation.Private; import org.apache.spark.unsafe.PlatformDependent; /** @@ -35,7 +33,8 @@ * `write() OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work * around this, we pass a dummy no-op serializer. */ -final class DummySerializerInstance extends SerializerInstance { +@Private +public final class DummySerializerInstance extends SerializerInstance { public static final DummySerializerInstance INSTANCE = new DummySerializerInstance(); diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 9e9ed94b7890c..56289573209fb 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -30,6 +30,7 @@ import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.storage.*; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java new file mode 100644 index 0000000000000..45b78829e4cf7 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import org.apache.spark.annotation.Private; + +/** + * Compares 8-byte key prefixes in prefix sort. Subclasses may implement type-specific + * comparisons, such as lexicographic comparison for strings. + */ +@Private +public abstract class PrefixComparator { + public abstract int compare(long prefix1, long prefix2); +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java new file mode 100644 index 0000000000000..438742565c51d --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import com.google.common.base.Charsets; +import com.google.common.primitives.Longs; +import com.google.common.primitives.UnsignedBytes; + +import org.apache.spark.annotation.Private; +import org.apache.spark.unsafe.types.UTF8String; + +@Private +public class PrefixComparators { + private PrefixComparators() {} + + public static final StringPrefixComparator STRING = new StringPrefixComparator(); + public static final IntegralPrefixComparator INTEGRAL = new IntegralPrefixComparator(); + public static final FloatPrefixComparator FLOAT = new FloatPrefixComparator(); + public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator(); + + public static final class StringPrefixComparator extends PrefixComparator { + @Override + public int compare(long aPrefix, long bPrefix) { + // TODO: can done more efficiently + byte[] a = Longs.toByteArray(aPrefix); + byte[] b = Longs.toByteArray(bPrefix); + for (int i = 0; i < 8; i++) { + int c = UnsignedBytes.compare(a[i], b[i]); + if (c != 0) return c; + } + return 0; + } + + public long computePrefix(byte[] bytes) { + if (bytes == null) { + return 0L; + } else { + byte[] padded = new byte[8]; + System.arraycopy(bytes, 0, padded, 0, Math.min(bytes.length, 8)); + return Longs.fromByteArray(padded); + } + } + + public long computePrefix(String value) { + return value == null ? 0L : computePrefix(value.getBytes(Charsets.UTF_8)); + } + + public long computePrefix(UTF8String value) { + return value == null ? 0L : computePrefix(value.getBytes()); + } + } + + /** + * Prefix comparator for all integral types (boolean, byte, short, int, long). + */ + public static final class IntegralPrefixComparator extends PrefixComparator { + @Override + public int compare(long a, long b) { + return (a < b) ? -1 : (a > b) ? 1 : 0; + } + + public final long NULL_PREFIX = Long.MIN_VALUE; + } + + public static final class FloatPrefixComparator extends PrefixComparator { + @Override + public int compare(long aPrefix, long bPrefix) { + float a = Float.intBitsToFloat((int) aPrefix); + float b = Float.intBitsToFloat((int) bPrefix); + return (a < b) ? -1 : (a > b) ? 1 : 0; + } + + public long computePrefix(float value) { + return Float.floatToIntBits(value) & 0xffffffffL; + } + + public final long NULL_PREFIX = computePrefix(Float.NEGATIVE_INFINITY); + } + + public static final class DoublePrefixComparator extends PrefixComparator { + @Override + public int compare(long aPrefix, long bPrefix) { + double a = Double.longBitsToDouble(aPrefix); + double b = Double.longBitsToDouble(bPrefix); + return (a < b) ? -1 : (a > b) ? 1 : 0; + } + + public long computePrefix(double value) { + return Double.doubleToLongBits(value); + } + + public final long NULL_PREFIX = computePrefix(Double.NEGATIVE_INFINITY); + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java new file mode 100644 index 0000000000000..09e4258792204 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +/** + * Compares records for ordering. In cases where the entire sorting key can fit in the 8-byte + * prefix, this may simply return 0. + */ +public abstract class RecordComparator { + + /** + * Compare two records for order. + * + * @return a negative integer, zero, or a positive integer as the first record is less than, + * equal to, or greater than the second. + */ + public abstract int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset); +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java new file mode 100644 index 0000000000000..0c4ebde407cfc --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +final class RecordPointerAndKeyPrefix { + /** + * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a + * description of how these addresses are encoded. + */ + public long recordPointer; + + /** + * A key prefix, for use in comparisons. + */ + public long keyPrefix; +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java new file mode 100644 index 0000000000000..4d6731ee60af3 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.IOException; +import java.util.LinkedList; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; + +/** + * External sorter based on {@link UnsafeInMemorySorter}. + */ +public final class UnsafeExternalSorter { + + private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class); + + private static final int PAGE_SIZE = 1 << 27; // 128 megabytes + @VisibleForTesting + static final int MAX_RECORD_SIZE = PAGE_SIZE - 4; + + private final PrefixComparator prefixComparator; + private final RecordComparator recordComparator; + private final int initialSize; + private final TaskMemoryManager memoryManager; + private final ShuffleMemoryManager shuffleMemoryManager; + private final BlockManager blockManager; + private final TaskContext taskContext; + private ShuffleWriteMetrics writeMetrics; + + /** The buffer size to use when writing spills using DiskBlockObjectWriter */ + private final int fileBufferSizeBytes; + + /** + * Memory pages that hold the records being sorted. The pages in this list are freed when + * spilling, although in principle we could recycle these pages across spills (on the other hand, + * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager + * itself). + */ + private final LinkedList allocatedPages = new LinkedList(); + + // These variables are reset after spilling: + private UnsafeInMemorySorter sorter; + private MemoryBlock currentPage = null; + private long currentPagePosition = -1; + private long freeSpaceInCurrentPage = 0; + + private final LinkedList spillWriters = new LinkedList<>(); + + public UnsafeExternalSorter( + TaskMemoryManager memoryManager, + ShuffleMemoryManager shuffleMemoryManager, + BlockManager blockManager, + TaskContext taskContext, + RecordComparator recordComparator, + PrefixComparator prefixComparator, + int initialSize, + SparkConf conf) throws IOException { + this.memoryManager = memoryManager; + this.shuffleMemoryManager = shuffleMemoryManager; + this.blockManager = blockManager; + this.taskContext = taskContext; + this.recordComparator = recordComparator; + this.prefixComparator = prefixComparator; + this.initialSize = initialSize; + // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units + this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + initializeForWriting(); + } + + // TODO: metrics tracking + integration with shuffle write metrics + // need to connect the write metrics to task metrics so we count the spill IO somewhere. + + /** + * Allocates new sort data structures. Called when creating the sorter and after each spill. + */ + private void initializeForWriting() throws IOException { + this.writeMetrics = new ShuffleWriteMetrics(); + // TODO: move this sizing calculation logic into a static method of sorter: + final long memoryRequested = initialSize * 8L * 2; + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested); + if (memoryAcquired != memoryRequested) { + shuffleMemoryManager.release(memoryAcquired); + throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); + } + + this.sorter = + new UnsafeInMemorySorter(memoryManager, recordComparator, prefixComparator, initialSize); + } + + /** + * Sort and spill the current records in response to memory pressure. + */ + @VisibleForTesting + public void spill() throws IOException { + logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", + Thread.currentThread().getId(), + Utils.bytesToString(getMemoryUsage()), + spillWriters.size(), + spillWriters.size() > 1 ? " times" : " time"); + + final UnsafeSorterSpillWriter spillWriter = + new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, + sorter.numRecords()); + spillWriters.add(spillWriter); + final UnsafeSorterIterator sortedRecords = sorter.getSortedIterator(); + while (sortedRecords.hasNext()) { + sortedRecords.loadNext(); + final Object baseObject = sortedRecords.getBaseObject(); + final long baseOffset = sortedRecords.getBaseOffset(); + final int recordLength = sortedRecords.getRecordLength(); + spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix()); + } + spillWriter.close(); + final long sorterMemoryUsage = sorter.getMemoryUsage(); + sorter = null; + shuffleMemoryManager.release(sorterMemoryUsage); + final long spillSize = freeMemory(); + taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + initializeForWriting(); + } + + private long getMemoryUsage() { + return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE); + } + + public long freeMemory() { + long memoryFreed = 0; + for (MemoryBlock block : allocatedPages) { + memoryManager.freePage(block); + shuffleMemoryManager.release(block.size()); + memoryFreed += block.size(); + } + allocatedPages.clear(); + currentPage = null; + currentPagePosition = -1; + freeSpaceInCurrentPage = 0; + return memoryFreed; + } + + /** + * Checks whether there is enough space to insert a new record into the sorter. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. + + * @return true if the record can be inserted without requiring more allocations, false otherwise. + */ + private boolean haveSpaceForRecord(int requiredSpace) { + assert (requiredSpace > 0); + return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage)); + } + + /** + * Allocates more memory in order to insert an additional record. This will request additional + * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be + * obtained. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. + */ + private void allocateSpaceForRecord(int requiredSpace) throws IOException { + // TODO: merge these steps to first calculate total memory requirements for this insert, + // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the + // data page. + if (!sorter.hasSpaceForAnotherRecord()) { + logger.debug("Attempting to expand sort pointer array"); + final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage(); + final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2; + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray); + if (memoryAcquired < memoryToGrowPointerArray) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + } else { + sorter.expandPointerArray(); + shuffleMemoryManager.release(oldPointerArrayMemoryUsage); + } + } + + if (requiredSpace > freeSpaceInCurrentPage) { + logger.trace("Required space {} is less than free space in current page ({})", requiredSpace, + freeSpaceInCurrentPage); + // TODO: we should track metrics on the amount of space wasted when we roll over to a new page + // without using the free space at the end of the current page. We should also do this for + // BytesToBytesMap. + if (requiredSpace > PAGE_SIZE) { + throw new IOException("Required space " + requiredSpace + " is greater than page size (" + + PAGE_SIZE + ")"); + } else { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); + if (memoryAcquired < PAGE_SIZE) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); + if (memoryAcquiredAfterSpilling != PAGE_SIZE) { + shuffleMemoryManager.release(memoryAcquiredAfterSpilling); + throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory"); + } + } + currentPage = memoryManager.allocatePage(PAGE_SIZE); + currentPagePosition = currentPage.getBaseOffset(); + freeSpaceInCurrentPage = PAGE_SIZE; + allocatedPages.add(currentPage); + } + } + } + + /** + * Write a record to the sorter. + */ + public void insertRecord( + Object recordBaseObject, + long recordBaseOffset, + int lengthInBytes, + long prefix) throws IOException { + // Need 4 bytes to store the record length. + final int totalSpaceRequired = lengthInBytes + 4; + if (!haveSpaceForRecord(totalSpaceRequired)) { + allocateSpaceForRecord(totalSpaceRequired); + } + + final long recordAddress = + memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); + final Object dataPageBaseObject = currentPage.getBaseObject(); + PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes); + currentPagePosition += 4; + PlatformDependent.copyMemory( + recordBaseObject, + recordBaseOffset, + dataPageBaseObject, + currentPagePosition, + lengthInBytes); + currentPagePosition += lengthInBytes; + + sorter.insertRecord(recordAddress, prefix); + } + + public UnsafeSorterIterator getSortedIterator() throws IOException { + final UnsafeSorterIterator inMemoryIterator = sorter.getSortedIterator(); + int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0); + if (spillWriters.isEmpty()) { + return inMemoryIterator; + } else { + final UnsafeSorterSpillMerger spillMerger = + new UnsafeSorterSpillMerger(recordComparator, prefixComparator, numIteratorsToMerge); + for (UnsafeSorterSpillWriter spillWriter : spillWriters) { + spillMerger.addSpill(spillWriter.getReader(blockManager)); + } + spillWriters.clear(); + if (inMemoryIterator.hasNext()) { + spillMerger.addSpill(inMemoryIterator); + } + return spillMerger.getSortedIterator(); + } + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java new file mode 100644 index 0000000000000..fc34ad9cff369 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.util.Comparator; + +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.util.collection.Sorter; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +/** + * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records + * alongside a user-defined prefix of the record's sorting key. When the underlying sort algorithm + * compares records, it will first compare the stored key prefixes; if the prefixes are not equal, + * then we do not need to traverse the record pointers to compare the actual records. Avoiding these + * random memory accesses improves cache hit rates. + */ +public final class UnsafeInMemorySorter { + + private static final class SortComparator implements Comparator { + + private final RecordComparator recordComparator; + private final PrefixComparator prefixComparator; + private final TaskMemoryManager memoryManager; + + SortComparator( + RecordComparator recordComparator, + PrefixComparator prefixComparator, + TaskMemoryManager memoryManager) { + this.recordComparator = recordComparator; + this.prefixComparator = prefixComparator; + this.memoryManager = memoryManager; + } + + @Override + public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { + final int prefixComparisonResult = prefixComparator.compare(r1.keyPrefix, r2.keyPrefix); + if (prefixComparisonResult == 0) { + final Object baseObject1 = memoryManager.getPage(r1.recordPointer); + final long baseOffset1 = memoryManager.getOffsetInPage(r1.recordPointer) + 4; // skip length + final Object baseObject2 = memoryManager.getPage(r2.recordPointer); + final long baseOffset2 = memoryManager.getOffsetInPage(r2.recordPointer) + 4; // skip length + return recordComparator.compare(baseObject1, baseOffset1, baseObject2, baseOffset2); + } else { + return prefixComparisonResult; + } + } + } + + private final TaskMemoryManager memoryManager; + private final Sorter sorter; + private final Comparator sortComparator; + + /** + * Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at + * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. + */ + private long[] pointerArray; + + /** + * The position in the sort buffer where new records can be inserted. + */ + private int pointerArrayInsertPosition = 0; + + public UnsafeInMemorySorter( + final TaskMemoryManager memoryManager, + final RecordComparator recordComparator, + final PrefixComparator prefixComparator, + int initialSize) { + assert (initialSize > 0); + this.pointerArray = new long[initialSize * 2]; + this.memoryManager = memoryManager; + this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE); + this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); + } + + /** + * @return the number of records that have been inserted into this sorter. + */ + public int numRecords() { + return pointerArrayInsertPosition / 2; + } + + public long getMemoryUsage() { + return pointerArray.length * 8L; + } + + public boolean hasSpaceForAnotherRecord() { + return pointerArrayInsertPosition + 2 < pointerArray.length; + } + + public void expandPointerArray() { + final long[] oldArray = pointerArray; + // Guard against overflow: + final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE; + pointerArray = new long[newLength]; + System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length); + } + + /** + * Inserts a record to be sorted. Assumes that the record pointer points to a record length + * stored as a 4-byte integer, followed by the record's bytes. + * + * @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}. + * @param keyPrefix a user-defined key prefix + */ + public void insertRecord(long recordPointer, long keyPrefix) { + if (!hasSpaceForAnotherRecord()) { + expandPointerArray(); + } + pointerArray[pointerArrayInsertPosition] = recordPointer; + pointerArrayInsertPosition++; + pointerArray[pointerArrayInsertPosition] = keyPrefix; + pointerArrayInsertPosition++; + } + + private static final class SortedIterator extends UnsafeSorterIterator { + + private final TaskMemoryManager memoryManager; + private final int sortBufferInsertPosition; + private final long[] sortBuffer; + private int position = 0; + private Object baseObject; + private long baseOffset; + private long keyPrefix; + private int recordLength; + + SortedIterator( + TaskMemoryManager memoryManager, + int sortBufferInsertPosition, + long[] sortBuffer) { + this.memoryManager = memoryManager; + this.sortBufferInsertPosition = sortBufferInsertPosition; + this.sortBuffer = sortBuffer; + } + + @Override + public boolean hasNext() { + return position < sortBufferInsertPosition; + } + + @Override + public void loadNext() { + // This pointer points to a 4-byte record length, followed by the record's bytes + final long recordPointer = sortBuffer[position]; + baseObject = memoryManager.getPage(recordPointer); + baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length + recordLength = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset - 4); + keyPrefix = sortBuffer[position + 1]; + position += 2; + } + + @Override + public Object getBaseObject() { return baseObject; } + + @Override + public long getBaseOffset() { return baseOffset; } + + @Override + public int getRecordLength() { return recordLength; } + + @Override + public long getKeyPrefix() { return keyPrefix; } + } + + /** + * Return an iterator over record pointers in sorted order. For efficiency, all calls to + * {@code next()} will return the same mutable object. + */ + public UnsafeSorterIterator getSortedIterator() { + sorter.sort(pointerArray, 0, pointerArrayInsertPosition / 2, sortComparator); + return new SortedIterator(memoryManager, pointerArrayInsertPosition, pointerArray); + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java new file mode 100644 index 0000000000000..d09c728a7a638 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import org.apache.spark.util.collection.SortDataFormat; + +/** + * Supports sorting an array of (record pointer, key prefix) pairs. + * Used in {@link UnsafeInMemorySorter}. + *

+ * Within each long[] buffer, position {@code 2 * i} holds a pointer pointer to the record at + * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. + */ +final class UnsafeSortDataFormat extends SortDataFormat { + + public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat(); + + private UnsafeSortDataFormat() { } + + @Override + public RecordPointerAndKeyPrefix getKey(long[] data, int pos) { + // Since we re-use keys, this method shouldn't be called. + throw new UnsupportedOperationException(); + } + + @Override + public RecordPointerAndKeyPrefix newKey() { + return new RecordPointerAndKeyPrefix(); + } + + @Override + public RecordPointerAndKeyPrefix getKey(long[] data, int pos, RecordPointerAndKeyPrefix reuse) { + reuse.recordPointer = data[pos * 2]; + reuse.keyPrefix = data[pos * 2 + 1]; + return reuse; + } + + @Override + public void swap(long[] data, int pos0, int pos1) { + long tempPointer = data[pos0 * 2]; + long tempKeyPrefix = data[pos0 * 2 + 1]; + data[pos0 * 2] = data[pos1 * 2]; + data[pos0 * 2 + 1] = data[pos1 * 2 + 1]; + data[pos1 * 2] = tempPointer; + data[pos1 * 2 + 1] = tempKeyPrefix; + } + + @Override + public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) { + dst[dstPos * 2] = src[srcPos * 2]; + dst[dstPos * 2 + 1] = src[srcPos * 2 + 1]; + } + + @Override + public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) { + System.arraycopy(src, srcPos * 2, dst, dstPos * 2, length * 2); + } + + @Override + public long[] allocate(int length) { + assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large"; + return new long[length * 2]; + } + +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java new file mode 100644 index 0000000000000..16ac2e8d821ba --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.IOException; + +public abstract class UnsafeSorterIterator { + + public abstract boolean hasNext(); + + public abstract void loadNext() throws IOException; + + public abstract Object getBaseObject(); + + public abstract long getBaseOffset(); + + public abstract int getRecordLength(); + + public abstract long getKeyPrefix(); +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java new file mode 100644 index 0000000000000..8272c2a5be0d1 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.IOException; +import java.util.Comparator; +import java.util.PriorityQueue; + +final class UnsafeSorterSpillMerger { + + private final PriorityQueue priorityQueue; + + public UnsafeSorterSpillMerger( + final RecordComparator recordComparator, + final PrefixComparator prefixComparator, + final int numSpills) { + final Comparator comparator = new Comparator() { + + @Override + public int compare(UnsafeSorterIterator left, UnsafeSorterIterator right) { + final int prefixComparisonResult = + prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix()); + if (prefixComparisonResult == 0) { + return recordComparator.compare( + left.getBaseObject(), left.getBaseOffset(), + right.getBaseObject(), right.getBaseOffset()); + } else { + return prefixComparisonResult; + } + } + }; + priorityQueue = new PriorityQueue(numSpills, comparator); + } + + public void addSpill(UnsafeSorterIterator spillReader) throws IOException { + if (spillReader.hasNext()) { + spillReader.loadNext(); + } + priorityQueue.add(spillReader); + } + + public UnsafeSorterIterator getSortedIterator() throws IOException { + return new UnsafeSorterIterator() { + + private UnsafeSorterIterator spillReader; + + @Override + public boolean hasNext() { + return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext()); + } + + @Override + public void loadNext() throws IOException { + if (spillReader != null) { + if (spillReader.hasNext()) { + spillReader.loadNext(); + priorityQueue.add(spillReader); + } + } + spillReader = priorityQueue.remove(); + } + + @Override + public Object getBaseObject() { return spillReader.getBaseObject(); } + + @Override + public long getBaseOffset() { return spillReader.getBaseOffset(); } + + @Override + public int getRecordLength() { return spillReader.getRecordLength(); } + + @Override + public long getKeyPrefix() { return spillReader.getKeyPrefix(); } + }; + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java new file mode 100644 index 0000000000000..29e9e0f30f934 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.*; + +import com.google.common.io.ByteStreams; + +import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.unsafe.PlatformDependent; + +/** + * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description + * of the file format). + */ +final class UnsafeSorterSpillReader extends UnsafeSorterIterator { + + private InputStream in; + private DataInputStream din; + + // Variables that change with every record read: + private int recordLength; + private long keyPrefix; + private int numRecordsRemaining; + + private byte[] arr = new byte[1024 * 1024]; + private Object baseObject = arr; + private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET; + + public UnsafeSorterSpillReader( + BlockManager blockManager, + File file, + BlockId blockId) throws IOException { + assert (file.length() > 0); + final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file)); + this.in = blockManager.wrapForCompression(blockId, bs); + this.din = new DataInputStream(this.in); + numRecordsRemaining = din.readInt(); + } + + @Override + public boolean hasNext() { + return (numRecordsRemaining > 0); + } + + @Override + public void loadNext() throws IOException { + recordLength = din.readInt(); + keyPrefix = din.readLong(); + if (recordLength > arr.length) { + arr = new byte[recordLength]; + baseObject = arr; + } + ByteStreams.readFully(in, arr, 0, recordLength); + numRecordsRemaining--; + if (numRecordsRemaining == 0) { + in.close(); + in = null; + din = null; + } + } + + @Override + public Object getBaseObject() { + return baseObject; + } + + @Override + public long getBaseOffset() { + return baseOffset; + } + + @Override + public int getRecordLength() { + return recordLength; + } + + @Override + public long getKeyPrefix() { + return keyPrefix; + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java new file mode 100644 index 0000000000000..b8d66659804ad --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.File; +import java.io.IOException; + +import scala.Tuple2; + +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.DummySerializerInstance; +import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.BlockObjectWriter; +import org.apache.spark.storage.TempLocalBlockId; +import org.apache.spark.unsafe.PlatformDependent; + +/** + * Spills a list of sorted records to disk. Spill files have the following format: + * + * [# of records (int)] [[len (int)][prefix (long)][data (bytes)]...] + */ +final class UnsafeSorterSpillWriter { + + static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; + + // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to + // be an API to directly transfer bytes from managed memory to the disk writer, we buffer + // data through a byte array. + private byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE]; + + private final File file; + private final BlockId blockId; + private final int numRecordsToWrite; + private BlockObjectWriter writer; + private int numRecordsSpilled = 0; + + public UnsafeSorterSpillWriter( + BlockManager blockManager, + int fileBufferSize, + ShuffleWriteMetrics writeMetrics, + int numRecordsToWrite) throws IOException { + final Tuple2 spilledFileInfo = + blockManager.diskBlockManager().createTempLocalBlock(); + this.file = spilledFileInfo._2(); + this.blockId = spilledFileInfo._1(); + this.numRecordsToWrite = numRecordsToWrite; + // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. + // Our write path doesn't actually use this serializer (since we end up calling the `write()` + // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work + // around this, we pass a dummy no-op serializer. + writer = blockManager.getDiskWriter( + blockId, file, DummySerializerInstance.INSTANCE, fileBufferSize, writeMetrics); + // Write the number of records + writeIntToBuffer(numRecordsToWrite, 0); + writer.write(writeBuffer, 0, 4); + } + + // Based on DataOutputStream.writeLong. + private void writeLongToBuffer(long v, int offset) throws IOException { + writeBuffer[offset + 0] = (byte)(v >>> 56); + writeBuffer[offset + 1] = (byte)(v >>> 48); + writeBuffer[offset + 2] = (byte)(v >>> 40); + writeBuffer[offset + 3] = (byte)(v >>> 32); + writeBuffer[offset + 4] = (byte)(v >>> 24); + writeBuffer[offset + 5] = (byte)(v >>> 16); + writeBuffer[offset + 6] = (byte)(v >>> 8); + writeBuffer[offset + 7] = (byte)(v >>> 0); + } + + // Based on DataOutputStream.writeInt. + private void writeIntToBuffer(int v, int offset) throws IOException { + writeBuffer[offset + 0] = (byte)(v >>> 24); + writeBuffer[offset + 1] = (byte)(v >>> 16); + writeBuffer[offset + 2] = (byte)(v >>> 8); + writeBuffer[offset + 3] = (byte)(v >>> 0); + } + + /** + * Write a record to a spill file. + * + * @param baseObject the base object / memory page containing the record + * @param baseOffset the base offset which points directly to the record data. + * @param recordLength the length of the record. + * @param keyPrefix a sort key prefix + */ + public void write( + Object baseObject, + long baseOffset, + int recordLength, + long keyPrefix) throws IOException { + if (numRecordsSpilled == numRecordsToWrite) { + throw new IllegalStateException( + "Number of records written exceeded numRecordsToWrite = " + numRecordsToWrite); + } else { + numRecordsSpilled++; + } + writeIntToBuffer(recordLength, 0); + writeLongToBuffer(keyPrefix, 4); + int dataRemaining = recordLength; + int freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE - 4 - 8; // space used by prefix + len + long recordReadPosition = baseOffset; + while (dataRemaining > 0) { + final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining); + PlatformDependent.copyMemory( + baseObject, + recordReadPosition, + writeBuffer, + PlatformDependent.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer), + toTransfer); + writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer) + toTransfer); + recordReadPosition += toTransfer; + dataRemaining -= toTransfer; + freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE; + } + if (freeSpaceInWriteBuffer < DISK_WRITE_BUFFER_SIZE) { + writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer)); + } + writer.recordWritten(); + } + + public void close() throws IOException { + writer.commitAndClose(); + writer = null; + writeBuffer = null; + } + + public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException { + return new UnsafeSorterSpillReader(blockManager, file, blockId); + } +} diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java new file mode 100644 index 0000000000000..ea8755e21eb68 --- /dev/null +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.File; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.UUID; + +import scala.Tuple2; +import scala.Tuple2$; +import scala.runtime.AbstractFunction1; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import static org.junit.Assert.*; +import static org.mockito.AdditionalAnswers.returnsFirstArg; +import static org.mockito.AdditionalAnswers.returnsSecondArg; +import static org.mockito.Answers.RETURNS_SMART_NULLS; +import static org.mockito.Mockito.*; + +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.storage.*; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; + +public class UnsafeExternalSorterSuite { + + final TaskMemoryManager memoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + // Use integer comparison for comparing prefixes (which are partition ids, in this case) + final PrefixComparator prefixComparator = new PrefixComparator() { + @Override + public int compare(long prefix1, long prefix2) { + return (int) prefix1 - (int) prefix2; + } + }; + // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so + // use a dummy comparator + final RecordComparator recordComparator = new RecordComparator() { + @Override + public int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset) { + return 0; + } + }; + + @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager; + @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; + @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; + @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext; + + File tempDir; + + private static final class CompressStream extends AbstractFunction1 { + @Override + public OutputStream apply(OutputStream stream) { + return stream; + } + } + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + tempDir = new File(Utils.createTempDir$default$1()); + taskContext = mock(TaskContext.class); + when(taskContext.taskMetrics()).thenReturn(new TaskMetrics()); + when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); + when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); + when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer>() { + @Override + public Tuple2 answer(InvocationOnMock invocationOnMock) throws Throwable { + TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID()); + File file = File.createTempFile("spillFile", ".spill", tempDir); + return Tuple2$.MODULE$.apply(blockId, file); + } + }); + when(blockManager.getDiskWriter( + any(BlockId.class), + any(File.class), + any(SerializerInstance.class), + anyInt(), + any(ShuffleWriteMetrics.class))).thenAnswer(new Answer() { + @Override + public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable { + Object[] args = invocationOnMock.getArguments(); + + return new DiskBlockObjectWriter( + (BlockId) args[0], + (File) args[1], + (SerializerInstance) args[2], + (Integer) args[3], + new CompressStream(), + false, + (ShuffleWriteMetrics) args[4] + ); + } + }); + when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))) + .then(returnsSecondArg()); + } + + private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception { + final int[] arr = new int[] { value }; + sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value); + } + + @Test + public void testSortingOnlyByPrefix() throws Exception { + + final UnsafeExternalSorter sorter = new UnsafeExternalSorter( + memoryManager, + shuffleMemoryManager, + blockManager, + taskContext, + recordComparator, + prefixComparator, + 1024, + new SparkConf()); + + insertNumber(sorter, 5); + insertNumber(sorter, 1); + insertNumber(sorter, 3); + sorter.spill(); + insertNumber(sorter, 4); + sorter.spill(); + insertNumber(sorter, 2); + + UnsafeSorterIterator iter = sorter.getSortedIterator(); + + for (int i = 1; i <= 5; i++) { + iter.loadNext(); + assertEquals(i, iter.getKeyPrefix()); + assertEquals(4, iter.getRecordLength()); + // TODO: read rest of value. + } + + // TODO: test for cleanup: + // assert(tempDir.isEmpty) + } + + @Test + public void testSortingEmptyArrays() throws Exception { + + final UnsafeExternalSorter sorter = new UnsafeExternalSorter( + memoryManager, + shuffleMemoryManager, + blockManager, + taskContext, + recordComparator, + prefixComparator, + 1024, + new SparkConf()); + + sorter.insertRecord(null, 0, 0, 0); + sorter.insertRecord(null, 0, 0, 0); + sorter.spill(); + sorter.insertRecord(null, 0, 0, 0); + sorter.spill(); + sorter.insertRecord(null, 0, 0, 0); + sorter.insertRecord(null, 0, 0, 0); + + UnsafeSorterIterator iter = sorter.getSortedIterator(); + + for (int i = 1; i <= 5; i++) { + iter.loadNext(); + assertEquals(0, iter.getKeyPrefix()); + assertEquals(0, iter.getRecordLength()); + } + } + +} diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java new file mode 100644 index 0000000000000..909500930539c --- /dev/null +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.util.Arrays; + +import org.junit.Test; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.*; +import static org.junit.Assert.*; +import static org.mockito.Mockito.mock; + +import org.apache.spark.HashPartitioner; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +public class UnsafeInMemorySorterSuite { + + private static String getStringFromDataPage(Object baseObject, long baseOffset, int length) { + final byte[] strBytes = new byte[length]; + PlatformDependent.copyMemory( + baseObject, + baseOffset, + strBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, length); + return new String(strBytes); + } + + @Test + public void testSortingEmptyInput() { + final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter( + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)), + mock(RecordComparator.class), + mock(PrefixComparator.class), + 100); + final UnsafeSorterIterator iter = sorter.getSortedIterator(); + assert(!iter.hasNext()); + } + + @Test + public void testSortingOnlyByIntegerPrefix() throws Exception { + final String[] dataToSort = new String[] { + "Boba", + "Pearls", + "Tapioca", + "Taho", + "Condensed Milk", + "Jasmine", + "Milk Tea", + "Lychee", + "Mango" + }; + final TaskMemoryManager memoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final MemoryBlock dataPage = memoryManager.allocatePage(2048); + final Object baseObject = dataPage.getBaseObject(); + // Write the records into the data page: + long position = dataPage.getBaseOffset(); + for (String str : dataToSort) { + final byte[] strBytes = str.getBytes("utf-8"); + PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length); + position += 4; + PlatformDependent.copyMemory( + strBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + baseObject, + position, + strBytes.length); + position += strBytes.length; + } + // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so + // use a dummy comparator + final RecordComparator recordComparator = new RecordComparator() { + @Override + public int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset) { + return 0; + } + }; + // Compute key prefixes based on the records' partition ids + final HashPartitioner hashPartitioner = new HashPartitioner(4); + // Use integer comparison for comparing prefixes (which are partition ids, in this case) + final PrefixComparator prefixComparator = new PrefixComparator() { + @Override + public int compare(long prefix1, long prefix2) { + return (int) prefix1 - (int) prefix2; + } + }; + UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(memoryManager, recordComparator, + prefixComparator, dataToSort.length); + // Given a page of records, insert those records into the sorter one-by-one: + position = dataPage.getBaseOffset(); + for (int i = 0; i < dataToSort.length; i++) { + // position now points to the start of a record (which holds its length). + final int recordLength = PlatformDependent.UNSAFE.getInt(baseObject, position); + final long address = memoryManager.encodePageNumberAndOffset(dataPage, position); + final String str = getStringFromDataPage(baseObject, position + 4, recordLength); + final int partitionId = hashPartitioner.getPartition(str); + sorter.insertRecord(address, partitionId); + position += 4 + recordLength; + } + final UnsafeSorterIterator iter = sorter.getSortedIterator(); + int iterLength = 0; + long prevPrefix = -1; + Arrays.sort(dataToSort); + while (iter.hasNext()) { + iter.loadNext(); + final String str = + getStringFromDataPage(iter.getBaseObject(), iter.getBaseOffset(), iter.getRecordLength()); + final long keyPrefix = iter.getKeyPrefix(); + assertThat(str, isIn(Arrays.asList(dataToSort))); + assertThat(keyPrefix, greaterThanOrEqualTo(prevPrefix)); + prevPrefix = keyPrefix; + iterLength++; + } + assertEquals(dataToSort.length, iterLength); + } +} diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala new file mode 100644 index 0000000000000..dd505dfa7d758 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort + +import org.scalatest.prop.PropertyChecks + +import org.apache.spark.SparkFunSuite + +class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { + + test("String prefix comparator") { + + def testPrefixComparison(s1: String, s2: String): Unit = { + val s1Prefix = PrefixComparators.STRING.computePrefix(s1) + val s2Prefix = PrefixComparators.STRING.computePrefix(s2) + val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix) + assert( + (prefixComparisonResult == 0) || + (prefixComparisonResult < 0 && s1 < s2) || + (prefixComparisonResult > 0 && s1 > s2)) + } + + // scalastyle:off + val regressionTests = Table( + ("s1", "s2"), + ("abc", "世界"), + ("你好", "世界"), + ("你好123", "你好122") + ) + // scalastyle:on + + forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) } + forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) } + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index edb7202245289..4b99030d1046f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -61,9 +61,10 @@ public final class UnsafeRow extends MutableRow { /** A pool to hold non-primitive objects */ private ObjectPool pool; - Object getBaseObject() { return baseObject; } - long getBaseOffset() { return baseOffset; } - ObjectPool getPool() { return pool; } + public Object getBaseObject() { return baseObject; } + public long getBaseOffset() { return baseOffset; } + public int getSizeInBytes() { return sizeInBytes; } + public ObjectPool getPool() { return pool; } /** The number of fields in this row, used for calculating the bitset width (and in assertions) */ private int numFields; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java new file mode 100644 index 0000000000000..b94601cf6d818 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution; + +import java.io.IOException; + +import scala.collection.Iterator; +import scala.math.Ordering; + +import com.google.common.annotations.VisibleForTesting; + +import org.apache.spark.SparkEnv; +import org.apache.spark.TaskContext; +import org.apache.spark.sql.AbstractScalaRowIterator; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.ObjectUnsafeColumnWriter; +import org.apache.spark.sql.catalyst.expressions.UnsafeColumnWriter; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeRowConverter; +import org.apache.spark.sql.catalyst.util.ObjectPool; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; +import org.apache.spark.util.collection.unsafe.sort.RecordComparator; +import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter; +import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator; + +final class UnsafeExternalRowSorter { + + /** + * If positive, forces records to be spilled to disk at the given frequency (measured in numbers + * of records). This is only intended to be used in tests. + */ + private int testSpillFrequency = 0; + + private long numRowsInserted = 0; + + private final StructType schema; + private final UnsafeRowConverter rowConverter; + private final PrefixComputer prefixComputer; + private final UnsafeExternalSorter sorter; + private byte[] rowConversionBuffer = new byte[1024 * 8]; + + public static abstract class PrefixComputer { + abstract long computePrefix(InternalRow row); + } + + public UnsafeExternalRowSorter( + StructType schema, + Ordering ordering, + PrefixComparator prefixComparator, + PrefixComputer prefixComputer) throws IOException { + this.schema = schema; + this.rowConverter = new UnsafeRowConverter(schema); + this.prefixComputer = prefixComputer; + final SparkEnv sparkEnv = SparkEnv.get(); + final TaskContext taskContext = TaskContext.get(); + sorter = new UnsafeExternalSorter( + taskContext.taskMemoryManager(), + sparkEnv.shuffleMemoryManager(), + sparkEnv.blockManager(), + taskContext, + new RowComparator(ordering, schema.length(), null), + prefixComparator, + 4096, + sparkEnv.conf() + ); + } + + /** + * Forces spills to occur every `frequency` records. Only for use in tests. + */ + @VisibleForTesting + void setTestSpillFrequency(int frequency) { + assert frequency > 0 : "Frequency must be positive"; + testSpillFrequency = frequency; + } + + @VisibleForTesting + void insertRow(InternalRow row) throws IOException { + final int sizeRequirement = rowConverter.getSizeRequirement(row); + if (sizeRequirement > rowConversionBuffer.length) { + rowConversionBuffer = new byte[sizeRequirement]; + } + final int bytesWritten = rowConverter.writeRow( + row, rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, sizeRequirement, null); + assert (bytesWritten == sizeRequirement); + final long prefix = prefixComputer.computePrefix(row); + sorter.insertRecord( + rowConversionBuffer, + PlatformDependent.BYTE_ARRAY_OFFSET, + sizeRequirement, + prefix + ); + numRowsInserted++; + if (testSpillFrequency > 0 && (numRowsInserted % testSpillFrequency) == 0) { + spill(); + } + } + + @VisibleForTesting + void spill() throws IOException { + sorter.spill(); + } + + private void cleanupResources() { + sorter.freeMemory(); + } + + @VisibleForTesting + Iterator sort() throws IOException { + try { + final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator(); + if (!sortedIterator.hasNext()) { + // Since we won't ever call next() on an empty iterator, we need to clean up resources + // here in order to prevent memory leaks. + cleanupResources(); + } + return new AbstractScalaRowIterator() { + + private final int numFields = schema.length(); + private final UnsafeRow row = new UnsafeRow(); + + @Override + public boolean hasNext() { + return sortedIterator.hasNext(); + } + + @Override + public InternalRow next() { + try { + sortedIterator.loadNext(); + row.pointTo( + sortedIterator.getBaseObject(), + sortedIterator.getBaseOffset(), + numFields, + sortedIterator.getRecordLength(), + null); + if (!hasNext()) { + row.copy(); // so that we don't have dangling pointers to freed page + cleanupResources(); + } + return row; + } catch (IOException e) { + cleanupResources(); + // Scala iterators don't declare any checked exceptions, so we need to use this hack + // to re-throw the exception: + PlatformDependent.throwException(e); + } + throw new RuntimeException("Exception should have been re-thrown in next()"); + }; + }; + } catch (IOException e) { + cleanupResources(); + throw e; + } + } + + + public Iterator sort(Iterator inputIterator) throws IOException { + while (inputIterator.hasNext()) { + insertRow(inputIterator.next()); + } + return sort(); + } + + /** + * Return true if UnsafeExternalRowSorter can sort rows with the given schema, false otherwise. + */ + public static boolean supportsSchema(StructType schema) { + // TODO: add spilling note to explain why we do this for now: + for (StructField field : schema.fields()) { + if (UnsafeColumnWriter.forType(field.dataType()) instanceof ObjectUnsafeColumnWriter) { + return false; + } + } + return true; + } + + private static final class RowComparator extends RecordComparator { + private final Ordering ordering; + private final int numFields; + private final ObjectPool objPool; + private final UnsafeRow row1 = new UnsafeRow(); + private final UnsafeRow row2 = new UnsafeRow(); + + public RowComparator(Ordering ordering, int numFields, ObjectPool objPool) { + this.numFields = numFields; + this.ordering = ordering; + this.objPool = objPool; + } + + @Override + public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { + row1.pointTo(baseObj1, baseOff1, numFields, -1, objPool); + row2.pointTo(baseObj2, baseOff2, numFields, -1, objPool); + return ordering.compare(row1, row2); + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala new file mode 100644 index 0000000000000..cfefb13e7721e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.InternalRow + +/** + * Shim to allow us to implement [[scala.Iterator]] in Java. Scala 2.11+ has an AbstractIterator + * class for this, but that class is `private[scala]` in 2.10. We need to explicitly fix this to + * `Row` in order to work around a spurious IntelliJ compiler error. + */ +private[spark] abstract class AbstractScalaRowIterator extends Iterator[InternalRow] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 74d933404551c..4b783e30d95e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -289,11 +289,8 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ } val withSort = if (needSort) { - if (sqlContext.conf.externalSortEnabled) { - ExternalSort(rowOrdering, global = false, withShuffle) - } else { - Sort(rowOrdering, global = false, withShuffle) - } + sqlContext.planner.BasicOperators.getSortOperator( + rowOrdering, global = false, withShuffle) } else { withShuffle } @@ -321,11 +318,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ case (UnspecifiedDistribution, Seq(), child) => child case (UnspecifiedDistribution, rowOrdering, child) => - if (sqlContext.conf.externalSortEnabled) { - ExternalSort(rowOrdering, global = false, child) - } else { - Sort(rowOrdering, global = false, child) - } + sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) case (dist, ordering, _) => sys.error(s"Don't know how to ensure $dist with ordering $ordering") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala new file mode 100644 index 0000000000000..2dee3542d6101 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator} + + +object SortPrefixUtils { + + /** + * A dummy prefix comparator which always claims that prefixes are equal. This is used in cases + * where we don't know how to generate or compare prefixes for a SortOrder. + */ + private object NoOpPrefixComparator extends PrefixComparator { + override def compare(prefix1: Long, prefix2: Long): Int = 0 + } + + def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = { + sortOrder.dataType match { + case StringType => PrefixComparators.STRING + case BooleanType | ByteType | ShortType | IntegerType | LongType => PrefixComparators.INTEGRAL + case FloatType => PrefixComparators.FLOAT + case DoubleType => PrefixComparators.DOUBLE + case _ => NoOpPrefixComparator + } + } + + def getPrefixComputer(sortOrder: SortOrder): InternalRow => Long = { + sortOrder.dataType match { + case StringType => (row: InternalRow) => { + PrefixComparators.STRING.computePrefix(sortOrder.child.eval(row).asInstanceOf[UTF8String]) + } + case BooleanType => + (row: InternalRow) => { + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX + else if (sortOrder.child.eval(row).asInstanceOf[Boolean]) 1 + else 0 + } + case ByteType => + (row: InternalRow) => { + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX + else sortOrder.child.eval(row).asInstanceOf[Byte] + } + case ShortType => + (row: InternalRow) => { + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX + else sortOrder.child.eval(row).asInstanceOf[Short] + } + case IntegerType => + (row: InternalRow) => { + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX + else sortOrder.child.eval(row).asInstanceOf[Int] + } + case LongType => + (row: InternalRow) => { + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX + else sortOrder.child.eval(row).asInstanceOf[Long] + } + case FloatType => (row: InternalRow) => { + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.FLOAT.NULL_PREFIX + else PrefixComparators.FLOAT.computePrefix(sortOrder.child.eval(row).asInstanceOf[Float]) + } + case DoubleType => (row: InternalRow) => { + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.DOUBLE.NULL_PREFIX + else PrefixComparators.DOUBLE.computePrefix(sortOrder.child.eval(row).asInstanceOf[Double]) + } + case _ => (row: InternalRow) => 0L + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 59b9b553a7ae5..ce25af58b6cab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -302,6 +302,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object BasicOperators extends Strategy { def numPartitions: Int = self.numPartitions + /** + * Picks an appropriate sort operator. + * + * @param global when true performs a global sort of all partitions by shuffling the data first + * if necessary. + */ + def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = { + if (sqlContext.conf.unsafeEnabled && UnsafeExternalSort.supportsSchema(child.schema)) { + execution.UnsafeExternalSort(sortExprs, global, child) + } else if (sqlContext.conf.externalSortEnabled) { + execution.ExternalSort(sortExprs, global, child) + } else { + execution.Sort(sortExprs, global, child) + } + } + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case r: RunnableCommand => ExecutedCommand(r) :: Nil @@ -313,11 +329,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.SortPartitions(sortExprs, child) => // This sort only sorts tuples within a partition. Its requiredDistribution will be // an UnspecifiedDistribution. - execution.Sort(sortExprs, global = false, planLater(child)) :: Nil - case logical.Sort(sortExprs, global, child) if sqlContext.conf.externalSortEnabled => - execution.ExternalSort(sortExprs, global, planLater(child)):: Nil + getSortOperator(sortExprs, global = false, planLater(child)) :: Nil case logical.Sort(sortExprs, global, child) => - execution.Sort(sortExprs, global, planLater(child)):: Nil + getSortOperator(sortExprs, global, planLater(child)):: Nil case logical.Project(projectList, child) => execution.Project(projectList, planLater(child)) :: Nil case logical.Filter(condition, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index de14e6ad79ad6..4c063c299ba53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.types.StructType import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager @@ -27,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.util.collection.ExternalSorter +import org.apache.spark.util.collection.unsafe.sort.PrefixComparator import org.apache.spark.util.{CompletionIterator, MutablePair} import org.apache.spark.{HashPartitioner, SparkEnv} @@ -246,6 +248,77 @@ case class ExternalSort( override def outputOrdering: Seq[SortOrder] = sortOrder } +/** + * :: DeveloperApi :: + * Optimized version of [[ExternalSort]] that operates on binary data (implemented as part of + * Project Tungsten). + * + * @param global when true performs a global sort of all partitions by shuffling the data first + * if necessary. + * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will + * spill every `frequency` records. + */ +@DeveloperApi +case class UnsafeExternalSort( + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan, + testSpillFrequency: Int = 0) + extends UnaryNode { + + private[this] val schema: StructType = child.schema + + override def requiredChildDistribution: Seq[Distribution] = + if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { + assert(codegenEnabled, "UnsafeExternalSort requires code generation to be enabled") + def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = { + val ordering = newOrdering(sortOrder, child.output) + val boundSortExpression = BindReferences.bindReference(sortOrder.head, child.output) + // Hack until we generate separate comparator implementations for ascending vs. descending + // (or choose to codegen them): + val prefixComparator = { + val comp = SortPrefixUtils.getPrefixComparator(boundSortExpression) + if (sortOrder.head.direction == Descending) { + new PrefixComparator { + override def compare(p1: Long, p2: Long): Int = -1 * comp.compare(p1, p2) + } + } else { + comp + } + } + val prefixComputer = { + val prefixComputer = SortPrefixUtils.getPrefixComputer(boundSortExpression) + new UnsafeExternalRowSorter.PrefixComputer { + override def computePrefix(row: InternalRow): Long = prefixComputer(row) + } + } + val sorter = new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer) + if (testSpillFrequency > 0) { + sorter.setTestSpillFrequency(testSpillFrequency) + } + sorter.sort(iterator) + } + child.execute().mapPartitions(doSort, preservesPartitioning = true) + } + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder +} + +@DeveloperApi +object UnsafeExternalSort { + /** + * Return true if UnsafeExternalSort can sort rows with the given schema, false otherwise. + */ + def supportsSchema(schema: StructType): Boolean = { + UnsafeExternalRowSorter.supportsSchema(schema) + } +} + + /** * :: DeveloperApi :: * Return a new RDD that has exactly `numPartitions` partitions. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index a1e3ca11b1ad9..a2c10fdaf6cdb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ class SortSuite extends SparkPlanTest { @@ -33,12 +34,14 @@ class SortSuite extends SparkPlanTest { checkAnswer( input.toDF("a", "b", "c"), - ExternalSort('a.asc :: 'b.asc :: Nil, global = false, _: SparkPlan), - input.sorted) + ExternalSort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan), + input.sortBy(t => (t._1, t._2)).map(Row.fromTuple), + sortAnswers = false) checkAnswer( input.toDF("a", "b", "c"), - ExternalSort('b.asc :: 'a.asc :: Nil, global = false, _: SparkPlan), - input.sortBy(t => (t._2, t._1))) + ExternalSort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan), + input.sortBy(t => (t._2, t._1)).map(Row.fromTuple), + sortAnswers = false) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 108b1122f7bff..6a8f394545816 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -17,18 +17,15 @@ package org.apache.spark.sql.execution -import scala.language.implicitConversions -import scala.reflect.runtime.universe.TypeTag -import scala.util.control.NonFatal - import org.apache.spark.SparkFunSuite - import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.BoundReference import org.apache.spark.sql.catalyst.util._ - import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{DataFrameHolder, Row, DataFrame} +import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row} + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.TypeTag +import scala.util.control.NonFatal /** * Base class for writing tests for individual physical operators. For an example of how this @@ -49,12 +46,19 @@ class SparkPlanTest extends SparkFunSuite { * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate * the physical operator that's being tested. * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. */ protected def checkAnswer( input: DataFrame, planFunction: SparkPlan => SparkPlan, - expectedAnswer: Seq[Row]): Unit = { - checkAnswer(input :: Nil, (plans: Seq[SparkPlan]) => planFunction(plans.head), expectedAnswer) + expectedAnswer: Seq[Row], + sortAnswers: Boolean = true): Unit = { + doCheckAnswer( + input :: Nil, + (plans: Seq[SparkPlan]) => planFunction(plans.head), + expectedAnswer, + sortAnswers) } /** @@ -64,86 +68,131 @@ class SparkPlanTest extends SparkFunSuite { * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate * the physical operator that's being tested. * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. */ - protected def checkAnswer( + protected def checkAnswer2( left: DataFrame, right: DataFrame, planFunction: (SparkPlan, SparkPlan) => SparkPlan, - expectedAnswer: Seq[Row]): Unit = { - checkAnswer(left :: right :: Nil, - (plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)), expectedAnswer) + expectedAnswer: Seq[Row], + sortAnswers: Boolean = true): Unit = { + doCheckAnswer( + left :: right :: Nil, + (plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)), + expectedAnswer, + sortAnswers) } /** * Runs the plan and makes sure the answer matches the expected result. * @param input the input data to be used. - * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate - * the physical operator that's being tested. + * @param planFunction a function which accepts a sequence of input SparkPlans and uses them to + * instantiate the physical operator that's being tested. * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. */ - protected def checkAnswer( + protected def doCheckAnswer( input: Seq[DataFrame], planFunction: Seq[SparkPlan] => SparkPlan, - expectedAnswer: Seq[Row]): Unit = { - SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer) match { + expectedAnswer: Seq[Row], + sortAnswers: Boolean = true): Unit = { + SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers) match { case Some(errorMessage) => fail(errorMessage) case None => } } /** - * Runs the plan and makes sure the answer matches the expected result. + * Runs the plan and makes sure the answer matches the result produced by a reference plan. * @param input the input data to be used. * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate * the physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. + * @param expectedPlanFunction a function which accepts the input SparkPlan and uses it to + * instantiate a reference implementation of the physical operator + * that's being tested. The result of executing this plan will be + * treated as the source-of-truth for the test. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. */ - protected def checkAnswer[A <: Product : TypeTag]( + protected def checkThatPlansAgree( input: DataFrame, planFunction: SparkPlan => SparkPlan, - expectedAnswer: Seq[A]): Unit = { - val expectedRows = expectedAnswer.map(Row.fromTuple) - checkAnswer(input, planFunction, expectedRows) + expectedPlanFunction: SparkPlan => SparkPlan, + sortAnswers: Boolean = true): Unit = { + SparkPlanTest.checkAnswer(input, planFunction, expectedPlanFunction, sortAnswers) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } } +} - /** - * Runs the plan and makes sure the answer matches the expected result. - * @param left the left input data to be used. - * @param right the right input data to be used. - * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate - * the physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. - */ - protected def checkAnswer[A <: Product : TypeTag]( - left: DataFrame, - right: DataFrame, - planFunction: (SparkPlan, SparkPlan) => SparkPlan, - expectedAnswer: Seq[A]): Unit = { - val expectedRows = expectedAnswer.map(Row.fromTuple) - checkAnswer(left, right, planFunction, expectedRows) - } +/** + * Helper methods for writing tests of individual physical operators. + */ +object SparkPlanTest { /** - * Runs the plan and makes sure the answer matches the expected result. + * Runs the plan and makes sure the answer matches the result produced by a reference plan. * @param input the input data to be used. * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate * the physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. + * @param expectedPlanFunction a function which accepts the input SparkPlan and uses it to + * instantiate a reference implementation of the physical operator + * that's being tested. The result of executing this plan will be + * treated as the source-of-truth for the test. */ - protected def checkAnswer[A <: Product : TypeTag]( - input: Seq[DataFrame], - planFunction: Seq[SparkPlan] => SparkPlan, - expectedAnswer: Seq[A]): Unit = { - val expectedRows = expectedAnswer.map(Row.fromTuple) - checkAnswer(input, planFunction, expectedRows) - } + def checkAnswer( + input: DataFrame, + planFunction: SparkPlan => SparkPlan, + expectedPlanFunction: SparkPlan => SparkPlan, + sortAnswers: Boolean): Option[String] = { -} + val outputPlan = planFunction(input.queryExecution.sparkPlan) + val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan) -/** - * Helper methods for writing tests of individual physical operators. - */ -object SparkPlanTest { + val expectedAnswer: Seq[Row] = try { + executePlan(expectedOutputPlan) + } catch { + case NonFatal(e) => + val errorMessage = + s""" + | Exception thrown while executing Spark plan to calculate expected answer: + | $expectedOutputPlan + | == Exception == + | $e + | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} + """.stripMargin + return Some(errorMessage) + } + + val actualAnswer: Seq[Row] = try { + executePlan(outputPlan) + } catch { + case NonFatal(e) => + val errorMessage = + s""" + | Exception thrown while executing Spark plan: + | $outputPlan + | == Exception == + | $e + | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} + """.stripMargin + return Some(errorMessage) + } + + compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage => + s""" + | Results do not match. + | Actual result Spark plan: + | $outputPlan + | Expected result Spark plan: + | $expectedOutputPlan + | $errorMessage + """.stripMargin + } + } /** * Runs the plan and makes sure the answer matches the expected result. @@ -151,28 +200,45 @@ object SparkPlanTest { * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate * the physical operator that's being tested. * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. */ def checkAnswer( input: Seq[DataFrame], planFunction: Seq[SparkPlan] => SparkPlan, - expectedAnswer: Seq[Row]): Option[String] = { + expectedAnswer: Seq[Row], + sortAnswers: Boolean): Option[String] = { val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan)) - // A very simple resolver to make writing tests easier. In contrast to the real resolver - // this is always case sensitive and does not try to handle scoping or complex type resolution. - val resolvedPlan = TestSQLContext.prepareForExecution.execute( - outputPlan transform { - case plan: SparkPlan => - val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap - plan.transformExpressions { - case UnresolvedAttribute(Seq(u)) => - inputMap.getOrElse(u, - sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) - } - } - ) + val sparkAnswer: Seq[Row] = try { + executePlan(outputPlan) + } catch { + case NonFatal(e) => + val errorMessage = + s""" + | Exception thrown while executing Spark plan: + | $outputPlan + | == Exception == + | $e + | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} + """.stripMargin + return Some(errorMessage) + } + compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage => + s""" + | Results do not match for Spark plan: + | $outputPlan + | $errorMessage + """.stripMargin + } + } + + private def compareAnswers( + sparkAnswer: Seq[Row], + expectedAnswer: Seq[Row], + sort: Boolean): Option[String] = { def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. // For BigDecimal type, the Scala type has a better definition of equality test (similar to @@ -187,40 +253,43 @@ object SparkPlanTest { case o => o }) } - converted.sortBy(_.toString()) - } - - val sparkAnswer: Seq[Row] = try { - resolvedPlan.executeCollect().toSeq - } catch { - case NonFatal(e) => - val errorMessage = - s""" - | Exception thrown while executing Spark plan: - | $outputPlan - | == Exception == - | $e - | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} - """.stripMargin - return Some(errorMessage) + if (sort) { + converted.sortBy(_.toString()) + } else { + converted + } } - if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { val errorMessage = s""" - | Results do not match for Spark plan: - | $outputPlan | == Results == | ${sideBySide( - s"== Correct Answer - ${expectedAnswer.size} ==" +: + s"== Expected Answer - ${expectedAnswer.size} ==" +: prepareAnswer(expectedAnswer).map(_.toString()), - s"== Spark Answer - ${sparkAnswer.size} ==" +: + s"== Actual Answer - ${sparkAnswer.size} ==" +: prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} """.stripMargin - return Some(errorMessage) + Some(errorMessage) + } else { + None } + } - None + private def executePlan(outputPlan: SparkPlan): Seq[Row] = { + // A very simple resolver to make writing tests easier. In contrast to the real resolver + // this is always case sensitive and does not try to handle scoping or complex type resolution. + val resolvedPlan = TestSQLContext.prepareForExecution.execute( + outputPlan transform { + case plan: SparkPlan => + val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap + plan.transformExpressions { + case UnresolvedAttribute(Seq(u)) => + inputMap.getOrElse(u, + sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) + } + } + ) + resolvedPlan.executeCollect().toSeq } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala new file mode 100644 index 0000000000000..4f4c1f28564cb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import scala.util.Random + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.types._ + +class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { + + override def beforeAll(): Unit = { + TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true) + } + + override def afterAll(): Unit = { + TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) + } + + ignore("sort followed by limit should not leak memory") { + // TODO: this test is going to fail until we implement a proper iterator interface + // with a close() method. + TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "true") + checkThatPlansAgree( + (1 to 100).map(v => Tuple1(v)).toDF("a"), + (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), + (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), + sortAnswers = false + ) + } + + test("sort followed by limit") { + TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false") + try { + checkThatPlansAgree( + (1 to 100).map(v => Tuple1(v)).toDF("a"), + (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), + (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), + sortAnswers = false + ) + } finally { + TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "true") + + } + } + + test("sorting does not crash for large inputs") { + val sortOrder = 'a.asc :: Nil + val stringLength = 1024 * 1024 * 2 + checkThatPlansAgree( + Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), + UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), + Sort(sortOrder, global = true, _: SparkPlan), + sortAnswers = false + ) + } + + // Test sorting on different data types + for ( + dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType) + if !dataType.isInstanceOf[DecimalType]; // We don't have an unsafe representation for decimals + nullable <- Seq(true, false); + sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil); + randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) + ) { + test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { + val inputData = Seq.fill(1000)(randomDataGenerator()).filter { + case d: Double => !d.isNaN + case f: Float => !java.lang.Float.isNaN(f) + case x => true + } + val inputDf = TestSQLContext.createDataFrame( + TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), + StructType(StructField("a", dataType, nullable = true) :: Nil) + ) + assert(UnsafeExternalSort.supportsSchema(inputDf.schema)) + checkThatPlansAgree( + inputDf, + UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 23), + Sort(sortOrder, global = true, _: SparkPlan), + sortAnswers = false + ) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 5707d2fb300ae..2c27da596bc4f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.joins +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{Expression, LessThan} import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter} @@ -41,23 +42,23 @@ class OuterJoinSuite extends SparkPlanTest { val condition = Some(LessThan('b, 'd)) test("shuffled hash outer join") { - checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => ShuffledHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right), Seq( (1, 2.0, null, null), (2, 1.0, 2, 3.0), (3, 3.0, null, null) - )) + ).map(Row.fromTuple)) - checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => ShuffledHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), Seq( (2, 1.0, 2, 3.0), (null, null, 3, 2.0), (null, null, 4, 1.0) - )) + ).map(Row.fromTuple)) - checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => ShuffledHashOuterJoin(leftKeys, rightKeys, FullOuter, condition, left, right), Seq( (1, 2.0, null, null), @@ -65,24 +66,24 @@ class OuterJoinSuite extends SparkPlanTest { (3, 3.0, null, null), (null, null, 3, 2.0), (null, null, 4, 1.0) - )) + ).map(Row.fromTuple)) } test("broadcast hash outer join") { - checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => BroadcastHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right), Seq( (1, 2.0, null, null), (2, 1.0, 2, 3.0), (3, 3.0, null, null) - )) + ).map(Row.fromTuple)) - checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => BroadcastHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), Seq( (2, 1.0, 2, 3.0), (null, null, 3, 2.0), (null, null, 4, 1.0) - )) + ).map(Row.fromTuple)) } }