sortComparator;
+
+ /**
+ * Within this buffer, position {@code 2 * i} holds a pointer to the record at
+ * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
+ */
+ private long[] pointerArray;
+
+ /**
+ * The position in the sort buffer where new records can be inserted.
+ */
+ private int pointerArrayInsertPosition = 0;
+
+ public UnsafeInMemorySorter(
+ final TaskMemoryManager memoryManager,
+ final RecordComparator recordComparator,
+ final PrefixComparator prefixComparator,
+ int initialSize) {
+ assert (initialSize > 0);
+ this.pointerArray = new long[initialSize * 2];
+ this.memoryManager = memoryManager;
+ this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
+ this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
+ }
+
+ /**
+ * @return the number of records that have been inserted into this sorter.
+ */
+ public int numRecords() {
+ return pointerArrayInsertPosition / 2;
+ }
+
+ public long getMemoryUsage() {
+ return pointerArray.length * 8L;
+ }
+
+ public boolean hasSpaceForAnotherRecord() {
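+ // Each record occupies two entries: a pointer and an 8-byte key prefix.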
+ return pointerArrayInsertPosition + 2 < pointerArray.length;
+ }
+
+ public void expandPointerArray() {
+ final long[] oldArray = pointerArray;
+ // Guard against overflow:
+ final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE;
+ pointerArray = new long[newLength];
+ System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length);
+ }
+
+ /**
+ * Inserts a record to be sorted. Assumes that the record pointer points to a record length
+ * stored as a 4-byte integer, followed by the record's bytes.
+ *
+ * @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}.
+ * @param keyPrefix a user-defined key prefix
+ */
+ public void insertRecord(long recordPointer, long keyPrefix) {
+ if (!hasSpaceForAnotherRecord()) {
+ expandPointerArray();
+ }
+ pointerArray[pointerArrayInsertPosition] = recordPointer;
+ pointerArrayInsertPosition++;
+ pointerArray[pointerArrayInsertPosition] = keyPrefix;
+ pointerArrayInsertPosition++;
+ }
+
+ private static final class SortedIterator extends UnsafeSorterIterator {
+
+ private final TaskMemoryManager memoryManager;
+ private final int sortBufferInsertPosition;
+ private final long[] sortBuffer;
+ private int position = 0;
+ private Object baseObject;
+ private long baseOffset;
+ private long keyPrefix;
+ private int recordLength;
+
+ SortedIterator(
+ TaskMemoryManager memoryManager,
+ int sortBufferInsertPosition,
+ long[] sortBuffer) {
+ this.memoryManager = memoryManager;
+ this.sortBufferInsertPosition = sortBufferInsertPosition;
+ this.sortBuffer = sortBuffer;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return position < sortBufferInsertPosition;
+ }
+
+ @Override
+ public void loadNext() {
+ // This pointer points to a 4-byte record length, followed by the record's bytes
+ final long recordPointer = sortBuffer[position];
+ baseObject = memoryManager.getPage(recordPointer);
+ baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length
+ recordLength = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset - 4);
+ keyPrefix = sortBuffer[position + 1];
+ position += 2;
+ }
+
+ @Override
+ public Object getBaseObject() { return baseObject; }
+
+ @Override
+ public long getBaseOffset() { return baseOffset; }
+
+ @Override
+ public int getRecordLength() { return recordLength; }
+
+ @Override
+ public long getKeyPrefix() { return keyPrefix; }
+ }
+
+ /**
+ * Return an iterator over record pointers in sorted order. For efficiency, all calls to
+ * {@code next()} will return the same mutable object.
+ */
+ public UnsafeSorterIterator getSortedIterator() {
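+ // The sorter treats each (pointer, prefix) pair as one element, so the length passed
+ // here is the record count, not the length of the underlying array.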
+ sorter.sort(pointerArray, 0, pointerArrayInsertPosition / 2, sortComparator);
+ return new SortedIterator(memoryManager, pointerArrayInsertPosition, pointerArray);
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
new file mode 100644
index 0000000000000..d09c728a7a638
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import org.apache.spark.util.collection.SortDataFormat;
+
+/**
+ * Supports sorting an array of (record pointer, key prefix) pairs.
+ * Used in {@link UnsafeInMemorySorter}.
+ *
+ * Within each long[] buffer, position {@code 2 * i} holds a pointer to the record at
+ * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
+ */
+final class UnsafeSortDataFormat extends SortDataFormat<RecordPointerAndKeyPrefix, long[]> {
+
+ public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat();
+
+ private UnsafeSortDataFormat() { }
+
+ @Override
+ public RecordPointerAndKeyPrefix getKey(long[] data, int pos) {
+ // Since we re-use keys, this method shouldn't be called.
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public RecordPointerAndKeyPrefix newKey() {
+ return new RecordPointerAndKeyPrefix();
+ }
+
+ @Override
+ public RecordPointerAndKeyPrefix getKey(long[] data, int pos, RecordPointerAndKeyPrefix reuse) {
+ reuse.recordPointer = data[pos * 2];
+ reuse.keyPrefix = data[pos * 2 + 1];
+ return reuse;
+ }
+
+ @Override
+ public void swap(long[] data, int pos0, int pos1) {
+ long tempPointer = data[pos0 * 2];
+ long tempKeyPrefix = data[pos0 * 2 + 1];
+ data[pos0 * 2] = data[pos1 * 2];
+ data[pos0 * 2 + 1] = data[pos1 * 2 + 1];
+ data[pos1 * 2] = tempPointer;
+ data[pos1 * 2 + 1] = tempKeyPrefix;
+ }
+
+ @Override
+ public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) {
+ dst[dstPos * 2] = src[srcPos * 2];
+ dst[dstPos * 2 + 1] = src[srcPos * 2 + 1];
+ }
+
+ @Override
+ public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
+ System.arraycopy(src, srcPos * 2, dst, dstPos * 2, length * 2);
+ }
+
+ @Override
+ public long[] allocate(int length) {
+ assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large";
+ return new long[length * 2];
+ }
+
+}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
new file mode 100644
index 0000000000000..16ac2e8d821ba
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.IOException;
+
+public abstract class UnsafeSorterIterator {
+
+ public abstract boolean hasNext();
+
+ public abstract void loadNext() throws IOException;
+
+ public abstract Object getBaseObject();
+
+ public abstract long getBaseOffset();
+
+ public abstract int getRecordLength();
+
+ public abstract long getKeyPrefix();
+}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
new file mode 100644
index 0000000000000..8272c2a5be0d1
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.IOException;
+import java.util.Comparator;
+import java.util.PriorityQueue;
+
+final class UnsafeSorterSpillMerger {
+
+ private final PriorityQueue<UnsafeSorterIterator> priorityQueue;
+
+ public UnsafeSorterSpillMerger(
+ final RecordComparator recordComparator,
+ final PrefixComparator prefixComparator,
+ final int numSpills) {
+ final Comparator<UnsafeSorterIterator> comparator = new Comparator<UnsafeSorterIterator>() {
+
+ @Override
+ public int compare(UnsafeSorterIterator left, UnsafeSorterIterator right) {
+ final int prefixComparisonResult =
+ prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix());
+ if (prefixComparisonResult == 0) {
+ return recordComparator.compare(
+ left.getBaseObject(), left.getBaseOffset(),
+ right.getBaseObject(), right.getBaseOffset());
+ } else {
+ return prefixComparisonResult;
+ }
+ }
+ };
+ priorityQueue = new PriorityQueue<UnsafeSorterIterator>(numSpills, comparator);
+ }
+
+ public void addSpill(UnsafeSorterIterator spillReader) throws IOException {
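+ // Eagerly load the first record so the priority queue's comparator has a record to compare.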
+ if (spillReader.hasNext()) {
+ spillReader.loadNext();
+ }
+ priorityQueue.add(spillReader);
+ }
+
+ public UnsafeSorterIterator getSortedIterator() throws IOException {
+ return new UnsafeSorterIterator() {
+
+ private UnsafeSorterIterator spillReader;
+
+ @Override
+ public boolean hasNext() {
+ return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext());
+ }
+
+ @Override
+ public void loadNext() throws IOException {
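+ // If the last-returned reader still has records, load its next one and return the reader
+ // to the queue before taking the reader with the smallest current record.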
+ if (spillReader != null) {
+ if (spillReader.hasNext()) {
+ spillReader.loadNext();
+ priorityQueue.add(spillReader);
+ }
+ }
+ spillReader = priorityQueue.remove();
+ }
+
+ @Override
+ public Object getBaseObject() { return spillReader.getBaseObject(); }
+
+ @Override
+ public long getBaseOffset() { return spillReader.getBaseOffset(); }
+
+ @Override
+ public int getRecordLength() { return spillReader.getRecordLength(); }
+
+ @Override
+ public long getKeyPrefix() { return spillReader.getKeyPrefix(); }
+ };
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
new file mode 100644
index 0000000000000..29e9e0f30f934
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.*;
+
+import com.google.common.io.ByteStreams;
+
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.unsafe.PlatformDependent;
+
+/**
+ * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description
+ * of the file format).
+ */
+final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
+
+ private InputStream in;
+ private DataInputStream din;
+
+ // Variables that change with every record read:
+ private int recordLength;
+ private long keyPrefix;
+ private int numRecordsRemaining;
+
+ private byte[] arr = new byte[1024 * 1024];
+ private Object baseObject = arr;
+ private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET;
+
+ public UnsafeSorterSpillReader(
+ BlockManager blockManager,
+ File file,
+ BlockId blockId) throws IOException {
+ assert (file.length() > 0);
+ final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file));
+ this.in = blockManager.wrapForCompression(blockId, bs);
+ this.din = new DataInputStream(this.in);
+ numRecordsRemaining = din.readInt();
+ }
+
+ @Override
+ public boolean hasNext() {
+ return (numRecordsRemaining > 0);
+ }
+
+ @Override
+ public void loadNext() throws IOException {
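+ // Each record was written as [len (int)][prefix (long)][data (bytes)]; read the header first.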
+ recordLength = din.readInt();
+ keyPrefix = din.readLong();
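+ // Grow the read buffer if this record is larger than any previously read.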
+ if (recordLength > arr.length) {
+ arr = new byte[recordLength];
+ baseObject = arr;
+ }
+ ByteStreams.readFully(in, arr, 0, recordLength);
+ numRecordsRemaining--;
+ if (numRecordsRemaining == 0) {
+ in.close();
+ in = null;
+ din = null;
+ }
+ }
+
+ @Override
+ public Object getBaseObject() {
+ return baseObject;
+ }
+
+ @Override
+ public long getBaseOffset() {
+ return baseOffset;
+ }
+
+ @Override
+ public int getRecordLength() {
+ return recordLength;
+ }
+
+ @Override
+ public long getKeyPrefix() {
+ return keyPrefix;
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
new file mode 100644
index 0000000000000..b8d66659804ad
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.File;
+import java.io.IOException;
+
+import scala.Tuple2;
+
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.DummySerializerInstance;
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.storage.BlockObjectWriter;
+import org.apache.spark.storage.TempLocalBlockId;
+import org.apache.spark.unsafe.PlatformDependent;
+
+/**
+ * Spills a list of sorted records to disk. Spill files have the following format:
+ *
+ * [# of records (int)] [[len (int)][prefix (long)][data (bytes)]...]
+ */
+final class UnsafeSorterSpillWriter {
+
+ static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
+
+ // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
+ // be an API to directly transfer bytes from managed memory to the disk writer, we buffer
+ // data through a byte array.
+ private byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE];
+
+ private final File file;
+ private final BlockId blockId;
+ private final int numRecordsToWrite;
+ private BlockObjectWriter writer;
+ private int numRecordsSpilled = 0;
+
+ public UnsafeSorterSpillWriter(
+ BlockManager blockManager,
+ int fileBufferSize,
+ ShuffleWriteMetrics writeMetrics,
+ int numRecordsToWrite) throws IOException {
+ final Tuple2<TempLocalBlockId, File> spilledFileInfo =
+ blockManager.diskBlockManager().createTempLocalBlock();
+ this.file = spilledFileInfo._2();
+ this.blockId = spilledFileInfo._1();
+ this.numRecordsToWrite = numRecordsToWrite;
+ // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
+ // Our write path doesn't actually use this serializer (since we end up calling the `write()`
+ // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
+ // around this, we pass a dummy no-op serializer.
+ writer = blockManager.getDiskWriter(
+ blockId, file, DummySerializerInstance.INSTANCE, fileBufferSize, writeMetrics);
+ // Write the number of records
+ writeIntToBuffer(numRecordsToWrite, 0);
+ writer.write(writeBuffer, 0, 4);
+ }
+
+ // Based on DataOutputStream.writeLong.
+ private void writeLongToBuffer(long v, int offset) throws IOException {
+ writeBuffer[offset + 0] = (byte)(v >>> 56);
+ writeBuffer[offset + 1] = (byte)(v >>> 48);
+ writeBuffer[offset + 2] = (byte)(v >>> 40);
+ writeBuffer[offset + 3] = (byte)(v >>> 32);
+ writeBuffer[offset + 4] = (byte)(v >>> 24);
+ writeBuffer[offset + 5] = (byte)(v >>> 16);
+ writeBuffer[offset + 6] = (byte)(v >>> 8);
+ writeBuffer[offset + 7] = (byte)(v >>> 0);
+ }
+
+ // Based on DataOutputStream.writeInt.
+ private void writeIntToBuffer(int v, int offset) throws IOException {
+ writeBuffer[offset + 0] = (byte)(v >>> 24);
+ writeBuffer[offset + 1] = (byte)(v >>> 16);
+ writeBuffer[offset + 2] = (byte)(v >>> 8);
+ writeBuffer[offset + 3] = (byte)(v >>> 0);
+ }
+
+ /**
+ * Write a record to a spill file.
+ *
+ * @param baseObject the base object / memory page containing the record
+ * @param baseOffset the base offset which points directly to the record data.
+ * @param recordLength the length of the record.
+ * @param keyPrefix a sort key prefix
+ */
+ public void write(
+ Object baseObject,
+ long baseOffset,
+ int recordLength,
+ long keyPrefix) throws IOException {
+ if (numRecordsSpilled == numRecordsToWrite) {
+ throw new IllegalStateException(
+ "Number of records written exceeded numRecordsToWrite = " + numRecordsToWrite);
+ } else {
+ numRecordsSpilled++;
+ }
+ writeIntToBuffer(recordLength, 0);
+ writeLongToBuffer(keyPrefix, 4);
+ int dataRemaining = recordLength;
+ int freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE - 4 - 8; // space used by prefix + len
+ long recordReadPosition = baseOffset;
+ while (dataRemaining > 0) {
+ final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining);
+ PlatformDependent.copyMemory(
+ baseObject,
+ recordReadPosition,
+ writeBuffer,
+ PlatformDependent.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer),
+ toTransfer);
+ writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer) + toTransfer);
+ recordReadPosition += toTransfer;
+ dataRemaining -= toTransfer;
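+ // The 12-byte length/prefix header only occupies the buffer on the first iteration;
+ // after it has been flushed, the whole buffer is free.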
+ freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE;
+ }
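+ // For zero-length records the loop body never runs, so the buffered header must be flushed here.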
+ if (freeSpaceInWriteBuffer < DISK_WRITE_BUFFER_SIZE) {
+ writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer));
+ }
+ writer.recordWritten();
+ }
+
+ public void close() throws IOException {
+ writer.commitAndClose();
+ writer = null;
+ writeBuffer = null;
+ }
+
+ public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException {
+ return new UnsafeSorterSpillReader(blockManager, file, blockId);
+ }
+}
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
new file mode 100644
index 0000000000000..ea8755e21eb68
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.File;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.UUID;
+
+import scala.Tuple2;
+import scala.Tuple2$;
+import scala.runtime.AbstractFunction1;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import static org.junit.Assert.*;
+import static org.mockito.AdditionalAnswers.returnsFirstArg;
+import static org.mockito.AdditionalAnswers.returnsSecondArg;
+import static org.mockito.Answers.RETURNS_SMART_NULLS;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.*;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
+
+public class UnsafeExternalSorterSuite {
+
+ final TaskMemoryManager memoryManager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ // Use integer comparison for comparing prefixes (which are partition ids, in this case)
+ final PrefixComparator prefixComparator = new PrefixComparator() {
+ @Override
+ public int compare(long prefix1, long prefix2) {
+ return (int) prefix1 - (int) prefix2;
+ }
+ };
+ // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
+ // use a dummy comparator
+ final RecordComparator recordComparator = new RecordComparator() {
+ @Override
+ public int compare(
+ Object leftBaseObject,
+ long leftBaseOffset,
+ Object rightBaseObject,
+ long rightBaseOffset) {
+ return 0;
+ }
+ };
+
+ @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
+ @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
+ @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
+ @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
+
+ File tempDir;
+
+ private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
+ @Override
+ public OutputStream apply(OutputStream stream) {
+ return stream;
+ }
+ }
+
+ @Before
+ public void setUp() {
+ MockitoAnnotations.initMocks(this);
+ tempDir = new File(Utils.createTempDir$default$1());
+ taskContext = mock(TaskContext.class);
+ when(taskContext.taskMetrics()).thenReturn(new TaskMetrics());
+ when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
+ when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
+ when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer<Tuple2<TempLocalBlockId, File>>() {
+ @Override
+ public Tuple2<TempLocalBlockId, File> answer(InvocationOnMock invocationOnMock) throws Throwable {
+ TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID());
+ File file = File.createTempFile("spillFile", ".spill", tempDir);
+ return Tuple2$.MODULE$.apply(blockId, file);
+ }
+ });
+ when(blockManager.getDiskWriter(
+ any(BlockId.class),
+ any(File.class),
+ any(SerializerInstance.class),
+ anyInt(),
+ any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
+ @Override
+ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
+ Object[] args = invocationOnMock.getArguments();
+
+ return new DiskBlockObjectWriter(
+ (BlockId) args[0],
+ (File) args[1],
+ (SerializerInstance) args[2],
+ (Integer) args[3],
+ new CompressStream(),
+ false,
+ (ShuffleWriteMetrics) args[4]
+ );
+ }
+ });
+ when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class)))
+ .then(returnsSecondArg());
+ }
+
+ private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception {
+ final int[] arr = new int[] { value };
+ sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value);
+ }
+
+ @Test
+ public void testSortingOnlyByPrefix() throws Exception {
+
+ final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
+ memoryManager,
+ shuffleMemoryManager,
+ blockManager,
+ taskContext,
+ recordComparator,
+ prefixComparator,
+ 1024,
+ new SparkConf());
+
+ insertNumber(sorter, 5);
+ insertNumber(sorter, 1);
+ insertNumber(sorter, 3);
+ sorter.spill();
+ insertNumber(sorter, 4);
+ sorter.spill();
+ insertNumber(sorter, 2);
+
+ UnsafeSorterIterator iter = sorter.getSortedIterator();
+
+ for (int i = 1; i <= 5; i++) {
+ iter.loadNext();
+ assertEquals(i, iter.getKeyPrefix());
+ assertEquals(4, iter.getRecordLength());
+ // TODO: read rest of value.
+ }
+
+ // TODO: test for cleanup:
+ // assert(tempDir.isEmpty)
+ }
+
+ @Test
+ public void testSortingEmptyArrays() throws Exception {
+
+ final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
+ memoryManager,
+ shuffleMemoryManager,
+ blockManager,
+ taskContext,
+ recordComparator,
+ prefixComparator,
+ 1024,
+ new SparkConf());
+
+ sorter.insertRecord(null, 0, 0, 0);
+ sorter.insertRecord(null, 0, 0, 0);
+ sorter.spill();
+ sorter.insertRecord(null, 0, 0, 0);
+ sorter.spill();
+ sorter.insertRecord(null, 0, 0, 0);
+ sorter.insertRecord(null, 0, 0, 0);
+
+ UnsafeSorterIterator iter = sorter.getSortedIterator();
+
+ for (int i = 1; i <= 5; i++) {
+ iter.loadNext();
+ assertEquals(0, iter.getKeyPrefix());
+ assertEquals(0, iter.getRecordLength());
+ }
+ }
+
+}
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
new file mode 100644
index 0000000000000..909500930539c
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.util.Arrays;
+
+import org.junit.Test;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.*;
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.mock;
+
+import org.apache.spark.HashPartitioner;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+public class UnsafeInMemorySorterSuite {
+
+ private static String getStringFromDataPage(Object baseObject, long baseOffset, int length) {
+ final byte[] strBytes = new byte[length];
+ PlatformDependent.copyMemory(
+ baseObject,
+ baseOffset,
+ strBytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET, length);
+ return new String(strBytes);
+ }
+
+ @Test
+ public void testSortingEmptyInput() {
+ final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)),
+ mock(RecordComparator.class),
+ mock(PrefixComparator.class),
+ 100);
+ final UnsafeSorterIterator iter = sorter.getSortedIterator();
+ assertFalse(iter.hasNext());
+ }
+
+ @Test
+ public void testSortingOnlyByIntegerPrefix() throws Exception {
+ final String[] dataToSort = new String[] {
+ "Boba",
+ "Pearls",
+ "Tapioca",
+ "Taho",
+ "Condensed Milk",
+ "Jasmine",
+ "Milk Tea",
+ "Lychee",
+ "Mango"
+ };
+ final TaskMemoryManager memoryManager =
+ new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ final MemoryBlock dataPage = memoryManager.allocatePage(2048);
+ final Object baseObject = dataPage.getBaseObject();
+ // Write the records into the data page:
+ long position = dataPage.getBaseOffset();
+ for (String str : dataToSort) {
+ final byte[] strBytes = str.getBytes("utf-8");
+ PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length);
+ position += 4;
+ PlatformDependent.copyMemory(
+ strBytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ baseObject,
+ position,
+ strBytes.length);
+ position += strBytes.length;
+ }
+ // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
+ // use a dummy comparator
+ final RecordComparator recordComparator = new RecordComparator() {
+ @Override
+ public int compare(
+ Object leftBaseObject,
+ long leftBaseOffset,
+ Object rightBaseObject,
+ long rightBaseOffset) {
+ return 0;
+ }
+ };
+ // Compute key prefixes based on the records' partition ids
+ final HashPartitioner hashPartitioner = new HashPartitioner(4);
+ // Use integer comparison for comparing prefixes (which are partition ids, in this case)
+ final PrefixComparator prefixComparator = new PrefixComparator() {
+ @Override
+ public int compare(long prefix1, long prefix2) {
+ return (int) prefix1 - (int) prefix2;
+ }
+ };
+ UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(memoryManager, recordComparator,
+ prefixComparator, dataToSort.length);
+ // Given a page of records, insert those records into the sorter one-by-one:
+ position = dataPage.getBaseOffset();
+ for (int i = 0; i < dataToSort.length; i++) {
+ // position now points to the start of a record (which holds its length).
+ final int recordLength = PlatformDependent.UNSAFE.getInt(baseObject, position);
+ final long address = memoryManager.encodePageNumberAndOffset(dataPage, position);
+ final String str = getStringFromDataPage(baseObject, position + 4, recordLength);
+ final int partitionId = hashPartitioner.getPartition(str);
+ sorter.insertRecord(address, partitionId);
+ position += 4 + recordLength;
+ }
+ final UnsafeSorterIterator iter = sorter.getSortedIterator();
+ int iterLength = 0;
+ long prevPrefix = -1;
+ Arrays.sort(dataToSort);
+ while (iter.hasNext()) {
+ iter.loadNext();
+ final String str =
+ getStringFromDataPage(iter.getBaseObject(), iter.getBaseOffset(), iter.getRecordLength());
+ final long keyPrefix = iter.getKeyPrefix();
+ assertThat(str, isIn(Arrays.asList(dataToSort)));
+ assertThat(keyPrefix, greaterThanOrEqualTo(prevPrefix));
+ prevPrefix = keyPrefix;
+ iterLength++;
+ }
+ assertEquals(dataToSort.length, iterLength);
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
new file mode 100644
index 0000000000000..dd505dfa7d758
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort
+
+import org.scalatest.prop.PropertyChecks
+
+import org.apache.spark.SparkFunSuite
+
+class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
+
+ test("String prefix comparator") {
+
+ def testPrefixComparison(s1: String, s2: String): Unit = {
+ val s1Prefix = PrefixComparators.STRING.computePrefix(s1)
+ val s2Prefix = PrefixComparators.STRING.computePrefix(s2)
+ val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix)
+ assert(
+ (prefixComparisonResult == 0) ||
+ (prefixComparisonResult < 0 && s1 < s2) ||
+ (prefixComparisonResult > 0 && s1 > s2))
+ }
+
+ // scalastyle:off
+ val regressionTests = Table(
+ ("s1", "s2"),
+ ("abc", "世界"),
+ ("你好", "世界"),
+ ("你好123", "你好122")
+ )
+ // scalastyle:on
+
+ forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
+ forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
+ }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index edb7202245289..4b99030d1046f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -61,9 +61,10 @@ public final class UnsafeRow extends MutableRow {
/** A pool to hold non-primitive objects */
private ObjectPool pool;
- Object getBaseObject() { return baseObject; }
- long getBaseOffset() { return baseOffset; }
- ObjectPool getPool() { return pool; }
+ public Object getBaseObject() { return baseObject; }
+ public long getBaseOffset() { return baseOffset; }
+ public int getSizeInBytes() { return sizeInBytes; }
+ public ObjectPool getPool() { return pool; }
/** The number of fields in this row, used for calculating the bitset width (and in assertions) */
private int numFields;
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
new file mode 100644
index 0000000000000..b94601cf6d818
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution;
+
+import java.io.IOException;
+
+import scala.collection.Iterator;
+import scala.math.Ordering;
+
+import com.google.common.annotations.VisibleForTesting;
+
+import org.apache.spark.SparkEnv;
+import org.apache.spark.TaskContext;
+import org.apache.spark.sql.AbstractScalaRowIterator;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.ObjectUnsafeColumnWriter;
+import org.apache.spark.sql.catalyst.expressions.UnsafeColumnWriter;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRowConverter;
+import org.apache.spark.sql.catalyst.util.ObjectPool;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
+import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter;
+import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;
+
+final class UnsafeExternalRowSorter {
+
+ /**
+ * If positive, forces records to be spilled to disk at the given frequency (measured in numbers
+ * of records). This is only intended to be used in tests.
+ */
+ private int testSpillFrequency = 0;
+
+ private long numRowsInserted = 0;
+
+ private final StructType schema;
+ private final UnsafeRowConverter rowConverter;
+ private final PrefixComputer prefixComputer;
+ private final UnsafeExternalSorter sorter;
+ private byte[] rowConversionBuffer = new byte[1024 * 8];
+
+ public static abstract class PrefixComputer {
+ abstract long computePrefix(InternalRow row);
+ }
+
+ public UnsafeExternalRowSorter(
+ StructType schema,
+ Ordering<InternalRow> ordering,
+ PrefixComparator prefixComparator,
+ PrefixComputer prefixComputer) throws IOException {
+ this.schema = schema;
+ this.rowConverter = new UnsafeRowConverter(schema);
+ this.prefixComputer = prefixComputer;
+ final SparkEnv sparkEnv = SparkEnv.get();
+ final TaskContext taskContext = TaskContext.get();
+ sorter = new UnsafeExternalSorter(
+ taskContext.taskMemoryManager(),
+ sparkEnv.shuffleMemoryManager(),
+ sparkEnv.blockManager(),
+ taskContext,
+ new RowComparator(ordering, schema.length(), null),
+ prefixComparator,
+ 4096,
+ sparkEnv.conf()
+ );
+ }
+
+ /**
+ * Forces spills to occur every `frequency` records. Only for use in tests.
+ */
+ @VisibleForTesting
+ void setTestSpillFrequency(int frequency) {
+ assert frequency > 0 : "Frequency must be positive";
+ testSpillFrequency = frequency;
+ }
+
+ @VisibleForTesting
+ void insertRow(InternalRow row) throws IOException {
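+ // Convert the row to the UnsafeRow binary format in an on-heap buffer, then copy
+ // those bytes into the sorter's memory pages.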
+ final int sizeRequirement = rowConverter.getSizeRequirement(row);
+ if (sizeRequirement > rowConversionBuffer.length) {
+ rowConversionBuffer = new byte[sizeRequirement];
+ }
+ final int bytesWritten = rowConverter.writeRow(
+ row, rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, sizeRequirement, null);
+ assert (bytesWritten == sizeRequirement);
+ final long prefix = prefixComputer.computePrefix(row);
+ sorter.insertRecord(
+ rowConversionBuffer,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ sizeRequirement,
+ prefix
+ );
+ numRowsInserted++;
+ if (testSpillFrequency > 0 && (numRowsInserted % testSpillFrequency) == 0) {
+ spill();
+ }
+ }
+
+ @VisibleForTesting
+ void spill() throws IOException {
+ sorter.spill();
+ }
+
+ private void cleanupResources() {
+ sorter.freeMemory();
+ }
+
+ @VisibleForTesting
+ Iterator<InternalRow> sort() throws IOException {
+ try {
+ final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
+ if (!sortedIterator.hasNext()) {
+ // Since we won't ever call next() on an empty iterator, we need to clean up resources
+ // here in order to prevent memory leaks.
+ cleanupResources();
+ }
+ return new AbstractScalaRowIterator() {
+
+ private final int numFields = schema.length();
+ private final UnsafeRow row = new UnsafeRow();
+
+ @Override
+ public boolean hasNext() {
+ return sortedIterator.hasNext();
+ }
+
+ @Override
+ public InternalRow next() {
+ try {
+ sortedIterator.loadNext();
+ row.pointTo(
+ sortedIterator.getBaseObject(),
+ sortedIterator.getBaseOffset(),
+ numFields,
+ sortedIterator.getRecordLength(),
+ null);
+ if (!hasNext()) {
+ // Copy the row before freeing its backing page, so that we don't return
+ // dangling pointers into freed memory:
+ final InternalRow rowCopy = row.copy();
+ cleanupResources();
+ return rowCopy;
+ }
+ return row;
+ } catch (IOException e) {
+ cleanupResources();
+ // Scala iterators don't declare any checked exceptions, so we need to use this hack
+ // to re-throw the exception:
+ PlatformDependent.throwException(e);
+ }
+ throw new RuntimeException("Exception should have been re-thrown in next()");
+ }
+ };
+ } catch (IOException e) {
+ cleanupResources();
+ throw e;
+ }
+ }
+
+
+ public Iterator<InternalRow> sort(Iterator<InternalRow> inputIterator) throws IOException {
+ while (inputIterator.hasNext()) {
+ insertRow(inputIterator.next());
+ }
+ return sort();
+ }
+
+ /**
+ * Return true if UnsafeExternalRowSorter can sort rows with the given schema, false otherwise.
+ */
+ public static boolean supportsSchema(StructType schema) {
+ // TODO: add spilling note to explain why we do this for now:
+ for (StructField field : schema.fields()) {
+ if (UnsafeColumnWriter.forType(field.dataType()) instanceof ObjectUnsafeColumnWriter) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ private static final class RowComparator extends RecordComparator {
+ private final Ordering<InternalRow> ordering;
+ private final int numFields;
+ private final ObjectPool objPool;
+ private final UnsafeRow row1 = new UnsafeRow();
+ private final UnsafeRow row2 = new UnsafeRow();
+
+ public RowComparator(Ordering<InternalRow> ordering, int numFields, ObjectPool objPool) {
+ this.numFields = numFields;
+ this.ordering = ordering;
+ this.objPool = objPool;
+ }
+
+ @Override
+ public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
+ row1.pointTo(baseObj1, baseOff1, numFields, -1, objPool);
+ row2.pointTo(baseObj2, baseOff2, numFields, -1, objPool);
+ return ordering.compare(row1, row2);
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala
new file mode 100644
index 0000000000000..cfefb13e7721e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.InternalRow
+
+/**
+ * Shim to allow us to implement [[scala.Iterator]] in Java. Scala 2.11+ has an AbstractIterator
+ * class for this, but that class is `private[scala]` in 2.10. We need to explicitly fix this to
+ * `Row` in order to work around a spurious IntelliJ compiler error.
+ */
+private[spark] abstract class AbstractScalaRowIterator extends Iterator[InternalRow]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 74d933404551c..4b783e30d95e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -289,11 +289,8 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
}
val withSort = if (needSort) {
- if (sqlContext.conf.externalSortEnabled) {
- ExternalSort(rowOrdering, global = false, withShuffle)
- } else {
- Sort(rowOrdering, global = false, withShuffle)
- }
+ sqlContext.planner.BasicOperators.getSortOperator(
+ rowOrdering, global = false, withShuffle)
} else {
withShuffle
}
@@ -321,11 +318,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
case (UnspecifiedDistribution, Seq(), child) =>
child
case (UnspecifiedDistribution, rowOrdering, child) =>
- if (sqlContext.conf.externalSortEnabled) {
- ExternalSort(rowOrdering, global = false, child)
- } else {
- Sort(rowOrdering, global = false, child)
- }
+ sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
case (dist, ordering, _) =>
sys.error(s"Don't know how to ensure $dist with ordering $ordering")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
new file mode 100644
index 0000000000000..2dee3542d6101
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator}
+
+
+object SortPrefixUtils {
+
+ /**
+ * A dummy prefix comparator which always claims that prefixes are equal. This is used in cases
+ * where we don't know how to generate or compare prefixes for a SortOrder.
+ */
+ private object NoOpPrefixComparator extends PrefixComparator {
+ override def compare(prefix1: Long, prefix2: Long): Int = 0
+ }
+
+ def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = {
+ sortOrder.dataType match {
+ case StringType => PrefixComparators.STRING
+ case BooleanType | ByteType | ShortType | IntegerType | LongType => PrefixComparators.INTEGRAL
+ case FloatType => PrefixComparators.FLOAT
+ case DoubleType => PrefixComparators.DOUBLE
+ case _ => NoOpPrefixComparator
+ }
+ }
+
+ def getPrefixComputer(sortOrder: SortOrder): InternalRow => Long = {
+ sortOrder.dataType match {
+ case StringType => (row: InternalRow) => {
+ PrefixComparators.STRING.computePrefix(sortOrder.child.eval(row).asInstanceOf[UTF8String])
+ }
+ case BooleanType =>
+ (row: InternalRow) => {
+ val exprVal = sortOrder.child.eval(row)
+ if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
+ else if (exprVal.asInstanceOf[Boolean]) 1
+ else 0
+ }
+ case ByteType =>
+ (row: InternalRow) => {
+ val exprVal = sortOrder.child.eval(row)
+ if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
+ else exprVal.asInstanceOf[Byte]
+ }
+ case ShortType =>
+ (row: InternalRow) => {
+ val exprVal = sortOrder.child.eval(row)
+ if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
+ else exprVal.asInstanceOf[Short]
+ }
+ case IntegerType =>
+ (row: InternalRow) => {
+ val exprVal = sortOrder.child.eval(row)
+ if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
+ else exprVal.asInstanceOf[Int]
+ }
+ case LongType =>
+ (row: InternalRow) => {
+ val exprVal = sortOrder.child.eval(row)
+ if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
+ else exprVal.asInstanceOf[Long]
+ }
+ case FloatType => (row: InternalRow) => {
+ val exprVal = sortOrder.child.eval(row)
+ if (exprVal == null) PrefixComparators.FLOAT.NULL_PREFIX
+ else PrefixComparators.FLOAT.computePrefix(exprVal.asInstanceOf[Float])
+ }
+ case DoubleType => (row: InternalRow) => {
+ val exprVal = sortOrder.child.eval(row)
+ if (exprVal == null) PrefixComparators.DOUBLE.NULL_PREFIX
+ else PrefixComparators.DOUBLE.computePrefix(exprVal.asInstanceOf[Double])
+ }
+ case _ => (row: InternalRow) => 0L
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 59b9b553a7ae5..ce25af58b6cab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -302,6 +302,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object BasicOperators extends Strategy {
def numPartitions: Int = self.numPartitions
+ /**
+ * Picks an appropriate sort operator.
+ *
+ * @param global when true performs a global sort of all partitions by shuffling the data first
+ * if necessary.
+ */
+ def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = {
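+ // Prefer Tungsten's UnsafeExternalSort when unsafe mode is enabled and the schema is
+ // supported; otherwise fall back to ExternalSort (if enabled) or the in-memory Sort.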
+ if (sqlContext.conf.unsafeEnabled && UnsafeExternalSort.supportsSchema(child.schema)) {
+ execution.UnsafeExternalSort(sortExprs, global, child)
+ } else if (sqlContext.conf.externalSortEnabled) {
+ execution.ExternalSort(sortExprs, global, child)
+ } else {
+ execution.Sort(sortExprs, global, child)
+ }
+ }
+
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case r: RunnableCommand => ExecutedCommand(r) :: Nil
@@ -313,11 +329,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.SortPartitions(sortExprs, child) =>
// This sort only sorts tuples within a partition. Its requiredDistribution will be
// an UnspecifiedDistribution.
- execution.Sort(sortExprs, global = false, planLater(child)) :: Nil
- case logical.Sort(sortExprs, global, child) if sqlContext.conf.externalSortEnabled =>
- execution.ExternalSort(sortExprs, global, planLater(child)):: Nil
+ getSortOperator(sortExprs, global = false, planLater(child)) :: Nil
case logical.Sort(sortExprs, global, child) =>
- execution.Sort(sortExprs, global, planLater(child)):: Nil
+ getSortOperator(sortExprs, global, planLater(child)):: Nil
case logical.Project(projectList, child) =>
execution.Project(projectList, planLater(child)) :: Nil
case logical.Filter(condition, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index de14e6ad79ad6..4c063c299ba53 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution
+import org.apache.spark.sql.types.StructType
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.{RDD, ShuffledRDD}
import org.apache.spark.shuffle.sort.SortShuffleManager
@@ -27,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.util.collection.ExternalSorter
+import org.apache.spark.util.collection.unsafe.sort.PrefixComparator
import org.apache.spark.util.{CompletionIterator, MutablePair}
import org.apache.spark.{HashPartitioner, SparkEnv}
@@ -246,6 +248,77 @@ case class ExternalSort(
override def outputOrdering: Seq[SortOrder] = sortOrder
}
+/**
+ * :: DeveloperApi ::
+ * Optimized version of [[ExternalSort]] that operates on binary data (implemented as part of
+ * Project Tungsten).
+ *
+ * @param global when true performs a global sort of all partitions by shuffling the data first
+ * if necessary.
+ * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will
+ * spill every `frequency` records.
+ */
+@DeveloperApi
+case class UnsafeExternalSort(
+ sortOrder: Seq[SortOrder],
+ global: Boolean,
+ child: SparkPlan,
+ testSpillFrequency: Int = 0)
+ extends UnaryNode {
+
+ private[this] val schema: StructType = child.schema
+
+ override def requiredChildDistribution: Seq[Distribution] =
+ if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
+
+ protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
+ assert(codegenEnabled, "UnsafeExternalSort requires code generation to be enabled")
+ def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = {
+ val ordering = newOrdering(sortOrder, child.output)
+ val boundSortExpression = BindReferences.bindReference(sortOrder.head, child.output)
+ // Hack until we generate separate comparator implementations for ascending vs. descending
+ // (or choose to codegen them):
+ val prefixComparator = {
+ val comp = SortPrefixUtils.getPrefixComparator(boundSortExpression)
+ if (sortOrder.head.direction == Descending) {
+ new PrefixComparator {
+ override def compare(p1: Long, p2: Long): Int = -1 * comp.compare(p1, p2)
+ }
+ } else {
+ comp
+ }
+ }
+ val prefixComputer = {
+ val prefixComputer = SortPrefixUtils.getPrefixComputer(boundSortExpression)
+ new UnsafeExternalRowSorter.PrefixComputer {
+ override def computePrefix(row: InternalRow): Long = prefixComputer(row)
+ }
+ }
+ val sorter = new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer)
+ if (testSpillFrequency > 0) {
+ sorter.setTestSpillFrequency(testSpillFrequency)
+ }
+ sorter.sort(iterator)
+ }
+ child.execute().mapPartitions(doSort, preservesPartitioning = true)
+ }
+
+ override def output: Seq[Attribute] = child.output
+
+ override def outputOrdering: Seq[SortOrder] = sortOrder
+}
+
+@DeveloperApi
+object UnsafeExternalSort {
+ /**
+ * Return true if UnsafeExternalSort can sort rows with the given schema, false otherwise.
+ */
+ def supportsSchema(schema: StructType): Boolean = {
+ UnsafeExternalRowSorter.supportsSchema(schema)
+ }
+}
+
+
/**
* :: DeveloperApi ::
* Return a new RDD that has exactly `numPartitions` partitions.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
index a1e3ca11b1ad9..a2c10fdaf6cdb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
class SortSuite extends SparkPlanTest {
@@ -33,12 +34,14 @@ class SortSuite extends SparkPlanTest {
checkAnswer(
input.toDF("a", "b", "c"),
- ExternalSort('a.asc :: 'b.asc :: Nil, global = false, _: SparkPlan),
- input.sorted)
+ ExternalSort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan),
+ input.sortBy(t => (t._1, t._2)).map(Row.fromTuple),
+ sortAnswers = false)
checkAnswer(
input.toDF("a", "b", "c"),
- ExternalSort('b.asc :: 'a.asc :: Nil, global = false, _: SparkPlan),
- input.sortBy(t => (t._2, t._1)))
+ ExternalSort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan),
+ input.sortBy(t => (t._2, t._1)).map(Row.fromTuple),
+ sortAnswers = false)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index 108b1122f7bff..6a8f394545816 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -17,18 +17,15 @@
package org.apache.spark.sql.execution
-import scala.language.implicitConversions
-import scala.reflect.runtime.universe.TypeTag
-import scala.util.control.NonFatal
-
import org.apache.spark.SparkFunSuite
-
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.util._
-
import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{DataFrameHolder, Row, DataFrame}
+import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row}
+
+import scala.language.implicitConversions
+import scala.reflect.runtime.universe.TypeTag
+import scala.util.control.NonFatal
/**
* Base class for writing tests for individual physical operators. For an example of how this
@@ -49,12 +46,19 @@ class SparkPlanTest extends SparkFunSuite {
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
* the physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+ * to being compared.
*/
protected def checkAnswer(
input: DataFrame,
planFunction: SparkPlan => SparkPlan,
- expectedAnswer: Seq[Row]): Unit = {
- checkAnswer(input :: Nil, (plans: Seq[SparkPlan]) => planFunction(plans.head), expectedAnswer)
+ expectedAnswer: Seq[Row],
+ sortAnswers: Boolean = true): Unit = {
+ doCheckAnswer(
+ input :: Nil,
+ (plans: Seq[SparkPlan]) => planFunction(plans.head),
+ expectedAnswer,
+ sortAnswers)
}
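
The `sortAnswers` flag separates two kinds of checks: by default the rows are compared as an order-insensitive collection, while order-sensitive operators such as sorts pass `sortAnswers = false` to compare the exact output order. A hedged sketch (`Filter`, `condition`, and the value names are illustrative):

```scala
// Order-insensitive: rows are canonicalized by a toString sort before comparison.
checkAnswer(df, (child: SparkPlan) => Filter(condition, child), expectedRows)

// Order-sensitive: a sort's output order is part of its contract.
checkAnswer(
  df,
  (child: SparkPlan) => ExternalSort('a.asc :: Nil, global = true, child),
  expectedRowsInAscendingOrder,
  sortAnswers = false)
```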
/**
@@ -64,86 +68,131 @@ class SparkPlanTest extends SparkFunSuite {
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
* the physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+ * to being compared.
*/
- protected def checkAnswer(
+ protected def checkAnswer2(
left: DataFrame,
right: DataFrame,
planFunction: (SparkPlan, SparkPlan) => SparkPlan,
- expectedAnswer: Seq[Row]): Unit = {
- checkAnswer(left :: right :: Nil,
- (plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)), expectedAnswer)
+ expectedAnswer: Seq[Row],
+ sortAnswers: Boolean = true): Unit = {
+ doCheckAnswer(
+ left :: right :: Nil,
+ (plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)),
+ expectedAnswer,
+ sortAnswers)
}
/**
* Runs the plan and makes sure the answer matches the expected result.
* @param input the input data to be used.
- * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
- * the physical operator that's being tested.
+ * @param planFunction a function which accepts a sequence of input SparkPlans and uses them to
+ * instantiate the physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+ * to being compared.
*/
- protected def checkAnswer(
+ protected def doCheckAnswer(
input: Seq[DataFrame],
planFunction: Seq[SparkPlan] => SparkPlan,
- expectedAnswer: Seq[Row]): Unit = {
- SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer) match {
+ expectedAnswer: Seq[Row],
+ sortAnswers: Boolean = true): Unit = {
+ SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
}
/**
- * Runs the plan and makes sure the answer matches the expected result.
+ * Runs the plan and makes sure the answer matches the result produced by a reference plan.
* @param input the input data to be used.
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
* the physical operator that's being tested.
- * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
+ * @param expectedPlanFunction a function which accepts the input SparkPlan and uses it to
+ * instantiate a reference implementation of the physical operator
+ * that's being tested. The result of executing this plan will be
+ * treated as the source-of-truth for the test.
+ * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+ * to being compared.
*/
- protected def checkAnswer[A <: Product : TypeTag](
+ protected def checkThatPlansAgree(
input: DataFrame,
planFunction: SparkPlan => SparkPlan,
- expectedAnswer: Seq[A]): Unit = {
- val expectedRows = expectedAnswer.map(Row.fromTuple)
- checkAnswer(input, planFunction, expectedRows)
+ expectedPlanFunction: SparkPlan => SparkPlan,
+ sortAnswers: Boolean = true): Unit = {
+ SparkPlanTest.checkAnswer(input, planFunction, expectedPlanFunction, sortAnswers) match {
+ case Some(errorMessage) => fail(errorMessage)
+ case None =>
+ }
}
+}
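
`checkThatPlansAgree` removes the need to hand-compute expected rows: the reference plan's output is the source of truth. The UnsafeExternalSortSuite added later in this patch uses exactly this pattern, e.g.:

```scala
// Validate the new operator against the existing Sort as a reference.
checkThatPlansAgree(
  inputDf,
  (child: SparkPlan) => UnsafeExternalSort('a.asc :: Nil, global = true, child),
  (child: SparkPlan) => Sort('a.asc :: Nil, global = true, child),
  sortAnswers = false)
```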
- /**
- * Runs the plan and makes sure the answer matches the expected result.
- * @param left the left input data to be used.
- * @param right the right input data to be used.
- * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
- * the physical operator that's being tested.
- * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
- */
- protected def checkAnswer[A <: Product : TypeTag](
- left: DataFrame,
- right: DataFrame,
- planFunction: (SparkPlan, SparkPlan) => SparkPlan,
- expectedAnswer: Seq[A]): Unit = {
- val expectedRows = expectedAnswer.map(Row.fromTuple)
- checkAnswer(left, right, planFunction, expectedRows)
- }
+/**
+ * Helper methods for writing tests of individual physical operators.
+ */
+object SparkPlanTest {
/**
- * Runs the plan and makes sure the answer matches the expected result.
+ * Runs the plan and makes sure the answer matches the result produced by a reference plan.
* @param input the input data to be used.
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
* the physical operator that's being tested.
- * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
+ * @param expectedPlanFunction a function which accepts the input SparkPlan and uses it to
+ * instantiate a reference implementation of the physical operator
+ * that's being tested. The result of executing this plan will be
+ * treated as the source-of-truth for the test.
*/
- protected def checkAnswer[A <: Product : TypeTag](
- input: Seq[DataFrame],
- planFunction: Seq[SparkPlan] => SparkPlan,
- expectedAnswer: Seq[A]): Unit = {
- val expectedRows = expectedAnswer.map(Row.fromTuple)
- checkAnswer(input, planFunction, expectedRows)
- }
+ def checkAnswer(
+ input: DataFrame,
+ planFunction: SparkPlan => SparkPlan,
+ expectedPlanFunction: SparkPlan => SparkPlan,
+ sortAnswers: Boolean): Option[String] = {
-}
+ val outputPlan = planFunction(input.queryExecution.sparkPlan)
+ val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan)
-/**
- * Helper methods for writing tests of individual physical operators.
- */
-object SparkPlanTest {
+ val expectedAnswer: Seq[Row] = try {
+ executePlan(expectedOutputPlan)
+ } catch {
+ case NonFatal(e) =>
+ val errorMessage =
+ s"""
+ | Exception thrown while executing Spark plan to calculate expected answer:
+ | $expectedOutputPlan
+ | == Exception ==
+ | $e
+ | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
+ """.stripMargin
+ return Some(errorMessage)
+ }
+
+ val actualAnswer: Seq[Row] = try {
+ executePlan(outputPlan)
+ } catch {
+ case NonFatal(e) =>
+ val errorMessage =
+ s"""
+ | Exception thrown while executing Spark plan:
+ | $outputPlan
+ | == Exception ==
+ | $e
+ | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
+ """.stripMargin
+ return Some(errorMessage)
+ }
+
+ compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
+ s"""
+ | Results do not match.
+ | Actual result Spark plan:
+ | $outputPlan
+ | Expected result Spark plan:
+ | $expectedOutputPlan
+ | $errorMessage
+ """.stripMargin
+ }
+ }
/**
* Runs the plan and makes sure the answer matches the expected result.
@@ -151,28 +200,45 @@ object SparkPlanTest {
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
* the physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+ * to being compared.
*/
def checkAnswer(
input: Seq[DataFrame],
planFunction: Seq[SparkPlan] => SparkPlan,
- expectedAnswer: Seq[Row]): Option[String] = {
+ expectedAnswer: Seq[Row],
+ sortAnswers: Boolean): Option[String] = {
val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan))
- // A very simple resolver to make writing tests easier. In contrast to the real resolver
- // this is always case sensitive and does not try to handle scoping or complex type resolution.
- val resolvedPlan = TestSQLContext.prepareForExecution.execute(
- outputPlan transform {
- case plan: SparkPlan =>
- val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
- plan.transformExpressions {
- case UnresolvedAttribute(Seq(u)) =>
- inputMap.getOrElse(u,
- sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
- }
- }
- )
+ val sparkAnswer: Seq[Row] = try {
+ executePlan(outputPlan)
+ } catch {
+ case NonFatal(e) =>
+ val errorMessage =
+ s"""
+ | Exception thrown while executing Spark plan:
+ | $outputPlan
+ | == Exception ==
+ | $e
+ | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
+ """.stripMargin
+ return Some(errorMessage)
+ }
+ compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
+ s"""
+ | Results do not match for Spark plan:
+ | $outputPlan
+ | $errorMessage
+ """.stripMargin
+ }
+ }
+
+ private def compareAnswers(
+ sparkAnswer: Seq[Row],
+ expectedAnswer: Seq[Row],
+ sort: Boolean): Option[String] = {
def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
// For BigDecimal type, the Scala type has a better definition of equality test (similar to
@@ -187,40 +253,43 @@ object SparkPlanTest {
case o => o
})
}
- converted.sortBy(_.toString())
- }
-
- val sparkAnswer: Seq[Row] = try {
- resolvedPlan.executeCollect().toSeq
- } catch {
- case NonFatal(e) =>
- val errorMessage =
- s"""
- | Exception thrown while executing Spark plan:
- | $outputPlan
- | == Exception ==
- | $e
- | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
- """.stripMargin
- return Some(errorMessage)
+ if (sort) {
+ converted.sortBy(_.toString())
+ } else {
+ converted
+ }
}
-
if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
val errorMessage =
s"""
- | Results do not match for Spark plan:
- | $outputPlan
| == Results ==
| ${sideBySide(
- s"== Correct Answer - ${expectedAnswer.size} ==" +:
+ s"== Expected Answer - ${expectedAnswer.size} ==" +:
prepareAnswer(expectedAnswer).map(_.toString()),
- s"== Spark Answer - ${sparkAnswer.size} ==" +:
+ s"== Actual Answer - ${sparkAnswer.size} ==" +:
prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")}
""".stripMargin
- return Some(errorMessage)
+ Some(errorMessage)
+ } else {
+ None
}
+ }
- None
+ private def executePlan(outputPlan: SparkPlan): Seq[Row] = {
+ // A very simple resolver to make writing tests easier. In contrast to the real resolver
+ // this is always case sensitive and does not try to handle scoping or complex type resolution.
+ val resolvedPlan = TestSQLContext.prepareForExecution.execute(
+ outputPlan transform {
+ case plan: SparkPlan =>
+ val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
+ plan.transformExpressions {
+ case UnresolvedAttribute(Seq(u)) =>
+ inputMap.getOrElse(u,
+ sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
+ }
+ }
+ )
+ resolvedPlan.executeCollect().toSeq
}
}
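
For intuition, the normalization performed by `compareAnswers` boils down to the following simplified sketch (the conversion cases elided at the hunk boundary above are not reproduced here):

```scala
import org.apache.spark.sql.Row

// Sketch of the answer normalization: convert values to types with
// well-behaved equality, then optionally canonicalize row order.
def prepare(answer: Seq[Row], sort: Boolean): Seq[Row] = {
  val converted = answer.map { row =>
    Row.fromSeq(row.toSeq.map {
      // scala.BigDecimal compares by value; java.math.BigDecimal also compares scale.
      case bd: java.math.BigDecimal => BigDecimal(bd)
      case other => other
    })
  }
  if (sort) converted.sortBy(_.toString()) else converted
}
```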
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
new file mode 100644
index 0000000000000..4f4c1f28564cb
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.util.Random
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.types._
+
+class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
+
+ override def beforeAll(): Unit = {
+ TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true)
+ }
+
+ override def afterAll(): Unit = {
+ TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get)
+ }
+
+ ignore("sort followed by limit should not leak memory") {
+ // TODO: this test is going to fail until we implement a proper iterator interface
+ // with a close() method.
+ TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "true")
+ checkThatPlansAgree(
+ (1 to 100).map(v => Tuple1(v)).toDF("a"),
+ (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, global = true, child)),
+ (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)),
+ sortAnswers = false
+ )
+ }
+
+ test("sort followed by limit") {
+ TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false")
+ try {
+ checkThatPlansAgree(
+ (1 to 100).map(v => Tuple1(v)).toDF("a"),
+ (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, global = true, child)),
+ (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)),
+ sortAnswers = false
+ )
+ } finally {
+ TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "true")
+ }
+ }
+
+ test("sorting does not crash for large inputs") {
+ val sortOrder = 'a.asc :: Nil
+ val stringLength = 1024 * 1024 * 2
+ checkThatPlansAgree(
+ Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1),
+ UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1),
+ Sort(sortOrder, global = true, _: SparkPlan),
+ sortAnswers = false
+ )
+ }
+
+ // Test sorting on different data types
+ for (
+ dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType)
+ if !dataType.isInstanceOf[DecimalType]; // We don't have an unsafe representation for decimals
+ nullable <- Seq(true, false);
+ sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil);
+ randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable)
+ ) {
+ test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") {
+ val inputData = Seq.fill(1000)(randomDataGenerator()).filter {
+ case d: Double => !d.isNaN
+ case f: Float => !java.lang.Float.isNaN(f)
+ case _ => true
+ }
+ val inputDf = TestSQLContext.createDataFrame(
+ TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
+ StructType(StructField("a", dataType, nullable = true) :: Nil)
+ )
+ assert(UnsafeExternalSort.supportsSchema(inputDf.schema))
+ checkThatPlansAgree(
+ inputDf,
+ UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 23),
+ Sort(sortOrder, global = true, _: SparkPlan),
+ sortAnswers = false
+ )
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 5707d2fb300ae..2c27da596bc4f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution.joins
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Expression, LessThan}
import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter}
@@ -41,23 +42,23 @@ class OuterJoinSuite extends SparkPlanTest {
val condition = Some(LessThan('b, 'd))
test("shuffled hash outer join") {
- checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+ checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
ShuffledHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right),
Seq(
(1, 2.0, null, null),
(2, 1.0, 2, 3.0),
(3, 3.0, null, null)
- ))
+ ).map(Row.fromTuple))
- checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+ checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
ShuffledHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right),
Seq(
(2, 1.0, 2, 3.0),
(null, null, 3, 2.0),
(null, null, 4, 1.0)
- ))
+ ).map(Row.fromTuple))
- checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+ checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
ShuffledHashOuterJoin(leftKeys, rightKeys, FullOuter, condition, left, right),
Seq(
(1, 2.0, null, null),
@@ -65,24 +66,24 @@ class OuterJoinSuite extends SparkPlanTest {
(3, 3.0, null, null),
(null, null, 3, 2.0),
(null, null, 4, 1.0)
- ))
+ ).map(Row.fromTuple))
}
test("broadcast hash outer join") {
- checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+ checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
BroadcastHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right),
Seq(
(1, 2.0, null, null),
(2, 1.0, 2, 3.0),
(3, 3.0, null, null)
- ))
+ ).map(Row.fromTuple))
- checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+ checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
BroadcastHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right),
Seq(
(2, 1.0, 2, 3.0),
(null, null, 3, 2.0),
(null, null, 4, 1.0)
- ))
+ ).map(Row.fromTuple))
}
}