From 732b5881b6a15d9a6f95c67a6ad027a35948197d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 10 Aug 2015 18:34:59 +0800 Subject: [PATCH 1/3] Support off-heap buffer in UnsafeInMemorySorter. --- .../apache/spark/util/collection/TimSort.java | 19 ++--- .../unsafe/sort/UnsafeExternalSorter.java | 3 +- .../unsafe/sort/UnsafeInMemorySorter.java | 70 +++++++++++++----- .../unsafe/sort/UnsafeSortDataFormat.java | 72 +++++++++++++------ .../util/collection/SortDataFormat.scala | 9 +++ .../sort/UnsafeInMemorySorterSuite.java | 8 ++- .../sql/execution/UnsafeKVExternalSorter.java | 3 +- .../apache/spark/unsafe/array/LongArray.java | 21 +++++- 8 files changed, 149 insertions(+), 56 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/TimSort.java b/core/src/main/java/org/apache/spark/util/collection/TimSort.java index a90cc0e761f62..6752e87a7fc1e 100644 --- a/core/src/main/java/org/apache/spark/util/collection/TimSort.java +++ b/core/src/main/java/org/apache/spark/util/collection/TimSort.java @@ -17,6 +17,7 @@ package org.apache.spark.util.collection; +import java.io.IOException; import java.util.Comparator; /** @@ -98,7 +99,7 @@ public TimSort(SortDataFormat sortDataFormat) { * * @author Josh Bloch */ - public void sort(Buffer a, int lo, int hi, Comparator c) { + public void sort(Buffer a, int lo, int hi, Comparator c) throws IOException { assert c != null; int nRemaining = hi - lo; @@ -164,7 +165,7 @@ public void sort(Buffer a, int lo, int hi, Comparator c) { * @param c comparator to used for the sort */ @SuppressWarnings("fallthrough") - private void binarySort(Buffer a, int lo, int hi, int start, Comparator c) { + private void binarySort(Buffer a, int lo, int hi, int start, Comparator c) throws IOException { assert lo <= start && start <= hi; if (start == lo) start++; @@ -373,7 +374,7 @@ private class SortState { * @param a the array to be sorted * @param c the comparator to determine the order of the sort */ - private SortState(Buffer 
a, Comparator c, int len) { + private SortState(Buffer a, Comparator c, int len) throws IOException { this.aLength = len; this.a = a; this.c = c; @@ -422,7 +423,7 @@ private void pushRun(int runBase, int runLen) { * so the invariants are guaranteed to hold for i < stackSize upon * entry to the method. */ - private void mergeCollapse() { + private void mergeCollapse() throws IOException { while (stackSize > 1) { int n = stackSize - 2; if ( (n >= 1 && runLen[n-1] <= runLen[n] + runLen[n+1]) @@ -440,7 +441,7 @@ private void mergeCollapse() { * Merges all runs on the stack until only one remains. This method is * called once, to complete the sort. */ - private void mergeForceCollapse() { + private void mergeForceCollapse() throws IOException { while (stackSize > 1) { int n = stackSize - 2; if (n > 0 && runLen[n - 1] < runLen[n + 1]) @@ -456,7 +457,7 @@ private void mergeForceCollapse() { * * @param i stack index of the first of the two runs to merge */ - private void mergeAt(int i) { + private void mergeAt(int i) throws IOException { assert stackSize >= 2; assert i >= 0; assert i == stackSize - 2 || i == stackSize - 3; @@ -673,7 +674,7 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator * (must be aBase + aLen) * @param len2 length of second run to be merged (must be > 0) */ - private void mergeLo(int base1, int len1, int base2, int len2) { + private void mergeLo(int base1, int len1, int base2, int len2) throws IOException { assert len1 > 0 && len2 > 0 && base1 + len1 == base2; // Copy first run into temp array @@ -793,7 +794,7 @@ private void mergeLo(int base1, int len1, int base2, int len2) { * (must be aBase + aLen) * @param len2 length of second run to be merged (must be > 0) */ - private void mergeHi(int base1, int len1, int base2, int len2) { + private void mergeHi(int base1, int len1, int base2, int len2) throws IOException { assert len1 > 0 && len2 > 0 && base1 + len1 == base2; // Copy second run into temp array @@ -914,7 +915,7 
@@ private void mergeHi(int base1, int len1, int base2, int len2) { * @param minCapacity the minimum required capacity of the tmp array * @return tmp, whether or not it grew */ - private Buffer ensureCapacity(int minCapacity) { + private Buffer ensureCapacity(int minCapacity) throws IOException { if (tmpLength < minCapacity) { // Compute smallest power of 2 > minCapacity int newSize = minCapacity; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 5ebbf9b068fd6..de82521ef9765 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -171,7 +171,8 @@ private void initializeForWriting() throws IOException { } this.inMemSorter = - new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize); + new UnsafeInMemorySorter(taskMemoryManager, shuffleMemoryManager, recordComparator, + prefixComparator, initialSize); this.isInMemSorterExternal = false; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 1e4b8a116e11a..859b27608ffdf 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -18,9 +18,13 @@ package org.apache.spark.util.collection.unsafe.sort; import java.util.Comparator; +import java.io.IOException; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.util.collection.Sorter; +import org.apache.spark.unsafe.memory.MemoryBlock; 
import org.apache.spark.unsafe.memory.TaskMemoryManager; /** @@ -63,14 +67,15 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { } private final TaskMemoryManager memoryManager; - private final Sorter sorter; + private final ShuffleMemoryManager shuffleMemoryManager; + private final Sorter sorter; private final Comparator sortComparator; /** * Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. */ - private long[] pointerArray; + private LongArray pointerArray; /** * The position in the sort buffer where new records can be inserted. @@ -79,16 +84,34 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { public UnsafeInMemorySorter( final TaskMemoryManager memoryManager, + final ShuffleMemoryManager shuffleMemoryManager, final RecordComparator recordComparator, final PrefixComparator prefixComparator, - int initialSize) { + int initialSize) throws IOException { assert (initialSize > 0); - this.pointerArray = new long[initialSize * 2]; this.memoryManager = memoryManager; - this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE); + this.shuffleMemoryManager = shuffleMemoryManager; + this.pointerArray = allocateLongArray(initialSize); + this.sorter = new Sorter<>(new UnsafeSortDataFormat(memoryManager, shuffleMemoryManager)); this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); } + private LongArray allocateLongArray(int size) throws IOException { + MemoryBlock page = allocateMemoryBlock(size * 2); + return new LongArray(page); + } + + private MemoryBlock allocateMemoryBlock(int size) throws IOException { + long memoryToAcquire = size * LongArray.WIDTH; + final long memoryGranted = shuffleMemoryManager.tryToAcquire(memoryToAcquire); + if (memoryGranted != memoryToAcquire) { + shuffleMemoryManager.release(memoryGranted); + throw new 
IOException("Unable to acquire " + memoryToAcquire + " bytes of memory"); + } + MemoryBlock page = memoryManager.allocatePage(memoryToAcquire); + return page; + } + /** * @return the number of records that have been inserted into this sorter. */ @@ -97,7 +120,7 @@ public int numRecords() { } public long getMemoryUsage() { - return pointerArray.length * 8L; + return pointerArray.memoryBlock().size(); } static long getMemoryRequirementsForPointerArray(long numEntries) { @@ -105,15 +128,24 @@ static long getMemoryRequirementsForPointerArray(long numEntries) { } public boolean hasSpaceForAnotherRecord() { - return pointerArrayInsertPosition + 2 < pointerArray.length; + return pointerArrayInsertPosition + 2 < pointerArray.size(); + } + + private void releasedPointerArray(LongArray array) { + if (array != null) { + memoryManager.freePage(array.memoryBlock()); + shuffleMemoryManager.release(array.memoryBlock().size()); + array = null; + } } - public void expandPointerArray() { - final long[] oldArray = pointerArray; + public void expandPointerArray() throws IOException { + final LongArray oldArray = pointerArray; // Guard against overflow: - final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE; - pointerArray = new long[newLength]; - System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length); + final int newLength = oldArray.size() * 2 > 0 ? (int)(oldArray.size() * 2) : Integer.MAX_VALUE; + pointerArray = allocateLongArray(newLength); + pointerArray.copyFrom(oldArray); + releasedPointerArray(oldArray); } /** @@ -123,13 +155,13 @@ public void expandPointerArray() { * @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}. 
* @param keyPrefix a user-defined key prefix */ - public void insertRecord(long recordPointer, long keyPrefix) { + public void insertRecord(long recordPointer, long keyPrefix) throws IOException { if (!hasSpaceForAnotherRecord()) { expandPointerArray(); } - pointerArray[pointerArrayInsertPosition] = recordPointer; + pointerArray.set(pointerArrayInsertPosition, recordPointer); pointerArrayInsertPosition++; - pointerArray[pointerArrayInsertPosition] = keyPrefix; + pointerArray.set(pointerArrayInsertPosition, keyPrefix); pointerArrayInsertPosition++; } @@ -137,7 +169,7 @@ public static final class SortedIterator extends UnsafeSorterIterator { private final TaskMemoryManager memoryManager; private final int sortBufferInsertPosition; - private final long[] sortBuffer; + private final LongArray sortBuffer; private int position = 0; private Object baseObject; private long baseOffset; @@ -147,7 +179,7 @@ public static final class SortedIterator extends UnsafeSorterIterator { private SortedIterator( TaskMemoryManager memoryManager, int sortBufferInsertPosition, - long[] sortBuffer) { + LongArray sortBuffer) { this.memoryManager = memoryManager; this.sortBufferInsertPosition = sortBufferInsertPosition; this.sortBuffer = sortBuffer; @@ -161,11 +193,11 @@ public boolean hasNext() { @Override public void loadNext() { // This pointer points to a 4-byte record length, followed by the record's bytes - final long recordPointer = sortBuffer[position]; + final long recordPointer = sortBuffer.get(position); baseObject = memoryManager.getPage(recordPointer); baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length recordLength = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset - 4); - keyPrefix = sortBuffer[position + 1]; + keyPrefix = sortBuffer.get(position + 1); position += 2; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java index d09c728a7a638..5bf9970fd343c 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java @@ -17,23 +17,34 @@ package org.apache.spark.util.collection.unsafe.sort; +import java.io.IOException; + +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.util.collection.SortDataFormat; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; /** * Supports sorting an array of (record pointer, key prefix) pairs. * Used in {@link UnsafeInMemorySorter}. *

- * Within each long[] buffer, position {@code 2 * i} holds a pointer pointer to the record at + * Within each LongArray buffer, position {@code 2 * i} holds a pointer to the record at * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. */ -final class UnsafeSortDataFormat extends SortDataFormat { +final class UnsafeSortDataFormat extends SortDataFormat { - public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat(); + private final TaskMemoryManager memoryManager; + private final ShuffleMemoryManager shuffleMemoryManager; - private UnsafeSortDataFormat() { } + public UnsafeSortDataFormat( + TaskMemoryManager memoryManager, ShuffleMemoryManager shuffleMemoryManager) { + this.memoryManager = memoryManager; + this.shuffleMemoryManager = shuffleMemoryManager; + } @Override - public RecordPointerAndKeyPrefix getKey(long[] data, int pos) { + public RecordPointerAndKeyPrefix getKey(LongArray data, int pos) { // Since we re-use keys, this method shouldn't be called. 
throw new UnsupportedOperationException(); } @@ -44,37 +55,52 @@ public RecordPointerAndKeyPrefix newKey() { } @Override - public RecordPointerAndKeyPrefix getKey(long[] data, int pos, RecordPointerAndKeyPrefix reuse) { - reuse.recordPointer = data[pos * 2]; - reuse.keyPrefix = data[pos * 2 + 1]; + public RecordPointerAndKeyPrefix getKey(LongArray data, int pos, RecordPointerAndKeyPrefix reuse) { + reuse.recordPointer = data.get(pos * 2); + reuse.keyPrefix = data.get(pos * 2 + 1); return reuse; } @Override - public void swap(long[] data, int pos0, int pos1) { - long tempPointer = data[pos0 * 2]; - long tempKeyPrefix = data[pos0 * 2 + 1]; - data[pos0 * 2] = data[pos1 * 2]; - data[pos0 * 2 + 1] = data[pos1 * 2 + 1]; - data[pos1 * 2] = tempPointer; - data[pos1 * 2 + 1] = tempKeyPrefix; + public void swap(LongArray data, int pos0, int pos1) { + long tempPointer = data.get(pos0 * 2); + long tempKeyPrefix = data.get(pos0 * 2 + 1); + data.set(pos0 * 2, data.get(pos1 * 2)); + data.set(pos0 * 2 + 1, data.get(pos1 * 2 + 1)); + data.set(pos1 * 2, tempPointer); + data.set(pos1 * 2 + 1, tempKeyPrefix); } @Override - public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) { - dst[dstPos * 2] = src[srcPos * 2]; - dst[dstPos * 2 + 1] = src[srcPos * 2 + 1]; + public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) { + dst.set(dstPos * 2, src.get(srcPos * 2)); + dst.set(dstPos * 2 + 1, src.get(srcPos * 2 + 1)); } @Override - public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) { - System.arraycopy(src, srcPos * 2, dst, dstPos * 2, length * 2); + public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) { + dst.copyFrom(src, srcPos * 2, dstPos * 2, length * 2); + } + + private LongArray allocateLongArray(int size) throws IOException { + MemoryBlock page = allocateMemoryBlock(size * 2); + return new LongArray(page); + } + + private MemoryBlock allocateMemoryBlock(int size) throws 
IOException { + long memoryToAcquire = size * LongArray.WIDTH; + final long memoryGranted = shuffleMemoryManager.tryToAcquire(memoryToAcquire); + if (memoryGranted != memoryToAcquire) { + shuffleMemoryManager.release(memoryGranted); + throw new IOException("Unable to acquire " + memoryToAcquire + " bytes of memory"); + } + MemoryBlock page = memoryManager.allocatePage(memoryToAcquire); + return page; } @Override - public long[] allocate(int length) { + public LongArray allocate(int length) throws IOException { assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large"; - return new long[length * 2]; + return allocateLongArray(length); } - } diff --git a/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala b/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala index 9a7a5a4e74868..6a1118d00070d 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala @@ -17,6 +17,8 @@ package org.apache.spark.util.collection +import java.io.IOException + import scala.reflect.ClassTag /** @@ -71,7 +73,14 @@ abstract class SortDataFormat[K, Buffer] { * Allocates a Buffer that can hold up to 'length' elements. * All elements of the buffer should be considered invalid until data is explicitly copied in. */ + @throws(classOf[IOException]) def allocate(length: Int): Buffer + + /** + * Releases a previously allocated Buffer. + * By default, it does nothing. 
+ */ + def release(buffer: Buffer): Unit = {} } /** diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 909500930539c..d7077367a6e33 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -26,6 +26,7 @@ import static org.mockito.Mockito.mock; import org.apache.spark.HashPartitioner; +import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.memory.ExecutorMemoryManager; import org.apache.spark.unsafe.memory.MemoryAllocator; @@ -45,9 +46,11 @@ private static String getStringFromDataPage(Object baseObject, long baseOffset, } @Test - public void testSortingEmptyInput() { + public void testSortingEmptyInput() throws Exception { + ShuffleMemoryManager manager = ShuffleMemoryManager.createForTesting(10000L); final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter( new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)), + manager, mock(RecordComparator.class), mock(PrefixComparator.class), 100); @@ -107,7 +110,8 @@ public int compare(long prefix1, long prefix2) { return (int) prefix1 - (int) prefix2; } }; - UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(memoryManager, recordComparator, + ShuffleMemoryManager manager = ShuffleMemoryManager.createForTesting(10000L); + UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(memoryManager, manager, recordComparator, prefixComparator, dataToSort.length); // Given a page of records, insert those records into the sorter one-by-one: position = dataPage.getBaseOffset(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java 
b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 69d6784713a24..1cd5e8ddfc528 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -86,7 +86,8 @@ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, // UnsafeInMemorySorter. Because UnsafeInMemorySorter does not accept 0 as the initialSize, // we will use 1 as its initial size if the map is empty. final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( - taskMemoryManager, recordComparator, prefixComparator, Math.max(1, map.numElements())); + taskMemoryManager, shuffleMemoryManager, recordComparator, + prefixComparator, Math.max(1, map.numElements())); // We cannot use the destructive iterator here because we are reusing the existing memory // pages in BytesToBytesMap to hold records during sorting. diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java index 18d1f0d2d7eb2..81731d3227668 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java @@ -30,7 +30,7 @@ public final class LongArray { // This is a long so that we perform long multiplications when computing offsets. - private static final long WIDTH = 8; + public static final long WIDTH = 8; private final MemoryBlock memory; private final Object baseObj; @@ -58,6 +58,25 @@ public long size() { return length; } + /** + * Copy the elements from another LongArray to this array. + * The length of another array must be equal or less than this array. 
+ */ + public void copyFrom(LongArray that) { + assert that.length <= this.length: "Can't copy from a larger array"; + PlatformDependent.copyMemory( + that.baseObj, that.baseOffset, this.baseObj, this.baseOffset, that.memory.size()); + } + + /** + * Copy the elements in a range from another LongArray to this array. + */ + public void copyFrom(LongArray that, int srcPos, int dstPos, int length) { + PlatformDependent.copyMemory( + that.baseObj, that.baseOffset + srcPos * WIDTH, + this.baseObj, this.baseOffset + dstPos * WIDTH, length * WIDTH); + } + /** * Sets the value at position {@code index}. */ From bdbfc8045f1958527ded9603a20871ac48516b3d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 11 Aug 2015 15:01:09 +0800 Subject: [PATCH 2/3] Release allocated memory. --- .../org/apache/spark/util/collection/TimSort.java | 9 +++++++++ .../collection/unsafe/sort/UnsafeExternalSorter.java | 12 ++---------- .../collection/unsafe/sort/UnsafeInMemorySorter.java | 8 ++++++-- .../collection/unsafe/sort/UnsafeSortDataFormat.java | 8 ++++++++ 4 files changed, 25 insertions(+), 12 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/TimSort.java b/core/src/main/java/org/apache/spark/util/collection/TimSort.java index 6752e87a7fc1e..159ed017d248e 100644 --- a/core/src/main/java/org/apache/spark/util/collection/TimSort.java +++ b/core/src/main/java/org/apache/spark/util/collection/TimSort.java @@ -144,6 +144,7 @@ public void sort(Buffer a, int lo, int hi, Comparator c) throws IOExc assert lo == hi; sortState.mergeForceCollapse(); assert sortState.stackSize == 1; + sortState.release(); } /** @@ -213,6 +214,7 @@ private void binarySort(Buffer a, int lo, int hi, int start, Comparator>> 1); + s.release(tmp); tmp = s.allocate(newSize); tmpLength = newSize; } return tmp; } + + public void release() { + if (tmp != null) { + s.release(tmp); + } + } } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index de82521ef9765..00032526368eb 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -162,14 +162,6 @@ public BoxedUnit apply() { */ private void initializeForWriting() throws IOException { this.writeMetrics = new ShuffleWriteMetrics(); - final long pointerArrayMemory = - UnsafeInMemorySorter.getMemoryRequirementsForPointerArray(initialSize); - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pointerArrayMemory); - if (memoryAcquired != pointerArrayMemory) { - shuffleMemoryManager.release(memoryAcquired); - throw new IOException("Could not acquire " + pointerArrayMemory + " bytes of memory"); - } - this.inMemSorter = new UnsafeInMemorySorter(taskMemoryManager, shuffleMemoryManager, recordComparator, prefixComparator, initialSize); @@ -270,8 +262,8 @@ private long freeMemory() { if (inMemSorter != null) { if (!isInMemSorterExternal) { long sorterMemoryUsage = inMemSorter.getMemoryUsage(); + inMemSorter.releaseMemory(); memoryFreed += sorterMemoryUsage; - shuffleMemoryManager.release(sorterMemoryUsage); } inMemSorter = null; } @@ -320,8 +312,8 @@ private void growPointerArrayIfNecessary() throws IOException { shuffleMemoryManager.release(memoryAcquired); spill(); } else { + shuffleMemoryManager.release(memoryAcquired); inMemSorter.expandPointerArray(); - shuffleMemoryManager.release(oldPointerArrayMemoryUsage); } } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 859b27608ffdf..012bea1b9b00a 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -131,6 +131,10 @@ public boolean hasSpaceForAnotherRecord() { return pointerArrayInsertPosition + 2 < pointerArray.size(); } + public void releaseMemory() { + releasedPointerArray(pointerArray); + } + private void releasedPointerArray(LongArray array) { if (array != null) { memoryManager.freePage(array.memoryBlock()); @@ -142,8 +146,8 @@ private void releasedPointerArray(LongArray array) { public void expandPointerArray() throws IOException { final LongArray oldArray = pointerArray; // Guard against overflow: - final int newLength = oldArray.size() * 2 > 0 ? (int)(oldArray.size() * 2) : Integer.MAX_VALUE; - pointerArray = allocateLongArray(newLength); + final int newSize = oldArray.size() * 2 > 0 ? (int)(oldArray.size() * 2) : Integer.MAX_VALUE; + pointerArray = allocateLongArray(newSize / 2); pointerArray.copyFrom(oldArray); releasedPointerArray(oldArray); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java index 5bf9970fd343c..ae97ed77e4177 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java @@ -103,4 +103,12 @@ public LongArray allocate(int length) throws IOException { assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large"; return allocateLongArray(length); } + + @Override + public void release(LongArray array) { + if (array != null) { + memoryManager.freePage(array.memoryBlock()); + shuffleMemoryManager.release(array.memoryBlock().size()); + } + } } From 0d7c2c5c2bdce8a2386efa0104cea4522106eb9e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 11 Aug 2015 19:03:35 +0800 Subject: [PATCH 3/3] Fix some tests. 
--- .../apache/spark/sql/execution/UnsafeKVExternalSorter.java | 1 + .../apache/spark/sql/execution/TestShuffleMemoryManager.scala | 4 ++++ .../spark/sql/execution/UnsafeKVExternalSorterSuite.scala | 4 +++- 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 1cd5e8ddfc528..2f460de3500de 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -125,6 +125,7 @@ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, inMemSorter); sorter.spill(); + inMemSorter.releaseMemory(); map.free(); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala index 48c3938ff87ba..4429b9798b480 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala @@ -48,4 +48,8 @@ class TestShuffleMemoryManager extends ShuffleMemoryManager(Long.MaxValue, 4 * 1 def markAsOutOfMemory(): Unit = { oom = true } + + def resetOutOfMemory(): Unit = { + oom = false + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index a9515a03acf2c..ea8d3e7938a08 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -135,7 +135,9 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite { sorter.closeCurrentPage() } } - + // In order to 
allocate memory in UnsafeInMemorySorter, + // we need to reset OOM in TestShuffleMemoryManager + shuffleMemMgr.resetOutOfMemory() // Collect the sorted output val out = new scala.collection.mutable.ArrayBuffer[(InternalRow, InternalRow)] val iter = sorter.sortedIterator()