From 81d52c558d88f64a32ba73719da352c2365a56eb Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 29 Apr 2015 13:29:29 -0700 Subject: [PATCH 01/92] WIP on UnsafeSorter --- .../unsafe/sort/UnsafeSortDataFormat.java | 93 ++++++++++++ .../spark/unsafe/sort/UnsafeSorter.java | 136 ++++++++++++++++++ .../spark/unsafe/sort/UnsafeSorterSuite.java | 7 + 3 files changed, 236 insertions(+) create mode 100644 core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java create mode 100644 core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java create mode 100644 core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java new file mode 100644 index 0000000000000..6bae742e2bdab --- /dev/null +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.sort; + +import org.apache.spark.util.collection.SortDataFormat; + +/** + * TODO: finish writing this description + * + * Within each long[] buffer, position {@code 2 * i} holds a pointer pointer to the record at + * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. + */ +final class UnsafeSortDataFormat + extends SortDataFormat { + + public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat(); + + private UnsafeSortDataFormat() { }; + + public static final class KeyPointerAndPrefix { + /** + * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a + * description of how these addresses are encoded. + */ + long recordPointer; + + /** + * A key prefix, for use in comparisons. + */ + long keyPrefix; + } + + @Override + public KeyPointerAndPrefix getKey(long[] data, int pos) { + // Since we re-use keys, this method shouldn't be called. + throw new UnsupportedOperationException(); + } + + @Override + public KeyPointerAndPrefix newKey() { + return new KeyPointerAndPrefix(); + } + + @Override + public KeyPointerAndPrefix getKey(long[] data, int pos, KeyPointerAndPrefix reuse) { + reuse.recordPointer = data[pos * 2]; + reuse.keyPrefix = data[pos * 2 + 1]; + return reuse; + } + + @Override + public void swap(long[] data, int pos0, int pos1) { + long tempPointer = data[pos0 * 2]; + long tempKeyPrefix = data[pos0 * 2 + 1]; + data[pos0 * 2] = data[pos1 * 2]; + data[pos0 * 2 + 1] = data[pos1 * 2 + 1]; + data[pos1 * 2] = tempPointer; + data[pos1 * 2 + 1] = tempKeyPrefix; + } + + @Override + public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) { + dst[dstPos * 2] = src[srcPos * 2]; + dst[dstPos * 2 + 1] = src[srcPos * 2 + 1]; + } + + @Override + public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) { + System.arraycopy(src, srcPos * 2, dst, dstPos * 2, length * 2); + } + + @Override + public long[] allocate(int length) { + assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large"; + return new long[length * 2]; + } + +} diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java new file mode 100644 index 0000000000000..9e8a4d707d181 --- /dev/null +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.sort; + +import java.util.Comparator; +import java.util.Iterator; + +import org.apache.spark.unsafe.memory.MemoryLocation; +import org.apache.spark.util.collection.Sorter; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import static org.apache.spark.unsafe.sort.UnsafeSortDataFormat.KeyPointerAndPrefix; + +public final class UnsafeSorter { + + public static abstract class RecordComparator { + public abstract int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset); + } + + public static abstract class PrefixComputer { + public abstract long computePrefix(Object baseObject, long baseOffset); + } + + /** + * Compares 8-byte key prefixes in prefix sort. Subclasses may implement type-specific comparisons, + * such as lexicographic comparison for strings. + */ + public static abstract class PrefixComparator { + public abstract int compare(long prefix1, long prefix2); + } + + private final TaskMemoryManager memoryManager; + private final PrefixComputer prefixComputer; + private final Sorter sorter; + private final Comparator sortComparator; + + /** + * Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at + * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. + */ + private long[] sortBuffer = new long[1024]; + + private int sortBufferInsertPosition = 0; + + private void expandSortBuffer(int newSize) { + assert (newSize > sortBuffer.length); + final long[] oldBuffer = sortBuffer; + sortBuffer = new long[newSize]; + System.arraycopy(oldBuffer, 0, sortBuffer, 0, oldBuffer.length); + } + + public UnsafeSorter( + final TaskMemoryManager memoryManager, + final RecordComparator recordComparator, + PrefixComputer prefixComputer, + final PrefixComparator prefixComparator) { + this.memoryManager = memoryManager; + this.prefixComputer = prefixComputer; + this.sorter = + new Sorter(UnsafeSortDataFormat.INSTANCE); + this.sortComparator = new Comparator() { + @Override + public int compare(KeyPointerAndPrefix left, KeyPointerAndPrefix right) { + if (left.keyPrefix == right.keyPrefix) { + final Object leftBaseObject = memoryManager.getPage(left.recordPointer); + final long leftBaseOffset = memoryManager.getOffsetInPage(left.recordPointer); + final Object rightBaseObject = memoryManager.getPage(right.recordPointer); + final long rightBaseOffset = memoryManager.getOffsetInPage(right.recordPointer); + return recordComparator.compare( + leftBaseObject, leftBaseOffset, rightBaseObject, rightBaseOffset); + } else { + return prefixComparator.compare(left.keyPrefix, right.keyPrefix); + } + } + }; + } + + public void insertRecord(long objectAddress) { + if (sortBufferInsertPosition + 2 == sortBuffer.length) { + expandSortBuffer(sortBuffer.length * 2); + } + final Object baseObject = memoryManager.getPage(objectAddress); + final long baseOffset = memoryManager.getOffsetInPage(objectAddress); + final long keyPrefix = prefixComputer.computePrefix(baseObject, baseOffset); + sortBuffer[sortBufferInsertPosition] = objectAddress; + sortBuffer[sortBufferInsertPosition + 1] = keyPrefix; + sortBufferInsertPosition += 2; + } + + public Iterator getSortedIterator() { + final MemoryLocation memoryLocation = new MemoryLocation(); + sorter.sort(sortBuffer, 0, sortBufferInsertPosition, sortComparator); + return new Iterator() { + int position = 0; + + @Override + public boolean hasNext() { + return position < sortBufferInsertPosition; + } + + @Override + public MemoryLocation next() { + final long address = sortBuffer[position]; + position += 2; + final Object baseObject = memoryManager.getPage(address); + final long baseOffset = memoryManager.getOffsetInPage(address); + memoryLocation.setObjAndOffset(baseObject, baseOffset); + return memoryLocation; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + +} diff --git a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java new file mode 100644 index 0000000000000..f96c8ebd723c9 --- /dev/null +++ b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java @@ -0,0 +1,7 @@ +package org.apache.spark.unsafe.sort; + +/** + * Created by joshrosen on 4/29/15. + */ +public class UnsafeSorterSuite { +} From abf7bfe4ddbb2603272ef3926776ceefcc07ff7f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 29 Apr 2015 14:34:15 -0700 Subject: [PATCH 02/92] Add basic test case. --- .../unsafe/sort/UnsafeSortDataFormat.java | 18 +-- .../spark/unsafe/sort/UnsafeSorter.java | 33 +++-- .../spark/unsafe/sort/UnsafeSorterSuite.java | 136 +++++++++++++++++- 3 files changed, 157 insertions(+), 30 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java index 6bae742e2bdab..9955e3fcaabbb 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java @@ -17,6 +17,7 @@ package org.apache.spark.unsafe.sort; +import static org.apache.spark.unsafe.sort.UnsafeSorter.KeyPointerAndPrefix; import org.apache.spark.util.collection.SortDataFormat; /** @@ -26,24 +27,11 @@ * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. */ final class UnsafeSortDataFormat - extends SortDataFormat { + extends SortDataFormat { public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat(); - private UnsafeSortDataFormat() { }; - - public static final class KeyPointerAndPrefix { - /** - * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a - * description of how these addresses are encoded. - */ - long recordPointer; - - /** - * A key prefix, for use in comparisons. - */ - long keyPrefix; - } + private UnsafeSortDataFormat() { } @Override public KeyPointerAndPrefix getKey(long[] data, int pos) { diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java index 9e8a4d707d181..6da89004d2f53 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java @@ -20,13 +20,24 @@ import java.util.Comparator; import java.util.Iterator; -import org.apache.spark.unsafe.memory.MemoryLocation; import org.apache.spark.util.collection.Sorter; import org.apache.spark.unsafe.memory.TaskMemoryManager; -import static org.apache.spark.unsafe.sort.UnsafeSortDataFormat.KeyPointerAndPrefix; public final class UnsafeSorter { + public static final class KeyPointerAndPrefix { + /** + * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a + * description of how these addresses are encoded. + */ + long recordPointer; + + /** + * A key prefix, for use in comparisons. + */ + long keyPrefix; + } + public static abstract class RecordComparator { public abstract int compare( Object leftBaseObject, @@ -105,11 +116,11 @@ public void insertRecord(long objectAddress) { sortBufferInsertPosition += 2; } - public Iterator getSortedIterator() { - final MemoryLocation memoryLocation = new MemoryLocation(); + public Iterator getSortedIterator() { sorter.sort(sortBuffer, 0, sortBufferInsertPosition, sortComparator); - return new Iterator() { - int position = 0; + return new Iterator() { + private int position = 0; + private final KeyPointerAndPrefix keyPointerAndPrefix = new KeyPointerAndPrefix(); @Override public boolean hasNext() { @@ -117,13 +128,11 @@ public boolean hasNext() { } @Override - public MemoryLocation next() { - final long address = sortBuffer[position]; + public KeyPointerAndPrefix next() { + keyPointerAndPrefix.recordPointer = sortBuffer[position]; + keyPointerAndPrefix.keyPrefix = sortBuffer[position + 1]; position += 2; - final Object baseObject = memoryManager.getPage(address); - final long baseOffset = memoryManager.getOffsetInPage(address); - memoryLocation.setObjAndOffset(baseObject, baseOffset); - return memoryLocation; + return keyPointerAndPrefix; } @Override diff --git a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java index f96c8ebd723c9..c22edfb412e1b 100644 --- a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java @@ -1,7 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.unsafe.sort; -/** - * Created by joshrosen on 4/29/15. - */ +import java.util.Arrays; +import java.util.Iterator; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.spark.HashPartitioner; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + public class UnsafeSorterSuite { + + private static String getStringFromDataPage(Object baseObject, long baseOffset) { + final int strLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset); + final byte[] strBytes = new byte[strLength]; + PlatformDependent.UNSAFE.copyMemory( + baseObject, + baseOffset + 8, + strBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, strLength); + return new String(strBytes); + } + + /** + * Tests the type of sorting that's used in the non-combiner path of sort-based shuffle. + */ + @Test + public void testSortingOnlyByPartitionId() throws Exception { + final String[] dataToSort = new String[] { + "Boba", + "Pearls", + "Tapioca", + "Taho", + "Condensed Milk", + "Jasmine", + "Milk Tea", + "Lychee", + "Mango" + }; + final TaskMemoryManager memoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final MemoryBlock dataPage = memoryManager.allocatePage(2048); + final Object baseObject = dataPage.getBaseObject(); + // Write the records into the data page: + long position = dataPage.getBaseOffset(); + for (String str : dataToSort) { + final byte[] strBytes = str.getBytes("utf-8"); + PlatformDependent.UNSAFE.putLong(baseObject, position, strBytes.length); + position += 8; + PlatformDependent.copyMemory( + strBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + baseObject, + position, + strBytes.length); + position += strBytes.length; + } + // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so + // use a dummy comparator + final UnsafeSorter.RecordComparator recordComparator = new UnsafeSorter.RecordComparator() { + @Override + public int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset) { + return 0; + } + }; + // Compute key prefixes based on the records' partition ids + final HashPartitioner hashPartitioner = new HashPartitioner(4); + final UnsafeSorter.PrefixComputer prefixComputer = new UnsafeSorter.PrefixComputer() { + @Override + public long computePrefix(Object baseObject, long baseOffset) { + final String str = getStringFromDataPage(baseObject, baseOffset); + final int partitionId = hashPartitioner.getPartition(str); + return (long) partitionId; + } + }; + // Use integer comparison for comparing prefixes (which are partition ids, in this case) + final UnsafeSorter.PrefixComparator prefixComparator = new UnsafeSorter.PrefixComparator() { + @Override + public int compare(long prefix1, long prefix2) { + return (int) prefix1 - (int) prefix2; + } + }; + final UnsafeSorter sorter = + new UnsafeSorter(memoryManager, recordComparator, prefixComputer, prefixComparator); + // Given a page of records, insert those records into the sorter one-by-one: + position = dataPage.getBaseOffset(); + for (int i = 0; i < dataToSort.length; i++) { + // position now points to the start of a record (which holds its length). + final long recordLength = PlatformDependent.UNSAFE.getLong(baseObject, position); + final long address = memoryManager.encodePageNumberAndOffset(dataPage, position); + sorter.insertRecord(address); + position += 8 + recordLength; + } + final Iterator iter = sorter.getSortedIterator(); + int iterLength = 0; + long prevPrefix = -1; + Arrays.sort(dataToSort); + while (iter.hasNext()) { + final UnsafeSorter.KeyPointerAndPrefix pointerAndPrefix = iter.next(); + final Object recordBaseObject = memoryManager.getPage(pointerAndPrefix.recordPointer); + final long recordBaseOffset = memoryManager.getOffsetInPage(pointerAndPrefix.recordPointer); + final String str = getStringFromDataPage(recordBaseObject, recordBaseOffset); + Assert.assertTrue("String should be valid", Arrays.binarySearch(dataToSort, str) != -1); + Assert.assertTrue("Prefix " + pointerAndPrefix.keyPrefix + " should be >= previous prefix " + + prevPrefix, pointerAndPrefix.keyPrefix >= prevPrefix); + prevPrefix = pointerAndPrefix.keyPrefix; + iterLength++; + } + Assert.assertEquals(dataToSort.length, iterLength); + } } From 57a4ea08c2f415f1fd63167b090c8567ef91ec2e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 29 Apr 2015 21:16:17 -0700 Subject: [PATCH 03/92] Make initialSize configurable in UnsafeSorter --- .../java/org/apache/spark/unsafe/sort/UnsafeSorter.java | 7 +++++-- .../org/apache/spark/unsafe/sort/UnsafeSorterSuite.java | 4 ++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java index 6da89004d2f53..517d737368864 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java @@ -67,7 +67,7 @@ public static abstract class PrefixComparator { * Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. */ - private long[] sortBuffer = new long[1024]; + private long[] sortBuffer; private int sortBufferInsertPosition = 0; @@ -82,7 +82,10 @@ public UnsafeSorter( final TaskMemoryManager memoryManager, final RecordComparator recordComparator, PrefixComputer prefixComputer, - final PrefixComparator prefixComparator) { + final PrefixComparator prefixComparator, + int initialSize) { + assert (initialSize > 0); + this.sortBuffer = new long[initialSize * 2]; this.memoryManager = memoryManager; this.prefixComputer = prefixComputer; this.sorter = diff --git a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java index c22edfb412e1b..95d54ea79cf40 100644 --- a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java @@ -106,8 +106,8 @@ public int compare(long prefix1, long prefix2) { return (int) prefix1 - (int) prefix2; } }; - final UnsafeSorter sorter = - new UnsafeSorter(memoryManager, recordComparator, prefixComputer, prefixComparator); + final UnsafeSorter sorter = new UnsafeSorter(memoryManager, recordComparator, prefixComputer, + prefixComparator, dataToSort.length); // Given a page of records, insert those records into the sorter one-by-one: position = dataPage.getBaseOffset(); for (int i = 0; i < dataToSort.length; i++) { From e90015245c328755d26d185d6c13f15c1c7f30f7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 30 Apr 2015 17:02:18 -0700 Subject: [PATCH 04/92] Add test for empty iterator in UnsafeSorter --- .../apache/spark/unsafe/sort/UnsafeSorterSuite.java | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java index 95d54ea79cf40..52bb27a5b572f 100644 --- a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java @@ -22,6 +22,7 @@ import org.junit.Assert; import org.junit.Test; +import static org.mockito.Mockito.*; import org.apache.spark.HashPartitioner; import org.apache.spark.unsafe.PlatformDependent; @@ -43,6 +44,18 @@ private static String getStringFromDataPage(Object baseObject, long baseOffset) return new String(strBytes); } + @Test + public void testSortingEmptyInput() { + final UnsafeSorter sorter = new UnsafeSorter( + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)), + mock(UnsafeSorter.RecordComparator.class), + mock(UnsafeSorter.PrefixComputer.class), + mock(UnsafeSorter.PrefixComparator.class), + 100); + final Iterator iter = sorter.getSortedIterator(); + assert(!iter.hasNext()); + } + /** * Tests the type of sorting that's used in the non-combiner path of sort-based shuffle. */ From 767d3cad606b47dc508a189642c12eff51b29682 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 30 Apr 2015 18:47:28 -0700 Subject: [PATCH 05/92] Fix invalid range in UnsafeSorter. TODO: write fuzz tests to uncover stuff like this. Sorting has nice invariants; should be an easy test to write. --- .../main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java index 517d737368864..63dcf1596c2e9 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java @@ -120,7 +120,7 @@ public void insertRecord(long objectAddress) { } public Iterator getSortedIterator() { - sorter.sort(sortBuffer, 0, sortBufferInsertPosition, sortComparator); + sorter.sort(sortBuffer, 0, sortBufferInsertPosition / 2, sortComparator); return new Iterator() { private int position = 0; private final KeyPointerAndPrefix keyPointerAndPrefix = new KeyPointerAndPrefix(); From 3db12de8ad6b9c724d952715aa5ad38a74cb2eda Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 30 Apr 2015 18:49:40 -0700 Subject: [PATCH 06/92] Minor simplification and sanity checks in UnsafeSorter --- .gitignore | 1 + .../spark/unsafe/sort/UnsafeSorter.java | 19 +++++++++++++++---- .../scala/org/apache/spark/SparkEnv.scala | 3 ++- .../shuffle/FileShuffleBlockManager.scala | 2 +- 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index d54d21b802be8..46a2a3a3f190d 100644 --- a/.gitignore +++ b/.gitignore @@ -29,6 +29,7 @@ conf/*.properties conf/*.conf conf/*.xml conf/slaves +core/build/py4j/ docs/_site docs/api target/ diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java index 63dcf1596c2e9..d33ca321a9835 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java @@ -30,12 +30,22 @@ public static final class KeyPointerAndPrefix { * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a * description of how these addresses are encoded. */ - long recordPointer; + public long recordPointer; /** * A key prefix, for use in comparisons. */ - long keyPrefix; + public long keyPrefix; + + @Override + public int hashCode() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean equals(Object obj) { + throw new UnsupportedOperationException(); + } } public static abstract class RecordComparator { @@ -115,8 +125,9 @@ public void insertRecord(long objectAddress) { final long baseOffset = memoryManager.getOffsetInPage(objectAddress); final long keyPrefix = prefixComputer.computePrefix(baseObject, baseOffset); sortBuffer[sortBufferInsertPosition] = objectAddress; - sortBuffer[sortBufferInsertPosition + 1] = keyPrefix; - sortBufferInsertPosition += 2; + sortBufferInsertPosition++; + sortBuffer[sortBufferInsertPosition] = keyPrefix; + sortBufferInsertPosition++; } public Iterator getSortedIterator() { diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 0c4d28f786edd..8c40bc93863b2 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -313,7 +313,8 @@ object SparkEnv extends Logging { // Let the user specify short names for shuffle managers val shortShuffleMgrNames = Map( "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager", - "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager") + "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager", + "unsafe" -> "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager") val shuffleMgrName = conf.get("spark.shuffle.manager", "sort") val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index e9b4e2b955dc8..0a84fdc0e4ca2 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -34,7 +34,7 @@ import org.apache.spark.storage._ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} -/** A group of writers for a ShuffleMapTask, one writer per reducer. */ +/** A group of writers for ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { val writers: Array[BlockObjectWriter] From 4d2f5e1eb5af05e1d2e13e226192818d54ed7221 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 30 Apr 2015 20:11:43 -0700 Subject: [PATCH 07/92] WIP --- .../shuffle/unsafe/UnsafeShuffleManager.scala | 324 ++++++++++++++++++ .../shuffle/unsafe/UnsafeShuffleSuite.scala | 30 ++ 2 files changed, 354 insertions(+) create mode 100644 core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala create mode 100644 core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala new file mode 100644 index 0000000000000..56d76304e713a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe + +import java.io.{ByteArrayOutputStream, FileOutputStream} +import java.nio.ByteBuffer +import java.util + +import com.esotericsoftware.kryo.io.ByteBufferOutputStream +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.memory.{TaskMemoryManager, MemoryBlock} +import org.apache.spark.unsafe.sort.UnsafeSorter +import org.apache.spark.unsafe.sort.UnsafeSorter.{KeyPointerAndPrefix, PrefixComparator, PrefixComputer, RecordComparator} +import org.apache.spark.{SparkEnv, SparkConf, ShuffleDependency, TaskContext} +import org.apache.spark.shuffle._ + +private[spark] class UnsafeShuffleHandle[K, V]( + shuffleId: Int, + override val numMaps: Int, + override val dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, numMaps, dependency) { + require(UnsafeShuffleManager.canUseUnsafeShuffle(dependency)) +} + +private[spark] object UnsafeShuffleManager { + def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = { + dependency.aggregator.isEmpty && dependency.keyOrdering.isEmpty + } +} + +private object DummyRecordComparator extends RecordComparator { + override def compare( + leftBaseObject: scala.Any, + leftBaseOffset: Long, + rightBaseObject: scala.Any, + rightBaseOffset: Long): Int = { + 0 + } +} + +private object PartitionerPrefixComputer extends PrefixComputer { + override def computePrefix(baseObject: scala.Any, baseOffset: Long): Long = { + // TODO: should the prefix be computed when inserting the record pointer rather than being + // read from the record itself? May be more efficient in terms of space, etc, and is a simple + // change. + PlatformDependent.UNSAFE.getLong(baseObject, baseOffset) + } +} + +private object PartitionerPrefixComparator extends PrefixComparator { + override def compare(prefix1: Long, prefix2: Long): Int = { + (prefix1 - prefix2).toInt + } +} + +private[spark] class UnsafeShuffleWriter[K, V]( + shuffleBlockManager: IndexShuffleBlockManager, + handle: UnsafeShuffleHandle[K, V], + mapId: Int, + context: TaskContext) + extends ShuffleWriter[K, V] { + + println("Construcing a new UnsafeShuffleWriter") + + private[this] val memoryManager: TaskMemoryManager = context.taskMemoryManager() + + private[this] val dep = handle.dependency + + private[this] var sorter: UnsafeSorter = null + + // Are we in the process of stopping? Because map tasks can call stop() with success = true + // and then call stop() with success = false if they get an exception, we want to make sure + // we don't try deleting files, etc twice. + private[this] var stopping = false + + private[this] var mapStatus: MapStatus = null + + private[this] val writeMetrics = new ShuffleWriteMetrics() + context.taskMetrics().shuffleWriteMetrics = Some(writeMetrics) + + private[this] val allocatedPages: util.LinkedList[MemoryBlock] = + new util.LinkedList[MemoryBlock]() + + private[this] val blockManager = SparkEnv.get.blockManager + + /** Write a sequence of records to this task's output */ + override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { + println("Opened writer!") + val serializer = Serializer.getSerializer(dep.serializer).newInstance() + val partitioner = dep.partitioner + sorter = new UnsafeSorter( + context.taskMemoryManager(), + DummyRecordComparator, + PartitionerPrefixComputer, + PartitionerPrefixComparator, + 4096 // initial size + ) + + // Pack records into data pages: + val PAGE_SIZE = 1024 * 1024 * 1 + var currentPage: MemoryBlock = memoryManager.allocatePage(PAGE_SIZE) + allocatedPages.add(currentPage) + var currentPagePosition: Long = currentPage.getBaseOffset + + // TODO make this configurable + val serArray = new Array[Byte](1024 * 1024) + val byteBuffer = ByteBuffer.wrap(serArray) + val bbos = new ByteBufferOutputStream() + bbos.setByteBuffer(byteBuffer) + val serBufferSerStream = serializer.serializeStream(bbos) + + while (records.hasNext) { + val nextRecord: Product2[K, V] = records.next() + println("Writing record " + nextRecord) + val partitionId: Int = partitioner.getPartition(nextRecord._1) + serBufferSerStream.writeObject(nextRecord) + + val sizeRequirement: Int = byteBuffer.position() + 8 + 8 + println("Size requirement in intenral buffer is " + sizeRequirement) + if (sizeRequirement > (PAGE_SIZE - currentPagePosition)) { + println("Allocating a new data page after writing " + currentPagePosition) + currentPage = memoryManager.allocatePage(PAGE_SIZE) + allocatedPages.add(currentPage) + currentPagePosition = currentPage.getBaseOffset + } + println("Before writing record, current page position is " + currentPagePosition) + // TODO: check that it's still not too large + val newRecordAddress = + memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition) + PlatformDependent.UNSAFE.putLong(currentPage.getBaseObject, currentPagePosition, partitionId) + currentPagePosition += 8 + PlatformDependent.UNSAFE.putLong( + currentPage.getBaseObject, currentPagePosition, byteBuffer.position()) + currentPagePosition += 8 + PlatformDependent.copyMemory( + serArray, + PlatformDependent.BYTE_ARRAY_OFFSET, + currentPage.getBaseObject, + currentPagePosition, + byteBuffer.position()) + currentPagePosition += byteBuffer.position() + println("After writing record, current page position is " + currentPagePosition) + sorter.insertRecord(newRecordAddress) + byteBuffer.position(0) + } + // TODO: free the buffers, etc, at this point since they're not needed + val sortedIterator: util.Iterator[KeyPointerAndPrefix] = sorter.getSortedIterator + // Now that the partition is sorted, write out the data to a file, keeping track off offsets + // for use in the sort-based shuffle index. + val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId) + val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID) + val partitionLengths = new Array[Long](partitioner.numPartitions) + // TODO: compression tests? + // TODO why is append true here? + // TODO: metrics tracking and all of the other stuff that diskblockobjectwriter would give us + // TODO: note that we saw FAILED_TO_UNCOMPRESS(5) at some points during debugging when we were + // not properly wrapping the writer for compression even though readers expected compressed + // data; the fact that someone still reported this issue in newer Spark versions suggests that + // we should audit the code to make sure wrapping is done at the right set of places and to + // check that we haven't missed any rare corner-cases / rarely-used paths. + val out = blockManager.wrapForCompression(blockId, new FileOutputStream(outputFile, true)) + val serOut = serializer.serializeStream(out) + serOut.flush() + var currentPartition = -1 + var currentPartitionLength: Long = 0 + while (sortedIterator.hasNext) { + val keyPointerAndPrefix: KeyPointerAndPrefix = sortedIterator.next() + val partition = keyPointerAndPrefix.keyPrefix.toInt + println("Partition is " + partition) + if (currentPartition == -1) { + currentPartition = partition + } + if (partition != currentPartition) { + println("switching partition") + partitionLengths(currentPartition) = currentPartitionLength + currentPartitionLength = 0 + currentPartition = partition + } + val baseObject = memoryManager.getPage(keyPointerAndPrefix.recordPointer) + val baseOffset = memoryManager.getOffsetInPage(keyPointerAndPrefix.recordPointer) + val recordLength = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8) + partitionLengths(currentPartition) += recordLength + println("Base offset is " + baseOffset) + println("Record length is " + recordLength) + var i: Int = 0 + // TODO: need to have a way to figure out whether a serializer supports relocation of + // serialized objects or not. Sandy also ran into this in his patch (see + // https://github.com/apache/spark/pull/4450). If we're using Java serialization, we might + // as well just bypass this optimized code path in favor of the old one. + while (i < recordLength) { + out.write(PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + 16 + i)) + i += 1 + } + } + out.flush() + //serOut.close() + //out.flush() + out.close() + shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths) + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) + } + + /** Close this writer, passing along whether the map completed */ + override def stop(success: Boolean): Option[MapStatus] = { + println("Stopping unsafeshufflewriter") + try { + if (stopping) { + None + } else { + stopping = true + if (success) { + Option(mapStatus) + } else { + // The map task failed, so delete our output data. + shuffleBlockManager.removeDataByMap(dep.shuffleId, mapId) + None + } + } + } finally { + // Clean up our sorter, which may have its own intermediate files + if (sorter != null) { + val iter = allocatedPages.iterator() + while (iter.hasNext) { + memoryManager.freePage(iter.next()) + iter.remove() + } + val startTime = System.nanoTime() + //sorter.stop() + context.taskMetrics().shuffleWriteMetrics.foreach( + _.incShuffleWriteTime(System.nanoTime - startTime)) + sorter = null + } + } + } +} + + + +private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager { + + private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf) + + /** + * Register a shuffle with the manager and obtain a handle for it to pass to tasks. + */ + override def registerShuffle[K, V, C]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + if (UnsafeShuffleManager.canUseUnsafeShuffle(dependency)) { + println("Opening unsafeShuffleWriter") + new UnsafeShuffleHandle[K, V]( + shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else { + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } + } + + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). + * Called on executors by reduce tasks. + */ + override def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext): ShuffleReader[K, C] = { + sortShuffleManager.getReader(handle, startPartition, endPartition, context) + } + + /** Get a writer for a given partition. Called on executors by map tasks. */ + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Int, + context: TaskContext): ShuffleWriter[K, V] = { + handle match { + case unsafeShuffleHandle: UnsafeShuffleHandle[K, V] => + // TODO: do we need to do anything to register the shuffle here? + new UnsafeShuffleWriter( + shuffleBlockResolver.asInstanceOf[IndexShuffleBlockManager], + unsafeShuffleHandle, + mapId, + context) + case other => + sortShuffleManager.getWriter(handle, mapId, context) + } + } + + /** Remove a shuffle's metadata from the ShuffleManager. */ + override def unregisterShuffle(shuffleId: Int): Boolean = { + // TODO: need to do something here for our unsafe path + sortShuffleManager.unregisterShuffle(shuffleId) + } + + override def shuffleBlockResolver: ShuffleBlockResolver = { + sortShuffleManager.shuffleBlockResolver + } + + /** Shut down this ShuffleManager. */ + override def stop(): Unit = { + sortShuffleManager.stop() + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala new file mode 100644 index 0000000000000..8ff3abefea897 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe + +import org.apache.spark.ShuffleSuite +import org.scalatest.BeforeAndAfterAll + +class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { + + // This test suite should run all tests in ShuffleSuite with unsafe-based shuffle. + + override def beforeAll() { + conf.set("spark.shuffle.manager", "unsafe") + } +} From 8e3ec208be6cdf4eb167bdbe6940ef5552aeb58a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 1 May 2015 13:31:28 -0700 Subject: [PATCH 08/92] Begin code cleanup. --- .../shuffle/unsafe/UnsafeShuffleManager.scala | 149 ++++++++++-------- 1 file changed, 84 insertions(+), 65 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala index 56d76304e713a..0e796dfe2aefd 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala @@ -17,23 +17,23 @@ package org.apache.spark.shuffle.unsafe -import java.io.{ByteArrayOutputStream, FileOutputStream} +import java.io.{FileOutputStream, OutputStream} import java.nio.ByteBuffer import java.util import com.esotericsoftware.kryo.io.ByteBufferOutputStream + +import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext} import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.scheduler.MapStatus import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle._ import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.unsafe.PlatformDependent -import org.apache.spark.unsafe.memory.{TaskMemoryManager, MemoryBlock} +import org.apache.spark.unsafe.memory.{MemoryBlock, TaskMemoryManager} import org.apache.spark.unsafe.sort.UnsafeSorter import org.apache.spark.unsafe.sort.UnsafeSorter.{KeyPointerAndPrefix, PrefixComparator, PrefixComputer, RecordComparator} -import org.apache.spark.{SparkEnv, SparkConf, ShuffleDependency, TaskContext} -import org.apache.spark.shuffle._ private[spark] class UnsafeShuffleHandle[K, V]( shuffleId: Int, @@ -87,7 +87,7 @@ private[spark] class UnsafeShuffleWriter[K, V]( private[this] val dep = handle.dependency - private[this] var sorter: UnsafeSorter = null + private[this] val partitioner = dep.partitioner // Are we in the process of stopping? Because map tasks can call stop() with success = true // and then call stop() with success = false if they get an exception, we want to make sure @@ -104,52 +104,55 @@ private[spark] class UnsafeShuffleWriter[K, V]( private[this] val blockManager = SparkEnv.get.blockManager - /** Write a sequence of records to this task's output */ - override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { - println("Opened writer!") - val serializer = Serializer.getSerializer(dep.serializer).newInstance() - val partitioner = dep.partitioner - sorter = new UnsafeSorter( + private def sortRecords(records: Iterator[_ <: Product2[K, V]]): java.util.Iterator[KeyPointerAndPrefix] = { + val sorter = new UnsafeSorter( context.taskMemoryManager(), DummyRecordComparator, PartitionerPrefixComputer, PartitionerPrefixComparator, 4096 // initial size ) - - // Pack records into data pages: + val serializer = Serializer.getSerializer(dep.serializer).newInstance() val PAGE_SIZE = 1024 * 1024 * 1 + var currentPage: MemoryBlock = memoryManager.allocatePage(PAGE_SIZE) - allocatedPages.add(currentPage) var currentPagePosition: Long = currentPage.getBaseOffset - // TODO make this configurable + def ensureSpaceInDataPage(spaceRequired: Long): Unit = { + if (spaceRequired > PAGE_SIZE) { + throw new Exception(s"Size requirement $spaceRequired is greater than page size $PAGE_SIZE") + } else if (spaceRequired > (PAGE_SIZE - currentPagePosition)) { + currentPage = memoryManager.allocatePage(PAGE_SIZE) + allocatedPages.add(currentPage) + currentPagePosition = currentPage.getBaseOffset + } + } + + // TODO: the size of this buffer should be configurable val serArray = new Array[Byte](1024 * 1024) val byteBuffer = ByteBuffer.wrap(serArray) val bbos = new ByteBufferOutputStream() bbos.setByteBuffer(byteBuffer) val serBufferSerStream = serializer.serializeStream(bbos) - while (records.hasNext) { - val nextRecord: Product2[K, V] = records.next() - println("Writing record " + nextRecord) - val partitionId: Int = partitioner.getPartition(nextRecord._1) - serBufferSerStream.writeObject(nextRecord) - - val sizeRequirement: Int = byteBuffer.position() + 8 + 8 - println("Size requirement in intenral buffer is " + sizeRequirement) - if (sizeRequirement > (PAGE_SIZE - currentPagePosition)) { - println("Allocating a new data page after writing " + currentPagePosition) - currentPage = memoryManager.allocatePage(PAGE_SIZE) - allocatedPages.add(currentPage) - currentPagePosition = currentPage.getBaseOffset - } - println("Before writing record, current page position is " + currentPagePosition) - // TODO: check that it's still not too large + def writeRecord(record: Product2[Any, Any]): Unit = { + val (key, value) = record + val partitionId = partitioner.getPartition(key) + serBufferSerStream.writeKey(key) + serBufferSerStream.writeValue(value) + serBufferSerStream.flush() + + val serializedRecordSize = byteBuffer.position() + // TODO: we should run the partition extraction function _now_, at insert time, rather than + // requiring it to be stored alongisde the data, since this may lead to double storage + val sizeRequirementInSortDataPage = serializedRecordSize + 8 + 8 + ensureSpaceInDataPage(sizeRequirementInSortDataPage) + val newRecordAddress = memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition) PlatformDependent.UNSAFE.putLong(currentPage.getBaseObject, currentPagePosition, partitionId) currentPagePosition += 8 + println("The stored record length is " + byteBuffer.position()) PlatformDependent.UNSAFE.putLong( currentPage.getBaseObject, currentPagePosition, byteBuffer.position()) currentPagePosition += 8 @@ -162,45 +165,53 @@ private[spark] class UnsafeShuffleWriter[K, V]( currentPagePosition += byteBuffer.position() println("After writing record, current page position is " + currentPagePosition) sorter.insertRecord(newRecordAddress) + + // Reset for writing the next record byteBuffer.position(0) } - // TODO: free the buffers, etc, at this point since they're not needed - val sortedIterator: util.Iterator[KeyPointerAndPrefix] = sorter.getSortedIterator - // Now that the partition is sorted, write out the data to a file, keeping track off offsets - // for use in the sort-based shuffle index. + + while (records.hasNext) { + writeRecord(records.next()) + } + + sorter.getSortedIterator + } + + private def writeSortedRecordsToFile(sortedRecords: java.util.Iterator[KeyPointerAndPrefix]): Array[Long] = { val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId) val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID) val partitionLengths = new Array[Long](partitioner.numPartitions) - // TODO: compression tests? - // TODO why is append true here? - // TODO: metrics tracking and all of the other stuff that diskblockobjectwriter would give us - // TODO: note that we saw FAILED_TO_UNCOMPRESS(5) at some points during debugging when we were - // not properly wrapping the writer for compression even though readers expected compressed - // data; the fact that someone still reported this issue in newer Spark versions suggests that - // we should audit the code to make sure wrapping is done at the right set of places and to - // check that we haven't missed any rare corner-cases / rarely-used paths. - val out = blockManager.wrapForCompression(blockId, new FileOutputStream(outputFile, true)) - val serOut = serializer.serializeStream(out) - serOut.flush() + var currentPartition = -1 - var currentPartitionLength: Long = 0 - while (sortedIterator.hasNext) { - val keyPointerAndPrefix: KeyPointerAndPrefix = sortedIterator.next() - val partition = keyPointerAndPrefix.keyPrefix.toInt - println("Partition is " + partition) - if (currentPartition == -1) { - currentPartition = partition + var prevPartitionLength: Long = 0 + var out: OutputStream = null + + // TODO: don't close and re-open file handles so often; this could be inefficient + + def closePartition(): Unit = { + out.flush() + out.close() + partitionLengths(currentPartition) = outputFile.length() - prevPartitionLength + } + + def switchToPartition(newPartition: Int): Unit = { + if (currentPartition != -1) { + closePartition() + prevPartitionLength = partitionLengths(currentPartition) } + currentPartition = newPartition + out = blockManager.wrapForCompression(blockId, new FileOutputStream(outputFile, true)) + } + + while (sortedRecords.hasNext) { + val keyPointerAndPrefix: KeyPointerAndPrefix = sortedRecords.next() + val partition = keyPointerAndPrefix.keyPrefix.toInt if (partition != currentPartition) { - println("switching partition") - partitionLengths(currentPartition) = currentPartitionLength - currentPartitionLength = 0 - currentPartition = partition + switchToPartition(partition) } val baseObject = memoryManager.getPage(keyPointerAndPrefix.recordPointer) val baseOffset = memoryManager.getOffsetInPage(keyPointerAndPrefix.recordPointer) val recordLength = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8) - partitionLengths(currentPartition) += recordLength println("Base offset is " + baseOffset) println("Record length is " + recordLength) var i: Int = 0 @@ -213,10 +224,19 @@ private[spark] class UnsafeShuffleWriter[K, V]( i += 1 } } - out.flush() - //serOut.close() - //out.flush() - out.close() + closePartition() + + partitionLengths + } + + /** Write a sequence of records to this task's output */ + override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { + println("Opened writer!") + + val sortedIterator = sortRecords(records) + val partitionLengths = writeSortedRecordsToFile(sortedIterator) + + println("Partition lengths are " + partitionLengths.toSeq) shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths) mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) } @@ -239,7 +259,7 @@ private[spark] class UnsafeShuffleWriter[K, V]( } } finally { // Clean up our sorter, which may have its own intermediate files - if (sorter != null) { + if (!allocatedPages.isEmpty) { val iter = allocatedPages.iterator() while (iter.hasNext) { memoryManager.freePage(iter.next()) @@ -249,7 +269,6 @@ private[spark] class UnsafeShuffleWriter[K, V]( //sorter.stop() context.taskMetrics().shuffleWriteMetrics.foreach( _.incShuffleWriteTime(System.nanoTime - startTime)) - sorter = null } } } From 253f13ee0796aa724decd075d159b81eda459daf Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 1 May 2015 13:55:35 -0700 Subject: [PATCH 09/92] More cleanup --- .../shuffle/unsafe/UnsafeShuffleManager.scala | 44 ++++++++++--------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala index 0e796dfe2aefd..04ac811ac7966 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala @@ -115,8 +115,8 @@ private[spark] class UnsafeShuffleWriter[K, V]( val serializer = Serializer.getSerializer(dep.serializer).newInstance() val PAGE_SIZE = 1024 * 1024 * 1 - var currentPage: MemoryBlock = memoryManager.allocatePage(PAGE_SIZE) - var currentPagePosition: Long = currentPage.getBaseOffset + var currentPage: MemoryBlock = null + var currentPagePosition: Long = PAGE_SIZE def ensureSpaceInDataPage(spaceRequired: Long): Unit = { if (spaceRequired > PAGE_SIZE) { @@ -143,6 +143,7 @@ private[spark] class UnsafeShuffleWriter[K, V]( serBufferSerStream.flush() val serializedRecordSize = byteBuffer.position() + assert(serializedRecordSize > 0) // TODO: we should run the partition extraction function _now_, at insert time, rather than // requiring it to be stored alongisde the data, since this may lead to double storage val sizeRequirementInSortDataPage = serializedRecordSize + 8 + 8 @@ -152,17 +153,17 @@ private[spark] class UnsafeShuffleWriter[K, V]( memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition) PlatformDependent.UNSAFE.putLong(currentPage.getBaseObject, currentPagePosition, partitionId) currentPagePosition += 8 - println("The stored record length is " + byteBuffer.position()) + println("The stored record length is " + serializedRecordSize) PlatformDependent.UNSAFE.putLong( - currentPage.getBaseObject, currentPagePosition, byteBuffer.position()) + currentPage.getBaseObject, currentPagePosition, serializedRecordSize) currentPagePosition += 8 PlatformDependent.copyMemory( serArray, PlatformDependent.BYTE_ARRAY_OFFSET, currentPage.getBaseObject, currentPagePosition, - byteBuffer.position()) - currentPagePosition += byteBuffer.position() + serializedRecordSize) + currentPagePosition += serializedRecordSize println("After writing record, current page position is " + currentPagePosition) sorter.insertRecord(newRecordAddress) @@ -195,10 +196,12 @@ private[spark] class UnsafeShuffleWriter[K, V]( } def switchToPartition(newPartition: Int): Unit = { + assert (newPartition > currentPartition, s"new partition $newPartition should be >= $currentPartition") if (currentPartition != -1) { closePartition() prevPartitionLength = partitionLengths(currentPartition) } + println(s"Before switching to partition $newPartition, partition lengths are " + partitionLengths.toSeq) currentPartition = newPartition out = blockManager.wrapForCompression(blockId, new FileOutputStream(outputFile, true)) } @@ -214,11 +217,11 @@ private[spark] class UnsafeShuffleWriter[K, V]( val recordLength = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8) println("Base offset is " + baseOffset) println("Record length is " + recordLength) - var i: Int = 0 // TODO: need to have a way to figure out whether a serializer supports relocation of // serialized objects or not. Sandy also ran into this in his patch (see // https://github.com/apache/spark/pull/4450). If we're using Java serialization, we might // as well just bypass this optimized code path in favor of the old one. + var i: Int = 0 while (i < recordLength) { out.write(PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + 16 + i)) i += 1 @@ -241,6 +244,14 @@ private[spark] class UnsafeShuffleWriter[K, V]( mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) } + private def freeMemory(): Unit = { + val iter = allocatedPages.iterator() + while (iter.hasNext) { + memoryManager.freePage(iter.next()) + iter.remove() + } + } + /** Close this writer, passing along whether the map completed */ override def stop(success: Boolean): Option[MapStatus] = { println("Stopping unsafeshufflewriter") @@ -249,6 +260,7 @@ private[spark] class UnsafeShuffleWriter[K, V]( None } else { stopping = true + freeMemory() if (success) { Option(mapStatus) } else { @@ -258,24 +270,14 @@ private[spark] class UnsafeShuffleWriter[K, V]( } } } finally { - // Clean up our sorter, which may have its own intermediate files - if (!allocatedPages.isEmpty) { - val iter = allocatedPages.iterator() - while (iter.hasNext) { - memoryManager.freePage(iter.next()) - iter.remove() - } - val startTime = System.nanoTime() - //sorter.stop() - context.taskMetrics().shuffleWriteMetrics.foreach( - _.incShuffleWriteTime(System.nanoTime - startTime)) - } + freeMemory() + val startTime = System.nanoTime() + context.taskMetrics().shuffleWriteMetrics.foreach( + _.incShuffleWriteTime(System.nanoTime - startTime)) } } } - - private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager { private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf) From 9c6cf58e1569dcf2a193b4af2e822a649a9d7775 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 1 May 2015 14:11:32 -0700 Subject: [PATCH 10/92] Refactor to use DiskBlockObjectWriter. --- .../shuffle/unsafe/UnsafeShuffleManager.scala | 50 +++++++++++-------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala index 04ac811ac7966..fe092683d5400 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala @@ -17,7 +17,6 @@ package org.apache.spark.shuffle.unsafe -import java.io.{FileOutputStream, OutputStream} import java.nio.ByteBuffer import java.util @@ -29,7 +28,7 @@ import org.apache.spark.scheduler.MapStatus import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle._ import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.storage.{BlockObjectWriter, ShuffleBlockId} import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.memory.{MemoryBlock, TaskMemoryManager} import org.apache.spark.unsafe.sort.UnsafeSorter @@ -104,7 +103,14 @@ private[spark] class UnsafeShuffleWriter[K, V]( private[this] val blockManager = SparkEnv.get.blockManager - private def sortRecords(records: Iterator[_ <: Product2[K, V]]): java.util.Iterator[KeyPointerAndPrefix] = { + // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + private[this] val fileBufferSize = + SparkEnv.get.conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 + + private[this] val serializer = Serializer.getSerializer(dep.serializer).newInstance() + + private def sortRecords( + records: Iterator[_ <: Product2[K, V]]): java.util.Iterator[KeyPointerAndPrefix] = { val sorter = new UnsafeSorter( context.taskMemoryManager(), DummyRecordComparator, @@ -112,7 +118,6 @@ private[spark] class UnsafeShuffleWriter[K, V]( PartitionerPrefixComparator, 4096 // initial size ) - val serializer = Serializer.getSerializer(dep.serializer).newInstance() val PAGE_SIZE = 1024 * 1024 * 1 var currentPage: MemoryBlock = null @@ -178,32 +183,31 @@ private[spark] class UnsafeShuffleWriter[K, V]( sorter.getSortedIterator } - private def writeSortedRecordsToFile(sortedRecords: java.util.Iterator[KeyPointerAndPrefix]): Array[Long] = { + private def writeSortedRecordsToFile( + sortedRecords: java.util.Iterator[KeyPointerAndPrefix]): Array[Long] = { val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId) val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID) val partitionLengths = new Array[Long](partitioner.numPartitions) var currentPartition = -1 - var prevPartitionLength: Long = 0 - var out: OutputStream = null + var writer: BlockObjectWriter = null // TODO: don't close and re-open file handles so often; this could be inefficient def closePartition(): Unit = { - out.flush() - out.close() - partitionLengths(currentPartition) = outputFile.length() - prevPartitionLength + writer.commitAndClose() + partitionLengths(currentPartition) = writer.fileSegment().length } def switchToPartition(newPartition: Int): Unit = { - assert (newPartition > currentPartition, s"new partition $newPartition should be >= $currentPartition") + assert (newPartition > currentPartition, + s"new partition $newPartition should be >= $currentPartition") if (currentPartition != -1) { closePartition() - prevPartitionLength = partitionLengths(currentPartition) } - println(s"Before switching to partition $newPartition, partition lengths are " + partitionLengths.toSeq) currentPartition = newPartition - out = blockManager.wrapForCompression(blockId, new FileOutputStream(outputFile, true)) + writer = + blockManager.getDiskWriter(blockId, outputFile, serializer, fileBufferSize, writeMetrics) } while (sortedRecords.hasNext) { @@ -214,18 +218,24 @@ private[spark] class UnsafeShuffleWriter[K, V]( } val baseObject = memoryManager.getPage(keyPointerAndPrefix.recordPointer) val baseOffset = memoryManager.getOffsetInPage(keyPointerAndPrefix.recordPointer) - val recordLength = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8) + val recordLength: Int = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8).toInt println("Base offset is " + baseOffset) println("Record length is " + recordLength) // TODO: need to have a way to figure out whether a serializer supports relocation of // serialized objects or not. Sandy also ran into this in his patch (see // https://github.com/apache/spark/pull/4450). If we're using Java serialization, we might // as well just bypass this optimized code path in favor of the old one. - var i: Int = 0 - while (i < recordLength) { - out.write(PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + 16 + i)) - i += 1 - } + // TODO: re-use a buffer or avoid double-buffering entirely + val arr: Array[Byte] = new Array[Byte](recordLength) + PlatformDependent.copyMemory( + baseObject, + baseOffset + 16, + arr, + PlatformDependent.BYTE_ARRAY_OFFSET, + recordLength) + writer.write(arr) + // TODO: add a test that detects whether we leave this call out: + writer.recordWritten() } closePartition() From e267cee305a275e7e50fe13a08f5f22b2f9d5939 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 1 May 2015 14:13:21 -0700 Subject: [PATCH 11/92] Fix compilation of UnsafeSorterSuite --- .../java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java index 52bb27a5b572f..4c3b982747693 100644 --- a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java @@ -36,7 +36,7 @@ public class UnsafeSorterSuite { private static String getStringFromDataPage(Object baseObject, long baseOffset) { final int strLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset); final byte[] strBytes = new byte[strLength]; - PlatformDependent.UNSAFE.copyMemory( + PlatformDependent.copyMemory( baseObject, baseOffset + 8, strBytes, From e2d96ca59b74c2aa004c471b651c7de2acaca51f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 1 May 2015 14:43:56 -0700 Subject: [PATCH 12/92] Expand serializer API and use new function to help control when new UnsafeShuffle path is used. --- .../spark/serializer/KryoSerializer.scala | 5 +++ .../apache/spark/serializer/Serializer.scala | 26 ++++++++++- .../shuffle/unsafe/UnsafeShuffleManager.scala | 44 ++++++++++--------- .../util/collection/ExternalSorter.scala | 3 +- 4 files changed, 55 insertions(+), 23 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index b7bc087855b9f..f6c17e362f9b3 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -125,6 +125,11 @@ class KryoSerializer(conf: SparkConf) override def newInstance(): SerializerInstance = { new KryoSerializerInstance(this) } + + override def supportsRelocationOfSerializedObjects: Boolean = { + // TODO: we should have a citation / explanatory comment here clarifying _why_ this is the case + newInstance().asInstanceOf[KryoSerializerInstance].getAutoReset() + } } private[spark] diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index c381672a4f588..144a1c51ac858 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -23,7 +23,7 @@ import java.nio.ByteBuffer import scala.reflect.ClassTag import org.apache.spark.{SparkConf, SparkEnv} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{Experimental, DeveloperApi} import org.apache.spark.util.{Utils, ByteBufferInputStream, NextIterator} /** @@ -63,6 +63,30 @@ abstract class Serializer { /** Creates a new [[SerializerInstance]]. */ def newInstance(): SerializerInstance + + /** + * Returns true if this serializer supports relocation of its serialized objects and false + * otherwise. This should return true if and only if reordering the bytes of serialized objects + * in serialization stream output results in re-ordered input that can be read with the + * deserializer. For instance, the following should work if the serializer supports relocation: + * + * serOut.open() + * position = 0 + * serOut.write(obj1) + * serOut.flush() + * position = # of bytes writen to stream so far + * obj1Bytes = [bytes 0 through position of stream] + * serOut.write(obj2) + * serOut.flush + * position2 = # of bytes written to stream so far + * obj2Bytes = bytes[position through position2 of stream] + * + * serIn.open([obj2bytes] concatenate [obj1bytes]) should return (obj2, obj1) + * + * See SPARK-7311 for more discussion. + */ + @Experimental + def supportsRelocationOfSerializedObjects: Boolean = false } diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala index fe092683d5400..489bcf42cb448 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala @@ -22,7 +22,7 @@ import java.util import com.esotericsoftware.kryo.io.ByteBufferOutputStream -import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext} +import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus import org.apache.spark.serializer.Serializer @@ -34,17 +34,31 @@ import org.apache.spark.unsafe.memory.{MemoryBlock, TaskMemoryManager} import org.apache.spark.unsafe.sort.UnsafeSorter import org.apache.spark.unsafe.sort.UnsafeSorter.{KeyPointerAndPrefix, PrefixComparator, PrefixComputer, RecordComparator} -private[spark] class UnsafeShuffleHandle[K, V]( +private class UnsafeShuffleHandle[K, V]( shuffleId: Int, override val numMaps: Int, override val dependency: ShuffleDependency[K, V, V]) extends BaseShuffleHandle(shuffleId, numMaps, dependency) { - require(UnsafeShuffleManager.canUseUnsafeShuffle(dependency)) } -private[spark] object UnsafeShuffleManager { +private[spark] object UnsafeShuffleManager extends Logging { def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = { - dependency.aggregator.isEmpty && dependency.keyOrdering.isEmpty + val shufId = dependency.shuffleId + val serializer = Serializer.getSerializer(dependency.serializer) + if (!serializer.supportsRelocationOfSerializedObjects) { + log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because the serializer, " + + s"${serializer.getClass.getName}, does not support object relocation") + false + } else if (dependency.aggregator.isDefined) { + log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined") + false + } else if (dependency.keyOrdering.isDefined) { + log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because a key ordering is defined") + false + } else { + log.debug(s"Can use UnsafeShuffle for shuffle $shufId") + true + } } } @@ -73,15 +87,13 @@ private object PartitionerPrefixComparator extends PrefixComparator { } } -private[spark] class UnsafeShuffleWriter[K, V]( +private class UnsafeShuffleWriter[K, V]( shuffleBlockManager: IndexShuffleBlockManager, handle: UnsafeShuffleHandle[K, V], mapId: Int, context: TaskContext) extends ShuffleWriter[K, V] { - println("Construcing a new UnsafeShuffleWriter") - private[this] val memoryManager: TaskMemoryManager = context.taskMemoryManager() private[this] val dep = handle.dependency @@ -158,7 +170,6 @@ private[spark] class UnsafeShuffleWriter[K, V]( memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition) PlatformDependent.UNSAFE.putLong(currentPage.getBaseObject, currentPagePosition, partitionId) currentPagePosition += 8 - println("The stored record length is " + serializedRecordSize) PlatformDependent.UNSAFE.putLong( currentPage.getBaseObject, currentPagePosition, serializedRecordSize) currentPagePosition += 8 @@ -169,7 +180,6 @@ private[spark] class UnsafeShuffleWriter[K, V]( currentPagePosition, serializedRecordSize) currentPagePosition += serializedRecordSize - println("After writing record, current page position is " + currentPagePosition) sorter.insertRecord(newRecordAddress) // Reset for writing the next record @@ -195,8 +205,10 @@ private[spark] class UnsafeShuffleWriter[K, V]( // TODO: don't close and re-open file handles so often; this could be inefficient def closePartition(): Unit = { - writer.commitAndClose() - partitionLengths(currentPartition) = writer.fileSegment().length + if (writer != null) { + writer.commitAndClose() + partitionLengths(currentPartition) = writer.fileSegment().length + } } def switchToPartition(newPartition: Int): Unit = { @@ -219,8 +231,6 @@ private[spark] class UnsafeShuffleWriter[K, V]( val baseObject = memoryManager.getPage(keyPointerAndPrefix.recordPointer) val baseOffset = memoryManager.getOffsetInPage(keyPointerAndPrefix.recordPointer) val recordLength: Int = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8).toInt - println("Base offset is " + baseOffset) - println("Record length is " + recordLength) // TODO: need to have a way to figure out whether a serializer supports relocation of // serialized objects or not. Sandy also ran into this in his patch (see // https://github.com/apache/spark/pull/4450). If we're using Java serialization, we might @@ -244,12 +254,8 @@ private[spark] class UnsafeShuffleWriter[K, V]( /** Write a sequence of records to this task's output */ override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { - println("Opened writer!") - val sortedIterator = sortRecords(records) val partitionLengths = writeSortedRecordsToFile(sortedIterator) - - println("Partition lengths are " + partitionLengths.toSeq) shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths) mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) } @@ -264,7 +270,6 @@ private[spark] class UnsafeShuffleWriter[K, V]( /** Close this writer, passing along whether the map completed */ override def stop(success: Boolean): Option[MapStatus] = { - println("Stopping unsafeshufflewriter") try { if (stopping) { None @@ -300,7 +305,6 @@ private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManage numMaps: Int, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { if (UnsafeShuffleManager.canUseUnsafeShuffle(dependency)) { - println("Opening unsafeShuffleWriter") new UnsafeShuffleHandle[K, V]( shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) } else { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index b7306cd551918..7d5cf7b61e56a 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -131,8 +131,7 @@ private[spark] class ExternalSorter[K, V, C]( private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB private val useSerializedPairBuffer = !ordering.isDefined && conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) && - ser.isInstanceOf[KryoSerializer] && - serInstance.asInstanceOf[KryoSerializerInstance].getAutoReset + ser.supportsRelocationOfSerializedObjects // Data structures to store in-memory objects before we spill. Depending on whether we have an // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we From d3cc310de0e35057d316d201ed7ee5498b1eca9c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 1 May 2015 14:47:33 -0700 Subject: [PATCH 13/92] Flag that SparkSqlSerializer2 supports relocation --- core/src/test/resources/log4j.properties | 2 +- .../org/apache/spark/sql/execution/SparkSqlSerializer2.scala | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties index eb3b1999eb996..9512ac1ac79c3 100644 --- a/core/src/test/resources/log4j.properties +++ b/core/src/test/resources/log4j.properties @@ -16,7 +16,7 @@ # # Set everything to be logged to the file target/unit-tests.log -log4j.rootCategory=INFO, file +log4j.rootCategory=DEBUG, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=true log4j.appender.file.file=target/unit-tests.log diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index 9552f41115866..c841362a246ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -154,6 +154,8 @@ private[sql] class SparkSqlSerializer2(keySchema: Array[DataType], valueSchema: with Serializable{ def newInstance(): SerializerInstance = new ShuffleSerializerInstance(keySchema, valueSchema) + + override def supportsRelocationOfSerializedObjects: Boolean = true } private[sql] object SparkSqlSerializer2 { From 87e721b7501ba6f96db919384b52e90c7a8c8d91 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 1 May 2015 16:34:47 -0700 Subject: [PATCH 14/92] Renaming and comments --- .../unsafe/sort/UnsafeSortDataFormat.java | 15 ++--- .../spark/unsafe/sort/UnsafeSorter.java | 65 ++++++++++++++----- .../shuffle/unsafe/UnsafeShuffleManager.scala | 8 +-- .../spark/unsafe/sort/UnsafeSorterSuite.java | 6 +- 4 files changed, 64 insertions(+), 30 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java index 9955e3fcaabbb..290a87b70cad6 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java @@ -17,35 +17,34 @@ package org.apache.spark.unsafe.sort; -import static org.apache.spark.unsafe.sort.UnsafeSorter.KeyPointerAndPrefix; +import static org.apache.spark.unsafe.sort.UnsafeSorter.RecordPointerAndKeyPrefix; import org.apache.spark.util.collection.SortDataFormat; /** - * TODO: finish writing this description + * Supports sorting an array of (record pointer, key prefix) pairs. Used in {@link UnsafeSorter}. * * Within each long[] buffer, position {@code 2 * i} holds a pointer pointer to the record at * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. */ -final class UnsafeSortDataFormat - extends SortDataFormat { +final class UnsafeSortDataFormat extends SortDataFormat { public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat(); private UnsafeSortDataFormat() { } @Override - public KeyPointerAndPrefix getKey(long[] data, int pos) { + public RecordPointerAndKeyPrefix getKey(long[] data, int pos) { // Since we re-use keys, this method shouldn't be called. throw new UnsupportedOperationException(); } @Override - public KeyPointerAndPrefix newKey() { - return new KeyPointerAndPrefix(); + public RecordPointerAndKeyPrefix newKey() { + return new RecordPointerAndKeyPrefix(); } @Override - public KeyPointerAndPrefix getKey(long[] data, int pos, KeyPointerAndPrefix reuse) { + public RecordPointerAndKeyPrefix getKey(long[] data, int pos, RecordPointerAndKeyPrefix reuse) { reuse.recordPointer = data[pos * 2]; reuse.keyPrefix = data[pos * 2 + 1]; return reuse; diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java index d33ca321a9835..7795ee6a5f0e2 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java @@ -23,9 +23,16 @@ import org.apache.spark.util.collection.Sorter; import org.apache.spark.unsafe.memory.TaskMemoryManager; +/** + * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records + * alongside a user-defined prefix of the record's sorting key. When the underlying sort algorithm + * compares records, it will first compare the stored key prefixes; if the prefixes are not equal, + * then we do not need to traverse the record pointers to compare the actual records. Avoiding these + * random memory accesses improves cache hit rates. + */ public final class UnsafeSorter { - public static final class KeyPointerAndPrefix { + public static final class RecordPointerAndKeyPrefix { /** * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a * description of how these addresses are encoded. @@ -37,6 +44,7 @@ public static final class KeyPointerAndPrefix { */ public long keyPrefix; + // TODO: this was a carryover from test code; may want to remove this @Override public int hashCode() { throw new UnsupportedOperationException(); @@ -48,7 +56,17 @@ public boolean equals(Object obj) { } } + /** + * Compares records for ordering. In cases where the entire sorting key can fit in the 8-byte + * prefix, this may simply return 0. + */ public static abstract class RecordComparator { + /** + * Compare two records for order. + * + * @return a negative integer, zero, or a positive integer as the first record is less than, + * equal to, or greater than the second. + */ public abstract int compare( Object leftBaseObject, long leftBaseOffset, @@ -56,13 +74,16 @@ public abstract int compare( long rightBaseOffset); } + /** + * Given a pointer to a record, computes a prefix. + */ public static abstract class PrefixComputer { public abstract long computePrefix(Object baseObject, long baseOffset); } /** - * Compares 8-byte key prefixes in prefix sort. Subclasses may implement type-specific comparisons, - * such as lexicographic comparison for strings. + * Compares 8-byte key prefixes in prefix sort. Subclasses may implement type-specific + * comparisons, such as lexicographic comparison for strings. */ public static abstract class PrefixComparator { public abstract int compare(long prefix1, long prefix2); @@ -70,8 +91,8 @@ public static abstract class PrefixComparator { private final TaskMemoryManager memoryManager; private final PrefixComputer prefixComputer; - private final Sorter sorter; - private final Comparator sortComparator; + private final Sorter sorter; + private final Comparator sortComparator; /** * Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at @@ -79,8 +100,12 @@ public static abstract class PrefixComparator { */ private long[] sortBuffer; + /** + * The position in the sort buffer where new records can be inserted. + */ private int sortBufferInsertPosition = 0; + private void expandSortBuffer(int newSize) { assert (newSize > sortBuffer.length); final long[] oldBuffer = sortBuffer; @@ -99,11 +124,13 @@ public UnsafeSorter( this.memoryManager = memoryManager; this.prefixComputer = prefixComputer; this.sorter = - new Sorter(UnsafeSortDataFormat.INSTANCE); - this.sortComparator = new Comparator() { + new Sorter(UnsafeSortDataFormat.INSTANCE); + this.sortComparator = new Comparator() { @Override - public int compare(KeyPointerAndPrefix left, KeyPointerAndPrefix right) { - if (left.keyPrefix == right.keyPrefix) { + public int compare(RecordPointerAndKeyPrefix left, RecordPointerAndKeyPrefix right) { + final int prefixComparisonResult = + prefixComparator.compare(left.keyPrefix, right.keyPrefix); + if (prefixComparisonResult == 0) { final Object leftBaseObject = memoryManager.getPage(left.recordPointer); final long leftBaseOffset = memoryManager.getOffsetInPage(left.recordPointer); final Object rightBaseObject = memoryManager.getPage(right.recordPointer); @@ -111,12 +138,17 @@ public int compare(KeyPointerAndPrefix left, KeyPointerAndPrefix right) { return recordComparator.compare( leftBaseObject, leftBaseOffset, rightBaseObject, rightBaseOffset); } else { - return prefixComparator.compare(left.keyPrefix, right.keyPrefix); + return prefixComparisonResult; } } }; } + /** + * Insert a record into the sort buffer. + * + * @param objectAddress pointer to a record in a data page, encoded by {@link TaskMemoryManager}. + */ public void insertRecord(long objectAddress) { if (sortBufferInsertPosition + 2 == sortBuffer.length) { expandSortBuffer(sortBuffer.length * 2); @@ -130,11 +162,15 @@ public void insertRecord(long objectAddress) { sortBufferInsertPosition++; } - public Iterator getSortedIterator() { + /** + * Return an iterator over record pointers in sorted order. For efficiency, all calls to + * {@code next()} will return the same mutable object. + */ + public Iterator getSortedIterator() { sorter.sort(sortBuffer, 0, sortBufferInsertPosition / 2, sortComparator); - return new Iterator() { + return new Iterator() { private int position = 0; - private final KeyPointerAndPrefix keyPointerAndPrefix = new KeyPointerAndPrefix(); + private final RecordPointerAndKeyPrefix keyPointerAndPrefix = new RecordPointerAndKeyPrefix(); @Override public boolean hasNext() { @@ -142,7 +178,7 @@ public boolean hasNext() { } @Override - public KeyPointerAndPrefix next() { + public RecordPointerAndKeyPrefix next() { keyPointerAndPrefix.recordPointer = sortBuffer[position]; keyPointerAndPrefix.keyPrefix = sortBuffer[position + 1]; position += 2; @@ -155,5 +191,4 @@ public void remove() { } }; } - } diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala index 489bcf42cb448..7eb825641263b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala @@ -32,7 +32,7 @@ import org.apache.spark.storage.{BlockObjectWriter, ShuffleBlockId} import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.memory.{MemoryBlock, TaskMemoryManager} import org.apache.spark.unsafe.sort.UnsafeSorter -import org.apache.spark.unsafe.sort.UnsafeSorter.{KeyPointerAndPrefix, PrefixComparator, PrefixComputer, RecordComparator} +import org.apache.spark.unsafe.sort.UnsafeSorter.{RecordPointerAndKeyPrefix, PrefixComparator, PrefixComputer, RecordComparator} private class UnsafeShuffleHandle[K, V]( shuffleId: Int, @@ -122,7 +122,7 @@ private class UnsafeShuffleWriter[K, V]( private[this] val serializer = Serializer.getSerializer(dep.serializer).newInstance() private def sortRecords( - records: Iterator[_ <: Product2[K, V]]): java.util.Iterator[KeyPointerAndPrefix] = { + records: Iterator[_ <: Product2[K, V]]): java.util.Iterator[RecordPointerAndKeyPrefix] = { val sorter = new UnsafeSorter( context.taskMemoryManager(), DummyRecordComparator, @@ -194,7 +194,7 @@ private class UnsafeShuffleWriter[K, V]( } private def writeSortedRecordsToFile( - sortedRecords: java.util.Iterator[KeyPointerAndPrefix]): Array[Long] = { + sortedRecords: java.util.Iterator[RecordPointerAndKeyPrefix]): Array[Long] = { val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId) val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID) val partitionLengths = new Array[Long](partitioner.numPartitions) @@ -223,7 +223,7 @@ private class UnsafeShuffleWriter[K, V]( } while (sortedRecords.hasNext) { - val keyPointerAndPrefix: KeyPointerAndPrefix = sortedRecords.next() + val keyPointerAndPrefix: RecordPointerAndKeyPrefix = sortedRecords.next() val partition = keyPointerAndPrefix.keyPrefix.toInt if (partition != currentPartition) { switchToPartition(partition) diff --git a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java index 4c3b982747693..2f88df1210bbc 100644 --- a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java @@ -52,7 +52,7 @@ public void testSortingEmptyInput() { mock(UnsafeSorter.PrefixComputer.class), mock(UnsafeSorter.PrefixComparator.class), 100); - final Iterator iter = sorter.getSortedIterator(); + final Iterator iter = sorter.getSortedIterator(); assert(!iter.hasNext()); } @@ -130,12 +130,12 @@ public int compare(long prefix1, long prefix2) { sorter.insertRecord(address); position += 8 + recordLength; } - final Iterator iter = sorter.getSortedIterator(); + final Iterator iter = sorter.getSortedIterator(); int iterLength = 0; long prevPrefix = -1; Arrays.sort(dataToSort); while (iter.hasNext()) { - final UnsafeSorter.KeyPointerAndPrefix pointerAndPrefix = iter.next(); + final UnsafeSorter.RecordPointerAndKeyPrefix pointerAndPrefix = iter.next(); final Object recordBaseObject = memoryManager.getPage(pointerAndPrefix.recordPointer); final long recordBaseOffset = memoryManager.getOffsetInPage(pointerAndPrefix.recordPointer); final String str = getStringFromDataPage(recordBaseObject, recordBaseOffset); From 07484589af43de03343fe58601793f6aaff33d56 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 1 May 2015 17:50:48 -0700 Subject: [PATCH 15/92] Port UnsafeShuffleWriter to Java. --- .../shuffle/unsafe/UnsafeShuffleWriter.java | 282 ++++++++++++++++++ .../shuffle/unsafe/UnsafeShuffleManager.scala | 243 --------------- 2 files changed, 282 insertions(+), 243 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java new file mode 100644 index 0000000000000..bf368d4a11526 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import scala.Option; +import scala.Product2; +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Iterator; +import java.util.LinkedList; + +import com.esotericsoftware.kryo.io.ByteBufferOutputStream; + +import org.apache.spark.Partitioner; +import org.apache.spark.ShuffleDependency; +import org.apache.spark.SparkEnv; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.scheduler.MapStatus$; +import org.apache.spark.serializer.SerializationStream; +import org.apache.spark.serializer.Serializer; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.IndexShuffleBlockManager; +import org.apache.spark.shuffle.ShuffleWriter; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.BlockObjectWriter; +import org.apache.spark.storage.ShuffleBlockId; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.unsafe.sort.UnsafeSorter; +import static org.apache.spark.unsafe.sort.UnsafeSorter.*; + +// IntelliJ gets confused and claims that this class should be abstract, but this actually compiles +public class UnsafeShuffleWriter implements ShuffleWriter { + + private static final int PAGE_SIZE = 1024 * 1024; // TODO: tune this + private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this + private static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); + + private final IndexShuffleBlockManager shuffleBlockManager; + private final BlockManager blockManager = SparkEnv.get().blockManager(); + private final int shuffleId; + private final int mapId; + private final TaskMemoryManager memoryManager; + private final SerializerInstance serializer; + private final Partitioner partitioner; + private final ShuffleWriteMetrics writeMetrics; + private final LinkedList allocatedPages = new LinkedList(); + private final int fileBufferSize; + private MapStatus mapStatus = null; + + private MemoryBlock currentPage = null; + private long currentPagePosition = PAGE_SIZE; + + /** + * Are we in the process of stopping? Because map tasks can call stop() with success = true + * and then call stop() with success = false if they get an exception, we want to make sure + * we don't try deleting files, etc twice. + */ + private boolean stopping = false; + + public UnsafeShuffleWriter( + IndexShuffleBlockManager shuffleBlockManager, + UnsafeShuffleHandle handle, + int mapId, + TaskContext context) { + this.shuffleBlockManager = shuffleBlockManager; + this.mapId = mapId; + this.memoryManager = context.taskMemoryManager(); + final ShuffleDependency dep = handle.dependency(); + this.shuffleId = dep.shuffleId(); + this.serializer = Serializer.getSerializer(dep.serializer()).newInstance(); + this.partitioner = dep.partitioner(); + this.writeMetrics = new ShuffleWriteMetrics(); + context.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics)); + this.fileBufferSize = + // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units + (int) SparkEnv.get().conf().getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + } + + public void write(scala.collection.Iterator> records) { + try { + final long[] partitionLengths = writeSortedRecordsToFile(sortRecords(records)); + shuffleBlockManager.writeIndexFile(shuffleId, mapId, partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + } catch (Exception e) { + PlatformDependent.throwException(e); + } + } + + private void ensureSpaceInDataPage(long requiredSpace) throws Exception { + if (requiredSpace > PAGE_SIZE) { + // TODO: throw a more specific exception? + throw new Exception("Required space " + requiredSpace + " is greater than page size (" + + PAGE_SIZE + ")"); + } else if (requiredSpace > (PAGE_SIZE - currentPagePosition)) { + currentPage = memoryManager.allocatePage(PAGE_SIZE); + currentPagePosition = currentPage.getBaseOffset(); + allocatedPages.add(currentPage); + } + } + + private void freeMemory() { + final Iterator iter = allocatedPages.iterator(); + while (iter.hasNext()) { + memoryManager.freePage(iter.next()); + iter.remove(); + } + } + + private Iterator sortRecords( + scala.collection.Iterator> records) throws Exception { + final UnsafeSorter sorter = new UnsafeSorter( + memoryManager, + RECORD_COMPARATOR, + PREFIX_COMPUTER, + PREFIX_COMPARATOR, + 4096 // Initial size (TODO: tune this!) + ); + + final byte[] serArray = new byte[SER_BUFFER_SIZE]; + final ByteBuffer serByteBuffer = ByteBuffer.wrap(serArray); + // TODO: we should not depend on this class from Kryo; copy its source or find an alternative + final SerializationStream serOutputStream = + serializer.serializeStream(new ByteBufferOutputStream(serByteBuffer)); + + while (records.hasNext()) { + final Product2 record = records.next(); + final K key = record._1(); + final int partitionId = partitioner.getPartition(key); + serByteBuffer.position(0); + serOutputStream.writeKey(key, OBJECT_CLASS_TAG); + serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG); + serOutputStream.flush(); + + final int serializedRecordSize = serByteBuffer.position(); + assert (serializedRecordSize > 0); + // TODO: we should run the partition extraction function _now_, at insert time, rather than + // requiring it to be stored alongisde the data, since this may lead to double storage + // Need 8 bytes to store the prefix (for later retrieval in the prefix computer), plus + // 8 to store the record length (TODO: can store as an int instead). + ensureSpaceInDataPage(serializedRecordSize + 8 + 8); + + final long recordAddress = + memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); + final Object baseObject = currentPage.getBaseObject(); + PlatformDependent.UNSAFE.putLong(baseObject, currentPagePosition, partitionId); + currentPagePosition += 8; + PlatformDependent.UNSAFE.putLong(baseObject, currentPagePosition, serializedRecordSize); + currentPagePosition += 8; + PlatformDependent.copyMemory( + serArray, + PlatformDependent.BYTE_ARRAY_OFFSET, + baseObject, + currentPagePosition, + serializedRecordSize); + currentPagePosition += serializedRecordSize; + + sorter.insertRecord(recordAddress); + } + + return sorter.getSortedIterator(); + } + + private long[] writeSortedRecordsToFile( + Iterator sortedRecords) throws IOException { + final File outputFile = shuffleBlockManager.getDataFile(shuffleId, mapId); + final ShuffleBlockId blockId = + new ShuffleBlockId(shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID()); + final long[] partitionLengths = new long[partitioner.numPartitions()]; + + int currentPartition = -1; + BlockObjectWriter writer = null; + + while (sortedRecords.hasNext()) { + final RecordPointerAndKeyPrefix recordPointer = sortedRecords.next(); + final int partition = (int) recordPointer.keyPrefix; + assert (partition >= currentPartition); + if (partition != currentPartition) { + // Switch to the new partition + if (currentPartition != -1) { + writer.commitAndClose(); + partitionLengths[currentPartition] = writer.fileSegment().length(); + } + currentPartition = partition; + writer = + blockManager.getDiskWriter(blockId, outputFile, serializer, fileBufferSize, writeMetrics); + } + + final Object baseObject = memoryManager.getPage(recordPointer.recordPointer); + final long baseOffset = memoryManager.getOffsetInPage(recordPointer.recordPointer); + final int recordLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8); + // TODO: re-use a buffer or avoid double-buffering entirely + final byte[] arr = new byte[recordLength]; + PlatformDependent.copyMemory( + baseObject, + baseOffset + 16, + arr, + PlatformDependent.BYTE_ARRAY_OFFSET, + recordLength); + assert (writer != null); // To suppress an IntelliJ warning + writer.write(arr); + // TODO: add a test that detects whether we leave this call out: + writer.recordWritten(); + } + + if (writer != null) { + writer.commitAndClose(); + partitionLengths[currentPartition] = writer.fileSegment().length(); + } + + return partitionLengths; + } + + @Override + public Option stop(boolean success) { + try { + if (stopping) { + return Option.apply(null); + } else { + stopping = true; + freeMemory(); + if (success) { + return Option.apply(mapStatus); + } else { + // The map task failed, so delete our output data. + shuffleBlockManager.removeDataByMap(shuffleId, mapId); + return Option.apply(null); + } + } + } finally { + freeMemory(); + // TODO: increment the shuffle write time metrics + } + } + + private static final RecordComparator RECORD_COMPARATOR = new RecordComparator() { + @Override + public int compare( + Object leftBaseObject, long leftBaseOffset, Object rightBaseObject, long rightBaseOffset) { + return 0; + } + }; + + private static final PrefixComputer PREFIX_COMPUTER = new PrefixComputer() { + @Override + public long computePrefix(Object baseObject, long baseOffset) { + // TODO: should the prefix be computed when inserting the record pointer rather than being + // read from the record itself? May be more efficient in terms of space, etc, and is a simple + // change. + return PlatformDependent.UNSAFE.getLong(baseObject, baseOffset); + } + }; + + private static final PrefixComparator PREFIX_COMPARATOR = new PrefixComparator() { + @Override + public int compare(long prefix1, long prefix2) { + return (int) (prefix1 - prefix2); + } + }; +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala index 7eb825641263b..0dd34b372f624 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala @@ -17,22 +17,10 @@ package org.apache.spark.shuffle.unsafe -import java.nio.ByteBuffer -import java.util - -import com.esotericsoftware.kryo.io.ByteBufferOutputStream - import org.apache.spark._ -import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.scheduler.MapStatus import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle._ import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.storage.{BlockObjectWriter, ShuffleBlockId} -import org.apache.spark.unsafe.PlatformDependent -import org.apache.spark.unsafe.memory.{MemoryBlock, TaskMemoryManager} -import org.apache.spark.unsafe.sort.UnsafeSorter -import org.apache.spark.unsafe.sort.UnsafeSorter.{RecordPointerAndKeyPrefix, PrefixComparator, PrefixComputer, RecordComparator} private class UnsafeShuffleHandle[K, V]( shuffleId: Int, @@ -62,237 +50,6 @@ private[spark] object UnsafeShuffleManager extends Logging { } } -private object DummyRecordComparator extends RecordComparator { - override def compare( - leftBaseObject: scala.Any, - leftBaseOffset: Long, - rightBaseObject: scala.Any, - rightBaseOffset: Long): Int = { - 0 - } -} - -private object PartitionerPrefixComputer extends PrefixComputer { - override def computePrefix(baseObject: scala.Any, baseOffset: Long): Long = { - // TODO: should the prefix be computed when inserting the record pointer rather than being - // read from the record itself? May be more efficient in terms of space, etc, and is a simple - // change. - PlatformDependent.UNSAFE.getLong(baseObject, baseOffset) - } -} - -private object PartitionerPrefixComparator extends PrefixComparator { - override def compare(prefix1: Long, prefix2: Long): Int = { - (prefix1 - prefix2).toInt - } -} - -private class UnsafeShuffleWriter[K, V]( - shuffleBlockManager: IndexShuffleBlockManager, - handle: UnsafeShuffleHandle[K, V], - mapId: Int, - context: TaskContext) - extends ShuffleWriter[K, V] { - - private[this] val memoryManager: TaskMemoryManager = context.taskMemoryManager() - - private[this] val dep = handle.dependency - - private[this] val partitioner = dep.partitioner - - // Are we in the process of stopping? Because map tasks can call stop() with success = true - // and then call stop() with success = false if they get an exception, we want to make sure - // we don't try deleting files, etc twice. - private[this] var stopping = false - - private[this] var mapStatus: MapStatus = null - - private[this] val writeMetrics = new ShuffleWriteMetrics() - context.taskMetrics().shuffleWriteMetrics = Some(writeMetrics) - - private[this] val allocatedPages: util.LinkedList[MemoryBlock] = - new util.LinkedList[MemoryBlock]() - - private[this] val blockManager = SparkEnv.get.blockManager - - // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided - private[this] val fileBufferSize = - SparkEnv.get.conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 - - private[this] val serializer = Serializer.getSerializer(dep.serializer).newInstance() - - private def sortRecords( - records: Iterator[_ <: Product2[K, V]]): java.util.Iterator[RecordPointerAndKeyPrefix] = { - val sorter = new UnsafeSorter( - context.taskMemoryManager(), - DummyRecordComparator, - PartitionerPrefixComputer, - PartitionerPrefixComparator, - 4096 // initial size - ) - val PAGE_SIZE = 1024 * 1024 * 1 - - var currentPage: MemoryBlock = null - var currentPagePosition: Long = PAGE_SIZE - - def ensureSpaceInDataPage(spaceRequired: Long): Unit = { - if (spaceRequired > PAGE_SIZE) { - throw new Exception(s"Size requirement $spaceRequired is greater than page size $PAGE_SIZE") - } else if (spaceRequired > (PAGE_SIZE - currentPagePosition)) { - currentPage = memoryManager.allocatePage(PAGE_SIZE) - allocatedPages.add(currentPage) - currentPagePosition = currentPage.getBaseOffset - } - } - - // TODO: the size of this buffer should be configurable - val serArray = new Array[Byte](1024 * 1024) - val byteBuffer = ByteBuffer.wrap(serArray) - val bbos = new ByteBufferOutputStream() - bbos.setByteBuffer(byteBuffer) - val serBufferSerStream = serializer.serializeStream(bbos) - - def writeRecord(record: Product2[Any, Any]): Unit = { - val (key, value) = record - val partitionId = partitioner.getPartition(key) - serBufferSerStream.writeKey(key) - serBufferSerStream.writeValue(value) - serBufferSerStream.flush() - - val serializedRecordSize = byteBuffer.position() - assert(serializedRecordSize > 0) - // TODO: we should run the partition extraction function _now_, at insert time, rather than - // requiring it to be stored alongisde the data, since this may lead to double storage - val sizeRequirementInSortDataPage = serializedRecordSize + 8 + 8 - ensureSpaceInDataPage(sizeRequirementInSortDataPage) - - val newRecordAddress = - memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition) - PlatformDependent.UNSAFE.putLong(currentPage.getBaseObject, currentPagePosition, partitionId) - currentPagePosition += 8 - PlatformDependent.UNSAFE.putLong( - currentPage.getBaseObject, currentPagePosition, serializedRecordSize) - currentPagePosition += 8 - PlatformDependent.copyMemory( - serArray, - PlatformDependent.BYTE_ARRAY_OFFSET, - currentPage.getBaseObject, - currentPagePosition, - serializedRecordSize) - currentPagePosition += serializedRecordSize - sorter.insertRecord(newRecordAddress) - - // Reset for writing the next record - byteBuffer.position(0) - } - - while (records.hasNext) { - writeRecord(records.next()) - } - - sorter.getSortedIterator - } - - private def writeSortedRecordsToFile( - sortedRecords: java.util.Iterator[RecordPointerAndKeyPrefix]): Array[Long] = { - val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId) - val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID) - val partitionLengths = new Array[Long](partitioner.numPartitions) - - var currentPartition = -1 - var writer: BlockObjectWriter = null - - // TODO: don't close and re-open file handles so often; this could be inefficient - - def closePartition(): Unit = { - if (writer != null) { - writer.commitAndClose() - partitionLengths(currentPartition) = writer.fileSegment().length - } - } - - def switchToPartition(newPartition: Int): Unit = { - assert (newPartition > currentPartition, - s"new partition $newPartition should be >= $currentPartition") - if (currentPartition != -1) { - closePartition() - } - currentPartition = newPartition - writer = - blockManager.getDiskWriter(blockId, outputFile, serializer, fileBufferSize, writeMetrics) - } - - while (sortedRecords.hasNext) { - val keyPointerAndPrefix: RecordPointerAndKeyPrefix = sortedRecords.next() - val partition = keyPointerAndPrefix.keyPrefix.toInt - if (partition != currentPartition) { - switchToPartition(partition) - } - val baseObject = memoryManager.getPage(keyPointerAndPrefix.recordPointer) - val baseOffset = memoryManager.getOffsetInPage(keyPointerAndPrefix.recordPointer) - val recordLength: Int = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8).toInt - // TODO: need to have a way to figure out whether a serializer supports relocation of - // serialized objects or not. Sandy also ran into this in his patch (see - // https://github.com/apache/spark/pull/4450). If we're using Java serialization, we might - // as well just bypass this optimized code path in favor of the old one. - // TODO: re-use a buffer or avoid double-buffering entirely - val arr: Array[Byte] = new Array[Byte](recordLength) - PlatformDependent.copyMemory( - baseObject, - baseOffset + 16, - arr, - PlatformDependent.BYTE_ARRAY_OFFSET, - recordLength) - writer.write(arr) - // TODO: add a test that detects whether we leave this call out: - writer.recordWritten() - } - closePartition() - - partitionLengths - } - - /** Write a sequence of records to this task's output */ - override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { - val sortedIterator = sortRecords(records) - val partitionLengths = writeSortedRecordsToFile(sortedIterator) - shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths) - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) - } - - private def freeMemory(): Unit = { - val iter = allocatedPages.iterator() - while (iter.hasNext) { - memoryManager.freePage(iter.next()) - iter.remove() - } - } - - /** Close this writer, passing along whether the map completed */ - override def stop(success: Boolean): Option[MapStatus] = { - try { - if (stopping) { - None - } else { - stopping = true - freeMemory() - if (success) { - Option(mapStatus) - } else { - // The map task failed, so delete our output data. - shuffleBlockManager.removeDataByMap(dep.shuffleId, mapId) - None - } - } - } finally { - freeMemory() - val startTime = System.nanoTime() - context.taskMetrics().shuffleWriteMetrics.foreach( - _.incShuffleWriteTime(System.nanoTime - startTime)) - } - } -} - private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager { private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf) From 026b4977a465b4c47af43e6365de4158d3d10ab7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 1 May 2015 22:55:39 -0700 Subject: [PATCH 16/92] Re-use a buffer in UnsafeShuffleWriter --- .../shuffle/unsafe/UnsafeShuffleWriter.java | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index bf368d4a11526..ea17aedb919ef 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -193,6 +193,7 @@ private long[] writeSortedRecordsToFile( int currentPartition = -1; BlockObjectWriter writer = null; + final byte[] arr = new byte[SER_BUFFER_SIZE]; while (sortedRecords.hasNext()) { final RecordPointerAndKeyPrefix recordPointer = sortedRecords.next(); final int partition = (int) recordPointer.keyPrefix; @@ -211,16 +212,14 @@ private long[] writeSortedRecordsToFile( final Object baseObject = memoryManager.getPage(recordPointer.recordPointer); final long baseOffset = memoryManager.getOffsetInPage(recordPointer.recordPointer); final int recordLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8); - // TODO: re-use a buffer or avoid double-buffering entirely - final byte[] arr = new byte[recordLength]; - PlatformDependent.copyMemory( - baseObject, - baseOffset + 16, - arr, - PlatformDependent.BYTE_ARRAY_OFFSET, - recordLength); + PlatformDependent.copyMemory( + baseObject, + baseOffset + 16, + arr, + PlatformDependent.BYTE_ARRAY_OFFSET, + recordLength); assert (writer != null); // To suppress an IntelliJ warning - writer.write(arr); + writer.write(arr, 0, recordLength); // TODO: add a test that detects whether we leave this call out: writer.recordWritten(); } From 1433b42961ff7bf777b5a966f634822144d13f7c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 2 May 2015 00:12:44 -0700 Subject: [PATCH 17/92] Store record length as int instead of long. --- .../spark/shuffle/unsafe/UnsafeShuffleWriter.java | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index ea17aedb919ef..3a5064f03cced 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -159,16 +159,16 @@ private Iterator sortRecords( // TODO: we should run the partition extraction function _now_, at insert time, rather than // requiring it to be stored alongisde the data, since this may lead to double storage // Need 8 bytes to store the prefix (for later retrieval in the prefix computer), plus - // 8 to store the record length (TODO: can store as an int instead). - ensureSpaceInDataPage(serializedRecordSize + 8 + 8); + // 4 to store the record length. + ensureSpaceInDataPage(serializedRecordSize + 8 + 4); final long recordAddress = memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); final Object baseObject = currentPage.getBaseObject(); PlatformDependent.UNSAFE.putLong(baseObject, currentPagePosition, partitionId); currentPagePosition += 8; - PlatformDependent.UNSAFE.putLong(baseObject, currentPagePosition, serializedRecordSize); - currentPagePosition += 8; + PlatformDependent.UNSAFE.putInt(baseObject, currentPagePosition, serializedRecordSize); + currentPagePosition += 4; PlatformDependent.copyMemory( serArray, PlatformDependent.BYTE_ARRAY_OFFSET, @@ -214,7 +214,7 @@ private long[] writeSortedRecordsToFile( final int recordLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8); PlatformDependent.copyMemory( baseObject, - baseOffset + 16, + baseOffset + 8 + 4, arr, PlatformDependent.BYTE_ARRAY_OFFSET, recordLength); From 240864c9f860d41c9d2ad51ff78c9160e6e90992 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 2 May 2015 00:20:40 -0700 Subject: [PATCH 18/92] Remove PrefixComputer and require prefix to be specified as part of insert() --- .../shuffle/unsafe/UnsafeShuffleWriter.java | 26 ++++--------------- .../spark/unsafe/sort/UnsafeSorter.java | 13 +--------- .../spark/unsafe/sort/UnsafeSorterSuite.java | 17 ++++-------- 3 files changed, 11 insertions(+), 45 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 3a5064f03cced..4d65016577872 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -134,7 +134,6 @@ private Iterator sortRecords( final UnsafeSorter sorter = new UnsafeSorter( memoryManager, RECORD_COMPARATOR, - PREFIX_COMPUTER, PREFIX_COMPARATOR, 4096 // Initial size (TODO: tune this!) ); @@ -156,17 +155,12 @@ private Iterator sortRecords( final int serializedRecordSize = serByteBuffer.position(); assert (serializedRecordSize > 0); - // TODO: we should run the partition extraction function _now_, at insert time, rather than - // requiring it to be stored alongisde the data, since this may lead to double storage - // Need 8 bytes to store the prefix (for later retrieval in the prefix computer), plus - // 4 to store the record length. - ensureSpaceInDataPage(serializedRecordSize + 8 + 4); + // Need 4 bytes to store the record length. + ensureSpaceInDataPage(serializedRecordSize + 4); final long recordAddress = memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); final Object baseObject = currentPage.getBaseObject(); - PlatformDependent.UNSAFE.putLong(baseObject, currentPagePosition, partitionId); - currentPagePosition += 8; PlatformDependent.UNSAFE.putInt(baseObject, currentPagePosition, serializedRecordSize); currentPagePosition += 4; PlatformDependent.copyMemory( @@ -177,7 +171,7 @@ private Iterator sortRecords( serializedRecordSize); currentPagePosition += serializedRecordSize; - sorter.insertRecord(recordAddress); + sorter.insertRecord(recordAddress, partitionId); } return sorter.getSortedIterator(); @@ -211,10 +205,10 @@ private long[] writeSortedRecordsToFile( final Object baseObject = memoryManager.getPage(recordPointer.recordPointer); final long baseOffset = memoryManager.getOffsetInPage(recordPointer.recordPointer); - final int recordLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8); + final int recordLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset); PlatformDependent.copyMemory( baseObject, - baseOffset + 8 + 4, + baseOffset + 4, arr, PlatformDependent.BYTE_ARRAY_OFFSET, recordLength); @@ -262,16 +256,6 @@ public int compare( } }; - private static final PrefixComputer PREFIX_COMPUTER = new PrefixComputer() { - @Override - public long computePrefix(Object baseObject, long baseOffset) { - // TODO: should the prefix be computed when inserting the record pointer rather than being - // read from the record itself? May be more efficient in terms of space, etc, and is a simple - // change. - return PlatformDependent.UNSAFE.getLong(baseObject, baseOffset); - } - }; - private static final PrefixComparator PREFIX_COMPARATOR = new PrefixComparator() { @Override public int compare(long prefix1, long prefix2) { diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java index 7795ee6a5f0e2..adbbc0b1f3cb8 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java @@ -74,13 +74,6 @@ public abstract int compare( long rightBaseOffset); } - /** - * Given a pointer to a record, computes a prefix. - */ - public static abstract class PrefixComputer { - public abstract long computePrefix(Object baseObject, long baseOffset); - } - /** * Compares 8-byte key prefixes in prefix sort. Subclasses may implement type-specific * comparisons, such as lexicographic comparison for strings. @@ -90,7 +83,6 @@ public static abstract class PrefixComparator { } private final TaskMemoryManager memoryManager; - private final PrefixComputer prefixComputer; private final Sorter sorter; private final Comparator sortComparator; @@ -116,13 +108,11 @@ private void expandSortBuffer(int newSize) { public UnsafeSorter( final TaskMemoryManager memoryManager, final RecordComparator recordComparator, - PrefixComputer prefixComputer, final PrefixComparator prefixComparator, int initialSize) { assert (initialSize > 0); this.sortBuffer = new long[initialSize * 2]; this.memoryManager = memoryManager; - this.prefixComputer = prefixComputer; this.sorter = new Sorter(UnsafeSortDataFormat.INSTANCE); this.sortComparator = new Comparator() { @@ -149,13 +139,12 @@ public int compare(RecordPointerAndKeyPrefix left, RecordPointerAndKeyPrefix rig * * @param objectAddress pointer to a record in a data page, encoded by {@link TaskMemoryManager}. */ - public void insertRecord(long objectAddress) { + public void insertRecord(long objectAddress, long keyPrefix) { if (sortBufferInsertPosition + 2 == sortBuffer.length) { expandSortBuffer(sortBuffer.length * 2); } final Object baseObject = memoryManager.getPage(objectAddress); final long baseOffset = memoryManager.getOffsetInPage(objectAddress); - final long keyPrefix = prefixComputer.computePrefix(baseObject, baseOffset); sortBuffer[sortBufferInsertPosition] = objectAddress; sortBufferInsertPosition++; sortBuffer[sortBufferInsertPosition] = keyPrefix; diff --git a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java index 2f88df1210bbc..aed115f83a368 100644 --- a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java @@ -49,7 +49,6 @@ public void testSortingEmptyInput() { final UnsafeSorter sorter = new UnsafeSorter( new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)), mock(UnsafeSorter.RecordComparator.class), - mock(UnsafeSorter.PrefixComputer.class), mock(UnsafeSorter.PrefixComparator.class), 100); final Iterator iter = sorter.getSortedIterator(); @@ -104,14 +103,6 @@ public int compare( }; // Compute key prefixes based on the records' partition ids final HashPartitioner hashPartitioner = new HashPartitioner(4); - final UnsafeSorter.PrefixComputer prefixComputer = new UnsafeSorter.PrefixComputer() { - @Override - public long computePrefix(Object baseObject, long baseOffset) { - final String str = getStringFromDataPage(baseObject, baseOffset); - final int partitionId = hashPartitioner.getPartition(str); - return (long) partitionId; - } - }; // Use integer comparison for comparing prefixes (which are partition ids, in this case) final UnsafeSorter.PrefixComparator prefixComparator = new UnsafeSorter.PrefixComparator() { @Override @@ -119,15 +110,17 @@ public int compare(long prefix1, long prefix2) { return (int) prefix1 - (int) prefix2; } }; - final UnsafeSorter sorter = new UnsafeSorter(memoryManager, recordComparator, prefixComputer, - prefixComparator, dataToSort.length); + final UnsafeSorter sorter = new UnsafeSorter(memoryManager, recordComparator, prefixComparator, + dataToSort.length); // Given a page of records, insert those records into the sorter one-by-one: position = dataPage.getBaseOffset(); for (int i = 0; i < dataToSort.length; i++) { // position now points to the start of a record (which holds its length). final long recordLength = PlatformDependent.UNSAFE.getLong(baseObject, position); final long address = memoryManager.encodePageNumberAndOffset(dataPage, position); - sorter.insertRecord(address); + final String str = getStringFromDataPage(baseObject, position); + final int partitionId = hashPartitioner.getPartition(str); + sorter.insertRecord(address, partitionId); position += 8 + recordLength; } final Iterator iter = sorter.getSortedIterator(); From bfc12d30b682061fed539dd652c3bc0420643d21 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 3 May 2015 12:52:15 -0700 Subject: [PATCH 19/92] Add tests for serializer relocation property. I verified that the Kryo tests will fail if we remove the auto-reset check in KryoSerializer. I also checked that this test fails if we mistakenly enable this flag for JavaSerializer. This demonstrates that the test case is actually capable of detecting the types of bugs that it's trying to prevent. Of course, it's possible that certain bugs will only surface when serializing specific data types, so we'll still have to be cautious when overriding `supportsRelocationOfSerializedObjects` for new serializers. --- .../spark/serializer/KryoSerializer.scala | 4 +- .../SerializerPropertiesSuite.scala | 103 ++++++++++++++++++ 2 files changed, 106 insertions(+), 1 deletion(-) create mode 100644 core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index f6c17e362f9b3..14b3890fac01a 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -127,7 +127,9 @@ class KryoSerializer(conf: SparkConf) } override def supportsRelocationOfSerializedObjects: Boolean = { - // TODO: we should have a citation / explanatory comment here clarifying _why_ this is the case + // If auto-flush is disabled, then Kryo may store references to duplicate occurrences of objects + // in the stream rather than writing those objects' serialized bytes, breaking relocation. See + // https://groups.google.com/d/msg/kryo-users/6ZUSyfjjtdo/FhGG1KHDXPgJ for more details. newInstance().asInstanceOf[KryoSerializerInstance].getAutoReset() } } diff --git a/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala b/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala new file mode 100644 index 0000000000000..a117619b04a43 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} + +import scala.util.Random + +import org.scalatest.FunSuite + +import org.apache.spark.SparkConf +import org.apache.spark.serializer.KryoTest.RegistratorWithoutAutoReset + +private case class MyCaseClass(foo: Int, bar: String) + +class SerializerPropertiesSuite extends FunSuite { + + test("JavaSerializer does not support relocation") { + testSupportsRelocationOfSerializedObjects(new JavaSerializer(new SparkConf())) + } + + test("KryoSerializer supports relocation when auto-reset is enabled") { + val ser = new KryoSerializer(new SparkConf) + assert(ser.newInstance().asInstanceOf[KryoSerializerInstance].getAutoReset()) + testSupportsRelocationOfSerializedObjects(ser) + } + + test("KryoSerializer does not support relocation when auto-reset is disabled") { + val conf = new SparkConf().set("spark.kryo.registrator", + classOf[RegistratorWithoutAutoReset].getName) + val ser = new KryoSerializer(conf) + assert(!ser.newInstance().asInstanceOf[KryoSerializerInstance].getAutoReset()) + testSupportsRelocationOfSerializedObjects(ser) + } + + def testSupportsRelocationOfSerializedObjects(serializer: Serializer): Unit = { + val NUM_TRIALS = 100 + if (!serializer.supportsRelocationOfSerializedObjects) { + return + } + val rand = new Random(42) + val randomFunctions: Seq[() => Any] = Seq( + () => rand.nextInt(), + () => rand.nextString(rand.nextInt(10)), + () => rand.nextDouble(), + () => rand.nextBoolean(), + () => (rand.nextInt(), rand.nextString(rand.nextInt(10))), + () => MyCaseClass(rand.nextInt(), rand.nextString(rand.nextInt(10))), + () => { + val x = MyCaseClass(rand.nextInt(), rand.nextString(rand.nextInt(10))) + (x, x) + } + ) + def generateRandomItem(): Any = { + randomFunctions(rand.nextInt(randomFunctions.size)).apply() + } + + for (_ <- 1 to NUM_TRIALS) { + val items = { + // Make sure that we have duplicate occurrences of the same object in the stream: + val randomItems = Seq.fill(10)(generateRandomItem()) + randomItems ++ randomItems.take(5) + } + val baos = new ByteArrayOutputStream() + val serStream = serializer.newInstance().serializeStream(baos) + def serializeItem(item: Any): Array[Byte] = { + val itemStartOffset = baos.toByteArray.length + serStream.writeObject(item) + serStream.flush() + val itemEndOffset = baos.toByteArray.length + baos.toByteArray.slice(itemStartOffset, itemEndOffset).clone() + } + val itemsAndSerializedItems: Seq[(Any, Array[Byte])] = { + val serItems = items.map { + item => (item, serializeItem(item)) + } + serStream.close() + rand.shuffle(serItems) + } + val reorderedSerializedData: Array[Byte] = itemsAndSerializedItems.flatMap(_._2).toArray + val deserializedItemsStream = serializer.newInstance().deserializeStream( + new ByteArrayInputStream(reorderedSerializedData)) + assert(deserializedItemsStream.asIterator.toSeq === itemsAndSerializedItems.map(_._1)) + deserializedItemsStream.close() + } + } + +} From b8a09fe831e74f5264b3a24ad0ccbef4209178d4 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 3 May 2015 12:57:07 -0700 Subject: [PATCH 20/92] Back out accidental log4j.properties change --- .gitignore | 1 - core/src/test/resources/log4j.properties | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 46a2a3a3f190d..d54d21b802be8 100644 --- a/.gitignore +++ b/.gitignore @@ -29,7 +29,6 @@ conf/*.properties conf/*.conf conf/*.xml conf/slaves -core/build/py4j/ docs/_site docs/api target/ diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties index 9512ac1ac79c3..eb3b1999eb996 100644 --- a/core/src/test/resources/log4j.properties +++ b/core/src/test/resources/log4j.properties @@ -16,7 +16,7 @@ # # Set everything to be logged to the file target/unit-tests.log -log4j.rootCategory=DEBUG, file +log4j.rootCategory=INFO, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=true log4j.appender.file.file=target/unit-tests.log From c2fca171c26b30b100d6d717f7173d8f5705c341 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 3 May 2015 13:08:35 -0700 Subject: [PATCH 21/92] Small refactoring of SerializerPropertiesSuite to enable test re-use: This lays some groundwork for re-using this test logic for serializers defined in other subprojects (those projects can just declare a test-jar dependency on Spark core). --- .../SerializerPropertiesSuite.scala | 41 +++++++++++-------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala b/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala index a117619b04a43..dbd831c3f6056 100644 --- a/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala @@ -21,23 +21,25 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.util.Random -import org.scalatest.FunSuite +import org.scalatest.{Assertions, FunSuite} import org.apache.spark.SparkConf import org.apache.spark.serializer.KryoTest.RegistratorWithoutAutoReset -private case class MyCaseClass(foo: Int, bar: String) class SerializerPropertiesSuite extends FunSuite { + import SerializerPropertiesSuite._ + test("JavaSerializer does not support relocation") { - testSupportsRelocationOfSerializedObjects(new JavaSerializer(new SparkConf())) + val ser = new JavaSerializer(new SparkConf()) + testSupportsRelocationOfSerializedObjects(ser, generateRandomItem) } test("KryoSerializer supports relocation when auto-reset is enabled") { val ser = new KryoSerializer(new SparkConf) assert(ser.newInstance().asInstanceOf[KryoSerializerInstance].getAutoReset()) - testSupportsRelocationOfSerializedObjects(ser) + testSupportsRelocationOfSerializedObjects(ser, generateRandomItem) } test("KryoSerializer does not support relocation when auto-reset is disabled") { @@ -45,15 +47,14 @@ class SerializerPropertiesSuite extends FunSuite { classOf[RegistratorWithoutAutoReset].getName) val ser = new KryoSerializer(conf) assert(!ser.newInstance().asInstanceOf[KryoSerializerInstance].getAutoReset()) - testSupportsRelocationOfSerializedObjects(ser) + testSupportsRelocationOfSerializedObjects(ser, generateRandomItem) } - def testSupportsRelocationOfSerializedObjects(serializer: Serializer): Unit = { - val NUM_TRIALS = 100 - if (!serializer.supportsRelocationOfSerializedObjects) { - return - } - val rand = new Random(42) +} + +object SerializerPropertiesSuite extends Assertions { + + def generateRandomItem(rand: Random): Any = { val randomFunctions: Seq[() => Any] = Seq( () => rand.nextInt(), () => rand.nextString(rand.nextInt(10)), @@ -66,14 +67,21 @@ class SerializerPropertiesSuite extends FunSuite { (x, x) } ) - def generateRandomItem(): Any = { - randomFunctions(rand.nextInt(randomFunctions.size)).apply() - } + randomFunctions(rand.nextInt(randomFunctions.size)).apply() + } + def testSupportsRelocationOfSerializedObjects( + serializer: Serializer, + generateRandomItem: Random => Any): Unit = { + if (!serializer.supportsRelocationOfSerializedObjects) { + return + } + val NUM_TRIALS = 10 + val rand = new Random(42) for (_ <- 1 to NUM_TRIALS) { val items = { // Make sure that we have duplicate occurrences of the same object in the stream: - val randomItems = Seq.fill(10)(generateRandomItem()) + val randomItems = Seq.fill(10)(generateRandomItem(rand)) randomItems ++ randomItems.take(5) } val baos = new ByteArrayOutputStream() @@ -99,5 +107,6 @@ class SerializerPropertiesSuite extends FunSuite { deserializedItemsStream.close() } } - } + +private case class MyCaseClass(foo: Int, bar: String) \ No newline at end of file From f17fa8fbf03ba2d0ec0ecf5050668a914c9fae44 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 3 May 2015 13:18:39 -0700 Subject: [PATCH 22/92] Add missing newline --- .../org/apache/spark/serializer/SerializerPropertiesSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala b/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala index dbd831c3f6056..b6848c3b19f51 100644 --- a/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala @@ -109,4 +109,4 @@ object SerializerPropertiesSuite extends Assertions { } } -private case class MyCaseClass(foo: Int, bar: String) \ No newline at end of file +private case class MyCaseClass(foo: Int, bar: String) From 89585847a4c0340caf7f681ea24b3b00f029a295 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 3 May 2015 14:42:36 -0700 Subject: [PATCH 23/92] Fix bug in calculating free space in current page. This broke off-heap mode. --- .../spark/shuffle/unsafe/UnsafeShuffleWriter.java | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 4d65016577872..9554298c0f3f8 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -71,7 +71,7 @@ public class UnsafeShuffleWriter implements ShuffleWriter { private MapStatus mapStatus = null; private MemoryBlock currentPage = null; - private long currentPagePosition = PAGE_SIZE; + private long currentPagePosition = -1; /** * Are we in the process of stopping? Because map tasks can call stop() with success = true @@ -110,11 +110,17 @@ public void write(scala.collection.Iterator> records) { } private void ensureSpaceInDataPage(long requiredSpace) throws Exception { + final long spaceInCurrentPage; + if (currentPage != null) { + spaceInCurrentPage = PAGE_SIZE - (currentPagePosition - currentPage.getBaseOffset()); + } else { + spaceInCurrentPage = 0; + } if (requiredSpace > PAGE_SIZE) { // TODO: throw a more specific exception? throw new Exception("Required space " + requiredSpace + " is greater than page size (" + PAGE_SIZE + ")"); - } else if (requiredSpace > (PAGE_SIZE - currentPagePosition)) { + } else if (requiredSpace > spaceInCurrentPage) { currentPage = memoryManager.allocatePage(PAGE_SIZE); currentPagePosition = currentPage.getBaseOffset(); allocatedPages.add(currentPage); From 595923aec7c0a94029ab2077bf0d9362cc56441d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 3 May 2015 14:55:58 -0700 Subject: [PATCH 24/92] Remove some unused variables. --- .../main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java | 4 ---- 1 file changed, 4 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java index adbbc0b1f3cb8..092e26f4ee1fc 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java @@ -82,7 +82,6 @@ public static abstract class PrefixComparator { public abstract int compare(long prefix1, long prefix2); } - private final TaskMemoryManager memoryManager; private final Sorter sorter; private final Comparator sortComparator; @@ -112,7 +111,6 @@ public UnsafeSorter( int initialSize) { assert (initialSize > 0); this.sortBuffer = new long[initialSize * 2]; - this.memoryManager = memoryManager; this.sorter = new Sorter(UnsafeSortDataFormat.INSTANCE); this.sortComparator = new Comparator() { @@ -143,8 +141,6 @@ public void insertRecord(long objectAddress, long keyPrefix) { if (sortBufferInsertPosition + 2 == sortBuffer.length) { expandSortBuffer(sortBuffer.length * 2); } - final Object baseObject = memoryManager.getPage(objectAddress); - final long baseOffset = memoryManager.getOffsetInPage(objectAddress); sortBuffer[sortBufferInsertPosition] = objectAddress; sortBufferInsertPosition++; sortBuffer[sortBufferInsertPosition] = keyPrefix; From 5e100b2db9c0cb986f7b90fd58d7274b83236241 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 3 May 2015 19:19:38 -0700 Subject: [PATCH 25/92] Super-messy WIP on external sort --- .../shuffle/unsafe/UnsafeShuffleWriter.java | 79 ++------ .../sort/UnsafeExternalSortSpillMerger.java | 106 +++++++++++ .../unsafe/sort/UnsafeExternalSorter.java | 172 ++++++++++++++++++ .../spark/unsafe/sort/UnsafeSorter.java | 54 ++++-- .../unsafe/sort/UnsafeSorterSpillReader.java | 93 ++++++++++ .../unsafe/sort/UnsafeSorterSpillWriter.java | 86 +++++++++ .../spark/util/collection/Spillable.scala | 14 +- .../sort/UnsafeExternalSorterSuite.java | 136 ++++++++++++++ 8 files changed, 663 insertions(+), 77 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSortSpillMerger.java create mode 100644 core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java create mode 100644 core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java create mode 100644 core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java create mode 100644 core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 9554298c0f3f8..0ea11e823d1d4 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -17,6 +17,9 @@ package org.apache.spark.shuffle.unsafe; +import org.apache.spark.*; +import org.apache.spark.unsafe.sort.UnsafeExternalSortSpillMerger; +import org.apache.spark.unsafe.sort.UnsafeExternalSorter; import scala.Option; import scala.Product2; import scala.reflect.ClassTag; @@ -30,10 +33,6 @@ import com.esotericsoftware.kryo.io.ByteBufferOutputStream; -import org.apache.spark.Partitioner; -import org.apache.spark.ShuffleDependency; -import org.apache.spark.SparkEnv; -import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; @@ -54,7 +53,6 @@ // IntelliJ gets confused and claims that this class should be abstract, but this actually compiles public class UnsafeShuffleWriter implements ShuffleWriter { - private static final int PAGE_SIZE = 1024 * 1024; // TODO: tune this private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this private static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); @@ -70,9 +68,6 @@ public class UnsafeShuffleWriter implements ShuffleWriter { private final int fileBufferSize; private MapStatus mapStatus = null; - private MemoryBlock currentPage = null; - private long currentPagePosition = -1; - /** * Are we in the process of stopping? Because map tasks can call stop() with success = true * and then call stop() with success = false if they get an exception, we want to make sure @@ -109,39 +104,20 @@ public void write(scala.collection.Iterator> records) { } } - private void ensureSpaceInDataPage(long requiredSpace) throws Exception { - final long spaceInCurrentPage; - if (currentPage != null) { - spaceInCurrentPage = PAGE_SIZE - (currentPagePosition - currentPage.getBaseOffset()); - } else { - spaceInCurrentPage = 0; - } - if (requiredSpace > PAGE_SIZE) { - // TODO: throw a more specific exception? - throw new Exception("Required space " + requiredSpace + " is greater than page size (" + - PAGE_SIZE + ")"); - } else if (requiredSpace > spaceInCurrentPage) { - currentPage = memoryManager.allocatePage(PAGE_SIZE); - currentPagePosition = currentPage.getBaseOffset(); - allocatedPages.add(currentPage); - } - } - private void freeMemory() { - final Iterator iter = allocatedPages.iterator(); - while (iter.hasNext()) { - memoryManager.freePage(iter.next()); - iter.remove(); - } + // TODO: free sorter memory } - private Iterator sortRecords( - scala.collection.Iterator> records) throws Exception { - final UnsafeSorter sorter = new UnsafeSorter( + private Iterator sortRecords( + scala.collection.Iterator> records) throws Exception { + final UnsafeExternalSorter sorter = new UnsafeExternalSorter( memoryManager, + SparkEnv$.MODULE$.get().shuffleMemoryManager(), + SparkEnv$.MODULE$.get().blockManager(), RECORD_COMPARATOR, PREFIX_COMPARATOR, - 4096 // Initial size (TODO: tune this!) + 4096, // Initial size (TODO: tune this!) + SparkEnv$.MODULE$.get().conf() ); final byte[] serArray = new byte[SER_BUFFER_SIZE]; @@ -161,30 +137,16 @@ private Iterator sortRecords( final int serializedRecordSize = serByteBuffer.position(); assert (serializedRecordSize > 0); - // Need 4 bytes to store the record length. - ensureSpaceInDataPage(serializedRecordSize + 4); - - final long recordAddress = - memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); - final Object baseObject = currentPage.getBaseObject(); - PlatformDependent.UNSAFE.putInt(baseObject, currentPagePosition, serializedRecordSize); - currentPagePosition += 4; - PlatformDependent.copyMemory( - serArray, - PlatformDependent.BYTE_ARRAY_OFFSET, - baseObject, - currentPagePosition, - serializedRecordSize); - currentPagePosition += serializedRecordSize; - sorter.insertRecord(recordAddress, partitionId); + sorter.insertRecord( + serArray, PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); } return sorter.getSortedIterator(); } private long[] writeSortedRecordsToFile( - Iterator sortedRecords) throws IOException { + Iterator sortedRecords) throws IOException { final File outputFile = shuffleBlockManager.getDataFile(shuffleId, mapId); final ShuffleBlockId blockId = new ShuffleBlockId(shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID()); @@ -195,7 +157,7 @@ private long[] writeSortedRecordsToFile( final byte[] arr = new byte[SER_BUFFER_SIZE]; while (sortedRecords.hasNext()) { - final RecordPointerAndKeyPrefix recordPointer = sortedRecords.next(); + final UnsafeExternalSortSpillMerger.RecordAddressAndKeyPrefix recordPointer = sortedRecords.next(); final int partition = (int) recordPointer.keyPrefix; assert (partition >= currentPartition); if (partition != currentPartition) { @@ -209,17 +171,14 @@ private long[] writeSortedRecordsToFile( blockManager.getDiskWriter(blockId, outputFile, serializer, fileBufferSize, writeMetrics); } - final Object baseObject = memoryManager.getPage(recordPointer.recordPointer); - final long baseOffset = memoryManager.getOffsetInPage(recordPointer.recordPointer); - final int recordLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset); PlatformDependent.copyMemory( - baseObject, - baseOffset + 4, + recordPointer.baseObject, + recordPointer.baseOffset + 4, arr, PlatformDependent.BYTE_ARRAY_OFFSET, - recordLength); + recordPointer.recordLength); assert (writer != null); // To suppress an IntelliJ warning - writer.write(arr, 0, recordLength); + writer.write(arr, 0, recordPointer.recordLength); // TODO: add a test that detects whether we leave this call out: writer.recordWritten(); } diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSortSpillMerger.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSortSpillMerger.java new file mode 100644 index 0000000000000..89928ffaa448d --- /dev/null +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSortSpillMerger.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.sort; + +import java.util.Comparator; +import java.util.Iterator; +import java.util.PriorityQueue; + +import static org.apache.spark.unsafe.sort.UnsafeSorter.*; + +public final class UnsafeExternalSortSpillMerger { + + private final PriorityQueue priorityQueue; + + public static abstract class MergeableIterator { + public abstract boolean hasNext(); + + public abstract void advanceRecord(); + + public abstract long getPrefix(); + + public abstract Object getBaseObject(); + + public abstract long getBaseOffset(); + } + + public static final class RecordAddressAndKeyPrefix { + public Object baseObject; + public long baseOffset; + public int recordLength; + public long keyPrefix; + } + + public UnsafeExternalSortSpillMerger( + final RecordComparator recordComparator, + final UnsafeSorter.PrefixComparator prefixComparator) { + final Comparator comparator = new Comparator() { + + @Override + public int compare(MergeableIterator left, MergeableIterator right) { + final int prefixComparisonResult = + prefixComparator.compare(left.getPrefix(), right.getPrefix()); + if (prefixComparisonResult == 0) { + return recordComparator.compare( + left.getBaseObject(), left.getBaseOffset(), + right.getBaseObject(), right.getBaseOffset()); + } else { + return prefixComparisonResult; + } + } + }; + priorityQueue = new PriorityQueue(10, comparator); + } + + public void addSpill(MergeableIterator spillReader) { + priorityQueue.add(spillReader); + } + + public Iterator getSortedIterator() { + return new Iterator() { + + private MergeableIterator spillReader; + private final RecordAddressAndKeyPrefix record = new RecordAddressAndKeyPrefix(); + + @Override + public boolean hasNext() { + return spillReader.hasNext() || !priorityQueue.isEmpty(); + } + + @Override + public RecordAddressAndKeyPrefix next() { + if (spillReader != null) { + if (spillReader.hasNext()) { + priorityQueue.add(spillReader); + } + } + spillReader = priorityQueue.poll(); + record.baseObject = spillReader.getBaseObject(); + record.baseOffset = spillReader.getBaseOffset(); + record.keyPrefix = spillReader.getPrefix(); + return record; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + +} diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java new file mode 100644 index 0000000000000..613d07cf6a316 --- /dev/null +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.sort; + +import com.google.common.annotations.VisibleForTesting; +import org.apache.spark.SparkConf; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +import java.io.IOException; +import java.util.Iterator; +import java.util.LinkedList; + +import static org.apache.spark.unsafe.sort.UnsafeSorter.*; + +/** + * External sorter based on {@link UnsafeSorter}. + */ +public final class UnsafeExternalSorter { + + private static final int PAGE_SIZE = 1024 * 1024; // TODO: tune this + private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this + + private final PrefixComparator prefixComparator; + private final RecordComparator recordComparator; + private final int initialSize; + private UnsafeSorter sorter; + + private final TaskMemoryManager memoryManager; + private final ShuffleMemoryManager shuffleMemoryManager; + private final BlockManager blockManager; + private final LinkedList allocatedPages = new LinkedList(); + private final boolean spillingEnabled; + private final int fileBufferSize; + private ShuffleWriteMetrics writeMetrics; + + + private MemoryBlock currentPage = null; + private long currentPagePosition = -1; + + private final LinkedList spillWriters = + new LinkedList(); + + public UnsafeExternalSorter( + TaskMemoryManager memoryManager, + ShuffleMemoryManager shuffleMemoryManager, + BlockManager blockManager, + RecordComparator recordComparator, + PrefixComparator prefixComparator, + int initialSize, + SparkConf conf) { + this.memoryManager = memoryManager; + this.shuffleMemoryManager = shuffleMemoryManager; + this.blockManager = blockManager; + this.recordComparator = recordComparator; + this.prefixComparator = prefixComparator; + this.initialSize = initialSize; + this.spillingEnabled = conf.getBoolean("spark.shuffle.spill", true); + // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units + this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + openSorter(); + } + + // TODO: metrics tracking + integration with shuffle write metrics + + private void openSorter() { + this.writeMetrics = new ShuffleWriteMetrics(); + // TODO: connect write metrics to task metrics? + this.sorter = new UnsafeSorter(memoryManager, recordComparator, prefixComparator, initialSize); + } + + @VisibleForTesting + public void spill() throws IOException { + final UnsafeSorterSpillWriter spillWriter = + new UnsafeSorterSpillWriter(blockManager, fileBufferSize, writeMetrics); + final Iterator sortedRecords = sorter.getSortedIterator(); + while (sortedRecords.hasNext()) { + final RecordPointerAndKeyPrefix recordPointer = sortedRecords.next(); + final Object baseObject = memoryManager.getPage(recordPointer.recordPointer); + final long baseOffset = memoryManager.getOffsetInPage(recordPointer.recordPointer); + final int recordLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset); + spillWriter.write(baseObject, baseOffset, recordLength, recordPointer.keyPrefix); + } + spillWriter.close(); + sorter = null; + freeMemory(); + openSorter(); + } + + private void freeMemory() { + final Iterator iter = allocatedPages.iterator(); + while (iter.hasNext()) { + memoryManager.freePage(iter.next()); + iter.remove(); + } + } + + private void ensureSpaceInDataPage(int requiredSpace) throws Exception { + final long spaceInCurrentPage; + if (currentPage != null) { + spaceInCurrentPage = PAGE_SIZE - (currentPagePosition - currentPage.getBaseOffset()); + } else { + spaceInCurrentPage = 0; + } + if (requiredSpace > PAGE_SIZE) { + // TODO: throw a more specific exception? + throw new Exception("Required space " + requiredSpace + " is greater than page size (" + + PAGE_SIZE + ")"); + } else if (requiredSpace > spaceInCurrentPage) { + if (spillingEnabled && shuffleMemoryManager.tryToAcquire(PAGE_SIZE) < PAGE_SIZE) { + spill(); + } + currentPage = memoryManager.allocatePage(PAGE_SIZE); + currentPagePosition = currentPage.getBaseOffset(); + allocatedPages.add(currentPage); + } + } + + public void insertRecord( + Object recordBaseObject, + long recordBaseOffset, + int lengthInBytes, + long prefix) throws Exception { + // Need 4 bytes to store the record length. + ensureSpaceInDataPage(lengthInBytes + 4); + + final long recordAddress = + memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); + final Object dataPageBaseObject = currentPage.getBaseObject(); + PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes); + currentPagePosition += 4; + PlatformDependent.copyMemory( + recordBaseObject, + recordBaseOffset, + dataPageBaseObject, + currentPagePosition, + lengthInBytes); + currentPagePosition += lengthInBytes; + + sorter.insertRecord(recordAddress, prefix); + } + + public Iterator getSortedIterator() throws IOException { + final UnsafeExternalSortSpillMerger spillMerger = + new UnsafeExternalSortSpillMerger(recordComparator, prefixComparator); + for (UnsafeSorterSpillWriter spillWriter : spillWriters) { + spillMerger.addSpill(spillWriter.getReader(blockManager)); + } + spillWriters.clear(); + spillMerger.addSpill(sorter.getMergeableIterator()); + return spillMerger.getSortedIterator(); + } +} diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java index 092e26f4ee1fc..1801585e2ed84 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java @@ -43,17 +43,6 @@ public static final class RecordPointerAndKeyPrefix { * A key prefix, for use in comparisons. */ public long keyPrefix; - - // TODO: this was a carryover from test code; may want to remove this - @Override - public int hashCode() { - throw new UnsupportedOperationException(); - } - - @Override - public boolean equals(Object obj) { - throw new UnsupportedOperationException(); - } } /** @@ -82,6 +71,7 @@ public static abstract class PrefixComparator { public abstract int compare(long prefix1, long prefix2); } + private final TaskMemoryManager memoryManager; private final Sorter sorter; private final Comparator sortComparator; @@ -96,7 +86,6 @@ public static abstract class PrefixComparator { */ private int sortBufferInsertPosition = 0; - private void expandSortBuffer(int newSize) { assert (newSize > sortBuffer.length); final long[] oldBuffer = sortBuffer; @@ -111,6 +100,7 @@ public UnsafeSorter( int initialSize) { assert (initialSize > 0); this.sortBuffer = new long[initialSize * 2]; + this.memoryManager = memoryManager; this.sorter = new Sorter(UnsafeSortDataFormat.INSTANCE); this.sortComparator = new Comparator() { @@ -176,4 +166,44 @@ public void remove() { } }; } + + public UnsafeExternalSortSpillMerger.MergeableIterator getMergeableIterator() { + sorter.sort(sortBuffer, 0, sortBufferInsertPosition / 2, sortComparator); + return new UnsafeExternalSortSpillMerger.MergeableIterator() { + + private int position = 0; + private Object baseObject; + private long baseOffset; + private long keyPrefix; + + @Override + public boolean hasNext() { + return position < sortBufferInsertPosition; + } + + @Override + public void advanceRecord() { + final long recordPointer = sortBuffer[position]; + baseObject = memoryManager.getPage(recordPointer); + baseOffset = memoryManager.getOffsetInPage(recordPointer); + keyPrefix = sortBuffer[position + 1]; + position += 2; + } + + @Override + public long getPrefix() { + return keyPrefix; + } + + @Override + public Object getBaseObject() { + return baseObject; + } + + @Override + public long getBaseOffset() { + return baseOffset; + } + }; + } } diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java new file mode 100644 index 0000000000000..e2d5e6a8faa10 --- /dev/null +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.sort; + +import com.google.common.io.ByteStreams; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.JavaSerializerInstance; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.BlockObjectWriter; +import org.apache.spark.storage.TempLocalBlockId; +import org.apache.spark.unsafe.PlatformDependent; +import scala.Tuple2; + +import java.io.*; + +public final class UnsafeSorterSpillReader extends UnsafeExternalSortSpillMerger.MergeableIterator { + + private final File file; + private InputStream in; + private DataInputStream din; + + private long keyPrefix; + private final byte[] arr = new byte[1024 * 1024]; // TODO: tune this (maybe grow dynamically)? + private final Object baseObject = arr; + private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET; + + public UnsafeSorterSpillReader( + BlockManager blockManager, + File file, + BlockId blockId) throws IOException { + this.file = file; + final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file)); + this.in = blockManager.wrapForCompression(blockId, bs); + this.din = new DataInputStream(this.in); + assert (file.length() > 0); + advanceRecord(); + } + + @Override + public boolean hasNext() { + return (in != null); + } + + @Override + public void advanceRecord() { + try { + final int recordLength = din.readInt(); + if (recordLength == UnsafeSorterSpillWriter.EOF_MARKER) { + in.close(); + in = null; + return; + } + keyPrefix = din.readLong(); + ByteStreams.readFully(in, arr, 0, recordLength); + + } catch (Exception e) { + PlatformDependent.throwException(e); + } + throw new IllegalStateException(); + } + + @Override + public long getPrefix() { + return keyPrefix; + } + + @Override + public Object getBaseObject() { + return baseObject; + } + + @Override + public long getBaseOffset() { + return baseOffset; + } +} diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java new file mode 100644 index 0000000000000..fdda38d3f1c47 --- /dev/null +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.sort; + +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.JavaSerializerInstance; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.BlockObjectWriter; +import org.apache.spark.storage.TempLocalBlockId; +import org.apache.spark.unsafe.PlatformDependent; +import scala.Tuple2; + +import java.io.DataOutputStream; +import java.io.File; +import java.io.IOException; + +public final class UnsafeSorterSpillWriter { + + private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this + public static final int EOF_MARKER = -1; + byte[] arr = new byte[SER_BUFFER_SIZE]; + + private final File file; + private final BlockId blockId; + BlockObjectWriter writer; + DataOutputStream dos; + + public UnsafeSorterSpillWriter( + BlockManager blockManager, + int fileBufferSize, + ShuffleWriteMetrics writeMetrics) throws IOException { + final Tuple2 spilledFileInfo = + blockManager.diskBlockManager().createTempLocalBlock(); + this.file = spilledFileInfo._2(); + this.blockId = spilledFileInfo._1(); + // Dummy serializer: + final SerializerInstance ser = new JavaSerializerInstance(0, false, null); + writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetrics); + dos = new DataOutputStream(writer); + } + + public void write( + Object baseObject, + long baseOffset, + int recordLength, + long keyPrefix) throws IOException { + PlatformDependent.copyMemory( + baseObject, + baseOffset + 4, + arr, + PlatformDependent.BYTE_ARRAY_OFFSET, + recordLength); + dos.writeInt(recordLength); + dos.writeLong(keyPrefix); + writer.write(arr, 0, recordLength); + // TODO: add a test that detects whether we leave this call out: + writer.recordWritten(); + } + + public void close() throws IOException { + dos.writeInt(EOF_MARKER); + writer.commitAndClose(); + arr = null; + } + + public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException { + return new UnsafeSorterSpillReader(blockManager, file, blockId); + } +} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 747ecf075a397..841a4cd791c4c 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -20,11 +20,20 @@ package org.apache.spark.util.collection import org.apache.spark.Logging import org.apache.spark.SparkEnv +private[spark] object Spillable { + // Initial threshold for the size of a collection before we start tracking its memory usage + val initialMemoryThreshold: Long = + SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024) +} + /** * Spills contents of an in-memory collection to disk when the memory threshold * has been exceeded. */ private[spark] trait Spillable[C] extends Logging { + + import Spillable._ + /** * Spills the current in-memory collection to disk, and releases the memory. * @@ -42,11 +51,6 @@ private[spark] trait Spillable[C] extends Logging { // Memory manager that can be used to acquire/release memory private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager - // Initial threshold for the size of a collection before we start tracking its memory usage - // Exposed for testing - private[this] val initialMemoryThreshold: Long = - SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024) - // Threshold for this collection's size in bytes before we start tracking its memory usage // To avoid a large number of small spills, initialize this to a value orders of magnitude > 0 private[this] var myMemoryThreshold = initialMemoryThreshold diff --git a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java new file mode 100644 index 0000000000000..e4376f1cea4fc --- /dev/null +++ b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.sort; + + +import org.apache.spark.HashPartitioner; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.junit.Assert; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Iterator; + +import static org.mockito.Mockito.mock; + +public class UnsafeExternalSorterSuite { + private static String getStringFromDataPage(Object baseObject, long baseOffset) { + final int strLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset); + final byte[] strBytes = new byte[strLength]; + PlatformDependent.copyMemory( + baseObject, + baseOffset + 8, + strBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, strLength); + return new String(strBytes); + } + + /** + * Tests the type of sorting that's used in the non-combiner path of sort-based shuffle. + */ + @Test + public void testSortingOnlyByPartitionId() throws Exception { + final String[] dataToSort = new String[] { + "Boba", + "Pearls", + "Tapioca", + "Taho", + "Condensed Milk", + "Jasmine", + "Milk Tea", + "Lychee", + "Mango" + }; + final TaskMemoryManager memoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final MemoryBlock dataPage = memoryManager.allocatePage(2048); + final Object baseObject = dataPage.getBaseObject(); + // Write the records into the data page: + long position = dataPage.getBaseOffset(); + for (String str : dataToSort) { + final byte[] strBytes = str.getBytes("utf-8"); + PlatformDependent.UNSAFE.putLong(baseObject, position, strBytes.length); + position += 8; + PlatformDependent.copyMemory( + strBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + baseObject, + position, + strBytes.length); + position += strBytes.length; + } + // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so + // use a dummy comparator + final UnsafeSorter.RecordComparator recordComparator = new UnsafeSorter.RecordComparator() { + @Override + public int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset) { + return 0; + } + }; + // Compute key prefixes based on the records' partition ids + final HashPartitioner hashPartitioner = new HashPartitioner(4); + // Use integer comparison for comparing prefixes (which are partition ids, in this case) + final UnsafeSorter.PrefixComparator prefixComparator = new UnsafeSorter.PrefixComparator() { + @Override + public int compare(long prefix1, long prefix2) { + return (int) prefix1 - (int) prefix2; + } + }; + final UnsafeSorter sorter = new UnsafeSorter( + memoryManager, + recordComparator, + prefixComparator, + dataToSort.length); + // Given a page of records, insert those records into the sorter one-by-one: + position = dataPage.getBaseOffset(); + for (int i = 0; i < dataToSort.length; i++) { + // position now points to the start of a record (which holds its length). + final long recordLength = PlatformDependent.UNSAFE.getLong(baseObject, position); + final long address = memoryManager.encodePageNumberAndOffset(dataPage, position); + final String str = getStringFromDataPage(baseObject, position); + final int partitionId = hashPartitioner.getPartition(str); + sorter.insertRecord(address, partitionId); + position += 8 + recordLength; + } + final Iterator iter = sorter.getSortedIterator(); + int iterLength = 0; + long prevPrefix = -1; + Arrays.sort(dataToSort); + while (iter.hasNext()) { + final UnsafeSorter.RecordPointerAndKeyPrefix pointerAndPrefix = iter.next(); + final Object recordBaseObject = memoryManager.getPage(pointerAndPrefix.recordPointer); + final long recordBaseOffset = memoryManager.getOffsetInPage(pointerAndPrefix.recordPointer); + final String str = getStringFromDataPage(recordBaseObject, recordBaseOffset); + Assert.assertTrue("String should be valid", Arrays.binarySearch(dataToSort, str) != -1); + Assert.assertTrue("Prefix " + pointerAndPrefix.keyPrefix + " should be >= previous prefix " + + prevPrefix, pointerAndPrefix.keyPrefix >= prevPrefix); + prevPrefix = pointerAndPrefix.keyPrefix; + iterLength++; + } + Assert.assertEquals(dataToSort.length, iterLength); + } + +} From 2776acaf368616f96f9c42f79744293a70b7a08a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 3 May 2015 21:20:47 -0700 Subject: [PATCH 26/92] First passing test for ExternalSorter. --- .../sort/UnsafeExternalSortSpillMerger.java | 10 +- .../unsafe/sort/UnsafeExternalSorter.java | 5 +- .../spark/unsafe/sort/UnsafeSorter.java | 10 +- .../unsafe/sort/UnsafeSorterSpillReader.java | 25 +- .../unsafe/sort/UnsafeSorterSpillWriter.java | 54 ++++- .../spark/storage/BlockObjectWriter.scala | 8 +- .../sort/UnsafeExternalSorterSuite.java | 218 ++++++++++-------- pom.xml | 2 +- 8 files changed, 209 insertions(+), 123 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSortSpillMerger.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSortSpillMerger.java index 89928ffaa448d..c6bd4ee9df4ff 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSortSpillMerger.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSortSpillMerger.java @@ -30,7 +30,7 @@ public final class UnsafeExternalSortSpillMerger { public static abstract class MergeableIterator { public abstract boolean hasNext(); - public abstract void advanceRecord(); + public abstract void loadNextRecord(); public abstract long getPrefix(); @@ -68,6 +68,9 @@ public int compare(MergeableIterator left, MergeableIterator right) { } public void addSpill(MergeableIterator spillReader) { + if (spillReader.hasNext()) { + spillReader.loadNextRecord(); + } priorityQueue.add(spillReader); } @@ -79,17 +82,18 @@ public Iterator getSortedIterator() { @Override public boolean hasNext() { - return spillReader.hasNext() || !priorityQueue.isEmpty(); + return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext()); } @Override public RecordAddressAndKeyPrefix next() { if (spillReader != null) { if (spillReader.hasNext()) { + spillReader.loadNextRecord(); priorityQueue.add(spillReader); } } - spillReader = priorityQueue.poll(); + spillReader = priorityQueue.remove(); record.baseObject = spillReader.getBaseObject(); record.baseOffset = spillReader.getBaseOffset(); record.keyPrefix = spillReader.getPrefix(); diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java index 613d07cf6a316..42669055bcf1c 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java @@ -38,7 +38,6 @@ public final class UnsafeExternalSorter { private static final int PAGE_SIZE = 1024 * 1024; // TODO: tune this - private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this private final PrefixComparator prefixComparator; private final RecordComparator recordComparator; @@ -92,6 +91,7 @@ private void openSorter() { public void spill() throws IOException { final UnsafeSorterSpillWriter spillWriter = new UnsafeSorterSpillWriter(blockManager, fileBufferSize, writeMetrics); + spillWriters.add(spillWriter); final Iterator sortedRecords = sorter.getSortedIterator(); while (sortedRecords.hasNext()) { final RecordPointerAndKeyPrefix recordPointer = sortedRecords.next(); @@ -110,8 +110,11 @@ private void freeMemory() { final Iterator iter = allocatedPages.iterator(); while (iter.hasNext()) { memoryManager.freePage(iter.next()); + shuffleMemoryManager.release(PAGE_SIZE); iter.remove(); } + currentPage = null; + currentPagePosition = -1; } private void ensureSpaceInDataPage(int requiredSpace) throws Exception { diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java index 1801585e2ed84..0f844c1997668 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java @@ -169,7 +169,8 @@ public void remove() { public UnsafeExternalSortSpillMerger.MergeableIterator getMergeableIterator() { sorter.sort(sortBuffer, 0, sortBufferInsertPosition / 2, sortComparator); - return new UnsafeExternalSortSpillMerger.MergeableIterator() { + UnsafeExternalSortSpillMerger.MergeableIterator iter = + new UnsafeExternalSortSpillMerger.MergeableIterator() { private int position = 0; private Object baseObject; @@ -182,12 +183,12 @@ public boolean hasNext() { } @Override - public void advanceRecord() { + public void loadNextRecord() { final long recordPointer = sortBuffer[position]; - baseObject = memoryManager.getPage(recordPointer); - baseOffset = memoryManager.getOffsetInPage(recordPointer); keyPrefix = sortBuffer[position + 1]; position += 2; + baseObject = memoryManager.getPage(recordPointer); + baseOffset = memoryManager.getOffsetInPage(recordPointer); } @Override @@ -205,5 +206,6 @@ public long getBaseOffset() { return baseOffset; } }; + return iter; } } diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java index e2d5e6a8faa10..7c696240aaa73 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java @@ -18,15 +18,9 @@ package org.apache.spark.unsafe.sort; import com.google.common.io.ByteStreams; -import org.apache.spark.executor.ShuffleWriteMetrics; -import org.apache.spark.serializer.JavaSerializerInstance; -import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.storage.BlockId; import org.apache.spark.storage.BlockManager; -import org.apache.spark.storage.BlockObjectWriter; -import org.apache.spark.storage.TempLocalBlockId; import org.apache.spark.unsafe.PlatformDependent; -import scala.Tuple2; import java.io.*; @@ -39,6 +33,7 @@ public final class UnsafeSorterSpillReader extends UnsafeExternalSortSpillMerger private long keyPrefix; private final byte[] arr = new byte[1024 * 1024]; // TODO: tune this (maybe grow dynamically)? private final Object baseObject = arr; + private int nextRecordLength; private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET; public UnsafeSorterSpillReader( @@ -46,11 +41,11 @@ public UnsafeSorterSpillReader( File file, BlockId blockId) throws IOException { this.file = file; + assert (file.length() > 0); final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file)); this.in = blockManager.wrapForCompression(blockId, bs); this.din = new DataInputStream(this.in); - assert (file.length() > 0); - advanceRecord(); + nextRecordLength = din.readInt(); } @Override @@ -59,21 +54,19 @@ public boolean hasNext() { } @Override - public void advanceRecord() { + public void loadNextRecord() { try { - final int recordLength = din.readInt(); - if (recordLength == UnsafeSorterSpillWriter.EOF_MARKER) { + keyPrefix = din.readLong(); + ByteStreams.readFully(in, arr, 0, nextRecordLength); + nextRecordLength = din.readInt(); + if (nextRecordLength == UnsafeSorterSpillWriter.EOF_MARKER) { in.close(); in = null; - return; + din = null; } - keyPrefix = din.readLong(); - ByteStreams.readFully(in, arr, 0, recordLength); - } catch (Exception e) { PlatformDependent.throwException(e); } - throw new IllegalStateException(); } @Override diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java index fdda38d3f1c47..e0649122ac09c 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java @@ -18,7 +18,9 @@ package org.apache.spark.unsafe.sort; import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.DeserializationStream; import org.apache.spark.serializer.JavaSerializerInstance; +import org.apache.spark.serializer.SerializationStream; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.storage.BlockId; import org.apache.spark.storage.BlockManager; @@ -26,10 +28,10 @@ import org.apache.spark.storage.TempLocalBlockId; import org.apache.spark.unsafe.PlatformDependent; import scala.Tuple2; +import scala.reflect.ClassTag; -import java.io.DataOutputStream; -import java.io.File; -import java.io.IOException; +import java.io.*; +import java.nio.ByteBuffer; public final class UnsafeSorterSpillWriter { @@ -51,7 +53,47 @@ public UnsafeSorterSpillWriter( this.file = spilledFileInfo._2(); this.blockId = spilledFileInfo._1(); // Dummy serializer: - final SerializerInstance ser = new JavaSerializerInstance(0, false, null); + final SerializerInstance ser = new SerializerInstance() { + @Override + public SerializationStream serializeStream(OutputStream s) { + return new SerializationStream() { + @Override + public void flush() { + + } + + @Override + public SerializationStream writeObject(T t, ClassTag ev1) { + return null; + } + + @Override + public void close() { + + } + }; + } + + @Override + public ByteBuffer serialize(T t, ClassTag ev1) { + return null; + } + + @Override + public DeserializationStream deserializeStream(InputStream s) { + return null; + } + + @Override + public T deserialize(ByteBuffer bytes, ClassLoader loader, ClassTag ev1) { + return null; + } + + @Override + public T deserialize(ByteBuffer bytes, ClassTag ev1) { + return null; + } + }; writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetrics); dos = new DataOutputStream(writer); } @@ -61,14 +103,14 @@ public void write( long baseOffset, int recordLength, long keyPrefix) throws IOException { + dos.writeInt(recordLength); + dos.writeLong(keyPrefix); PlatformDependent.copyMemory( baseObject, baseOffset + 4, arr, PlatformDependent.BYTE_ARRAY_OFFSET, recordLength); - dos.writeInt(recordLength); - dos.writeLong(keyPrefix); writer.write(arr, 0, recordLength); // TODO: add a test that detects whether we leave this call out: writer.recordWritten(); diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 499dd97c0656a..f273a31706cd8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -223,7 +223,13 @@ private[spark] class DiskBlockObjectWriter( } } - override def write(b: Int): Unit = throw new UnsupportedOperationException() + override def write(b: Int): Unit = { + if (!initialized) { + open() + } + + bs.write(b) + } override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = { if (!initialized) { diff --git a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java index e4376f1cea4fc..4f2aa9b895c01 100644 --- a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java @@ -19,29 +19,117 @@ import org.apache.spark.HashPartitioner; +import org.apache.spark.SparkConf; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.storage.*; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.memory.ExecutorMemoryManager; import org.apache.spark.unsafe.memory.MemoryAllocator; -import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; import org.junit.Assert; +import org.junit.Before; import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import scala.Tuple2; +import scala.Tuple2$; +import scala.runtime.AbstractFunction1; -import java.util.Arrays; +import java.io.File; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.file.Files; import java.util.Iterator; +import java.util.UUID; -import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.*; +import static org.mockito.AdditionalAnswers.*; public class UnsafeExternalSorterSuite { - private static String getStringFromDataPage(Object baseObject, long baseOffset) { - final int strLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset); - final byte[] strBytes = new byte[strLength]; - PlatformDependent.copyMemory( - baseObject, - baseOffset + 8, - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, strLength); - return new String(strBytes); + + final TaskMemoryManager memoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + // Compute key prefixes based on the records' partition ids + final HashPartitioner hashPartitioner = new HashPartitioner(4); + // Use integer comparison for comparing prefixes (which are partition ids, in this case) + final UnsafeSorter.PrefixComparator prefixComparator = new UnsafeSorter.PrefixComparator() { + @Override + public int compare(long prefix1, long prefix2) { + return (int) prefix1 - (int) prefix2; + } + }; + // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so + // use a dummy comparator + final UnsafeSorter.RecordComparator recordComparator = new UnsafeSorter.RecordComparator() { + @Override + public int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset) { + return 0; + } + }; + + ShuffleMemoryManager shuffleMemoryManager; + BlockManager blockManager; + DiskBlockManager diskBlockManager; + File tempDir; + + private static final class CompressStream extends AbstractFunction1 { + @Override + public OutputStream apply(OutputStream stream) { + return stream; + } + } + + @Before + public void setUp() { + shuffleMemoryManager = mock(ShuffleMemoryManager.class); + diskBlockManager = mock(DiskBlockManager.class); + blockManager = mock(BlockManager.class); + tempDir = new File(Utils.createTempDir$default$1()); + when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); + when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); + when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer>() { + @Override + public Tuple2 answer(InvocationOnMock invocationOnMock) throws Throwable { + TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID()); + File file = File.createTempFile("spillFile", ".spill", tempDir); + return Tuple2$.MODULE$.apply(blockId, file); + } + }); + when(blockManager.getDiskWriter( + any(BlockId.class), + any(File.class), + any(SerializerInstance.class), + anyInt(), + any(ShuffleWriteMetrics.class))).thenAnswer(new Answer() { + @Override + public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable { + Object[] args = invocationOnMock.getArguments(); + + return new DiskBlockObjectWriter( + (BlockId) args[0], + (File) args[1], + (SerializerInstance) args[2], + (Integer) args[3], + new CompressStream(), + false, + (ShuffleWriteMetrics) args[4] + ); + } + }); + when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))) + .then(returnsSecondArg()); + } + + private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception { + final int[] arr = new int[] { value }; + sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value); } /** @@ -49,88 +137,36 @@ private static String getStringFromDataPage(Object baseObject, long baseOffset) */ @Test public void testSortingOnlyByPartitionId() throws Exception { - final String[] dataToSort = new String[] { - "Boba", - "Pearls", - "Tapioca", - "Taho", - "Condensed Milk", - "Jasmine", - "Milk Tea", - "Lychee", - "Mango" - }; - final TaskMemoryManager memoryManager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); - final MemoryBlock dataPage = memoryManager.allocatePage(2048); - final Object baseObject = dataPage.getBaseObject(); - // Write the records into the data page: - long position = dataPage.getBaseOffset(); - for (String str : dataToSort) { - final byte[] strBytes = str.getBytes("utf-8"); - PlatformDependent.UNSAFE.putLong(baseObject, position, strBytes.length); - position += 8; - PlatformDependent.copyMemory( - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - baseObject, - position, - strBytes.length); - position += strBytes.length; - } - // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so - // use a dummy comparator - final UnsafeSorter.RecordComparator recordComparator = new UnsafeSorter.RecordComparator() { - @Override - public int compare( - Object leftBaseObject, - long leftBaseOffset, - Object rightBaseObject, - long rightBaseOffset) { - return 0; - } - }; - // Compute key prefixes based on the records' partition ids - final HashPartitioner hashPartitioner = new HashPartitioner(4); - // Use integer comparison for comparing prefixes (which are partition ids, in this case) - final UnsafeSorter.PrefixComparator prefixComparator = new UnsafeSorter.PrefixComparator() { - @Override - public int compare(long prefix1, long prefix2) { - return (int) prefix1 - (int) prefix2; - } - }; - final UnsafeSorter sorter = new UnsafeSorter( + + final UnsafeExternalSorter sorter = new UnsafeExternalSorter( memoryManager, + shuffleMemoryManager, + blockManager, recordComparator, prefixComparator, - dataToSort.length); - // Given a page of records, insert those records into the sorter one-by-one: - position = dataPage.getBaseOffset(); - for (int i = 0; i < dataToSort.length; i++) { - // position now points to the start of a record (which holds its length). - final long recordLength = PlatformDependent.UNSAFE.getLong(baseObject, position); - final long address = memoryManager.encodePageNumberAndOffset(dataPage, position); - final String str = getStringFromDataPage(baseObject, position); - final int partitionId = hashPartitioner.getPartition(str); - sorter.insertRecord(address, partitionId); - position += 8 + recordLength; - } - final Iterator iter = sorter.getSortedIterator(); - int iterLength = 0; - long prevPrefix = -1; - Arrays.sort(dataToSort); - while (iter.hasNext()) { - final UnsafeSorter.RecordPointerAndKeyPrefix pointerAndPrefix = iter.next(); - final Object recordBaseObject = memoryManager.getPage(pointerAndPrefix.recordPointer); - final long recordBaseOffset = memoryManager.getOffsetInPage(pointerAndPrefix.recordPointer); - final String str = getStringFromDataPage(recordBaseObject, recordBaseOffset); - Assert.assertTrue("String should be valid", Arrays.binarySearch(dataToSort, str) != -1); - Assert.assertTrue("Prefix " + pointerAndPrefix.keyPrefix + " should be >= previous prefix " + - prevPrefix, pointerAndPrefix.keyPrefix >= prevPrefix); - prevPrefix = pointerAndPrefix.keyPrefix; - iterLength++; - } - Assert.assertEquals(dataToSort.length, iterLength); + 1024, + new SparkConf()); + + insertNumber(sorter, 5); + insertNumber(sorter, 1); + insertNumber(sorter, 3); + sorter.spill(); + insertNumber(sorter, 4); + insertNumber(sorter, 2); + + Iterator iter = + sorter.getSortedIterator(); + + Assert.assertEquals(1, iter.next().keyPrefix); + Assert.assertEquals(2, iter.next().keyPrefix); + Assert.assertEquals(3, iter.next().keyPrefix); + Assert.assertEquals(4, iter.next().keyPrefix); + Assert.assertEquals(5, iter.next().keyPrefix); + Assert.assertFalse(iter.hasNext()); + // TODO: check that the values are also read back properly. + + // TODO: test for cleanup: + // assert(tempDir.isEmpty) } } diff --git a/pom.xml b/pom.xml index c85c5feeaf383..57d857340f735 100644 --- a/pom.xml +++ b/pom.xml @@ -652,7 +652,7 @@ org.mockito mockito-all - 1.9.0 + 1.9.5 test From f156a8f830af6867015324592801255a2bbcb17b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 3 May 2015 23:23:14 -0700 Subject: [PATCH 27/92] Hacky metrics integration; refactor some interfaces. --- .../shuffle/unsafe/UnsafeShuffleWriter.java | 24 +++--- .../unsafe/sort/ExternalSorterIterator.java | 31 ++++++++ .../unsafe/sort/UnsafeExternalSorter.java | 76 ++++++++++++++++--- .../spark/unsafe/sort/UnsafeSorter.java | 23 ++++-- ...rger.java => UnsafeSorterSpillMerger.java} | 8 +- .../unsafe/sort/UnsafeSorterSpillReader.java | 2 +- .../unsafe/sort/UnsafeSorterSpillWriter.java | 6 +- .../sort/UnsafeExternalSorterSuite.java | 26 ++++--- 8 files changed, 151 insertions(+), 45 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/unsafe/sort/ExternalSorterIterator.java rename core/src/main/java/org/apache/spark/unsafe/sort/{UnsafeExternalSortSpillMerger.java => UnsafeSorterSpillMerger.java} (94%) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 0ea11e823d1d4..d142cf59d8085 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -18,7 +18,7 @@ package org.apache.spark.shuffle.unsafe; import org.apache.spark.*; -import org.apache.spark.unsafe.sort.UnsafeExternalSortSpillMerger; +import org.apache.spark.unsafe.sort.ExternalSorterIterator; import org.apache.spark.unsafe.sort.UnsafeExternalSorter; import scala.Option; import scala.Product2; @@ -28,7 +28,6 @@ import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; -import java.util.Iterator; import java.util.LinkedList; import com.esotericsoftware.kryo.io.ByteBufferOutputStream; @@ -47,7 +46,7 @@ import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; -import org.apache.spark.unsafe.sort.UnsafeSorter; + import static org.apache.spark.unsafe.sort.UnsafeSorter.*; // IntelliJ gets confused and claims that this class should be abstract, but this actually compiles @@ -64,7 +63,6 @@ public class UnsafeShuffleWriter implements ShuffleWriter { private final SerializerInstance serializer; private final Partitioner partitioner; private final ShuffleWriteMetrics writeMetrics; - private final LinkedList allocatedPages = new LinkedList(); private final int fileBufferSize; private MapStatus mapStatus = null; @@ -108,12 +106,13 @@ private void freeMemory() { // TODO: free sorter memory } - private Iterator sortRecords( + private ExternalSorterIterator sortRecords( scala.collection.Iterator> records) throws Exception { final UnsafeExternalSorter sorter = new UnsafeExternalSorter( memoryManager, SparkEnv$.MODULE$.get().shuffleMemoryManager(), SparkEnv$.MODULE$.get().blockManager(), + TaskContext.get(), RECORD_COMPARATOR, PREFIX_COMPARATOR, 4096, // Initial size (TODO: tune this!) @@ -145,8 +144,7 @@ private Iterator sortRe return sorter.getSortedIterator(); } - private long[] writeSortedRecordsToFile( - Iterator sortedRecords) throws IOException { + private long[] writeSortedRecordsToFile(ExternalSorterIterator sortedRecords) throws IOException { final File outputFile = shuffleBlockManager.getDataFile(shuffleId, mapId); final ShuffleBlockId blockId = new ShuffleBlockId(shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID()); @@ -157,8 +155,8 @@ private long[] writeSortedRecordsToFile( final byte[] arr = new byte[SER_BUFFER_SIZE]; while (sortedRecords.hasNext()) { - final UnsafeExternalSortSpillMerger.RecordAddressAndKeyPrefix recordPointer = sortedRecords.next(); - final int partition = (int) recordPointer.keyPrefix; + sortedRecords.loadNext(); + final int partition = (int) sortedRecords.keyPrefix; assert (partition >= currentPartition); if (partition != currentPartition) { // Switch to the new partition @@ -172,13 +170,13 @@ private long[] writeSortedRecordsToFile( } PlatformDependent.copyMemory( - recordPointer.baseObject, - recordPointer.baseOffset + 4, + sortedRecords.baseObject, + sortedRecords.baseOffset + 4, arr, PlatformDependent.BYTE_ARRAY_OFFSET, - recordPointer.recordLength); + sortedRecords.recordLength); assert (writer != null); // To suppress an IntelliJ warning - writer.write(arr, 0, recordPointer.recordLength); + writer.write(arr, 0, sortedRecords.recordLength); // TODO: add a test that detects whether we leave this call out: writer.recordWritten(); } diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/ExternalSorterIterator.java b/core/src/main/java/org/apache/spark/unsafe/sort/ExternalSorterIterator.java new file mode 100644 index 0000000000000..d53a0baaf351f --- /dev/null +++ b/core/src/main/java/org/apache/spark/unsafe/sort/ExternalSorterIterator.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.sort; + +public abstract class ExternalSorterIterator { + + public Object baseObject; + public long baseOffset; + public int recordLength; + public long keyPrefix; + + public abstract boolean hasNext(); + + public abstract void loadNext(); + +} diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java index 42669055bcf1c..bf0019c51703f 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java @@ -19,12 +19,15 @@ import com.google.common.annotations.VisibleForTesting; import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.IOException; import java.util.Iterator; @@ -37,16 +40,20 @@ */ public final class UnsafeExternalSorter { + private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class); + private static final int PAGE_SIZE = 1024 * 1024; // TODO: tune this private final PrefixComparator prefixComparator; private final RecordComparator recordComparator; private final int initialSize; + private int numSpills = 0; private UnsafeSorter sorter; private final TaskMemoryManager memoryManager; private final ShuffleMemoryManager shuffleMemoryManager; private final BlockManager blockManager; + private final TaskContext taskContext; private final LinkedList allocatedPages = new LinkedList(); private final boolean spillingEnabled; private final int fileBufferSize; @@ -63,13 +70,15 @@ public UnsafeExternalSorter( TaskMemoryManager memoryManager, ShuffleMemoryManager shuffleMemoryManager, BlockManager blockManager, + TaskContext taskContext, RecordComparator recordComparator, PrefixComparator prefixComparator, int initialSize, - SparkConf conf) { + SparkConf conf) throws IOException { this.memoryManager = memoryManager; this.shuffleMemoryManager = shuffleMemoryManager; this.blockManager = blockManager; + this.taskContext = taskContext; this.recordComparator = recordComparator; this.prefixComparator = prefixComparator; this.initialSize = initialSize; @@ -81,9 +90,19 @@ public UnsafeExternalSorter( // TODO: metrics tracking + integration with shuffle write metrics - private void openSorter() { + private void openSorter() throws IOException { this.writeMetrics = new ShuffleWriteMetrics(); // TODO: connect write metrics to task metrics? + // TODO: move this sizing calculation logic into a static method of sorter: + final long memoryRequested = initialSize * 8L * 2; + if (spillingEnabled) { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested); + if (memoryAcquired != memoryRequested) { + shuffleMemoryManager.release(memoryAcquired); + throw new IOException("Could not acquire memory!"); + } + } + this.sorter = new UnsafeSorter(memoryManager, recordComparator, prefixComparator, initialSize); } @@ -101,23 +120,52 @@ public void spill() throws IOException { spillWriter.write(baseObject, baseOffset, recordLength, recordPointer.keyPrefix); } spillWriter.close(); + final long sorterMemoryUsage = sorter.getMemoryUsage(); sorter = null; - freeMemory(); + shuffleMemoryManager.release(sorterMemoryUsage); + final long spillSize = freeMemory(); + taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + taskContext.taskMetrics().incDiskBytesSpilled(spillWriter.numberOfSpilledBytes()); + numSpills++; + final long threadId = Thread.currentThread().getId(); + // TODO: messy; log _before_ spill + logger.info("Thread " + threadId + " spilling in-memory map of " + + org.apache.spark.util.Utils.bytesToString(spillSize) + " to disk (" + + (numSpills + ((numSpills > 1) ? " times" : " time")) + " so far)"); openSorter(); } - private void freeMemory() { + private long freeMemory() { + long memoryFreed = 0; final Iterator iter = allocatedPages.iterator(); while (iter.hasNext()) { memoryManager.freePage(iter.next()); shuffleMemoryManager.release(PAGE_SIZE); + memoryFreed += PAGE_SIZE; iter.remove(); } currentPage = null; currentPagePosition = -1; + return memoryFreed; } private void ensureSpaceInDataPage(int requiredSpace) throws Exception { + // TODO: merge these steps to first calculate total memory requirements for this insert, + // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the + // data page. + if (!sorter.hasSpaceForAnotherRecord() && spillingEnabled) { + final long oldSortBufferMemoryUsage = sorter.getMemoryUsage(); + final long memoryToGrowSortBuffer = oldSortBufferMemoryUsage * 2; + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowSortBuffer); + if (memoryAcquired < memoryToGrowSortBuffer) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + } else { + sorter.expandSortBuffer(); + shuffleMemoryManager.release(oldSortBufferMemoryUsage); + } + } + final long spaceInCurrentPage; if (currentPage != null) { spaceInCurrentPage = PAGE_SIZE - (currentPagePosition - currentPage.getBaseOffset()); @@ -129,12 +177,22 @@ private void ensureSpaceInDataPage(int requiredSpace) throws Exception { throw new Exception("Required space " + requiredSpace + " is greater than page size (" + PAGE_SIZE + ")"); } else if (requiredSpace > spaceInCurrentPage) { - if (spillingEnabled && shuffleMemoryManager.tryToAcquire(PAGE_SIZE) < PAGE_SIZE) { - spill(); + if (spillingEnabled) { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); + if (memoryAcquired < PAGE_SIZE) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + final long memoryAcquiredAfterSpill = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); + if (memoryAcquiredAfterSpill != PAGE_SIZE) { + shuffleMemoryManager.release(memoryAcquiredAfterSpill); + throw new Exception("Can't allocate memory!"); + } + } } currentPage = memoryManager.allocatePage(PAGE_SIZE); currentPagePosition = currentPage.getBaseOffset(); allocatedPages.add(currentPage); + logger.info("Acquired new page! " + allocatedPages.size() * PAGE_SIZE); } } @@ -162,9 +220,9 @@ public void insertRecord( sorter.insertRecord(recordAddress, prefix); } - public Iterator getSortedIterator() throws IOException { - final UnsafeExternalSortSpillMerger spillMerger = - new UnsafeExternalSortSpillMerger(recordComparator, prefixComparator); + public ExternalSorterIterator getSortedIterator() throws IOException { + final UnsafeSorterSpillMerger spillMerger = + new UnsafeSorterSpillMerger(recordComparator, prefixComparator); for (UnsafeSorterSpillWriter spillWriter : spillWriters) { spillMerger.addSpill(spillWriter.getReader(blockManager)); } diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java index 0f844c1997668..917cbdb564a15 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java @@ -86,10 +86,9 @@ public static abstract class PrefixComparator { */ private int sortBufferInsertPosition = 0; - private void expandSortBuffer(int newSize) { - assert (newSize > sortBuffer.length); + public void expandSortBuffer() { final long[] oldBuffer = sortBuffer; - sortBuffer = new long[newSize]; + sortBuffer = new long[oldBuffer.length * 2]; System.arraycopy(oldBuffer, 0, sortBuffer, 0, oldBuffer.length); } @@ -122,14 +121,22 @@ public int compare(RecordPointerAndKeyPrefix left, RecordPointerAndKeyPrefix rig }; } + public long getMemoryUsage() { + return sortBuffer.length * 8L; + } + + public boolean hasSpaceForAnotherRecord() { + return sortBufferInsertPosition + 2 < sortBuffer.length; + } + /** * Insert a record into the sort buffer. * * @param objectAddress pointer to a record in a data page, encoded by {@link TaskMemoryManager}. */ public void insertRecord(long objectAddress, long keyPrefix) { - if (sortBufferInsertPosition + 2 == sortBuffer.length) { - expandSortBuffer(sortBuffer.length * 2); + if (!hasSpaceForAnotherRecord()) { + expandSortBuffer(); } sortBuffer[sortBufferInsertPosition] = objectAddress; sortBufferInsertPosition++; @@ -167,10 +174,10 @@ public void remove() { }; } - public UnsafeExternalSortSpillMerger.MergeableIterator getMergeableIterator() { + public UnsafeSorterSpillMerger.MergeableIterator getMergeableIterator() { sorter.sort(sortBuffer, 0, sortBufferInsertPosition / 2, sortComparator); - UnsafeExternalSortSpillMerger.MergeableIterator iter = - new UnsafeExternalSortSpillMerger.MergeableIterator() { + UnsafeSorterSpillMerger.MergeableIterator iter = + new UnsafeSorterSpillMerger.MergeableIterator() { private int position = 0; private Object baseObject; diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSortSpillMerger.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillMerger.java similarity index 94% rename from core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSortSpillMerger.java rename to core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillMerger.java index c6bd4ee9df4ff..93278d5a26473 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSortSpillMerger.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillMerger.java @@ -23,7 +23,7 @@ import static org.apache.spark.unsafe.sort.UnsafeSorter.*; -public final class UnsafeExternalSortSpillMerger { +public final class UnsafeSorterSpillMerger { private final PriorityQueue priorityQueue; @@ -46,9 +46,9 @@ public static final class RecordAddressAndKeyPrefix { public long keyPrefix; } - public UnsafeExternalSortSpillMerger( - final RecordComparator recordComparator, - final UnsafeSorter.PrefixComparator prefixComparator) { + public UnsafeSorterSpillMerger( + final RecordComparator recordComparator, + final UnsafeSorter.PrefixComparator prefixComparator) { final Comparator comparator = new Comparator() { @Override diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java index 7c696240aaa73..894a593d41f3e 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java @@ -24,7 +24,7 @@ import java.io.*; -public final class UnsafeSorterSpillReader extends UnsafeExternalSortSpillMerger.MergeableIterator { +final class UnsafeSorterSpillReader extends UnsafeSorterSpillMerger.MergeableIterator { private final File file; private InputStream in; diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java index e0649122ac09c..6085df67d2c2e 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java @@ -33,7 +33,7 @@ import java.io.*; import java.nio.ByteBuffer; -public final class UnsafeSorterSpillWriter { +final class UnsafeSorterSpillWriter { private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this public static final int EOF_MARKER = -1; @@ -122,6 +122,10 @@ public void close() throws IOException { arr = null; } + public long numberOfSpilledBytes() { + return file.length(); + } + public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException { return new UnsafeSorterSpillReader(blockManager, file, blockId); } diff --git a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java index 4f2aa9b895c01..e745074af075c 100644 --- a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java @@ -20,6 +20,8 @@ import org.apache.spark.HashPartitioner; import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.TaskContextImpl; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleMemoryManager; @@ -41,7 +43,6 @@ import java.io.File; import java.io.InputStream; import java.io.OutputStream; -import java.nio.file.Files; import java.util.Iterator; import java.util.UUID; @@ -78,6 +79,7 @@ public int compare( BlockManager blockManager; DiskBlockManager diskBlockManager; File tempDir; + TaskContext taskContext; private static final class CompressStream extends AbstractFunction1 { @Override @@ -92,6 +94,7 @@ public void setUp() { diskBlockManager = mock(DiskBlockManager.class); blockManager = mock(BlockManager.class); tempDir = new File(Utils.createTempDir$default$1()); + taskContext = mock(TaskContext.class); when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer>() { @@ -142,6 +145,7 @@ public void testSortingOnlyByPartitionId() throws Exception { memoryManager, shuffleMemoryManager, blockManager, + taskContext, recordComparator, prefixComparator, 1024, @@ -154,14 +158,18 @@ public void testSortingOnlyByPartitionId() throws Exception { insertNumber(sorter, 4); insertNumber(sorter, 2); - Iterator iter = - sorter.getSortedIterator(); - - Assert.assertEquals(1, iter.next().keyPrefix); - Assert.assertEquals(2, iter.next().keyPrefix); - Assert.assertEquals(3, iter.next().keyPrefix); - Assert.assertEquals(4, iter.next().keyPrefix); - Assert.assertEquals(5, iter.next().keyPrefix); + ExternalSorterIterator iter = sorter.getSortedIterator(); + + iter.loadNext(); + Assert.assertEquals(1, iter.keyPrefix); + iter.loadNext(); + Assert.assertEquals(2, iter.keyPrefix); + iter.loadNext(); + Assert.assertEquals(3, iter.keyPrefix); + iter.loadNext(); + Assert.assertEquals(4, iter.keyPrefix); + iter.loadNext(); + Assert.assertEquals(5, iter.keyPrefix); Assert.assertFalse(iter.hasNext()); // TODO: check that the values are also read back properly. From 3490512dfe1d26c51a535464363f4f6fe110c084 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 4 May 2015 12:33:19 -0700 Subject: [PATCH 28/92] Misc. cleanup --- .../shuffle/unsafe/UnsafeShuffleWriter.java | 20 ++++++------- .../spark/unsafe/sort/UnsafeSorter.java | 4 +-- .../unsafe/sort/UnsafeSorterSpillMerger.java | 30 +++++-------------- .../unsafe/sort/UnsafeSorterSpillReader.java | 5 ++-- .../unsafe/sort/UnsafeSorterSpillWriter.java | 25 +++++++++------- 5 files changed, 34 insertions(+), 50 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index d142cf59d8085..1c875a15687c1 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -17,21 +17,18 @@ package org.apache.spark.shuffle.unsafe; -import org.apache.spark.*; -import org.apache.spark.unsafe.sort.ExternalSorterIterator; -import org.apache.spark.unsafe.sort.UnsafeExternalSorter; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; + import scala.Option; import scala.Product2; import scala.reflect.ClassTag; import scala.reflect.ClassTag$; -import java.io.File; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.LinkedList; - import com.esotericsoftware.kryo.io.ByteBufferOutputStream; +import org.apache.spark.*; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; @@ -44,10 +41,11 @@ import org.apache.spark.storage.BlockObjectWriter; import org.apache.spark.storage.ShuffleBlockId; import org.apache.spark.unsafe.PlatformDependent; -import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; - -import static org.apache.spark.unsafe.sort.UnsafeSorter.*; +import org.apache.spark.unsafe.sort.ExternalSorterIterator; +import org.apache.spark.unsafe.sort.UnsafeExternalSorter; +import static org.apache.spark.unsafe.sort.UnsafeSorter.PrefixComparator; +import static org.apache.spark.unsafe.sort.UnsafeSorter.RecordComparator; // IntelliJ gets confused and claims that this class should be abstract, but this actually compiles public class UnsafeShuffleWriter implements ShuffleWriter { diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java index 917cbdb564a15..c69d486705f65 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java @@ -176,8 +176,7 @@ public void remove() { public UnsafeSorterSpillMerger.MergeableIterator getMergeableIterator() { sorter.sort(sortBuffer, 0, sortBufferInsertPosition / 2, sortComparator); - UnsafeSorterSpillMerger.MergeableIterator iter = - new UnsafeSorterSpillMerger.MergeableIterator() { + return new UnsafeSorterSpillMerger.MergeableIterator() { private int position = 0; private Object baseObject; @@ -213,6 +212,5 @@ public long getBaseOffset() { return baseOffset; } }; - return iter; } } diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillMerger.java index 93278d5a26473..bd3f4424724f6 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillMerger.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillMerger.java @@ -18,12 +18,11 @@ package org.apache.spark.unsafe.sort; import java.util.Comparator; -import java.util.Iterator; import java.util.PriorityQueue; import static org.apache.spark.unsafe.sort.UnsafeSorter.*; -public final class UnsafeSorterSpillMerger { +final class UnsafeSorterSpillMerger { private final PriorityQueue priorityQueue; @@ -39,13 +38,6 @@ public static abstract class MergeableIterator { public abstract long getBaseOffset(); } - public static final class RecordAddressAndKeyPrefix { - public Object baseObject; - public long baseOffset; - public int recordLength; - public long keyPrefix; - } - public UnsafeSorterSpillMerger( final RecordComparator recordComparator, final UnsafeSorter.PrefixComparator prefixComparator) { @@ -74,11 +66,10 @@ public void addSpill(MergeableIterator spillReader) { priorityQueue.add(spillReader); } - public Iterator getSortedIterator() { - return new Iterator() { + public ExternalSorterIterator getSortedIterator() { + return new ExternalSorterIterator() { private MergeableIterator spillReader; - private final RecordAddressAndKeyPrefix record = new RecordAddressAndKeyPrefix(); @Override public boolean hasNext() { @@ -86,7 +77,7 @@ public boolean hasNext() { } @Override - public RecordAddressAndKeyPrefix next() { + public void loadNext() { if (spillReader != null) { if (spillReader.hasNext()) { spillReader.loadNextRecord(); @@ -94,17 +85,10 @@ public RecordAddressAndKeyPrefix next() { } } spillReader = priorityQueue.remove(); - record.baseObject = spillReader.getBaseObject(); - record.baseOffset = spillReader.getBaseOffset(); - record.keyPrefix = spillReader.getPrefix(); - return record; - } - - @Override - public void remove() { - throw new UnsupportedOperationException(); + baseObject = spillReader.getBaseObject(); + baseOffset = spillReader.getBaseOffset(); + keyPrefix = spillReader.getPrefix(); } }; } - } diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java index 894a593d41f3e..3102e5ab3b6f4 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java @@ -17,13 +17,14 @@ package org.apache.spark.unsafe.sort; +import java.io.*; + import com.google.common.io.ByteStreams; + import org.apache.spark.storage.BlockId; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.PlatformDependent; -import java.io.*; - final class UnsafeSorterSpillReader extends UnsafeSorterSpillMerger.MergeableIterator { private final File file; diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java index 6085df67d2c2e..33356c3351967 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java @@ -17,9 +17,14 @@ package org.apache.spark.unsafe.sort; +import java.io.*; +import java.nio.ByteBuffer; + +import scala.Tuple2; +import scala.reflect.ClassTag; + import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.serializer.DeserializationStream; -import org.apache.spark.serializer.JavaSerializerInstance; import org.apache.spark.serializer.SerializationStream; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.storage.BlockId; @@ -27,27 +32,23 @@ import org.apache.spark.storage.BlockObjectWriter; import org.apache.spark.storage.TempLocalBlockId; import org.apache.spark.unsafe.PlatformDependent; -import scala.Tuple2; -import scala.reflect.ClassTag; - -import java.io.*; -import java.nio.ByteBuffer; final class UnsafeSorterSpillWriter { private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this - public static final int EOF_MARKER = -1; - byte[] arr = new byte[SER_BUFFER_SIZE]; + static final int EOF_MARKER = -1; + + private byte[] arr = new byte[SER_BUFFER_SIZE]; private final File file; private final BlockId blockId; - BlockObjectWriter writer; - DataOutputStream dos; + private BlockObjectWriter writer; + private DataOutputStream dos; public UnsafeSorterSpillWriter( BlockManager blockManager, int fileBufferSize, - ShuffleWriteMetrics writeMetrics) throws IOException { + ShuffleWriteMetrics writeMetrics) { final Tuple2 spilledFileInfo = blockManager.diskBlockManager().createTempLocalBlock(); this.file = spilledFileInfo._2(); @@ -119,6 +120,8 @@ public void write( public void close() throws IOException { dos.writeInt(EOF_MARKER); writer.commitAndClose(); + writer = null; + dos = null; arr = null; } From 3aeaff7599209dd47453ca8f109d5ccdf7b76b21 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 4 May 2015 14:50:07 -0700 Subject: [PATCH 29/92] More refactoring and cleanup; begin cleaning iterator interfaces --- .../shuffle/unsafe/UnsafeShuffleWriter.java | 21 +++-- ...terIterator.java => PrefixComparator.java} | 17 ++-- .../spark/unsafe/sort/RecordComparator.java | 37 ++++++++ .../unsafe/sort/UnsafeExternalSorter.java | 19 ++-- .../unsafe/sort/UnsafeSortDataFormat.java | 15 ++- .../spark/unsafe/sort/UnsafeSorter.java | 93 +++---------------- .../unsafe/sort/UnsafeSorterIterator.java | 35 +++++++ .../unsafe/sort/UnsafeSorterSpillMerger.java | 57 ++++++------ .../unsafe/sort/UnsafeSorterSpillReader.java | 42 +++++---- .../sort/UnsafeExternalSorterSuite.java | 20 ++-- .../spark/unsafe/sort/UnsafeSorterSuite.java | 26 +++--- 11 files changed, 197 insertions(+), 185 deletions(-) rename core/src/main/java/org/apache/spark/unsafe/sort/{ExternalSorterIterator.java => PrefixComparator.java} (76%) create mode 100644 core/src/main/java/org/apache/spark/unsafe/sort/RecordComparator.java create mode 100644 core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterIterator.java diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 1c875a15687c1..01fe022fad046 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -42,10 +42,11 @@ import org.apache.spark.storage.ShuffleBlockId; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.memory.TaskMemoryManager; -import org.apache.spark.unsafe.sort.ExternalSorterIterator; +import org.apache.spark.unsafe.sort.UnsafeSorterIterator; import org.apache.spark.unsafe.sort.UnsafeExternalSorter; -import static org.apache.spark.unsafe.sort.UnsafeSorter.PrefixComparator; -import static org.apache.spark.unsafe.sort.UnsafeSorter.RecordComparator; +import org.apache.spark.unsafe.sort.PrefixComparator; + +import org.apache.spark.unsafe.sort.RecordComparator; // IntelliJ gets confused and claims that this class should be abstract, but this actually compiles public class UnsafeShuffleWriter implements ShuffleWriter { @@ -104,7 +105,7 @@ private void freeMemory() { // TODO: free sorter memory } - private ExternalSorterIterator sortRecords( + private UnsafeSorterIterator sortRecords( scala.collection.Iterator> records) throws Exception { final UnsafeExternalSorter sorter = new UnsafeExternalSorter( memoryManager, @@ -142,7 +143,7 @@ private ExternalSorterIterator sortRecords( return sorter.getSortedIterator(); } - private long[] writeSortedRecordsToFile(ExternalSorterIterator sortedRecords) throws IOException { + private long[] writeSortedRecordsToFile(UnsafeSorterIterator sortedRecords) throws IOException { final File outputFile = shuffleBlockManager.getDataFile(shuffleId, mapId); final ShuffleBlockId blockId = new ShuffleBlockId(shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID()); @@ -154,7 +155,7 @@ private long[] writeSortedRecordsToFile(ExternalSorterIterator sortedRecords) th final byte[] arr = new byte[SER_BUFFER_SIZE]; while (sortedRecords.hasNext()) { sortedRecords.loadNext(); - final int partition = (int) sortedRecords.keyPrefix; + final int partition = (int) sortedRecords.getKeyPrefix(); assert (partition >= currentPartition); if (partition != currentPartition) { // Switch to the new partition @@ -168,13 +169,13 @@ private long[] writeSortedRecordsToFile(ExternalSorterIterator sortedRecords) th } PlatformDependent.copyMemory( - sortedRecords.baseObject, - sortedRecords.baseOffset + 4, + sortedRecords.getBaseObject(), + sortedRecords.getBaseOffset() + 4, arr, PlatformDependent.BYTE_ARRAY_OFFSET, - sortedRecords.recordLength); + sortedRecords.getRecordLength()); assert (writer != null); // To suppress an IntelliJ warning - writer.write(arr, 0, sortedRecords.recordLength); + writer.write(arr, 0, sortedRecords.getRecordLength()); // TODO: add a test that detects whether we leave this call out: writer.recordWritten(); } diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/ExternalSorterIterator.java b/core/src/main/java/org/apache/spark/unsafe/sort/PrefixComparator.java similarity index 76% rename from core/src/main/java/org/apache/spark/unsafe/sort/ExternalSorterIterator.java rename to core/src/main/java/org/apache/spark/unsafe/sort/PrefixComparator.java index d53a0baaf351f..a8468d8b1cdb9 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/ExternalSorterIterator.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/PrefixComparator.java @@ -17,15 +17,10 @@ package org.apache.spark.unsafe.sort; -public abstract class ExternalSorterIterator { - - public Object baseObject; - public long baseOffset; - public int recordLength; - public long keyPrefix; - - public abstract boolean hasNext(); - - public abstract void loadNext(); - +/** + * Compares 8-byte key prefixes in prefix sort. Subclasses may implement type-specific + * comparisons, such as lexicographic comparison for strings. + */ +public abstract class PrefixComparator { + public abstract int compare(long prefix1, long prefix2); } diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/RecordComparator.java b/core/src/main/java/org/apache/spark/unsafe/sort/RecordComparator.java new file mode 100644 index 0000000000000..99a2b077f9869 --- /dev/null +++ b/core/src/main/java/org/apache/spark/unsafe/sort/RecordComparator.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.sort; + +/** + * Compares records for ordering. In cases where the entire sorting key can fit in the 8-byte + * prefix, this may simply return 0. + */ +public abstract class RecordComparator { + + /** + * Compare two records for order. + * + * @return a negative integer, zero, or a positive integer as the first record is less than, + * equal to, or greater than the second. + */ + public abstract int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset); +} diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java index bf0019c51703f..f455f471e1d13 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java @@ -33,8 +33,6 @@ import java.util.Iterator; import java.util.LinkedList; -import static org.apache.spark.unsafe.sort.UnsafeSorter.*; - /** * External sorter based on {@link UnsafeSorter}. */ @@ -111,13 +109,16 @@ public void spill() throws IOException { final UnsafeSorterSpillWriter spillWriter = new UnsafeSorterSpillWriter(blockManager, fileBufferSize, writeMetrics); spillWriters.add(spillWriter); - final Iterator sortedRecords = sorter.getSortedIterator(); + final UnsafeSorterIterator sortedRecords = sorter.getSortedIterator(); while (sortedRecords.hasNext()) { - final RecordPointerAndKeyPrefix recordPointer = sortedRecords.next(); - final Object baseObject = memoryManager.getPage(recordPointer.recordPointer); - final long baseOffset = memoryManager.getOffsetInPage(recordPointer.recordPointer); + sortedRecords.loadNext(); + final Object baseObject = sortedRecords.getBaseObject(); + final long baseOffset = sortedRecords.getBaseOffset(); + // TODO: this assumption that the first long holds a length is not enforced via our interfaces + // We need to either always store this via the write path (e.g. not require the caller to do + // it), or provide interfaces / hooks for customizing the physical storage format etc. final int recordLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset); - spillWriter.write(baseObject, baseOffset, recordLength, recordPointer.keyPrefix); + spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix()); } spillWriter.close(); final long sorterMemoryUsage = sorter.getMemoryUsage(); @@ -220,14 +221,14 @@ public void insertRecord( sorter.insertRecord(recordAddress, prefix); } - public ExternalSorterIterator getSortedIterator() throws IOException { + public UnsafeSorterIterator getSortedIterator() throws IOException { final UnsafeSorterSpillMerger spillMerger = new UnsafeSorterSpillMerger(recordComparator, prefixComparator); for (UnsafeSorterSpillWriter spillWriter : spillWriters) { spillMerger.addSpill(spillWriter.getReader(blockManager)); } spillWriters.clear(); - spillMerger.addSpill(sorter.getMergeableIterator()); + spillMerger.addSpill(sorter.getSortedIterator()); return spillMerger.getSortedIterator(); } } diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java index 290a87b70cad6..1e0a5c7bd1113 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java @@ -17,8 +17,8 @@ package org.apache.spark.unsafe.sort; -import static org.apache.spark.unsafe.sort.UnsafeSorter.RecordPointerAndKeyPrefix; import org.apache.spark.util.collection.SortDataFormat; +import static org.apache.spark.unsafe.sort.UnsafeSortDataFormat.RecordPointerAndKeyPrefix; /** * Supports sorting an array of (record pointer, key prefix) pairs. Used in {@link UnsafeSorter}. @@ -28,6 +28,19 @@ */ final class UnsafeSortDataFormat extends SortDataFormat { + static final class RecordPointerAndKeyPrefix { + /** + * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a + * description of how these addresses are encoded. + */ + public long recordPointer; + + /** + * A key prefix, for use in comparisons. + */ + public long keyPrefix; + } + public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat(); private UnsafeSortDataFormat() { } diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java index c69d486705f65..cfb85ea55bcd6 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java @@ -18,10 +18,10 @@ package org.apache.spark.unsafe.sort; import java.util.Comparator; -import java.util.Iterator; import org.apache.spark.util.collection.Sorter; import org.apache.spark.unsafe.memory.TaskMemoryManager; +import static org.apache.spark.unsafe.sort.UnsafeSortDataFormat.RecordPointerAndKeyPrefix; /** * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records @@ -32,45 +32,6 @@ */ public final class UnsafeSorter { - public static final class RecordPointerAndKeyPrefix { - /** - * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a - * description of how these addresses are encoded. - */ - public long recordPointer; - - /** - * A key prefix, for use in comparisons. - */ - public long keyPrefix; - } - - /** - * Compares records for ordering. In cases where the entire sorting key can fit in the 8-byte - * prefix, this may simply return 0. - */ - public static abstract class RecordComparator { - /** - * Compare two records for order. - * - * @return a negative integer, zero, or a positive integer as the first record is less than, - * equal to, or greater than the second. - */ - public abstract int compare( - Object leftBaseObject, - long leftBaseOffset, - Object rightBaseObject, - long rightBaseOffset); - } - - /** - * Compares 8-byte key prefixes in prefix sort. Subclasses may implement type-specific - * comparisons, such as lexicographic comparison for strings. - */ - public static abstract class PrefixComparator { - public abstract int compare(long prefix1, long prefix2); - } - private final TaskMemoryManager memoryManager; private final Sorter sorter; private final Comparator sortComparator; @@ -148,40 +109,15 @@ public void insertRecord(long objectAddress, long keyPrefix) { * Return an iterator over record pointers in sorted order. For efficiency, all calls to * {@code next()} will return the same mutable object. */ - public Iterator getSortedIterator() { - sorter.sort(sortBuffer, 0, sortBufferInsertPosition / 2, sortComparator); - return new Iterator() { - private int position = 0; - private final RecordPointerAndKeyPrefix keyPointerAndPrefix = new RecordPointerAndKeyPrefix(); - - @Override - public boolean hasNext() { - return position < sortBufferInsertPosition; - } - - @Override - public RecordPointerAndKeyPrefix next() { - keyPointerAndPrefix.recordPointer = sortBuffer[position]; - keyPointerAndPrefix.keyPrefix = sortBuffer[position + 1]; - position += 2; - return keyPointerAndPrefix; - } - - @Override - public void remove() { - throw new UnsupportedOperationException(); - } - }; - } - - public UnsafeSorterSpillMerger.MergeableIterator getMergeableIterator() { + public UnsafeSorterIterator getSortedIterator() { sorter.sort(sortBuffer, 0, sortBufferInsertPosition / 2, sortComparator); - return new UnsafeSorterSpillMerger.MergeableIterator() { + return new UnsafeSorterIterator() { private int position = 0; private Object baseObject; private long baseOffset; private long keyPrefix; + private int recordLength; @Override public boolean hasNext() { @@ -189,28 +125,25 @@ public boolean hasNext() { } @Override - public void loadNextRecord() { + public void loadNext() { final long recordPointer = sortBuffer[position]; - keyPrefix = sortBuffer[position + 1]; - position += 2; baseObject = memoryManager.getPage(recordPointer); baseOffset = memoryManager.getOffsetInPage(recordPointer); + keyPrefix = sortBuffer[position + 1]; + position += 2; } @Override - public long getPrefix() { - return keyPrefix; - } + public Object getBaseObject() { return baseObject; } @Override - public Object getBaseObject() { - return baseObject; - } + public long getBaseOffset() { return baseOffset; } @Override - public long getBaseOffset() { - return baseOffset; - } + public int getRecordLength() { return recordLength; } + + @Override + public long getKeyPrefix() { return keyPrefix; } }; } } diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterIterator.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterIterator.java new file mode 100644 index 0000000000000..8cac5887f9c3f --- /dev/null +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterIterator.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.sort; + +import java.io.IOException; + +public abstract class UnsafeSorterIterator { + + public abstract boolean hasNext(); + + public abstract void loadNext() throws IOException; + + public abstract Object getBaseObject(); + + public abstract long getBaseOffset(); + + public abstract int getRecordLength(); + + public abstract long getKeyPrefix(); +} diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillMerger.java index bd3f4424724f6..0837843ae442b 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillMerger.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillMerger.java @@ -17,36 +17,23 @@ package org.apache.spark.unsafe.sort; +import java.io.IOException; import java.util.Comparator; import java.util.PriorityQueue; -import static org.apache.spark.unsafe.sort.UnsafeSorter.*; - final class UnsafeSorterSpillMerger { - private final PriorityQueue priorityQueue; - - public static abstract class MergeableIterator { - public abstract boolean hasNext(); - - public abstract void loadNextRecord(); - - public abstract long getPrefix(); - - public abstract Object getBaseObject(); - - public abstract long getBaseOffset(); - } + private final PriorityQueue priorityQueue; public UnsafeSorterSpillMerger( final RecordComparator recordComparator, - final UnsafeSorter.PrefixComparator prefixComparator) { - final Comparator comparator = new Comparator() { + final PrefixComparator prefixComparator) { + final Comparator comparator = new Comparator() { @Override - public int compare(MergeableIterator left, MergeableIterator right) { + public int compare(UnsafeSorterIterator left, UnsafeSorterIterator right) { final int prefixComparisonResult = - prefixComparator.compare(left.getPrefix(), right.getPrefix()); + prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix()); if (prefixComparisonResult == 0) { return recordComparator.compare( left.getBaseObject(), left.getBaseOffset(), @@ -56,20 +43,21 @@ public int compare(MergeableIterator left, MergeableIterator right) { } } }; - priorityQueue = new PriorityQueue(10, comparator); + // TODO: the size is often known; incorporate size hints here. + priorityQueue = new PriorityQueue(10, comparator); } - public void addSpill(MergeableIterator spillReader) { + public void addSpill(UnsafeSorterIterator spillReader) throws IOException { if (spillReader.hasNext()) { - spillReader.loadNextRecord(); + spillReader.loadNext(); } priorityQueue.add(spillReader); } - public ExternalSorterIterator getSortedIterator() { - return new ExternalSorterIterator() { + public UnsafeSorterIterator getSortedIterator() throws IOException { + return new UnsafeSorterIterator() { - private MergeableIterator spillReader; + private UnsafeSorterIterator spillReader; @Override public boolean hasNext() { @@ -77,18 +65,27 @@ public boolean hasNext() { } @Override - public void loadNext() { + public void loadNext() throws IOException { if (spillReader != null) { if (spillReader.hasNext()) { - spillReader.loadNextRecord(); + spillReader.loadNext(); priorityQueue.add(spillReader); } } spillReader = priorityQueue.remove(); - baseObject = spillReader.getBaseObject(); - baseOffset = spillReader.getBaseOffset(); - keyPrefix = spillReader.getPrefix(); } + + @Override + public Object getBaseObject() { return spillReader.getBaseObject(); } + + @Override + public long getBaseOffset() { return spillReader.getBaseOffset(); } + + @Override + public int getRecordLength() { return spillReader.getRecordLength(); } + + @Override + public long getKeyPrefix() { return spillReader.getKeyPrefix(); } }; } } diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java index 3102e5ab3b6f4..696bcd468b0a1 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java @@ -25,16 +25,17 @@ import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.PlatformDependent; -final class UnsafeSorterSpillReader extends UnsafeSorterSpillMerger.MergeableIterator { +final class UnsafeSorterSpillReader extends UnsafeSorterIterator { private final File file; private InputStream in; private DataInputStream din; - private long keyPrefix; private final byte[] arr = new byte[1024 * 1024]; // TODO: tune this (maybe grow dynamically)? - private final Object baseObject = arr; private int nextRecordLength; + + private long keyPrefix; + private final Object baseObject = arr; private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET; public UnsafeSorterSpillReader( @@ -55,26 +56,17 @@ public boolean hasNext() { } @Override - public void loadNextRecord() { - try { - keyPrefix = din.readLong(); - ByteStreams.readFully(in, arr, 0, nextRecordLength); - nextRecordLength = din.readInt(); - if (nextRecordLength == UnsafeSorterSpillWriter.EOF_MARKER) { - in.close(); - in = null; - din = null; - } - } catch (Exception e) { - PlatformDependent.throwException(e); + public void loadNext() throws IOException { + keyPrefix = din.readLong(); + ByteStreams.readFully(in, arr, 0, nextRecordLength); + nextRecordLength = din.readInt(); + if (nextRecordLength == UnsafeSorterSpillWriter.EOF_MARKER) { + in.close(); + in = null; + din = null; } } - @Override - public long getPrefix() { - return keyPrefix; - } - @Override public Object getBaseObject() { return baseObject; @@ -84,4 +76,14 @@ public Object getBaseObject() { public long getBaseOffset() { return baseOffset; } + + @Override + public int getRecordLength() { + return 0; + } + + @Override + public long getKeyPrefix() { + return keyPrefix; + } } diff --git a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java index e745074af075c..c33035c2c116b 100644 --- a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java @@ -21,8 +21,8 @@ import org.apache.spark.HashPartitioner; import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; -import org.apache.spark.TaskContextImpl; import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.executor.TaskMetrics; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.storage.*; @@ -43,7 +43,6 @@ import java.io.File; import java.io.InputStream; import java.io.OutputStream; -import java.util.Iterator; import java.util.UUID; import static org.mockito.Mockito.*; @@ -56,7 +55,7 @@ public class UnsafeExternalSorterSuite { // Compute key prefixes based on the records' partition ids final HashPartitioner hashPartitioner = new HashPartitioner(4); // Use integer comparison for comparing prefixes (which are partition ids, in this case) - final UnsafeSorter.PrefixComparator prefixComparator = new UnsafeSorter.PrefixComparator() { + final PrefixComparator prefixComparator = new PrefixComparator() { @Override public int compare(long prefix1, long prefix2) { return (int) prefix1 - (int) prefix2; @@ -64,7 +63,7 @@ public int compare(long prefix1, long prefix2) { }; // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so // use a dummy comparator - final UnsafeSorter.RecordComparator recordComparator = new UnsafeSorter.RecordComparator() { + final RecordComparator recordComparator = new RecordComparator() { @Override public int compare( Object leftBaseObject, @@ -95,6 +94,7 @@ public void setUp() { blockManager = mock(BlockManager.class); tempDir = new File(Utils.createTempDir$default$1()); taskContext = mock(TaskContext.class); + when(taskContext.taskMetrics()).thenReturn(new TaskMetrics()); when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer>() { @@ -158,18 +158,18 @@ public void testSortingOnlyByPartitionId() throws Exception { insertNumber(sorter, 4); insertNumber(sorter, 2); - ExternalSorterIterator iter = sorter.getSortedIterator(); + UnsafeSorterIterator iter = sorter.getSortedIterator(); iter.loadNext(); - Assert.assertEquals(1, iter.keyPrefix); + Assert.assertEquals(1, iter.getKeyPrefix()); iter.loadNext(); - Assert.assertEquals(2, iter.keyPrefix); + Assert.assertEquals(2, iter.getKeyPrefix()); iter.loadNext(); - Assert.assertEquals(3, iter.keyPrefix); + Assert.assertEquals(3, iter.getKeyPrefix()); iter.loadNext(); - Assert.assertEquals(4, iter.keyPrefix); + Assert.assertEquals(4, iter.getKeyPrefix()); iter.loadNext(); - Assert.assertEquals(5, iter.keyPrefix); + Assert.assertEquals(5, iter.getKeyPrefix()); Assert.assertFalse(iter.hasNext()); // TODO: check that the values are also read back properly. diff --git a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java index aed115f83a368..3a2e9696761de 100644 --- a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java @@ -18,7 +18,6 @@ package org.apache.spark.unsafe.sort; import java.util.Arrays; -import java.util.Iterator; import org.junit.Assert; import org.junit.Test; @@ -48,10 +47,10 @@ private static String getStringFromDataPage(Object baseObject, long baseOffset) public void testSortingEmptyInput() { final UnsafeSorter sorter = new UnsafeSorter( new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)), - mock(UnsafeSorter.RecordComparator.class), - mock(UnsafeSorter.PrefixComparator.class), + mock(RecordComparator.class), + mock(PrefixComparator.class), 100); - final Iterator iter = sorter.getSortedIterator(); + final UnsafeSorterIterator iter = sorter.getSortedIterator(); assert(!iter.hasNext()); } @@ -91,7 +90,7 @@ public void testSortingOnlyByPartitionId() throws Exception { } // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so // use a dummy comparator - final UnsafeSorter.RecordComparator recordComparator = new UnsafeSorter.RecordComparator() { + final RecordComparator recordComparator = new RecordComparator() { @Override public int compare( Object leftBaseObject, @@ -104,7 +103,7 @@ public int compare( // Compute key prefixes based on the records' partition ids final HashPartitioner hashPartitioner = new HashPartitioner(4); // Use integer comparison for comparing prefixes (which are partition ids, in this case) - final UnsafeSorter.PrefixComparator prefixComparator = new UnsafeSorter.PrefixComparator() { + final PrefixComparator prefixComparator = new PrefixComparator() { @Override public int compare(long prefix1, long prefix2) { return (int) prefix1 - (int) prefix2; @@ -123,19 +122,18 @@ public int compare(long prefix1, long prefix2) { sorter.insertRecord(address, partitionId); position += 8 + recordLength; } - final Iterator iter = sorter.getSortedIterator(); + final UnsafeSorterIterator iter = sorter.getSortedIterator(); int iterLength = 0; long prevPrefix = -1; Arrays.sort(dataToSort); while (iter.hasNext()) { - final UnsafeSorter.RecordPointerAndKeyPrefix pointerAndPrefix = iter.next(); - final Object recordBaseObject = memoryManager.getPage(pointerAndPrefix.recordPointer); - final long recordBaseOffset = memoryManager.getOffsetInPage(pointerAndPrefix.recordPointer); - final String str = getStringFromDataPage(recordBaseObject, recordBaseOffset); + iter.loadNext(); + final String str = getStringFromDataPage(iter.getBaseObject(), iter.getBaseOffset()); + final long keyPrefix = iter.getKeyPrefix(); Assert.assertTrue("String should be valid", Arrays.binarySearch(dataToSort, str) != -1); - Assert.assertTrue("Prefix " + pointerAndPrefix.keyPrefix + " should be >= previous prefix " + - prevPrefix, pointerAndPrefix.keyPrefix >= prevPrefix); - prevPrefix = pointerAndPrefix.keyPrefix; + Assert.assertTrue("Prefix " + keyPrefix + " should be >= previous prefix " + + prevPrefix, keyPrefix >= prevPrefix); + prevPrefix = keyPrefix; iterLength++; } Assert.assertEquals(dataToSort.length, iterLength); From 7ee918e6c2f72bcfa1070026e5165f5b96d71f56 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 4 May 2015 17:27:37 -0700 Subject: [PATCH 30/92] Re-order imports in tests --- .../sort/UnsafeExternalSorterSuite.java | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java index c33035c2c116b..9cb96fa2c3322 100644 --- a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java @@ -17,6 +17,23 @@ package org.apache.spark.unsafe.sort; +import java.io.File; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.UUID; + +import scala.Tuple2; +import scala.Tuple2$; +import scala.runtime.AbstractFunction1; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import static org.mockito.AdditionalAnswers.returnsFirstArg; +import static org.mockito.AdditionalAnswers.returnsSecondArg; +import static org.mockito.Mockito.*; import org.apache.spark.HashPartitioner; import org.apache.spark.SparkConf; @@ -31,22 +48,6 @@ import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.TaskMemoryManager; import org.apache.spark.util.Utils; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; -import scala.Tuple2; -import scala.Tuple2$; -import scala.runtime.AbstractFunction1; - -import java.io.File; -import java.io.InputStream; -import java.io.OutputStream; -import java.util.UUID; - -import static org.mockito.Mockito.*; -import static org.mockito.AdditionalAnswers.*; public class UnsafeExternalSorterSuite { From 69232fdf0023ac3786f6ade34f20564ba30b000d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 4 May 2015 22:16:54 -0700 Subject: [PATCH 31/92] Enable compressible address encoding for off-heap mode. --- .../unsafe/memory/TaskMemoryManager.java | 15 +++++++------- .../unsafe/memory/TaskMemoryManagerSuite.java | 20 +++++++++++++++++++ 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java index 9224988e6ad69..983bd99f0c4ce 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java @@ -175,12 +175,8 @@ public void free(MemoryBlock memory) { * This address will remain valid as long as the corresponding page has not been freed. */ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { - if (inHeap) { - assert (page.pageNumber != -1) : "encodePageNumberAndOffset called with invalid page"; - return (((long) page.pageNumber) << 51) | (offsetInPage & MASK_LONG_LOWER_51_BITS); - } else { - return offsetInPage; - } + assert (page.pageNumber != -1) : "encodePageNumberAndOffset called with invalid page"; + return (((long) page.pageNumber) << 51) | (offsetInPage & MASK_LONG_LOWER_51_BITS); } /** @@ -204,10 +200,13 @@ public Object getPage(long pagePlusOffsetAddress) { * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} */ public long getOffsetInPage(long pagePlusOffsetAddress) { + final long offsetInPage = (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS); if (inHeap) { - return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS); + return offsetInPage; } else { - return pagePlusOffsetAddress; + final int pageNumber = (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> 51); + assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); + return pageTable[pageNumber].getBaseOffset() + offsetInPage; } } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java index 932882f1ca248..8ace8625abb64 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java @@ -38,4 +38,24 @@ public void leakedPageMemoryIsDetected() { Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory()); } + @Test + public void encodePageNumberAndOffsetOffHeap() { + final TaskMemoryManager manager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE)); + final MemoryBlock dataPage = manager.allocatePage(256); + final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64); + Assert.assertEquals(null, manager.getPage(encodedAddress)); + Assert.assertEquals(dataPage.getBaseOffset() + 64, manager.getOffsetInPage(encodedAddress)); + } + + @Test + public void encodePageNumberAndOffsetOnHeap() { + final TaskMemoryManager manager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final MemoryBlock dataPage = manager.allocatePage(256); + final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64); + Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress)); + Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress)); + } + } From 57f1ec04f9e4444b034a4aefa47dcea1eca2603b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 4 May 2015 23:05:34 -0700 Subject: [PATCH 32/92] WIP towards packed record pointers for use in optimized shuffle sort. --- .../shuffle/unsafe/PackedRecordPointer.java | 74 +++++++++++++++++++ .../unsafe/PackedRecordPointerSuite.java | 57 ++++++++++++++ 2 files changed, 131 insertions(+) create mode 100644 core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java create mode 100644 core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java new file mode 100644 index 0000000000000..34c15e6bbcb0e --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +/** + * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer. + */ +final class PackedRecordPointer { + + /** Bit mask for the lower 40 bits of a long. */ + private static final long MASK_LONG_LOWER_40_BITS = 0xFFFFFFFFFFL; + + /** Bit mask for the upper 24 bits of a long */ + private static final long MASK_LONG_UPPER_24_BITS = ~MASK_LONG_LOWER_40_BITS; + + /** Bit mask for the lower 27 bits of a long. */ + private static final long MASK_LONG_LOWER_27_BITS = 0x7FFFFFFL; + + /** Bit mask for the lower 51 bits of a long. */ + private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL; + + /** Bit mask for the upper 13 bits of a long */ + private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS; + + // TODO: this shifting is probably extremely inefficient; this is just for prototyping + + /** + * Pack a record address and partition id into a single word. + * + * @param recordPointer a record pointer encoded by TaskMemoryManager. + * @param partitionId a shuffle partition id (maximum value of 2^24). + * @return a packed pointer that can be decoded using the {@link PackedRecordPointer} class. + */ + public static long packPointer(long recordPointer, int partitionId) { + // Note that without word alignment we can address 2^27 bytes = 128 megabytes per page. + // Also note that this relies on some internals of how TaskMemoryManager encodes its addresses. + final int pageNumber = (int) ((recordPointer & MASK_LONG_UPPER_13_BITS) >>> 51); + final long compressedAddress = + (((long) pageNumber) << 27) | (recordPointer & MASK_LONG_LOWER_27_BITS); + return (((long) partitionId) << 40) | compressedAddress; + } + + public long packedRecordPointer; + + public int getPartitionId() { + return (int) ((packedRecordPointer & MASK_LONG_UPPER_24_BITS) >>> 40); + } + + public long getRecordPointer() { + final long compressedAddress = packedRecordPointer & MASK_LONG_LOWER_40_BITS; + final long pageNumber = (compressedAddress << 24) & MASK_LONG_UPPER_13_BITS; + final long offsetInPage = compressedAddress & MASK_LONG_LOWER_27_BITS; + return pageNumber | offsetInPage; + } + + public int getRecordLength() { + return -1; // TODO + } +} diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java new file mode 100644 index 0000000000000..53554520b22b1 --- /dev/null +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +public class PackedRecordPointerSuite { + + @Test + public void heap() { + final TaskMemoryManager memoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final MemoryBlock page0 = memoryManager.allocatePage(100); + final MemoryBlock page1 = memoryManager.allocatePage(100); + final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, 42); + PackedRecordPointer packedPointerWrapper = new PackedRecordPointer(); + packedPointerWrapper.packedRecordPointer = PackedRecordPointer.packPointer(addressInPage1, 360); + Assert.assertEquals(360, packedPointerWrapper.getPartitionId()); + Assert.assertEquals(addressInPage1, packedPointerWrapper.getRecordPointer()); + memoryManager.cleanUpAllAllocatedMemory(); + } + + @Test + public void offHeap() { + final TaskMemoryManager memoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE)); + final MemoryBlock page0 = memoryManager.allocatePage(100); + final MemoryBlock page1 = memoryManager.allocatePage(100); + final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, 42); + PackedRecordPointer packedPointerWrapper = new PackedRecordPointer(); + packedPointerWrapper.packedRecordPointer = PackedRecordPointer.packPointer(addressInPage1, 360); + Assert.assertEquals(360, packedPointerWrapper.getPartitionId()); + Assert.assertEquals(addressInPage1, packedPointerWrapper.getRecordPointer()); + memoryManager.cleanUpAllAllocatedMemory(); + } +} From f480fb2a7c363f70bb31e8784868f5c3ceb2a883 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 4 May 2015 23:20:10 -0700 Subject: [PATCH 33/92] WIP in mega-refactoring towards shuffle-specific sort. --- .../unsafe/DummySerializerInstance.java | 69 +++++++ .../unsafe/SpillInfo.java} | 25 ++- .../unsafe/UnsafeShuffleSortDataFormat.java | 68 +++++++ .../shuffle/unsafe/UnsafeShuffleSorter.java | 105 ++++++++++ .../unsafe/UnsafeShuffleSpillWriter.java} | 138 ++++++++----- .../shuffle/unsafe/UnsafeShuffleWriter.java | 114 +++++------ .../spark/unsafe/sort/PrefixComparator.java | 26 --- .../spark/unsafe/sort/RecordComparator.java | 37 ---- .../unsafe/sort/UnsafeSortDataFormat.java | 93 --------- .../spark/unsafe/sort/UnsafeSorter.java | 149 -------------- .../unsafe/sort/UnsafeSorterSpillMerger.java | 91 --------- .../unsafe/sort/UnsafeSorterSpillReader.java | 89 --------- .../unsafe/sort/UnsafeSorterSpillWriter.java | 135 ------------- .../unsafe/UnsafeShuffleSorterSuite.java | 110 +++++++++++ .../sort/UnsafeExternalSorterSuite.java | 181 ------------------ .../spark/unsafe/sort/UnsafeSorterSuite.java | 141 -------------- 16 files changed, 497 insertions(+), 1074 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java rename core/src/main/java/org/apache/spark/{unsafe/sort/UnsafeSorterIterator.java => shuffle/unsafe/SpillInfo.java} (68%) create mode 100644 core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java rename core/src/main/java/org/apache/spark/{unsafe/sort/UnsafeExternalSorter.java => shuffle/unsafe/UnsafeShuffleSpillWriter.java} (65%) delete mode 100644 core/src/main/java/org/apache/spark/unsafe/sort/PrefixComparator.java delete mode 100644 core/src/main/java/org/apache/spark/unsafe/sort/RecordComparator.java delete mode 100644 core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java delete mode 100644 core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java delete mode 100644 core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillMerger.java delete mode 100644 core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java delete mode 100644 core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java create mode 100644 core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorterSuite.java delete mode 100644 core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java delete mode 100644 core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java new file mode 100644 index 0000000000000..ab174c3ca921a --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import org.apache.spark.serializer.DeserializationStream; +import org.apache.spark.serializer.SerializationStream; +import org.apache.spark.serializer.SerializerInstance; +import scala.reflect.ClassTag; + +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; + +class DummySerializerInstance extends SerializerInstance { + @Override + public SerializationStream serializeStream(OutputStream s) { + return new SerializationStream() { + @Override + public void flush() { + + } + + @Override + public SerializationStream writeObject(T t, ClassTag ev1) { + return null; + } + + @Override + public void close() { + + } + }; + } + + @Override + public ByteBuffer serialize(T t, ClassTag ev1) { + return null; + } + + @Override + public DeserializationStream deserializeStream(InputStream s) { + return null; + } + + @Override + public T deserialize(ByteBuffer bytes, ClassLoader loader, ClassTag ev1) { + return null; + } + + @Override + public T deserialize(ByteBuffer bytes, ClassTag ev1) { + return null; + } +} diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterIterator.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java similarity index 68% rename from core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterIterator.java rename to core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java index 8cac5887f9c3f..5e8c090405098 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterIterator.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java @@ -15,21 +15,20 @@ * limitations under the License. */ -package org.apache.spark.unsafe.sort; +package org.apache.spark.shuffle.unsafe; -import java.io.IOException; +import org.apache.spark.storage.BlockId; -public abstract class UnsafeSorterIterator { +import java.io.File; - public abstract boolean hasNext(); +final class SpillInfo { + final long[] partitionLengths; + final File file; + final BlockId blockId; - public abstract void loadNext() throws IOException; - - public abstract Object getBaseObject(); - - public abstract long getBaseOffset(); - - public abstract int getRecordLength(); - - public abstract long getKeyPrefix(); + public SpillInfo(int numPartitions, File file, BlockId blockId) { + this.partitionLengths = new long[numPartitions]; + this.file = file; + this.blockId = blockId; + } } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java new file mode 100644 index 0000000000000..d7afa1a906428 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import org.apache.spark.util.collection.SortDataFormat; + +final class UnsafeShuffleSortDataFormat extends SortDataFormat { + + public static final UnsafeShuffleSortDataFormat INSTANCE = new UnsafeShuffleSortDataFormat(); + + private UnsafeShuffleSortDataFormat() { } + + @Override + public PackedRecordPointer getKey(long[] data, int pos) { + // Since we re-use keys, this method shouldn't be called. + throw new UnsupportedOperationException(); + } + + @Override + public PackedRecordPointer newKey() { + return new PackedRecordPointer(); + } + + @Override + public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) { + reuse.packedRecordPointer = data[pos]; + return reuse; + } + + @Override + public void swap(long[] data, int pos0, int pos1) { + final long temp = data[pos0]; + data[pos0] = data[pos1]; + data[pos1] = temp; + } + + @Override + public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) { + dst[dstPos] = src[srcPos]; + } + + @Override + public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) { + System.arraycopy(src, srcPos, dst, dstPos, length); + } + + @Override + public long[] allocate(int length) { + assert (length < Integer.MAX_VALUE) : "Length " + length + " is too large"; + return new long[length]; + } + +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java new file mode 100644 index 0000000000000..eb46776efe12c --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.util.Comparator; + +import org.apache.spark.util.collection.Sorter; + +public final class UnsafeShuffleSorter { + + private final Sorter sorter; + private final Comparator sortComparator; + + private long[] sortBuffer; + + /** + * The position in the sort buffer where new records can be inserted. + */ + private int sortBufferInsertPosition = 0; + + public UnsafeShuffleSorter(int initialSize) { + assert (initialSize > 0); + this.sortBuffer = new long[initialSize]; + this.sorter = + new Sorter(UnsafeShuffleSortDataFormat.INSTANCE); + this.sortComparator = new Comparator() { + @Override + public int compare(PackedRecordPointer left, PackedRecordPointer right) { + return left.getPartitionId() - right.getPartitionId(); + } + }; + } + + public void expandSortBuffer() { + final long[] oldBuffer = sortBuffer; + sortBuffer = new long[oldBuffer.length * 2]; + System.arraycopy(oldBuffer, 0, sortBuffer, 0, oldBuffer.length); + } + + public boolean hasSpaceForAnotherRecord() { + return sortBufferInsertPosition + 1 < sortBuffer.length; + } + + public long getMemoryUsage() { + return sortBuffer.length * 8L; + } + + // TODO: clairify assumption that pointer points to record length. + public void insertRecord(long recordPointer, int partitionId) { + if (!hasSpaceForAnotherRecord()) { + expandSortBuffer(); + } + sortBuffer[sortBufferInsertPosition] = + PackedRecordPointer.packPointer(recordPointer, partitionId); + sortBufferInsertPosition++; + } + + public static abstract class UnsafeShuffleSorterIterator { + + final PackedRecordPointer packedRecordPointer = new PackedRecordPointer(); + + public abstract boolean hasNext(); + + public abstract void loadNext(); + + } + + /** + * Return an iterator over record pointers in sorted order. For efficiency, all calls to + * {@code next()} will return the same mutable object. + */ + public UnsafeShuffleSorterIterator getSortedIterator() { + sorter.sort(sortBuffer, 0, sortBufferInsertPosition, sortComparator); + return new UnsafeShuffleSorterIterator() { + + private int position = 0; + + @Override + public boolean hasNext() { + return position < sortBufferInsertPosition; + } + + @Override + public void loadNext() { + packedRecordPointer.packedRecordPointer = sortBuffer[position]; + position++; + } + }; + } +} diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java similarity index 65% rename from core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java rename to core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java index f455f471e1d13..b0e2b6022ef21 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java @@ -15,38 +15,43 @@ * limitations under the License. */ -package org.apache.spark.unsafe.sort; +package org.apache.spark.shuffle.unsafe; import com.google.common.annotations.VisibleForTesting; import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.storage.BlockId; import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.BlockObjectWriter; +import org.apache.spark.storage.TempLocalBlockId; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import scala.Tuple2; +import java.io.File; import java.io.IOException; import java.util.Iterator; import java.util.LinkedList; /** - * External sorter based on {@link UnsafeSorter}. + * External sorter based on {@link UnsafeShuffleSorter}. */ -public final class UnsafeExternalSorter { +public final class UnsafeShuffleSpillWriter { - private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class); + private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleSpillWriter.class); + private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this / don't duplicate private static final int PAGE_SIZE = 1024 * 1024; // TODO: tune this - private final PrefixComparator prefixComparator; - private final RecordComparator recordComparator; private final int initialSize; - private int numSpills = 0; - private UnsafeSorter sorter; + private final int numPartitions; + private UnsafeShuffleSorter sorter; private final TaskMemoryManager memoryManager; private final ShuffleMemoryManager shuffleMemoryManager; @@ -61,25 +66,22 @@ public final class UnsafeExternalSorter { private MemoryBlock currentPage = null; private long currentPagePosition = -1; - private final LinkedList spillWriters = - new LinkedList(); - - public UnsafeExternalSorter( - TaskMemoryManager memoryManager, - ShuffleMemoryManager shuffleMemoryManager, - BlockManager blockManager, - TaskContext taskContext, - RecordComparator recordComparator, - PrefixComparator prefixComparator, - int initialSize, - SparkConf conf) throws IOException { + private final LinkedList spills = new LinkedList(); + + public UnsafeShuffleSpillWriter( + TaskMemoryManager memoryManager, + ShuffleMemoryManager shuffleMemoryManager, + BlockManager blockManager, + TaskContext taskContext, + int initialSize, + int numPartitions, + SparkConf conf) throws IOException { this.memoryManager = memoryManager; this.shuffleMemoryManager = shuffleMemoryManager; this.blockManager = blockManager; this.taskContext = taskContext; - this.recordComparator = recordComparator; - this.prefixComparator = prefixComparator; this.initialSize = initialSize; + this.numPartitions = numPartitions; this.spillingEnabled = conf.getBoolean("spark.shuffle.spill", true); // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; @@ -92,7 +94,7 @@ private void openSorter() throws IOException { this.writeMetrics = new ShuffleWriteMetrics(); // TODO: connect write metrics to task metrics? // TODO: move this sizing calculation logic into a static method of sorter: - final long memoryRequested = initialSize * 8L * 2; + final long memoryRequested = initialSize * 8L; if (spillingEnabled) { final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested); if (memoryAcquired != memoryRequested) { @@ -101,38 +103,77 @@ private void openSorter() throws IOException { } } - this.sorter = new UnsafeSorter(memoryManager, recordComparator, prefixComparator, initialSize); + this.sorter = new UnsafeShuffleSorter(initialSize); } - @VisibleForTesting - public void spill() throws IOException { - final UnsafeSorterSpillWriter spillWriter = - new UnsafeSorterSpillWriter(blockManager, fileBufferSize, writeMetrics); - spillWriters.add(spillWriter); - final UnsafeSorterIterator sortedRecords = sorter.getSortedIterator(); + private SpillInfo writeSpillFile() throws IOException { + final UnsafeShuffleSorter.UnsafeShuffleSorterIterator sortedRecords = sorter.getSortedIterator(); + + int currentPartition = -1; + BlockObjectWriter writer = null; + final byte[] arr = new byte[SER_BUFFER_SIZE]; + + final Tuple2 spilledFileInfo = + blockManager.diskBlockManager().createTempLocalBlock(); + final File file = spilledFileInfo._2(); + final BlockId blockId = spilledFileInfo._1(); + final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId); + spills.add(spillInfo); + + final SerializerInstance ser = new DummySerializerInstance(); + writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetrics); + while (sortedRecords.hasNext()) { sortedRecords.loadNext(); - final Object baseObject = sortedRecords.getBaseObject(); - final long baseOffset = sortedRecords.getBaseOffset(); - // TODO: this assumption that the first long holds a length is not enforced via our interfaces - // We need to either always store this via the write path (e.g. not require the caller to do - // it), or provide interfaces / hooks for customizing the physical storage format etc. - final int recordLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset); - spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix()); + final int partition = sortedRecords.packedRecordPointer.getPartitionId(); + assert (partition >= currentPartition); + if (partition != currentPartition) { + // Switch to the new partition + if (currentPartition != -1) { + writer.commitAndClose(); + spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length(); + } + currentPartition = partition; + writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetrics); + } + + final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); + final int recordLength = PlatformDependent.UNSAFE.getInt( + memoryManager.getPage(recordPointer), memoryManager.getOffsetInPage(recordPointer)); + PlatformDependent.copyMemory( + memoryManager.getPage(recordPointer), + memoryManager.getOffsetInPage(recordPointer) + 4, // skip over record length + arr, + PlatformDependent.BYTE_ARRAY_OFFSET, + recordLength); + assert (writer != null); // To suppress an IntelliJ warning + writer.write(arr, 0, recordLength); + // TODO: add a test that detects whether we leave this call out: + writer.recordWritten(); + } + + if (writer != null) { + writer.commitAndClose(); + spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length(); } - spillWriter.close(); + return spillInfo; + } + + @VisibleForTesting + public void spill() throws IOException { + final SpillInfo spillInfo = writeSpillFile(); + final long sorterMemoryUsage = sorter.getMemoryUsage(); sorter = null; shuffleMemoryManager.release(sorterMemoryUsage); final long spillSize = freeMemory(); taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); - taskContext.taskMetrics().incDiskBytesSpilled(spillWriter.numberOfSpilledBytes()); - numSpills++; + taskContext.taskMetrics().incDiskBytesSpilled(spillInfo.file.length()); final long threadId = Thread.currentThread().getId(); // TODO: messy; log _before_ spill logger.info("Thread " + threadId + " spilling in-memory map of " + org.apache.spark.util.Utils.bytesToString(spillSize) + " to disk (" + - (numSpills + ((numSpills > 1) ? " times" : " time")) + " so far)"); + (spills.size() + ((spills.size() > 1) ? " times" : " time")) + " so far)"); openSorter(); } @@ -201,7 +242,7 @@ public void insertRecord( Object recordBaseObject, long recordBaseOffset, int lengthInBytes, - long prefix) throws Exception { + int prefix) throws Exception { // Need 4 bytes to store the record length. ensureSpaceInDataPage(lengthInBytes + 4); @@ -221,14 +262,11 @@ public void insertRecord( sorter.insertRecord(recordAddress, prefix); } - public UnsafeSorterIterator getSortedIterator() throws IOException { - final UnsafeSorterSpillMerger spillMerger = - new UnsafeSorterSpillMerger(recordComparator, prefixComparator); - for (UnsafeSorterSpillWriter spillWriter : spillWriters) { - spillMerger.addSpill(spillWriter.getReader(blockManager)); + public SpillInfo[] closeAndGetSpills() throws IOException { + if (sorter != null) { + writeSpillFile(); } - spillWriters.clear(); - spillMerger.addSpill(sorter.getSortedIterator()); - return spillMerger.getSortedIterator(); + return (SpillInfo[]) spills.toArray(); } + } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 01fe022fad046..839c854963ccf 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -18,8 +18,11 @@ package org.apache.spark.shuffle.unsafe; import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; import scala.Option; import scala.Product2; @@ -38,15 +41,8 @@ import org.apache.spark.shuffle.IndexShuffleBlockManager; import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.BlockManager; -import org.apache.spark.storage.BlockObjectWriter; -import org.apache.spark.storage.ShuffleBlockId; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.memory.TaskMemoryManager; -import org.apache.spark.unsafe.sort.UnsafeSorterIterator; -import org.apache.spark.unsafe.sort.UnsafeExternalSorter; -import org.apache.spark.unsafe.sort.PrefixComparator; - -import org.apache.spark.unsafe.sort.RecordComparator; // IntelliJ gets confused and claims that this class should be abstract, but this actually compiles public class UnsafeShuffleWriter implements ShuffleWriter { @@ -62,7 +58,6 @@ public class UnsafeShuffleWriter implements ShuffleWriter { private final SerializerInstance serializer; private final Partitioner partitioner; private final ShuffleWriteMetrics writeMetrics; - private final int fileBufferSize; private MapStatus mapStatus = null; /** @@ -86,14 +81,11 @@ public UnsafeShuffleWriter( this.partitioner = dep.partitioner(); this.writeMetrics = new ShuffleWriteMetrics(); context.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics)); - this.fileBufferSize = - // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units - (int) SparkEnv.get().conf().getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; } public void write(scala.collection.Iterator> records) { try { - final long[] partitionLengths = writeSortedRecordsToFile(sortRecords(records)); + final long[] partitionLengths = mergeSpills(insertRecordsIntoSorter(records)); shuffleBlockManager.writeIndexFile(shuffleId, mapId, partitionLengths); mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } catch (Exception e) { @@ -102,19 +94,18 @@ public void write(scala.collection.Iterator> records) { } private void freeMemory() { - // TODO: free sorter memory + // TODO } - private UnsafeSorterIterator sortRecords( - scala.collection.Iterator> records) throws Exception { - final UnsafeExternalSorter sorter = new UnsafeExternalSorter( + private SpillInfo[] insertRecordsIntoSorter( + scala.collection.Iterator> records) throws Exception { + final UnsafeShuffleSpillWriter sorter = new UnsafeShuffleSpillWriter( memoryManager, SparkEnv$.MODULE$.get().shuffleMemoryManager(), SparkEnv$.MODULE$.get().blockManager(), TaskContext.get(), - RECORD_COMPARATOR, - PREFIX_COMPARATOR, 4096, // Initial size (TODO: tune this!) + partitioner.numPartitions(), SparkEnv$.MODULE$.get().conf() ); @@ -140,50 +131,50 @@ private UnsafeSorterIterator sortRecords( serArray, PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); } - return sorter.getSortedIterator(); + return sorter.closeAndGetSpills(); } - private long[] writeSortedRecordsToFile(UnsafeSorterIterator sortedRecords) throws IOException { + private long[] mergeSpills(SpillInfo[] spills) throws IOException { final File outputFile = shuffleBlockManager.getDataFile(shuffleId, mapId); - final ShuffleBlockId blockId = - new ShuffleBlockId(shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID()); - final long[] partitionLengths = new long[partitioner.numPartitions()]; - - int currentPartition = -1; - BlockObjectWriter writer = null; - - final byte[] arr = new byte[SER_BUFFER_SIZE]; - while (sortedRecords.hasNext()) { - sortedRecords.loadNext(); - final int partition = (int) sortedRecords.getKeyPrefix(); - assert (partition >= currentPartition); - if (partition != currentPartition) { - // Switch to the new partition - if (currentPartition != -1) { - writer.commitAndClose(); - partitionLengths[currentPartition] = writer.fileSegment().length(); + final int numPartitions = partitioner.numPartitions(); + final long[] partitionLengths = new long[numPartitions]; + final FileChannel[] spillInputChannels = new FileChannel[spills.length]; + + // TODO: We need to add an option to bypass transferTo here since older Linux kernels are + // affected by a bug here that can lead to data truncation; see the comments Utils.scala, + // in the copyStream() method. I didn't use copyStream() here because we only want to copy + // a limited number of bytes from the stream and I didn't want to modify / extend that method + // to accept a length. + + // TODO: special case optimization for case where we only write one file (non-spill case). + + for (int i = 0; i < spills.length; i++) { + spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel(); + } + + final FileChannel mergedFileOutputChannel = new FileOutputStream(outputFile).getChannel(); + + for (int partition = 0; partition < numPartitions; partition++ ) { + for (int i = 0; i < spills.length; i++) { + final long bytesToTransfer = spills[i].partitionLengths[partition]; + long bytesRemainingToBeTransferred = bytesToTransfer; + final FileChannel spillInputChannel = spillInputChannels[i]; + long fromPosition = spillInputChannel.position(); + while (bytesRemainingToBeTransferred > 0) { + bytesRemainingToBeTransferred -= spillInputChannel.transferTo( + fromPosition, + bytesRemainingToBeTransferred, + mergedFileOutputChannel); } - currentPartition = partition; - writer = - blockManager.getDiskWriter(blockId, outputFile, serializer, fileBufferSize, writeMetrics); + partitionLengths[partition] += bytesToTransfer; } - - PlatformDependent.copyMemory( - sortedRecords.getBaseObject(), - sortedRecords.getBaseOffset() + 4, - arr, - PlatformDependent.BYTE_ARRAY_OFFSET, - sortedRecords.getRecordLength()); - assert (writer != null); // To suppress an IntelliJ warning - writer.write(arr, 0, sortedRecords.getRecordLength()); - // TODO: add a test that detects whether we leave this call out: - writer.recordWritten(); } - if (writer != null) { - writer.commitAndClose(); - partitionLengths[currentPartition] = writer.fileSegment().length(); + // TODO: should this be in a finally block? + for (int i = 0; i < spills.length; i++) { + spillInputChannels[i].close(); } + mergedFileOutputChannel.close(); return partitionLengths; } @@ -209,19 +200,4 @@ public Option stop(boolean success) { // TODO: increment the shuffle write time metrics } } - - private static final RecordComparator RECORD_COMPARATOR = new RecordComparator() { - @Override - public int compare( - Object leftBaseObject, long leftBaseOffset, Object rightBaseObject, long rightBaseOffset) { - return 0; - } - }; - - private static final PrefixComparator PREFIX_COMPARATOR = new PrefixComparator() { - @Override - public int compare(long prefix1, long prefix2) { - return (int) (prefix1 - prefix2); - } - }; } diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/PrefixComparator.java b/core/src/main/java/org/apache/spark/unsafe/sort/PrefixComparator.java deleted file mode 100644 index a8468d8b1cdb9..0000000000000 --- a/core/src/main/java/org/apache/spark/unsafe/sort/PrefixComparator.java +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.sort; - -/** - * Compares 8-byte key prefixes in prefix sort. Subclasses may implement type-specific - * comparisons, such as lexicographic comparison for strings. - */ -public abstract class PrefixComparator { - public abstract int compare(long prefix1, long prefix2); -} diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/RecordComparator.java b/core/src/main/java/org/apache/spark/unsafe/sort/RecordComparator.java deleted file mode 100644 index 99a2b077f9869..0000000000000 --- a/core/src/main/java/org/apache/spark/unsafe/sort/RecordComparator.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.sort; - -/** - * Compares records for ordering. In cases where the entire sorting key can fit in the 8-byte - * prefix, this may simply return 0. - */ -public abstract class RecordComparator { - - /** - * Compare two records for order. - * - * @return a negative integer, zero, or a positive integer as the first record is less than, - * equal to, or greater than the second. - */ - public abstract int compare( - Object leftBaseObject, - long leftBaseOffset, - Object rightBaseObject, - long rightBaseOffset); -} diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java deleted file mode 100644 index 1e0a5c7bd1113..0000000000000 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.sort; - -import org.apache.spark.util.collection.SortDataFormat; -import static org.apache.spark.unsafe.sort.UnsafeSortDataFormat.RecordPointerAndKeyPrefix; - -/** - * Supports sorting an array of (record pointer, key prefix) pairs. Used in {@link UnsafeSorter}. - * - * Within each long[] buffer, position {@code 2 * i} holds a pointer pointer to the record at - * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. - */ -final class UnsafeSortDataFormat extends SortDataFormat { - - static final class RecordPointerAndKeyPrefix { - /** - * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a - * description of how these addresses are encoded. - */ - public long recordPointer; - - /** - * A key prefix, for use in comparisons. - */ - public long keyPrefix; - } - - public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat(); - - private UnsafeSortDataFormat() { } - - @Override - public RecordPointerAndKeyPrefix getKey(long[] data, int pos) { - // Since we re-use keys, this method shouldn't be called. - throw new UnsupportedOperationException(); - } - - @Override - public RecordPointerAndKeyPrefix newKey() { - return new RecordPointerAndKeyPrefix(); - } - - @Override - public RecordPointerAndKeyPrefix getKey(long[] data, int pos, RecordPointerAndKeyPrefix reuse) { - reuse.recordPointer = data[pos * 2]; - reuse.keyPrefix = data[pos * 2 + 1]; - return reuse; - } - - @Override - public void swap(long[] data, int pos0, int pos1) { - long tempPointer = data[pos0 * 2]; - long tempKeyPrefix = data[pos0 * 2 + 1]; - data[pos0 * 2] = data[pos1 * 2]; - data[pos0 * 2 + 1] = data[pos1 * 2 + 1]; - data[pos1 * 2] = tempPointer; - data[pos1 * 2 + 1] = tempKeyPrefix; - } - - @Override - public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) { - dst[dstPos * 2] = src[srcPos * 2]; - dst[dstPos * 2 + 1] = src[srcPos * 2 + 1]; - } - - @Override - public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) { - System.arraycopy(src, srcPos * 2, dst, dstPos * 2, length * 2); - } - - @Override - public long[] allocate(int length) { - assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large"; - return new long[length * 2]; - } - -} diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java deleted file mode 100644 index cfb85ea55bcd6..0000000000000 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.sort; - -import java.util.Comparator; - -import org.apache.spark.util.collection.Sorter; -import org.apache.spark.unsafe.memory.TaskMemoryManager; -import static org.apache.spark.unsafe.sort.UnsafeSortDataFormat.RecordPointerAndKeyPrefix; - -/** - * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records - * alongside a user-defined prefix of the record's sorting key. When the underlying sort algorithm - * compares records, it will first compare the stored key prefixes; if the prefixes are not equal, - * then we do not need to traverse the record pointers to compare the actual records. Avoiding these - * random memory accesses improves cache hit rates. - */ -public final class UnsafeSorter { - - private final TaskMemoryManager memoryManager; - private final Sorter sorter; - private final Comparator sortComparator; - - /** - * Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at - * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. - */ - private long[] sortBuffer; - - /** - * The position in the sort buffer where new records can be inserted. - */ - private int sortBufferInsertPosition = 0; - - public void expandSortBuffer() { - final long[] oldBuffer = sortBuffer; - sortBuffer = new long[oldBuffer.length * 2]; - System.arraycopy(oldBuffer, 0, sortBuffer, 0, oldBuffer.length); - } - - public UnsafeSorter( - final TaskMemoryManager memoryManager, - final RecordComparator recordComparator, - final PrefixComparator prefixComparator, - int initialSize) { - assert (initialSize > 0); - this.sortBuffer = new long[initialSize * 2]; - this.memoryManager = memoryManager; - this.sorter = - new Sorter(UnsafeSortDataFormat.INSTANCE); - this.sortComparator = new Comparator() { - @Override - public int compare(RecordPointerAndKeyPrefix left, RecordPointerAndKeyPrefix right) { - final int prefixComparisonResult = - prefixComparator.compare(left.keyPrefix, right.keyPrefix); - if (prefixComparisonResult == 0) { - final Object leftBaseObject = memoryManager.getPage(left.recordPointer); - final long leftBaseOffset = memoryManager.getOffsetInPage(left.recordPointer); - final Object rightBaseObject = memoryManager.getPage(right.recordPointer); - final long rightBaseOffset = memoryManager.getOffsetInPage(right.recordPointer); - return recordComparator.compare( - leftBaseObject, leftBaseOffset, rightBaseObject, rightBaseOffset); - } else { - return prefixComparisonResult; - } - } - }; - } - - public long getMemoryUsage() { - return sortBuffer.length * 8L; - } - - public boolean hasSpaceForAnotherRecord() { - return sortBufferInsertPosition + 2 < sortBuffer.length; - } - - /** - * Insert a record into the sort buffer. - * - * @param objectAddress pointer to a record in a data page, encoded by {@link TaskMemoryManager}. - */ - public void insertRecord(long objectAddress, long keyPrefix) { - if (!hasSpaceForAnotherRecord()) { - expandSortBuffer(); - } - sortBuffer[sortBufferInsertPosition] = objectAddress; - sortBufferInsertPosition++; - sortBuffer[sortBufferInsertPosition] = keyPrefix; - sortBufferInsertPosition++; - } - - /** - * Return an iterator over record pointers in sorted order. For efficiency, all calls to - * {@code next()} will return the same mutable object. - */ - public UnsafeSorterIterator getSortedIterator() { - sorter.sort(sortBuffer, 0, sortBufferInsertPosition / 2, sortComparator); - return new UnsafeSorterIterator() { - - private int position = 0; - private Object baseObject; - private long baseOffset; - private long keyPrefix; - private int recordLength; - - @Override - public boolean hasNext() { - return position < sortBufferInsertPosition; - } - - @Override - public void loadNext() { - final long recordPointer = sortBuffer[position]; - baseObject = memoryManager.getPage(recordPointer); - baseOffset = memoryManager.getOffsetInPage(recordPointer); - keyPrefix = sortBuffer[position + 1]; - position += 2; - } - - @Override - public Object getBaseObject() { return baseObject; } - - @Override - public long getBaseOffset() { return baseOffset; } - - @Override - public int getRecordLength() { return recordLength; } - - @Override - public long getKeyPrefix() { return keyPrefix; } - }; - } -} diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillMerger.java deleted file mode 100644 index 0837843ae442b..0000000000000 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillMerger.java +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.sort; - -import java.io.IOException; -import java.util.Comparator; -import java.util.PriorityQueue; - -final class UnsafeSorterSpillMerger { - - private final PriorityQueue priorityQueue; - - public UnsafeSorterSpillMerger( - final RecordComparator recordComparator, - final PrefixComparator prefixComparator) { - final Comparator comparator = new Comparator() { - - @Override - public int compare(UnsafeSorterIterator left, UnsafeSorterIterator right) { - final int prefixComparisonResult = - prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix()); - if (prefixComparisonResult == 0) { - return recordComparator.compare( - left.getBaseObject(), left.getBaseOffset(), - right.getBaseObject(), right.getBaseOffset()); - } else { - return prefixComparisonResult; - } - } - }; - // TODO: the size is often known; incorporate size hints here. - priorityQueue = new PriorityQueue(10, comparator); - } - - public void addSpill(UnsafeSorterIterator spillReader) throws IOException { - if (spillReader.hasNext()) { - spillReader.loadNext(); - } - priorityQueue.add(spillReader); - } - - public UnsafeSorterIterator getSortedIterator() throws IOException { - return new UnsafeSorterIterator() { - - private UnsafeSorterIterator spillReader; - - @Override - public boolean hasNext() { - return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext()); - } - - @Override - public void loadNext() throws IOException { - if (spillReader != null) { - if (spillReader.hasNext()) { - spillReader.loadNext(); - priorityQueue.add(spillReader); - } - } - spillReader = priorityQueue.remove(); - } - - @Override - public Object getBaseObject() { return spillReader.getBaseObject(); } - - @Override - public long getBaseOffset() { return spillReader.getBaseOffset(); } - - @Override - public int getRecordLength() { return spillReader.getRecordLength(); } - - @Override - public long getKeyPrefix() { return spillReader.getKeyPrefix(); } - }; - } -} diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java deleted file mode 100644 index 696bcd468b0a1..0000000000000 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.sort; - -import java.io.*; - -import com.google.common.io.ByteStreams; - -import org.apache.spark.storage.BlockId; -import org.apache.spark.storage.BlockManager; -import org.apache.spark.unsafe.PlatformDependent; - -final class UnsafeSorterSpillReader extends UnsafeSorterIterator { - - private final File file; - private InputStream in; - private DataInputStream din; - - private final byte[] arr = new byte[1024 * 1024]; // TODO: tune this (maybe grow dynamically)? - private int nextRecordLength; - - private long keyPrefix; - private final Object baseObject = arr; - private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET; - - public UnsafeSorterSpillReader( - BlockManager blockManager, - File file, - BlockId blockId) throws IOException { - this.file = file; - assert (file.length() > 0); - final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file)); - this.in = blockManager.wrapForCompression(blockId, bs); - this.din = new DataInputStream(this.in); - nextRecordLength = din.readInt(); - } - - @Override - public boolean hasNext() { - return (in != null); - } - - @Override - public void loadNext() throws IOException { - keyPrefix = din.readLong(); - ByteStreams.readFully(in, arr, 0, nextRecordLength); - nextRecordLength = din.readInt(); - if (nextRecordLength == UnsafeSorterSpillWriter.EOF_MARKER) { - in.close(); - in = null; - din = null; - } - } - - @Override - public Object getBaseObject() { - return baseObject; - } - - @Override - public long getBaseOffset() { - return baseOffset; - } - - @Override - public int getRecordLength() { - return 0; - } - - @Override - public long getKeyPrefix() { - return keyPrefix; - } -} diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java deleted file mode 100644 index 33356c3351967..0000000000000 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java +++ /dev/null @@ -1,135 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.sort; - -import java.io.*; -import java.nio.ByteBuffer; - -import scala.Tuple2; -import scala.reflect.ClassTag; - -import org.apache.spark.executor.ShuffleWriteMetrics; -import org.apache.spark.serializer.DeserializationStream; -import org.apache.spark.serializer.SerializationStream; -import org.apache.spark.serializer.SerializerInstance; -import org.apache.spark.storage.BlockId; -import org.apache.spark.storage.BlockManager; -import org.apache.spark.storage.BlockObjectWriter; -import org.apache.spark.storage.TempLocalBlockId; -import org.apache.spark.unsafe.PlatformDependent; - -final class UnsafeSorterSpillWriter { - - private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this - static final int EOF_MARKER = -1; - - private byte[] arr = new byte[SER_BUFFER_SIZE]; - - private final File file; - private final BlockId blockId; - private BlockObjectWriter writer; - private DataOutputStream dos; - - public UnsafeSorterSpillWriter( - BlockManager blockManager, - int fileBufferSize, - ShuffleWriteMetrics writeMetrics) { - final Tuple2 spilledFileInfo = - blockManager.diskBlockManager().createTempLocalBlock(); - this.file = spilledFileInfo._2(); - this.blockId = spilledFileInfo._1(); - // Dummy serializer: - final SerializerInstance ser = new SerializerInstance() { - @Override - public SerializationStream serializeStream(OutputStream s) { - return new SerializationStream() { - @Override - public void flush() { - - } - - @Override - public SerializationStream writeObject(T t, ClassTag ev1) { - return null; - } - - @Override - public void close() { - - } - }; - } - - @Override - public ByteBuffer serialize(T t, ClassTag ev1) { - return null; - } - - @Override - public DeserializationStream deserializeStream(InputStream s) { - return null; - } - - @Override - public T deserialize(ByteBuffer bytes, ClassLoader loader, ClassTag ev1) { - return null; - } - - @Override - public T deserialize(ByteBuffer bytes, ClassTag ev1) { - return null; - } - }; - writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetrics); - dos = new DataOutputStream(writer); - } - - public void write( - Object baseObject, - long baseOffset, - int recordLength, - long keyPrefix) throws IOException { - dos.writeInt(recordLength); - dos.writeLong(keyPrefix); - PlatformDependent.copyMemory( - baseObject, - baseOffset + 4, - arr, - PlatformDependent.BYTE_ARRAY_OFFSET, - recordLength); - writer.write(arr, 0, recordLength); - // TODO: add a test that detects whether we leave this call out: - writer.recordWritten(); - } - - public void close() throws IOException { - dos.writeInt(EOF_MARKER); - writer.commitAndClose(); - writer = null; - dos = null; - arr = null; - } - - public long numberOfSpilledBytes() { - return file.length(); - } - - public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException { - return new UnsafeSorterSpillReader(blockManager, file, blockId); - } -} \ No newline at end of file diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorterSuite.java new file mode 100644 index 0000000000000..080145b90554a --- /dev/null +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorterSuite.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.util.Arrays; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.spark.HashPartitioner; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +public class UnsafeShuffleSorterSuite { + + private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) { + final byte[] strBytes = new byte[strLength]; + PlatformDependent.copyMemory( + baseObject, + baseOffset, + strBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, strLength); + return new String(strBytes); + } + + @Test + public void testSortingEmptyInput() { + final UnsafeShuffleSorter sorter = new UnsafeShuffleSorter(100); + final UnsafeShuffleSorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); + assert(!iter.hasNext()); + } + + @Test + public void testBasicSorting() throws Exception { + final String[] dataToSort = new String[] { + "Boba", + "Pearls", + "Tapioca", + "Taho", + "Condensed Milk", + "Jasmine", + "Milk Tea", + "Lychee", + "Mango" + }; + final TaskMemoryManager memoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final MemoryBlock dataPage = memoryManager.allocatePage(2048); + final Object baseObject = dataPage.getBaseObject(); + final UnsafeShuffleSorter sorter = new UnsafeShuffleSorter(4); + final HashPartitioner hashPartitioner = new HashPartitioner(4); + + // Write the records into the data page and store pointers into the sorter + long position = dataPage.getBaseOffset(); + for (String str : dataToSort) { + final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position); + final byte[] strBytes = str.getBytes("utf-8"); + PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length); + position += 4; + PlatformDependent.copyMemory( + strBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + baseObject, + position, + strBytes.length); + position += strBytes.length; + sorter.insertRecord(recordAddress, hashPartitioner.getPartition(str)); + } + + // Sort the records + final UnsafeShuffleSorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); + int prevPartitionId = -1; + Arrays.sort(dataToSort); + for (int i = 0; i < dataToSort.length; i++) { + Assert.assertTrue(iter.hasNext()); + iter.loadNext(); + final int partitionId = iter.packedRecordPointer.getPartitionId(); + Assert.assertTrue(partitionId >= 0 && partitionId <= 3); + Assert.assertTrue("Partition id " + partitionId + " should be >= prev id " + prevPartitionId, + partitionId >= prevPartitionId); + final long recordAddress = iter.packedRecordPointer.getRecordPointer(); + final int recordLength = PlatformDependent.UNSAFE.getInt( + memoryManager.getPage(recordAddress), memoryManager.getOffsetInPage(recordAddress)); + final String str = getStringFromDataPage( + memoryManager.getPage(recordAddress), + memoryManager.getOffsetInPage(recordAddress) + 4, // skip over record length + recordLength); + Assert.assertTrue(Arrays.binarySearch(dataToSort, str) != -1); + } + Assert.assertFalse(iter.hasNext()); + } +} diff --git a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java deleted file mode 100644 index 9cb96fa2c3322..0000000000000 --- a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java +++ /dev/null @@ -1,181 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.sort; - -import java.io.File; -import java.io.InputStream; -import java.io.OutputStream; -import java.util.UUID; - -import scala.Tuple2; -import scala.Tuple2$; -import scala.runtime.AbstractFunction1; - -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; -import static org.mockito.AdditionalAnswers.returnsFirstArg; -import static org.mockito.AdditionalAnswers.returnsSecondArg; -import static org.mockito.Mockito.*; - -import org.apache.spark.HashPartitioner; -import org.apache.spark.SparkConf; -import org.apache.spark.TaskContext; -import org.apache.spark.executor.ShuffleWriteMetrics; -import org.apache.spark.executor.TaskMetrics; -import org.apache.spark.serializer.SerializerInstance; -import org.apache.spark.shuffle.ShuffleMemoryManager; -import org.apache.spark.storage.*; -import org.apache.spark.unsafe.PlatformDependent; -import org.apache.spark.unsafe.memory.ExecutorMemoryManager; -import org.apache.spark.unsafe.memory.MemoryAllocator; -import org.apache.spark.unsafe.memory.TaskMemoryManager; -import org.apache.spark.util.Utils; - -public class UnsafeExternalSorterSuite { - - final TaskMemoryManager memoryManager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); - // Compute key prefixes based on the records' partition ids - final HashPartitioner hashPartitioner = new HashPartitioner(4); - // Use integer comparison for comparing prefixes (which are partition ids, in this case) - final PrefixComparator prefixComparator = new PrefixComparator() { - @Override - public int compare(long prefix1, long prefix2) { - return (int) prefix1 - (int) prefix2; - } - }; - // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so - // use a dummy comparator - final RecordComparator recordComparator = new RecordComparator() { - @Override - public int compare( - Object leftBaseObject, - long leftBaseOffset, - Object rightBaseObject, - long rightBaseOffset) { - return 0; - } - }; - - ShuffleMemoryManager shuffleMemoryManager; - BlockManager blockManager; - DiskBlockManager diskBlockManager; - File tempDir; - TaskContext taskContext; - - private static final class CompressStream extends AbstractFunction1 { - @Override - public OutputStream apply(OutputStream stream) { - return stream; - } - } - - @Before - public void setUp() { - shuffleMemoryManager = mock(ShuffleMemoryManager.class); - diskBlockManager = mock(DiskBlockManager.class); - blockManager = mock(BlockManager.class); - tempDir = new File(Utils.createTempDir$default$1()); - taskContext = mock(TaskContext.class); - when(taskContext.taskMetrics()).thenReturn(new TaskMetrics()); - when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); - when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); - when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer>() { - @Override - public Tuple2 answer(InvocationOnMock invocationOnMock) throws Throwable { - TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID()); - File file = File.createTempFile("spillFile", ".spill", tempDir); - return Tuple2$.MODULE$.apply(blockId, file); - } - }); - when(blockManager.getDiskWriter( - any(BlockId.class), - any(File.class), - any(SerializerInstance.class), - anyInt(), - any(ShuffleWriteMetrics.class))).thenAnswer(new Answer() { - @Override - public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable { - Object[] args = invocationOnMock.getArguments(); - - return new DiskBlockObjectWriter( - (BlockId) args[0], - (File) args[1], - (SerializerInstance) args[2], - (Integer) args[3], - new CompressStream(), - false, - (ShuffleWriteMetrics) args[4] - ); - } - }); - when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))) - .then(returnsSecondArg()); - } - - private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception { - final int[] arr = new int[] { value }; - sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value); - } - - /** - * Tests the type of sorting that's used in the non-combiner path of sort-based shuffle. - */ - @Test - public void testSortingOnlyByPartitionId() throws Exception { - - final UnsafeExternalSorter sorter = new UnsafeExternalSorter( - memoryManager, - shuffleMemoryManager, - blockManager, - taskContext, - recordComparator, - prefixComparator, - 1024, - new SparkConf()); - - insertNumber(sorter, 5); - insertNumber(sorter, 1); - insertNumber(sorter, 3); - sorter.spill(); - insertNumber(sorter, 4); - insertNumber(sorter, 2); - - UnsafeSorterIterator iter = sorter.getSortedIterator(); - - iter.loadNext(); - Assert.assertEquals(1, iter.getKeyPrefix()); - iter.loadNext(); - Assert.assertEquals(2, iter.getKeyPrefix()); - iter.loadNext(); - Assert.assertEquals(3, iter.getKeyPrefix()); - iter.loadNext(); - Assert.assertEquals(4, iter.getKeyPrefix()); - iter.loadNext(); - Assert.assertEquals(5, iter.getKeyPrefix()); - Assert.assertFalse(iter.hasNext()); - // TODO: check that the values are also read back properly. - - // TODO: test for cleanup: - // assert(tempDir.isEmpty) - } - -} diff --git a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java deleted file mode 100644 index 3a2e9696761de..0000000000000 --- a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.sort; - -import java.util.Arrays; - -import org.junit.Assert; -import org.junit.Test; -import static org.mockito.Mockito.*; - -import org.apache.spark.HashPartitioner; -import org.apache.spark.unsafe.PlatformDependent; -import org.apache.spark.unsafe.memory.ExecutorMemoryManager; -import org.apache.spark.unsafe.memory.MemoryAllocator; -import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.unsafe.memory.TaskMemoryManager; - -public class UnsafeSorterSuite { - - private static String getStringFromDataPage(Object baseObject, long baseOffset) { - final int strLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset); - final byte[] strBytes = new byte[strLength]; - PlatformDependent.copyMemory( - baseObject, - baseOffset + 8, - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, strLength); - return new String(strBytes); - } - - @Test - public void testSortingEmptyInput() { - final UnsafeSorter sorter = new UnsafeSorter( - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)), - mock(RecordComparator.class), - mock(PrefixComparator.class), - 100); - final UnsafeSorterIterator iter = sorter.getSortedIterator(); - assert(!iter.hasNext()); - } - - /** - * Tests the type of sorting that's used in the non-combiner path of sort-based shuffle. - */ - @Test - public void testSortingOnlyByPartitionId() throws Exception { - final String[] dataToSort = new String[] { - "Boba", - "Pearls", - "Tapioca", - "Taho", - "Condensed Milk", - "Jasmine", - "Milk Tea", - "Lychee", - "Mango" - }; - final TaskMemoryManager memoryManager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); - final MemoryBlock dataPage = memoryManager.allocatePage(2048); - final Object baseObject = dataPage.getBaseObject(); - // Write the records into the data page: - long position = dataPage.getBaseOffset(); - for (String str : dataToSort) { - final byte[] strBytes = str.getBytes("utf-8"); - PlatformDependent.UNSAFE.putLong(baseObject, position, strBytes.length); - position += 8; - PlatformDependent.copyMemory( - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - baseObject, - position, - strBytes.length); - position += strBytes.length; - } - // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so - // use a dummy comparator - final RecordComparator recordComparator = new RecordComparator() { - @Override - public int compare( - Object leftBaseObject, - long leftBaseOffset, - Object rightBaseObject, - long rightBaseOffset) { - return 0; - } - }; - // Compute key prefixes based on the records' partition ids - final HashPartitioner hashPartitioner = new HashPartitioner(4); - // Use integer comparison for comparing prefixes (which are partition ids, in this case) - final PrefixComparator prefixComparator = new PrefixComparator() { - @Override - public int compare(long prefix1, long prefix2) { - return (int) prefix1 - (int) prefix2; - } - }; - final UnsafeSorter sorter = new UnsafeSorter(memoryManager, recordComparator, prefixComparator, - dataToSort.length); - // Given a page of records, insert those records into the sorter one-by-one: - position = dataPage.getBaseOffset(); - for (int i = 0; i < dataToSort.length; i++) { - // position now points to the start of a record (which holds its length). - final long recordLength = PlatformDependent.UNSAFE.getLong(baseObject, position); - final long address = memoryManager.encodePageNumberAndOffset(dataPage, position); - final String str = getStringFromDataPage(baseObject, position); - final int partitionId = hashPartitioner.getPartition(str); - sorter.insertRecord(address, partitionId); - position += 8 + recordLength; - } - final UnsafeSorterIterator iter = sorter.getSortedIterator(); - int iterLength = 0; - long prevPrefix = -1; - Arrays.sort(dataToSort); - while (iter.hasNext()) { - iter.loadNext(); - final String str = getStringFromDataPage(iter.getBaseObject(), iter.getBaseOffset()); - final long keyPrefix = iter.getKeyPrefix(); - Assert.assertTrue("String should be valid", Arrays.binarySearch(dataToSort, str) != -1); - Assert.assertTrue("Prefix " + keyPrefix + " should be >= previous prefix " + - prevPrefix, keyPrefix >= prevPrefix); - prevPrefix = keyPrefix; - iterLength++; - } - Assert.assertEquals(dataToSort.length, iterLength); - } -} From 133c8c96f61a6809f6bad540a1b6d34e613febf1 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 5 May 2015 10:47:43 -0700 Subject: [PATCH 34/92] WIP towards testing UnsafeShuffleWriter. Unfortunately, this involved a TON of mocks; maybe it would be easier to split the writer into more objects, such as a spiller and merger, as I did when the sorting code was more generic. --- .../unsafe/UnsafeShuffleSpillWriter.java | 16 +- .../shuffle/unsafe/UnsafeShuffleWriter.java | 39 +++- .../shuffle/unsafe/UnsafeShuffleManager.scala | 7 +- .../unsafe/UnsafeShuffleWriterSuite.java | 172 ++++++++++++++++++ 4 files changed, 215 insertions(+), 19 deletions(-) create mode 100644 core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java index b0e2b6022ef21..8e0a21ec6b3a5 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java @@ -69,13 +69,13 @@ public final class UnsafeShuffleSpillWriter { private final LinkedList spills = new LinkedList(); public UnsafeShuffleSpillWriter( - TaskMemoryManager memoryManager, - ShuffleMemoryManager shuffleMemoryManager, - BlockManager blockManager, - TaskContext taskContext, - int initialSize, - int numPartitions, - SparkConf conf) throws IOException { + TaskMemoryManager memoryManager, + ShuffleMemoryManager shuffleMemoryManager, + BlockManager blockManager, + TaskContext taskContext, + int initialSize, + int numPartitions, + SparkConf conf) throws IOException { this.memoryManager = memoryManager; this.shuffleMemoryManager = shuffleMemoryManager; this.blockManager = blockManager; @@ -266,7 +266,7 @@ public SpillInfo[] closeAndGetSpills() throws IOException { if (sorter != null) { writeSpillFile(); } - return (SpillInfo[]) spills.toArray(); + return spills.toArray(new SpillInfo[0]); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 839c854963ccf..47fe214634abb 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -23,9 +23,12 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.FileChannel; +import java.util.Iterator; +import org.apache.spark.shuffle.ShuffleMemoryManager; import scala.Option; import scala.Product2; +import scala.collection.JavaConversions; import scala.reflect.ClassTag; import scala.reflect.ClassTag$; @@ -50,14 +53,18 @@ public class UnsafeShuffleWriter implements ShuffleWriter { private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this private static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); + private final BlockManager blockManager; private final IndexShuffleBlockManager shuffleBlockManager; - private final BlockManager blockManager = SparkEnv.get().blockManager(); - private final int shuffleId; - private final int mapId; private final TaskMemoryManager memoryManager; + private final ShuffleMemoryManager shuffleMemoryManager; private final SerializerInstance serializer; private final Partitioner partitioner; private final ShuffleWriteMetrics writeMetrics; + private final int shuffleId; + private final int mapId; + private final TaskContext taskContext; + private final SparkConf sparkConf; + private MapStatus mapStatus = null; /** @@ -68,19 +75,31 @@ public class UnsafeShuffleWriter implements ShuffleWriter { private boolean stopping = false; public UnsafeShuffleWriter( + BlockManager blockManager, IndexShuffleBlockManager shuffleBlockManager, + TaskMemoryManager memoryManager, + ShuffleMemoryManager shuffleMemoryManager, UnsafeShuffleHandle handle, int mapId, - TaskContext context) { + TaskContext taskContext, + SparkConf sparkConf) { + this.blockManager = blockManager; this.shuffleBlockManager = shuffleBlockManager; + this.memoryManager = memoryManager; + this.shuffleMemoryManager = shuffleMemoryManager; this.mapId = mapId; - this.memoryManager = context.taskMemoryManager(); final ShuffleDependency dep = handle.dependency(); this.shuffleId = dep.shuffleId(); this.serializer = Serializer.getSerializer(dep.serializer()).newInstance(); this.partitioner = dep.partitioner(); this.writeMetrics = new ShuffleWriteMetrics(); - context.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics)); + taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics)); + this.taskContext = taskContext; + this.sparkConf = sparkConf; + } + + public void write(Iterator> records) { + write(JavaConversions.asScalaIterator(records)); } public void write(scala.collection.Iterator> records) { @@ -101,12 +120,12 @@ private SpillInfo[] insertRecordsIntoSorter( scala.collection.Iterator> records) throws Exception { final UnsafeShuffleSpillWriter sorter = new UnsafeShuffleSpillWriter( memoryManager, - SparkEnv$.MODULE$.get().shuffleMemoryManager(), - SparkEnv$.MODULE$.get().blockManager(), - TaskContext.get(), + shuffleMemoryManager, + blockManager, + taskContext, 4096, // Initial size (TODO: tune this!) partitioner.numPartitions(), - SparkEnv$.MODULE$.get().conf() + sparkConf ); final byte[] serArray = new byte[SER_BUFFER_SIZE]; diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala index 0dd34b372f624..14f29a36ec4f6 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala @@ -88,12 +88,17 @@ private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManage context: TaskContext): ShuffleWriter[K, V] = { handle match { case unsafeShuffleHandle: UnsafeShuffleHandle[K, V] => + val env = SparkEnv.get // TODO: do we need to do anything to register the shuffle here? new UnsafeShuffleWriter( + env.blockManager, shuffleBlockResolver.asInstanceOf[IndexShuffleBlockManager], + context.taskMemoryManager(), + env.shuffleMemoryManager, unsafeShuffleHandle, mapId, - context) + context, + env.conf) case other => sortShuffleManager.getWriter(handle, mapId, context) } diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java new file mode 100644 index 0000000000000..8ba548420bd4b --- /dev/null +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.io.File; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.UUID; + +import scala.*; +import scala.runtime.AbstractFunction1; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import static org.mockito.AdditionalAnswers.returnsFirstArg; +import static org.mockito.AdditionalAnswers.returnsSecondArg; +import static org.mockito.Mockito.*; + +import org.apache.spark.*; +import org.apache.spark.serializer.Serializer; +import org.apache.spark.shuffle.IndexShuffleBlockManager; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.storage.*; +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; +import org.apache.spark.serializer.KryoSerializer; +import org.apache.spark.scheduler.MapStatus; + +public class UnsafeShuffleWriterSuite { + + final TaskMemoryManager memoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + // Compute key prefixes based on the records' partition ids + final HashPartitioner hashPartitioner = new HashPartitioner(4); + + ShuffleMemoryManager shuffleMemoryManager; + BlockManager blockManager; + IndexShuffleBlockManager shuffleBlockManager; + DiskBlockManager diskBlockManager; + File tempDir; + TaskContext taskContext; + SparkConf sparkConf; + + private static final class CompressStream extends AbstractFunction1 { + @Override + public OutputStream apply(OutputStream stream) { + return stream; + } + } + + @Before + public void setUp() { + shuffleMemoryManager = mock(ShuffleMemoryManager.class); + diskBlockManager = mock(DiskBlockManager.class); + blockManager = mock(BlockManager.class); + shuffleBlockManager = mock(IndexShuffleBlockManager.class); + tempDir = new File(Utils.createTempDir$default$1()); + taskContext = mock(TaskContext.class); + sparkConf = new SparkConf(); + when(taskContext.taskMetrics()).thenReturn(new TaskMetrics()); + when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); + when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); + when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer>() { + @Override + public Tuple2 answer(InvocationOnMock invocationOnMock) throws Throwable { + TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID()); + File file = File.createTempFile("spillFile", ".spill", tempDir); + return Tuple2$.MODULE$.apply(blockId, file); + } + }); + when(blockManager.getDiskWriter( + any(BlockId.class), + any(File.class), + any(SerializerInstance.class), + anyInt(), + any(ShuffleWriteMetrics.class))).thenAnswer(new Answer() { + @Override + public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable { + Object[] args = invocationOnMock.getArguments(); + + return new DiskBlockObjectWriter( + (BlockId) args[0], + (File) args[1], + (SerializerInstance) args[2], + (Integer) args[3], + new CompressStream(), + false, + (ShuffleWriteMetrics) args[4] + ); + } + }); + when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))) + .then(returnsSecondArg()); + } + + @Test + public void basicShuffleWriting() throws Exception { + + final ShuffleDependency dep = mock(ShuffleDependency.class); + when(dep.serializer()).thenReturn(Option.apply(new KryoSerializer(sparkConf))); + when(dep.partitioner()).thenReturn(hashPartitioner); + + final File mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir); + when(shuffleBlockManager.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile); + final long[] partitionSizes = new long[hashPartitioner.numPartitions()]; + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + long[] receivedPartitionSizes = (long[]) invocationOnMock.getArguments()[2]; + System.arraycopy( + receivedPartitionSizes, 0, partitionSizes, 0, receivedPartitionSizes.length); + return null; + } + }).when(shuffleBlockManager).writeIndexFile(anyInt(), anyInt(), any(long[].class)); + + final UnsafeShuffleWriter writer = new UnsafeShuffleWriter( + blockManager, + shuffleBlockManager, + memoryManager, + shuffleMemoryManager, + new UnsafeShuffleHandle(0, 1, dep), + 0, // map id + taskContext, + sparkConf + ); + + final ArrayList> numbersToSort = + new ArrayList>(); + numbersToSort.add(new Tuple2(5, 5)); + numbersToSort.add(new Tuple2(1, 1)); + numbersToSort.add(new Tuple2(3, 3)); + numbersToSort.add(new Tuple2(2, 2)); + numbersToSort.add(new Tuple2(4, 4)); + + + writer.write(numbersToSort.iterator()); + final MapStatus mapStatus = writer.stop(true).get(); + + long sumOfPartitionSizes = 0; + for (long size: partitionSizes) { + sumOfPartitionSizes += size; + } + Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes); + + // TODO: test that the temporary spill files were cleaned up after the merge. + } + +} From 4f70141aa6949d9251719790551e184cac66a05b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 5 May 2015 11:36:01 -0700 Subject: [PATCH 35/92] Fix merging; now passes UnsafeShuffleSuite tests. --- .../unsafe/UnsafeShuffleSpillWriter.java | 7 +++-- .../shuffle/unsafe/UnsafeShuffleWriter.java | 26 ++++++++++++++----- .../unsafe/UnsafeShuffleWriterSuite.java | 7 ++++- 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java index 8e0a21ec6b3a5..fd2c170bd2e41 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java @@ -118,7 +118,6 @@ private SpillInfo writeSpillFile() throws IOException { final File file = spilledFileInfo._2(); final BlockId blockId = spilledFileInfo._1(); final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId); - spills.add(spillInfo); final SerializerInstance ser = new DummySerializerInstance(); writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetrics); @@ -154,7 +153,11 @@ private SpillInfo writeSpillFile() throws IOException { if (writer != null) { writer.commitAndClose(); - spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length(); + // TODO: comment and explain why our handling of empty spills, etc. + if (currentPartition != -1) { + spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length(); + spills.add(spillInfo); + } } return spillInfo; } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 47fe214634abb..ad842502bf24f 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -157,7 +157,14 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException { final File outputFile = shuffleBlockManager.getDataFile(shuffleId, mapId); final int numPartitions = partitioner.numPartitions(); final long[] partitionLengths = new long[numPartitions]; + + if (spills.length == 0) { + new FileOutputStream(outputFile).close(); + return partitionLengths; + } + final FileChannel[] spillInputChannels = new FileChannel[spills.length]; + final long[] spillInputChannelPositions = new long[spills.length]; // TODO: We need to add an option to bypass transferTo here since older Linux kernels are // affected by a bug here that can lead to data truncation; see the comments Utils.scala, @@ -173,24 +180,29 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException { final FileChannel mergedFileOutputChannel = new FileOutputStream(outputFile).getChannel(); - for (int partition = 0; partition < numPartitions; partition++ ) { + for (int partition = 0; partition < numPartitions; partition++) { for (int i = 0; i < spills.length; i++) { - final long bytesToTransfer = spills[i].partitionLengths[partition]; - long bytesRemainingToBeTransferred = bytesToTransfer; + System.out.println("In partition " + partition + " and spill " + i ); + final long partitionLengthInSpill = spills[i].partitionLengths[partition]; + System.out.println("Partition length in spill is " + partitionLengthInSpill); + System.out.println("input channel position is " + spillInputChannels[i].position()); + long bytesRemainingToBeTransferred = partitionLengthInSpill; final FileChannel spillInputChannel = spillInputChannels[i]; - long fromPosition = spillInputChannel.position(); while (bytesRemainingToBeTransferred > 0) { - bytesRemainingToBeTransferred -= spillInputChannel.transferTo( - fromPosition, + final long actualBytesTransferred = spillInputChannel.transferTo( + spillInputChannelPositions[i], bytesRemainingToBeTransferred, mergedFileOutputChannel); + spillInputChannelPositions[i] += actualBytesTransferred; + bytesRemainingToBeTransferred -= actualBytesTransferred; } - partitionLengths[partition] += bytesToTransfer; + partitionLengths[partition] += partitionLengthInSpill; } } // TODO: should this be in a finally block? for (int i = 0; i < spills.length; i++) { + assert(spillInputChannelPositions[i] == spills[i].file.length()); spillInputChannels[i].close(); } mergedFileOutputChannel.close(); diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 8ba548420bd4b..9008cc2de9bd5 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -158,7 +158,8 @@ public Void answer(InvocationOnMock invocationOnMock) throws Throwable { writer.write(numbersToSort.iterator()); - final MapStatus mapStatus = writer.stop(true).get(); + final Option mapStatus = writer.stop(true); + Assert.assertTrue(mapStatus.isDefined()); long sumOfPartitionSizes = 0; for (long size: partitionSizes) { @@ -166,6 +167,10 @@ public Void answer(InvocationOnMock invocationOnMock) throws Throwable { } Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes); + // TODO: actually try to read the shuffle output? + + // TODO: add a test that manually triggers spills in order to exercise the merging. + // TODO: test that the temporary spill files were cleaned up after the merge. } From aaea17b5c07a1b3d1ebe99020eee31eb1d9d87e1 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 5 May 2015 18:29:29 -0700 Subject: [PATCH 36/92] Add comments to UnsafeShuffleSpillWriter. --- .../spark/shuffle/unsafe/SpillInfo.java | 3 + .../unsafe/UnsafeShuffleSpillWriter.java | 145 +++++++++++++----- 2 files changed, 110 insertions(+), 38 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java index 5e8c090405098..a1b5266631164 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java @@ -21,6 +21,9 @@ import java.io.File; +/** + * Metadata for a block of data written by {@link UnsafeShuffleSpillWriter}. + */ final class SpillInfo { final long[] partitionLengths; final File file; diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java index fd2c170bd2e41..05cf2e7d0d3cc 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java @@ -17,30 +17,41 @@ package org.apache.spark.shuffle.unsafe; -import com.google.common.annotations.VisibleForTesting; +import java.io.File; +import java.io.IOException; +import java.util.Iterator; +import java.util.LinkedList; + +import org.apache.spark.storage.*; +import scala.Tuple2; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleMemoryManager; -import org.apache.spark.storage.BlockId; -import org.apache.spark.storage.BlockManager; -import org.apache.spark.storage.BlockObjectWriter; -import org.apache.spark.storage.TempLocalBlockId; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import scala.Tuple2; - -import java.io.File; -import java.io.IOException; -import java.util.Iterator; -import java.util.LinkedList; /** - * External sorter based on {@link UnsafeShuffleSorter}. + * An external sorter that is specialized for sort-based shuffle. + *

+ * Incoming records are appended to data pages. When all records have been inserted (or when the + * current thread's shuffle memory limit is reached), the in-memory records are sorted according to + * their partition ids (using a {@link UnsafeShuffleSorter}). The sorted records are then written + * to a single output file (or multiple files, if we've spilled). The format of the output files is + * the same as the format of the final output file written by + * {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are + * written as a single serialized, compressed stream that can be read with a new decompression and + * deserialization stream. + *

+ * Unlike {@link org.apache.spark.util.collection.ExternalSorter}, this sorter does not merge its + * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a + * specialized merge procedure that avoids extra serialization/deserialization. */ public final class UnsafeShuffleSpillWriter { @@ -51,23 +62,31 @@ public final class UnsafeShuffleSpillWriter { private final int initialSize; private final int numPartitions; - private UnsafeShuffleSorter sorter; - private final TaskMemoryManager memoryManager; private final ShuffleMemoryManager shuffleMemoryManager; private final BlockManager blockManager; private final TaskContext taskContext; - private final LinkedList allocatedPages = new LinkedList(); private final boolean spillingEnabled; - private final int fileBufferSize; private ShuffleWriteMetrics writeMetrics; + /** The buffer size to use when writing spills using DiskBlockObjectWriter */ + private final int fileBufferSize; - private MemoryBlock currentPage = null; - private long currentPagePosition = -1; + /** + * Memory pages that hold the records being sorted. The pages in this list are freed when + * spilling, although in principle we could recycle these pages across spills (on the other hand, + * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager + * itself). + */ + private final LinkedList allocatedPages = new LinkedList(); private final LinkedList spills = new LinkedList(); + // All three of these variables are reset after spilling: + private UnsafeShuffleSorter sorter; + private MemoryBlock currentPage = null; + private long currentPagePosition = -1; + public UnsafeShuffleSpillWriter( TaskMemoryManager memoryManager, ShuffleMemoryManager shuffleMemoryManager, @@ -90,6 +109,10 @@ public UnsafeShuffleSpillWriter( // TODO: metrics tracking + integration with shuffle write metrics + /** + * Allocates a new sorter. Called when opening the spill writer for the first time and after + * each spill. + */ private void openSorter() throws IOException { this.writeMetrics = new ShuffleWriteMetrics(); // TODO: connect write metrics to task metrics? @@ -106,22 +129,41 @@ private void openSorter() throws IOException { this.sorter = new UnsafeShuffleSorter(initialSize); } + /** + * Sorts the in-memory records, writes the sorted records to a spill file, and frees the in-memory + * data structures associated with this sort. New data structures are not automatically allocated. + */ private SpillInfo writeSpillFile() throws IOException { - final UnsafeShuffleSorter.UnsafeShuffleSorterIterator sortedRecords = sorter.getSortedIterator(); + // This call performs the actual sort. + final UnsafeShuffleSorter.UnsafeShuffleSorterIterator sortedRecords = + sorter.getSortedIterator(); - int currentPartition = -1; + // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this + // after SPARK-5581 is fixed. BlockObjectWriter writer = null; + + // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to + // be an API to directly transfer bytes from managed memory to the disk writer, we buffer + // records in a byte array. This array only needs to be big enough to hold a single record. final byte[] arr = new byte[SER_BUFFER_SIZE]; - final Tuple2 spilledFileInfo = - blockManager.diskBlockManager().createTempLocalBlock(); + // Because this output will be read during shuffle, its compression codec must be controlled by + // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use + // createTempShuffleBlock here; see SPARK-3426 for more details. + final Tuple2 spilledFileInfo = + blockManager.diskBlockManager().createTempShuffleBlock(); final File file = spilledFileInfo._2(); final BlockId blockId = spilledFileInfo._1(); final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId); + // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. + // Our write path doesn't actually use this serializer (since we end up calling the `write()` + // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work + // around this, we pass a dummy no-op serializer. final SerializerInstance ser = new DummySerializerInstance(); writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetrics); + int currentPartition = -1; while (sortedRecords.hasNext()) { sortedRecords.loadNext(); final int partition = sortedRecords.packedRecordPointer.getPartitionId(); @@ -153,7 +195,9 @@ private SpillInfo writeSpillFile() throws IOException { if (writer != null) { writer.commitAndClose(); - // TODO: comment and explain why our handling of empty spills, etc. + // If `writeSpillFile()` was called from `closeAndGetSpills()` and no records were inserted, + // then the spill file might be empty. Note that it might be better to avoid calling + // writeSpillFile() in that case. if (currentPartition != -1) { spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length(); spills.add(spillInfo); @@ -162,24 +206,30 @@ private SpillInfo writeSpillFile() throws IOException { return spillInfo; } - @VisibleForTesting - public void spill() throws IOException { - final SpillInfo spillInfo = writeSpillFile(); + /** + * Sort and spill the current records in response to memory pressure. + */ + private void spill() throws IOException { + final long threadId = Thread.currentThread().getId(); + logger.info("Thread " + threadId + " spilling sort data of " + + org.apache.spark.util.Utils.bytesToString(getMemoryUsage()) + " to disk (" + + (spills.size() + (spills.size() > 1 ? " times" : " time")) + " so far)"); + final SpillInfo spillInfo = writeSpillFile(); final long sorterMemoryUsage = sorter.getMemoryUsage(); sorter = null; shuffleMemoryManager.release(sorterMemoryUsage); final long spillSize = freeMemory(); taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); taskContext.taskMetrics().incDiskBytesSpilled(spillInfo.file.length()); - final long threadId = Thread.currentThread().getId(); - // TODO: messy; log _before_ spill - logger.info("Thread " + threadId + " spilling in-memory map of " + - org.apache.spark.util.Utils.bytesToString(spillSize) + " to disk (" + - (spills.size() + ((spills.size() > 1) ? " times" : " time")) + " so far)"); + openSorter(); } + private long getMemoryUsage() { + return sorter.getMemoryUsage() + (allocatedPages.size() * PAGE_SIZE); + } + private long freeMemory() { long memoryFreed = 0; final Iterator iter = allocatedPages.iterator(); @@ -194,7 +244,15 @@ private long freeMemory() { return memoryFreed; } - private void ensureSpaceInDataPage(int requiredSpace) throws Exception { + /** + * Checks whether there is enough space to insert a new record into the sorter. If there is + * insufficient space, either allocate more memory or spill the current sort data (if spilling + * is enabled), then insert the record. + */ + private void ensureSpaceInDataPage(int requiredSpace) throws IOException { + // TODO: we should re-order the `if` cases in this function so that the most common case (there + // is enough space) appears first. + // TODO: merge these steps to first calculate total memory requirements for this insert, // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the // data page. @@ -219,7 +277,7 @@ private void ensureSpaceInDataPage(int requiredSpace) throws Exception { } if (requiredSpace > PAGE_SIZE) { // TODO: throw a more specific exception? - throw new Exception("Required space " + requiredSpace + " is greater than page size (" + + throw new IOException("Required space " + requiredSpace + " is greater than page size (" + PAGE_SIZE + ")"); } else if (requiredSpace > spaceInCurrentPage) { if (spillingEnabled) { @@ -230,7 +288,7 @@ private void ensureSpaceInDataPage(int requiredSpace) throws Exception { final long memoryAcquiredAfterSpill = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); if (memoryAcquiredAfterSpill != PAGE_SIZE) { shuffleMemoryManager.release(memoryAcquiredAfterSpill); - throw new Exception("Can't allocate memory!"); + throw new IOException("Can't allocate memory!"); } } } @@ -241,11 +299,14 @@ private void ensureSpaceInDataPage(int requiredSpace) throws Exception { } } + /** + * Write a record to the shuffle sorter. + */ public void insertRecord( Object recordBaseObject, long recordBaseOffset, int lengthInBytes, - int prefix) throws Exception { + int partitionId) throws IOException { // Need 4 bytes to store the record length. ensureSpaceInDataPage(lengthInBytes + 4); @@ -262,12 +323,20 @@ public void insertRecord( lengthInBytes); currentPagePosition += lengthInBytes; - sorter.insertRecord(recordAddress, prefix); + sorter.insertRecord(recordAddress, partitionId); } + /** + * Close the sorter, causing any buffered data to be sorted and written out to disk. + * + * @return metadata for the spill files written by this sorter. If no records were ever inserted + * into this sorter, then this will return an empty array. + * @throws IOException + */ public SpillInfo[] closeAndGetSpills() throws IOException { if (sorter != null) { writeSpillFile(); + freeMemory(); } return spills.toArray(new SpillInfo[0]); } From 11feeb6dba843b8b596c159eca781367056e6eb5 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 6 May 2015 17:27:08 -0700 Subject: [PATCH 37/92] Update TODOs related to shuffle write metrics. --- .../unsafe/UnsafeShuffleSpillWriter.java | 25 +++++++++++-------- .../shuffle/unsafe/UnsafeShuffleWriter.java | 7 +----- .../shuffle/FileShuffleBlockManager.scala | 2 +- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java index 05cf2e7d0d3cc..68f5b080572ea 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java @@ -67,7 +67,6 @@ public final class UnsafeShuffleSpillWriter { private final BlockManager blockManager; private final TaskContext taskContext; private final boolean spillingEnabled; - private ShuffleWriteMetrics writeMetrics; /** The buffer size to use when writing spills using DiskBlockObjectWriter */ private final int fileBufferSize; @@ -107,15 +106,11 @@ public UnsafeShuffleSpillWriter( openSorter(); } - // TODO: metrics tracking + integration with shuffle write metrics - /** * Allocates a new sorter. Called when opening the spill writer for the first time and after * each spill. */ private void openSorter() throws IOException { - this.writeMetrics = new ShuffleWriteMetrics(); - // TODO: connect write metrics to task metrics? // TODO: move this sizing calculation logic into a static method of sorter: final long memoryRequested = initialSize * 8L; if (spillingEnabled) { @@ -130,8 +125,8 @@ private void openSorter() throws IOException { } /** - * Sorts the in-memory records, writes the sorted records to a spill file, and frees the in-memory - * data structures associated with this sort. New data structures are not automatically allocated. + * Sorts the in-memory records and writes the sorted records to a spill file. + * This method does not free the sort data structures. */ private SpillInfo writeSpillFile() throws IOException { // This call performs the actual sort. @@ -161,7 +156,17 @@ private SpillInfo writeSpillFile() throws IOException { // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work // around this, we pass a dummy no-op serializer. final SerializerInstance ser = new DummySerializerInstance(); - writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetrics); + // TODO: audit the metrics-related code and ensure proper metrics integration: + // It's not clear how we should handle shuffle write metrics for spill files; currently, Spark + // doesn't report IO time spent writing spill files (see SPARK-7413). This method, + // writeSpillFile(), is called both when writing spill files and when writing the single output + // file in cases where we didn't spill. As a result, we don't necessarily know whether this + // should be reported as bytes spilled or as shuffle bytes written. We could defer the updating + // of these metrics until the end of the shuffle write, but that would mean that that users + // wouldn't get useful metrics updates in the UI from long-running tasks. Given this complexity, + // I'm deferring these decisions to a separate follow-up commit or patch. + writer = + blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, new ShuffleWriteMetrics()); int currentPartition = -1; while (sortedRecords.hasNext()) { @@ -175,7 +180,8 @@ private SpillInfo writeSpillFile() throws IOException { spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length(); } currentPartition = partition; - writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetrics); + writer = + blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, new ShuffleWriteMetrics()); } final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); @@ -295,7 +301,6 @@ private void ensureSpaceInDataPage(int requiredSpace) throws IOException { currentPage = memoryManager.allocatePage(PAGE_SIZE); currentPagePosition = currentPage.getBaseOffset(); allocatedPages.add(currentPage); - logger.info("Acquired new page! " + allocatedPages.size() * PAGE_SIZE); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index ad842502bf24f..80e01109eabd4 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -125,8 +125,7 @@ private SpillInfo[] insertRecordsIntoSorter( taskContext, 4096, // Initial size (TODO: tune this!) partitioner.numPartitions(), - sparkConf - ); + sparkConf); final byte[] serArray = new byte[SER_BUFFER_SIZE]; final ByteBuffer serByteBuffer = ByteBuffer.wrap(serArray); @@ -182,10 +181,7 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException { for (int partition = 0; partition < numPartitions; partition++) { for (int i = 0; i < spills.length; i++) { - System.out.println("In partition " + partition + " and spill " + i ); final long partitionLengthInSpill = spills[i].partitionLengths[partition]; - System.out.println("Partition length in spill is " + partitionLengthInSpill); - System.out.println("input channel position is " + spillInputChannels[i].position()); long bytesRemainingToBeTransferred = partitionLengthInSpill; final FileChannel spillInputChannel = spillInputChannels[i]; while (bytesRemainingToBeTransferred > 0) { @@ -228,7 +224,6 @@ public Option stop(boolean success) { } } finally { freeMemory(); - // TODO: increment the shuffle write time metrics } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index 0a84fdc0e4ca2..e9b4e2b955dc8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -34,7 +34,7 @@ import org.apache.spark.storage._ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} -/** A group of writers for ShuffleMapTask, one writer per reducer. */ +/** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { val writers: Array[BlockObjectWriter] From 8a6fe52af6fadc2841173bd098b1d7a87e2307b7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 6 May 2015 17:45:59 -0700 Subject: [PATCH 38/92] Rename UnsafeShuffleSpillWriter to UnsafeShuffleExternalSorter --- .../spark/shuffle/unsafe/SpillInfo.java | 2 +- ....java => UnsafeShuffleExternalSorter.java} | 20 +++++++++---------- .../shuffle/unsafe/UnsafeShuffleWriter.java | 2 +- 3 files changed, 12 insertions(+), 12 deletions(-) rename core/src/main/java/org/apache/spark/shuffle/unsafe/{UnsafeShuffleSpillWriter.java => UnsafeShuffleExternalSorter.java} (97%) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java index a1b5266631164..5435c2c98428f 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java @@ -22,7 +22,7 @@ import java.io.File; /** - * Metadata for a block of data written by {@link UnsafeShuffleSpillWriter}. + * Metadata for a block of data written by {@link UnsafeShuffleExternalSorter}. */ final class SpillInfo { final long[] partitionLengths; diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java similarity index 97% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java rename to core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 68f5b080572ea..64ef0f2c07820 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -53,9 +53,9 @@ * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a * specialized merge procedure that avoids extra serialization/deserialization. */ -public final class UnsafeShuffleSpillWriter { +public final class UnsafeShuffleExternalSorter { - private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleSpillWriter.class); + private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class); private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this / don't duplicate private static final int PAGE_SIZE = 1024 * 1024; // TODO: tune this @@ -86,14 +86,14 @@ public final class UnsafeShuffleSpillWriter { private MemoryBlock currentPage = null; private long currentPagePosition = -1; - public UnsafeShuffleSpillWriter( - TaskMemoryManager memoryManager, - ShuffleMemoryManager shuffleMemoryManager, - BlockManager blockManager, - TaskContext taskContext, - int initialSize, - int numPartitions, - SparkConf conf) throws IOException { + public UnsafeShuffleExternalSorter( + TaskMemoryManager memoryManager, + ShuffleMemoryManager shuffleMemoryManager, + BlockManager blockManager, + TaskContext taskContext, + int initialSize, + int numPartitions, + SparkConf conf) throws IOException { this.memoryManager = memoryManager; this.shuffleMemoryManager = shuffleMemoryManager; this.blockManager = blockManager; diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 80e01109eabd4..995754901e09b 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -118,7 +118,7 @@ private void freeMemory() { private SpillInfo[] insertRecordsIntoSorter( scala.collection.Iterator> records) throws Exception { - final UnsafeShuffleSpillWriter sorter = new UnsafeShuffleSpillWriter( + final UnsafeShuffleExternalSorter sorter = new UnsafeShuffleExternalSorter( memoryManager, shuffleMemoryManager, blockManager, From cfe0ec4c57c21aeffc36b309333a02029c8ecad8 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 6 May 2015 18:14:47 -0700 Subject: [PATCH 39/92] Address a number of minor review comments: --- .../unsafe/DummySerializerInstance.java | 31 ++++++++++++------- .../unsafe/UnsafeShuffleExternalSorter.java | 13 ++++---- .../unsafe/UnsafeShuffleSortDataFormat.java | 1 - .../shuffle/unsafe/UnsafeShuffleSorter.java | 22 ++++++------- .../spark/storage/BlockObjectWriter.scala | 8 +---- .../unsafe/UnsafeShuffleWriterSuite.java | 1 + 6 files changed, 37 insertions(+), 39 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java index ab174c3ca921a..1d31a46993a22 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java @@ -26,44 +26,51 @@ import java.io.OutputStream; import java.nio.ByteBuffer; -class DummySerializerInstance extends SerializerInstance { +/** + * Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. + * Our shuffle write path doesn't actually use this serializer (since we end up calling the + * `write() OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work + * around this, we pass a dummy no-op serializer. + */ +final class DummySerializerInstance extends SerializerInstance { + + public static final DummySerializerInstance INSTANCE = new DummySerializerInstance(); + + private DummySerializerInstance() { } + @Override public SerializationStream serializeStream(OutputStream s) { return new SerializationStream() { @Override - public void flush() { - - } + public void flush() { } @Override public SerializationStream writeObject(T t, ClassTag ev1) { - return null; + throw new UnsupportedOperationException(); } @Override - public void close() { - - } + public void close() { } }; } @Override public ByteBuffer serialize(T t, ClassTag ev1) { - return null; + throw new UnsupportedOperationException(); } @Override public DeserializationStream deserializeStream(InputStream s) { - return null; + throw new UnsupportedOperationException(); } @Override public T deserialize(ByteBuffer bytes, ClassLoader loader, ClassTag ev1) { - return null; + throw new UnsupportedOperationException(); } @Override public T deserialize(ByteBuffer bytes, ClassTag ev1) { - return null; + throw new UnsupportedOperationException(); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 64ef0f2c07820..10efa670dda1e 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -155,7 +155,7 @@ private SpillInfo writeSpillFile() throws IOException { // Our write path doesn't actually use this serializer (since we end up calling the `write()` // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work // around this, we pass a dummy no-op serializer. - final SerializerInstance ser = new DummySerializerInstance(); + final SerializerInstance ser = DummySerializerInstance.INSTANCE; // TODO: audit the metrics-related code and ensure proper metrics integration: // It's not clear how we should handle shuffle write metrics for spill files; currently, Spark // doesn't report IO time spent writing spill files (see SPARK-7413). This method, @@ -238,13 +238,12 @@ private long getMemoryUsage() { private long freeMemory() { long memoryFreed = 0; - final Iterator iter = allocatedPages.iterator(); - while (iter.hasNext()) { - memoryManager.freePage(iter.next()); - shuffleMemoryManager.release(PAGE_SIZE); - memoryFreed += PAGE_SIZE; - iter.remove(); + for (MemoryBlock block : allocatedPages) { + memoryManager.freePage(block); + shuffleMemoryManager.release(block.size()); + memoryFreed += block.size(); } + allocatedPages.clear(); currentPage = null; currentPagePosition = -1; return memoryFreed; diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java index d7afa1a906428..862845180584e 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java @@ -61,7 +61,6 @@ public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length @Override public long[] allocate(int length) { - assert (length < Integer.MAX_VALUE) : "Length " + length + " is too large"; return new long[length]; } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java index eb46776efe12c..d9ffe9a44fec7 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java @@ -24,7 +24,13 @@ public final class UnsafeShuffleSorter { private final Sorter sorter; - private final Comparator sortComparator; + private static final class SortComparator implements Comparator { + @Override + public int compare(PackedRecordPointer left, PackedRecordPointer right) { + return left.getPartitionId() - right.getPartitionId(); + } + } + private static final SortComparator SORT_COMPARATOR = new SortComparator(); private long[] sortBuffer; @@ -36,14 +42,7 @@ public final class UnsafeShuffleSorter { public UnsafeShuffleSorter(int initialSize) { assert (initialSize > 0); this.sortBuffer = new long[initialSize]; - this.sorter = - new Sorter(UnsafeShuffleSortDataFormat.INSTANCE); - this.sortComparator = new Comparator() { - @Override - public int compare(PackedRecordPointer left, PackedRecordPointer right) { - return left.getPartitionId() - right.getPartitionId(); - } - }; + this.sorter = new Sorter(UnsafeShuffleSortDataFormat.INSTANCE); } public void expandSortBuffer() { @@ -81,11 +80,10 @@ public static abstract class UnsafeShuffleSorterIterator { } /** - * Return an iterator over record pointers in sorted order. For efficiency, all calls to - * {@code next()} will return the same mutable object. + * Return an iterator over record pointers in sorted order. */ public UnsafeShuffleSorterIterator getSortedIterator() { - sorter.sort(sortBuffer, 0, sortBufferInsertPosition, sortComparator); + sorter.sort(sortBuffer, 0, sortBufferInsertPosition, SORT_COMPARATOR); return new UnsafeShuffleSorterIterator() { private int position = 0; diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 0254025122237..8bc4e205bc3c6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -218,13 +218,7 @@ private[spark] class DiskBlockObjectWriter( recordWritten() } - override def write(b: Int): Unit = { - if (!initialized) { - open() - } - - bs.write(b) - } + override def write(b: Int): Unit = throw new UnsupportedOperationException() override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = { if (!initialized) { diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 9008cc2de9bd5..55c447327ef35 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -118,6 +118,7 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th } @Test + @SuppressWarnings("unchecked") public void basicShuffleWriting() throws Exception { final ShuffleDependency dep = mock(ShuffleDependency.class); From e67f1ea9832a02113a8525c1bb6689553d215be0 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 6 May 2015 18:24:26 -0700 Subject: [PATCH 40/92] Remove upper type bound in ShuffleWriter interface. This bound wasn't necessary and was causing IntelliJ to display spurious errors when editing UnsafeShuffleWriter.java. --- .../org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java | 4 ++-- .../main/scala/org/apache/spark/shuffle/ShuffleWriter.scala | 4 ++-- .../org/apache/spark/shuffle/hash/HashShuffleWriter.scala | 2 +- .../org/apache/spark/shuffle/sort/SortShuffleWriter.scala | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 995754901e09b..6bf4384173fff 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -47,8 +47,7 @@ import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.memory.TaskMemoryManager; -// IntelliJ gets confused and claims that this class should be abstract, but this actually compiles -public class UnsafeShuffleWriter implements ShuffleWriter { +public class UnsafeShuffleWriter extends ShuffleWriter { private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this private static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); @@ -102,6 +101,7 @@ public void write(Iterator> records) { write(JavaConversions.asScalaIterator(records)); } + @Override public void write(scala.collection.Iterator> records) { try { final long[] partitionLengths = mergeSpills(insertRecordsIntoSorter(records)); diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala index f6e6fe5defe09..e28a2459cdff9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala @@ -22,9 +22,9 @@ import org.apache.spark.scheduler.MapStatus /** * Obtained inside a map task to write out records to the shuffle system. */ -private[spark] trait ShuffleWriter[K, V] { +private[spark] abstract class ShuffleWriter[K, V] { /** Write a sequence of records to this task's output */ - def write(records: Iterator[_ <: Product2[K, V]]): Unit + def write(records: Iterator[Product2[K, V]]): Unit /** Close this writer, passing along whether the map completed */ def stop(success: Boolean): Option[MapStatus] diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index cd27c9e07a3cd..8edfb9a054ada 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -49,7 +49,7 @@ private[spark] class HashShuffleWriter[K, V]( writeMetrics) /** Write a bunch of records to this task's output */ - override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { + override def write(records: Iterator[Product2[K, V]]): Unit = { val iter = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { dep.aggregator.get.combineValuesByKey(records, context) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index a066435df6fb0..72864b36c5824 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -48,7 +48,7 @@ private[spark] class SortShuffleWriter[K, V, C]( context.taskMetrics.shuffleWriteMetrics = Some(writeMetrics) /** Write a bunch of records to this task's output */ - override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { + override def write(records: Iterator[Product2[K, V]]): Unit = { if (dep.mapSideCombine) { require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") sorter = new ExternalSorter[K, V, C]( From 5e8cf751e92f949ae28ef303943442db2a7a9341 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 6 May 2015 22:44:35 -0700 Subject: [PATCH 41/92] More minor cleanup --- .../spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java | 1 - .../apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java | 8 ++++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 10efa670dda1e..e04ea92a327ae 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -19,7 +19,6 @@ import java.io.File; import java.io.IOException; -import java.util.Iterator; import java.util.LinkedList; import org.apache.spark.storage.*; diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 6bf4384173fff..1b5af45334238 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -182,15 +182,15 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException { for (int partition = 0; partition < numPartitions; partition++) { for (int i = 0; i < spills.length; i++) { final long partitionLengthInSpill = spills[i].partitionLengths[partition]; - long bytesRemainingToBeTransferred = partitionLengthInSpill; + long bytesToTransfer = partitionLengthInSpill; final FileChannel spillInputChannel = spillInputChannels[i]; - while (bytesRemainingToBeTransferred > 0) { + while (bytesToTransfer > 0) { final long actualBytesTransferred = spillInputChannel.transferTo( spillInputChannelPositions[i], - bytesRemainingToBeTransferred, + bytesToTransfer, mergedFileOutputChannel); spillInputChannelPositions[i] += actualBytesTransferred; - bytesRemainingToBeTransferred -= actualBytesTransferred; + bytesToTransfer -= actualBytesTransferred; } partitionLengths[partition] += partitionLengthInSpill; } From 1ce13002b56055aff57482edc161fb20ed7b706b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 6 May 2015 22:44:36 -0700 Subject: [PATCH 42/92] More minor cleanup From b95e6425ddc9295e20c988a069576230568981c6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 6 May 2015 23:15:11 -0700 Subject: [PATCH 43/92] Refactor and document logic that decides when to spill. --- .../unsafe/UnsafeShuffleExternalSorter.java | 97 ++++++++++++------- 1 file changed, 60 insertions(+), 37 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index e04ea92a327ae..70c911252fddb 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -84,6 +84,7 @@ public final class UnsafeShuffleExternalSorter { private UnsafeShuffleSorter sorter; private MemoryBlock currentPage = null; private long currentPagePosition = -1; + private long freeSpaceInCurrentPage = 0; public UnsafeShuffleExternalSorter( TaskMemoryManager memoryManager, @@ -245,22 +246,40 @@ private long freeMemory() { allocatedPages.clear(); currentPage = null; currentPagePosition = -1; + freeSpaceInCurrentPage = 0; return memoryFreed; } /** - * Checks whether there is enough space to insert a new record into the sorter. If there is - * insufficient space, either allocate more memory or spill the current sort data (if spilling - * is enabled), then insert the record. + * Checks whether there is enough space to insert a new record into the sorter. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. + + * @return true if the record can be inserted without requiring more allocations, false otherwise. + */ + private boolean haveSpaceForRecord(int requiredSpace) { + logger.warn("Seeing if there's space for the record"); + assert (requiredSpace > 0); + // The sort array will automatically expand when inserting a new record, so we only need to + // worry about it having free space when spilling is enabled. + final boolean sortBufferHasSpace = !spillingEnabled || sorter.hasSpaceForAnotherRecord(); + final boolean dataPageHasSpace = requiredSpace <= freeSpaceInCurrentPage; + return (sortBufferHasSpace && dataPageHasSpace); + } + + /** + * Allocates more memory in order to insert an additional record. If spilling is enabled, this + * will request additional memory from the {@link ShuffleMemoryManager} and spill if the requested + * memory can not be obtained. If spilling is disabled, then this will allocate memory without + * coordinating with the ShuffleMemoryManager. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. */ - private void ensureSpaceInDataPage(int requiredSpace) throws IOException { - // TODO: we should re-order the `if` cases in this function so that the most common case (there - // is enough space) appears first. - - // TODO: merge these steps to first calculate total memory requirements for this insert, - // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the - // data page. - if (!sorter.hasSpaceForAnotherRecord() && spillingEnabled) { + private void allocateSpaceForRecord(int requiredSpace) throws IOException { + if (spillingEnabled && !sorter.hasSpaceForAnotherRecord()) { + logger.debug("Attempting to expand sort buffer"); final long oldSortBufferMemoryUsage = sorter.getMemoryUsage(); final long memoryToGrowSortBuffer = oldSortBufferMemoryUsage * 2; final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowSortBuffer); @@ -272,33 +291,33 @@ private void ensureSpaceInDataPage(int requiredSpace) throws IOException { shuffleMemoryManager.release(oldSortBufferMemoryUsage); } } - - final long spaceInCurrentPage; - if (currentPage != null) { - spaceInCurrentPage = PAGE_SIZE - (currentPagePosition - currentPage.getBaseOffset()); - } else { - spaceInCurrentPage = 0; - } - if (requiredSpace > PAGE_SIZE) { - // TODO: throw a more specific exception? - throw new IOException("Required space " + requiredSpace + " is greater than page size (" + - PAGE_SIZE + ")"); - } else if (requiredSpace > spaceInCurrentPage) { - if (spillingEnabled) { - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); - if (memoryAcquired < PAGE_SIZE) { - shuffleMemoryManager.release(memoryAcquired); - spill(); - final long memoryAcquiredAfterSpill = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); - if (memoryAcquiredAfterSpill != PAGE_SIZE) { - shuffleMemoryManager.release(memoryAcquiredAfterSpill); - throw new IOException("Can't allocate memory!"); + if (requiredSpace > freeSpaceInCurrentPage) { + logger.debug("Required space {} is less than free space in current page ({}}", requiredSpace, + freeSpaceInCurrentPage); + // TODO: we should track metrics on the amount of space wasted when we roll over to a new page + // without using the free space at the end of the current page. We should also do this for + // BytesToBytesMap. + if (requiredSpace > PAGE_SIZE) { + throw new IOException("Required space " + requiredSpace + " is greater than page size (" + + PAGE_SIZE + ")"); + } else { + if (spillingEnabled) { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); + if (memoryAcquired < PAGE_SIZE) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); + if (memoryAcquiredAfterSpilling != PAGE_SIZE) { + shuffleMemoryManager.release(memoryAcquiredAfterSpilling); + throw new IOException("Can't allocate memory!"); + } } } + currentPage = memoryManager.allocatePage(PAGE_SIZE); + currentPagePosition = currentPage.getBaseOffset(); + freeSpaceInCurrentPage = PAGE_SIZE; + allocatedPages.add(currentPage); } - currentPage = memoryManager.allocatePage(PAGE_SIZE); - currentPagePosition = currentPage.getBaseOffset(); - allocatedPages.add(currentPage); } } @@ -311,13 +330,17 @@ public void insertRecord( int lengthInBytes, int partitionId) throws IOException { // Need 4 bytes to store the record length. - ensureSpaceInDataPage(lengthInBytes + 4); + final int totalSpaceRequired = lengthInBytes + 4; + if (!haveSpaceForRecord(totalSpaceRequired)) { + allocateSpaceForRecord(totalSpaceRequired); + } final long recordAddress = memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); final Object dataPageBaseObject = currentPage.getBaseObject(); PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes); currentPagePosition += 4; + freeSpaceInCurrentPage -= 4; PlatformDependent.copyMemory( recordBaseObject, recordBaseOffset, @@ -325,7 +348,7 @@ public void insertRecord( currentPagePosition, lengthInBytes); currentPagePosition += lengthInBytes; - + freeSpaceInCurrentPage -= lengthInBytes; sorter.insertRecord(recordAddress, partitionId); } From 722849b3f77dcbf3494f6a76a219d7f29b7a8284 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 8 May 2015 16:31:03 -0700 Subject: [PATCH 44/92] Add workaround for transferTo() bug in merging code; refactor tests. --- .../shuffle/unsafe/UnsafeShuffleWriter.java | 158 +++++++++++++---- .../unsafe/UnsafeShuffleWriterSuite.java | 161 +++++++++++------- 2 files changed, 225 insertions(+), 94 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 1b5af45334238..206812f8352d2 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -25,7 +25,6 @@ import java.nio.channels.FileChannel; import java.util.Iterator; -import org.apache.spark.shuffle.ShuffleMemoryManager; import scala.Option; import scala.Product2; import scala.collection.JavaConversions; @@ -33,15 +32,21 @@ import scala.reflect.ClassTag$; import com.esotericsoftware.kryo.io.ByteBufferOutputStream; +import com.google.common.io.ByteStreams; +import com.google.common.io.Files; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.apache.spark.*; import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; import org.apache.spark.serializer.SerializationStream; import org.apache.spark.serializer.Serializer; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.IndexShuffleBlockManager; +import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.PlatformDependent; @@ -49,6 +54,8 @@ public class UnsafeShuffleWriter extends ShuffleWriter { + private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class); + private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this private static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); @@ -63,6 +70,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final int mapId; private final TaskContext taskContext; private final SparkConf sparkConf; + private final boolean transferToEnabled; private MapStatus mapStatus = null; @@ -95,6 +103,7 @@ public UnsafeShuffleWriter( taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics)); this.taskContext = taskContext; this.sparkConf = sparkConf; + this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); } public void write(Iterator> records) { @@ -116,6 +125,10 @@ private void freeMemory() { // TODO } + private void deleteSpills() { + // TODO + } + private SpillInfo[] insertRecordsIntoSorter( scala.collection.Iterator> records) throws Exception { final UnsafeShuffleExternalSorter sorter = new UnsafeShuffleExternalSorter( @@ -154,55 +167,127 @@ private SpillInfo[] insertRecordsIntoSorter( private long[] mergeSpills(SpillInfo[] spills) throws IOException { final File outputFile = shuffleBlockManager.getDataFile(shuffleId, mapId); + try { + if (spills.length == 0) { + new FileOutputStream(outputFile).close(); // Create an empty file + return new long[partitioner.numPartitions()]; + } else if (spills.length == 1) { + // Note: we'll have to watch out for corner-cases in this code path when working on shuffle + // metrics integration, since any metrics updates that are performed during the merge will + // also have to be done here. In this branch, the shuffle technically didn't need to spill + // because we're only trying to merge one file, so we may need to ensure that metrics that + // would otherwise be counted as spill metrics are actually counted as regular write + // metrics. + Files.move(spills[0].file, outputFile); + return spills[0].partitionLengths; + } else { + // Need to merge multiple spills. + if (transferToEnabled) { + return mergeSpillsWithTransferTo(spills, outputFile); + } else { + return mergeSpillsWithFileStream(spills, outputFile); + } + } + } catch (IOException e) { + if (outputFile.exists() && !outputFile.delete()) { + logger.error("Unable to delete output file {}", outputFile.getPath()); + } + throw e; + } + } + + private long[] mergeSpillsWithFileStream(SpillInfo[] spills, File outputFile) throws IOException { final int numPartitions = partitioner.numPartitions(); final long[] partitionLengths = new long[numPartitions]; + final FileInputStream[] spillInputStreams = new FileInputStream[spills.length]; + FileOutputStream mergedFileOutputStream = null; + + try { + for (int i = 0; i < spills.length; i++) { + spillInputStreams[i] = new FileInputStream(spills[i].file); + } + mergedFileOutputStream = new FileOutputStream(outputFile); - if (spills.length == 0) { - new FileOutputStream(outputFile).close(); - return partitionLengths; + for (int partition = 0; partition < numPartitions; partition++) { + for (int i = 0; i < spills.length; i++) { + final long partitionLengthInSpill = spills[i].partitionLengths[partition]; + final FileInputStream spillInputStream = spillInputStreams[i]; + ByteStreams.copy + (new LimitedInputStream(spillInputStream, partitionLengthInSpill), + mergedFileOutputStream); + partitionLengths[partition] += partitionLengthInSpill; + } + } + } finally { + for (int i = 0; i < spills.length; i++) { + if (spillInputStreams[i] != null) { + spillInputStreams[i].close(); + } + } + if (mergedFileOutputStream != null) { + mergedFileOutputStream.close(); + } } + return partitionLengths; + } + private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException { + final int numPartitions = partitioner.numPartitions(); + final long[] partitionLengths = new long[numPartitions]; final FileChannel[] spillInputChannels = new FileChannel[spills.length]; final long[] spillInputChannelPositions = new long[spills.length]; + FileChannel mergedFileOutputChannel = null; - // TODO: We need to add an option to bypass transferTo here since older Linux kernels are - // affected by a bug here that can lead to data truncation; see the comments Utils.scala, - // in the copyStream() method. I didn't use copyStream() here because we only want to copy - // a limited number of bytes from the stream and I didn't want to modify / extend that method - // to accept a length. - - // TODO: special case optimization for case where we only write one file (non-spill case). - - for (int i = 0; i < spills.length; i++) { - spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel(); - } - - final FileChannel mergedFileOutputChannel = new FileOutputStream(outputFile).getChannel(); - - for (int partition = 0; partition < numPartitions; partition++) { + try { for (int i = 0; i < spills.length; i++) { - final long partitionLengthInSpill = spills[i].partitionLengths[partition]; - long bytesToTransfer = partitionLengthInSpill; - final FileChannel spillInputChannel = spillInputChannels[i]; - while (bytesToTransfer > 0) { - final long actualBytesTransferred = spillInputChannel.transferTo( + spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel(); + } + // This file needs to opened in append mode in order to work around a Linux kernel bug that + // affects transferTo; see SPARK-3948 for more details. + mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel(); + + long bytesWrittenToMergedFile = 0; + for (int partition = 0; partition < numPartitions; partition++) { + for (int i = 0; i < spills.length; i++) { + final long partitionLengthInSpill = spills[i].partitionLengths[partition]; + long bytesToTransfer = partitionLengthInSpill; + final FileChannel spillInputChannel = spillInputChannels[i]; + while (bytesToTransfer > 0) { + final long actualBytesTransferred = spillInputChannel.transferTo( spillInputChannelPositions[i], bytesToTransfer, mergedFileOutputChannel); - spillInputChannelPositions[i] += actualBytesTransferred; - bytesToTransfer -= actualBytesTransferred; + spillInputChannelPositions[i] += actualBytesTransferred; + bytesToTransfer -= actualBytesTransferred; + } + bytesWrittenToMergedFile += partitionLengthInSpill; + partitionLengths[partition] += partitionLengthInSpill; } - partitionLengths[partition] += partitionLengthInSpill; + } + // Check the position after transferTo loop to see if it is in the right position and raise an + // exception if it is incorrect. The position will not be increased to the expected length + // after calling transferTo in kernel version 2.6.32. This issue is described at + // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948. + if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) { + throw new IOException( + "Current position " + mergedFileOutputChannel.position() + " does not equal expected " + + "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" + + " version to see if it is 2.6.32, as there is a kernel bug which will lead to " + + "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " + + "to disable this NIO feature." + ); + } + } finally { + for (int i = 0; i < spills.length; i++) { + assert(spillInputChannelPositions[i] == spills[i].file.length()); + if (spillInputChannels[i] != null) { + spillInputChannels[i].close(); + } + } + if (mergedFileOutputChannel != null) { + mergedFileOutputChannel.close(); } } - - // TODO: should this be in a finally block? - for (int i = 0; i < spills.length; i++) { - assert(spillInputChannelPositions[i] == spills[i].file.length()); - spillInputChannels[i].close(); - } - mergedFileOutputChannel.close(); - return partitionLengths; } @@ -215,6 +300,9 @@ public Option stop(boolean success) { stopping = true; freeMemory(); if (success) { + if (mapStatus == null) { + throw new IllegalStateException("Cannot call stop(true) without having called write()"); + } return Option.apply(mapStatus); } else { // The map task failed, so delete our output data. diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 55c447327ef35..b2eb68ce9dfea 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -18,21 +18,25 @@ package org.apache.spark.shuffle.unsafe; import java.io.File; +import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import java.util.ArrayList; -import java.util.UUID; +import java.util.*; import scala.*; import scala.runtime.AbstractFunction1; +import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import static org.mockito.AdditionalAnswers.returnsFirstArg; import static org.mockito.AdditionalAnswers.returnsSecondArg; +import static org.mockito.Answers.RETURNS_SMART_NULLS; import static org.mockito.Mockito.*; import org.apache.spark.*; @@ -52,18 +56,21 @@ public class UnsafeShuffleWriterSuite { + static final int NUM_PARTITITONS = 4; final TaskMemoryManager memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); - // Compute key prefixes based on the records' partition ids - final HashPartitioner hashPartitioner = new HashPartitioner(4); - - ShuffleMemoryManager shuffleMemoryManager; - BlockManager blockManager; - IndexShuffleBlockManager shuffleBlockManager; - DiskBlockManager diskBlockManager; + final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS); + File mergedOutputFile; File tempDir; - TaskContext taskContext; - SparkConf sparkConf; + long[] partitionSizesInMergedFile; + final LinkedList spillFilesCreated = new LinkedList(); + + @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager; + @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; + @Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockManager shuffleBlockManager; + @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; + @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext; + @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency shuffleDep; private static final class CompressStream extends AbstractFunction1 { @Override @@ -72,26 +79,23 @@ public OutputStream apply(OutputStream stream) { } } + @After + public void tearDown() { + Utils.deleteRecursively(tempDir); + } + @Before - public void setUp() { - shuffleMemoryManager = mock(ShuffleMemoryManager.class); - diskBlockManager = mock(DiskBlockManager.class); - blockManager = mock(BlockManager.class); - shuffleBlockManager = mock(IndexShuffleBlockManager.class); - tempDir = new File(Utils.createTempDir$default$1()); - taskContext = mock(TaskContext.class); - sparkConf = new SparkConf(); - when(taskContext.taskMetrics()).thenReturn(new TaskMetrics()); + @SuppressWarnings("unchecked") + public void setUp() throws IOException { + MockitoAnnotations.initMocks(this); + tempDir = Utils.createTempDir("test", "test"); + mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir); + partitionSizesInMergedFile = null; + spillFilesCreated.clear(); + when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); + when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); - when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer>() { - @Override - public Tuple2 answer(InvocationOnMock invocationOnMock) throws Throwable { - TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID()); - File file = File.createTempFile("spillFile", ".spill", tempDir); - return Tuple2$.MODULE$.apply(blockId, file); - } - }); when(blockManager.getDiskWriter( any(BlockId.class), any(File.class), @@ -115,64 +119,103 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th }); when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))) .then(returnsSecondArg()); - } - - @Test - @SuppressWarnings("unchecked") - public void basicShuffleWriting() throws Exception { - final ShuffleDependency dep = mock(ShuffleDependency.class); - when(dep.serializer()).thenReturn(Option.apply(new KryoSerializer(sparkConf))); - when(dep.partitioner()).thenReturn(hashPartitioner); - - final File mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir); when(shuffleBlockManager.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile); - final long[] partitionSizes = new long[hashPartitioner.numPartitions()]; doAnswer(new Answer() { @Override public Void answer(InvocationOnMock invocationOnMock) throws Throwable { - long[] receivedPartitionSizes = (long[]) invocationOnMock.getArguments()[2]; - System.arraycopy( - receivedPartitionSizes, 0, partitionSizes, 0, receivedPartitionSizes.length); + partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2]; return null; } }).when(shuffleBlockManager).writeIndexFile(anyInt(), anyInt(), any(long[].class)); - final UnsafeShuffleWriter writer = new UnsafeShuffleWriter( + when(diskBlockManager.createTempShuffleBlock()).thenAnswer( + new Answer>() { + @Override + public Tuple2 answer( + InvocationOnMock invocationOnMock) throws Throwable { + TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID()); + File file = File.createTempFile("spillFile", ".spill", tempDir); + spillFilesCreated.add(file); + return Tuple2$.MODULE$.apply(blockId, file); + } + }); + + when(taskContext.taskMetrics()).thenReturn(new TaskMetrics()); + + when(shuffleDep.serializer()).thenReturn( + Option.apply(new KryoSerializer(new SparkConf()))); + when(shuffleDep.partitioner()).thenReturn(hashPartitioner); + } + + private UnsafeShuffleWriter createWriter(boolean transferToEnabled) { + SparkConf conf = new SparkConf(); + conf.set("spark.file.transferTo", String.valueOf(transferToEnabled)); + return new UnsafeShuffleWriter( blockManager, shuffleBlockManager, memoryManager, shuffleMemoryManager, - new UnsafeShuffleHandle(0, 1, dep), + new UnsafeShuffleHandle(0, 1, shuffleDep), 0, // map id taskContext, - sparkConf + new SparkConf() ); + } - final ArrayList> numbersToSort = - new ArrayList>(); - numbersToSort.add(new Tuple2(5, 5)); - numbersToSort.add(new Tuple2(1, 1)); - numbersToSort.add(new Tuple2(3, 3)); - numbersToSort.add(new Tuple2(2, 2)); - numbersToSort.add(new Tuple2(4, 4)); + private void assertSpillFilesWereCleanedUp() { + for (File spillFile : spillFilesCreated) { + Assert.assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up", + spillFile.exists()); + } + } + + @Test(expected=IllegalStateException.class) + public void mustCallWriteBeforeSuccessfulStop() { + createWriter(false).stop(true); + } + + @Test + public void doNotNeedToCallWriteBeforeUnsuccessfulStop() { + createWriter(false).stop(false); + } + @Test + public void writeEmptyIterator() throws Exception { + final UnsafeShuffleWriter writer = createWriter(true); + writer.write(Collections.>emptyIterator()); + final Option mapStatus = writer.stop(true); + Assert.assertTrue(mapStatus.isDefined()); + Assert.assertTrue(mergedOutputFile.exists()); + Assert.assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile); + } - writer.write(numbersToSort.iterator()); + @Test + public void writeWithoutSpilling() throws Exception { + // In this example, each partition should have exactly one record: + final ArrayList> datatToWrite = + new ArrayList>(); + for (int i = 0; i < NUM_PARTITITONS; i++) { + datatToWrite.add(new Tuple2(i, i)); + } + final UnsafeShuffleWriter writer = createWriter(true); + writer.write(datatToWrite.iterator()); final Option mapStatus = writer.stop(true); Assert.assertTrue(mapStatus.isDefined()); + Assert.assertTrue(mergedOutputFile.exists()); long sumOfPartitionSizes = 0; - for (long size: partitionSizes) { + for (long size: partitionSizesInMergedFile) { + // All partitions should be the same size: + Assert.assertEquals(partitionSizesInMergedFile[0], size); sumOfPartitionSizes += size; } Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes); - // TODO: actually try to read the shuffle output? - - // TODO: add a test that manually triggers spills in order to exercise the merging. - - // TODO: test that the temporary spill files were cleaned up after the merge. + assertSpillFilesWereCleanedUp(); } + // TODO: actually try to read the shuffle output? + // TODO: add a test that manually triggers spills in order to exercise the merging. + } From 7cd013be50add13d076c372534beaf2ff7aa3f31 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 8 May 2015 17:26:39 -0700 Subject: [PATCH 45/92] Begin refactoring to enable proper tests for spilling. --- .../unsafe/UnsafeShuffleExternalSorter.java | 6 +- .../shuffle/unsafe/UnsafeShuffleWriter.java | 92 ++++++++++++------- .../unsafe/UnsafeShuffleWriterSuite.java | 41 ++++++++- 3 files changed, 102 insertions(+), 37 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 70c911252fddb..8f7c3b4232691 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -21,9 +21,9 @@ import java.io.IOException; import java.util.LinkedList; -import org.apache.spark.storage.*; import scala.Tuple2; +import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -32,6 +32,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.storage.*; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; @@ -215,7 +216,8 @@ private SpillInfo writeSpillFile() throws IOException { /** * Sort and spill the current records in response to memory pressure. */ - private void spill() throws IOException { + @VisibleForTesting + void spill() throws IOException { final long threadId = Thread.currentThread().getId(); logger.info("Thread " + threadId + " spilling sort data of " + org.apache.spark.util.Utils.bytesToString(getMemoryUsage()) + " to disk (" + diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 206812f8352d2..e5a942498ae00 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -32,6 +32,7 @@ import scala.reflect.ClassTag$; import com.esotericsoftware.kryo.io.ByteBufferOutputStream; +import com.google.common.annotations.VisibleForTesting; import com.google.common.io.ByteStreams; import com.google.common.io.Files; import org.slf4j.Logger; @@ -73,6 +74,11 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final boolean transferToEnabled; private MapStatus mapStatus = null; + private UnsafeShuffleExternalSorter sorter = null; + private byte[] serArray = null; + private ByteBuffer serByteBuffer; + // TODO: we should not depend on this class from Kryo; copy its source or find an alternative + private SerializationStream serOutputStream; /** * Are we in the process of stopping? Because map tasks can call stop() with success = true @@ -113,25 +119,18 @@ public void write(Iterator> records) { @Override public void write(scala.collection.Iterator> records) { try { - final long[] partitionLengths = mergeSpills(insertRecordsIntoSorter(records)); - shuffleBlockManager.writeIndexFile(shuffleId, mapId, partitionLengths); - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + while (records.hasNext()) { + insertRecordIntoSorter(records.next()); + } + closeAndWriteOutput(); } catch (Exception e) { PlatformDependent.throwException(e); } } - private void freeMemory() { - // TODO - } - - private void deleteSpills() { - // TODO - } - - private SpillInfo[] insertRecordsIntoSorter( - scala.collection.Iterator> records) throws Exception { - final UnsafeShuffleExternalSorter sorter = new UnsafeShuffleExternalSorter( + private void open() throws IOException { + assert (sorter == null); + sorter = new UnsafeShuffleExternalSorter( memoryManager, shuffleMemoryManager, blockManager, @@ -139,30 +138,53 @@ private SpillInfo[] insertRecordsIntoSorter( 4096, // Initial size (TODO: tune this!) partitioner.numPartitions(), sparkConf); - - final byte[] serArray = new byte[SER_BUFFER_SIZE]; - final ByteBuffer serByteBuffer = ByteBuffer.wrap(serArray); + serArray = new byte[SER_BUFFER_SIZE]; + serByteBuffer = ByteBuffer.wrap(serArray); // TODO: we should not depend on this class from Kryo; copy its source or find an alternative - final SerializationStream serOutputStream = - serializer.serializeStream(new ByteBufferOutputStream(serByteBuffer)); + serOutputStream = serializer.serializeStream(new ByteBufferOutputStream(serByteBuffer)); + } - while (records.hasNext()) { - final Product2 record = records.next(); - final K key = record._1(); - final int partitionId = partitioner.getPartition(key); - serByteBuffer.position(0); - serOutputStream.writeKey(key, OBJECT_CLASS_TAG); - serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG); - serOutputStream.flush(); + @VisibleForTesting + void closeAndWriteOutput() throws IOException { + if (sorter == null) { + open(); + } + serArray = null; + serByteBuffer = null; + serOutputStream = null; + final long[] partitionLengths = mergeSpills(sorter.closeAndGetSpills()); + sorter = null; + shuffleBlockManager.writeIndexFile(shuffleId, mapId, partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + } - final int serializedRecordSize = serByteBuffer.position(); - assert (serializedRecordSize > 0); + private void freeMemory() { + // TODO + } - sorter.insertRecord( - serArray, PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); + @VisibleForTesting + void insertRecordIntoSorter(Product2 record) throws IOException{ + if (sorter == null) { + open(); } + final K key = record._1(); + final int partitionId = partitioner.getPartition(key); + serByteBuffer.position(0); + serOutputStream.writeKey(key, OBJECT_CLASS_TAG); + serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG); + serOutputStream.flush(); - return sorter.closeAndGetSpills(); + final int serializedRecordSize = serByteBuffer.position(); + assert (serializedRecordSize > 0); + + sorter.insertRecord( + serArray, PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); + } + + @VisibleForTesting + void forceSorterToSpill() throws IOException { + assert (sorter != null); + sorter.spill(); } private long[] mergeSpills(SpillInfo[] spills) throws IOException { @@ -222,6 +244,9 @@ private long[] mergeSpillsWithFileStream(SpillInfo[] spills, File outputFile) th for (int i = 0; i < spills.length; i++) { if (spillInputStreams[i] != null) { spillInputStreams[i].close(); + if (!spills[i].file.delete()) { + logger.error("Error while deleting spill file {}", spills[i]); + } } } if (mergedFileOutputStream != null) { @@ -282,6 +307,9 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th assert(spillInputChannelPositions[i] == spills[i].file.length()); if (spillInputChannels[i] != null) { spillInputChannels[i].close(); + if (!spills[i].file.delete()) { + logger.error("Error while deleting spill file {}", spills[i]); + } } } if (mergedFileOutputChannel != null) { diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index b2eb68ce9dfea..09eb537c04367 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -193,13 +193,13 @@ public void writeEmptyIterator() throws Exception { @Test public void writeWithoutSpilling() throws Exception { // In this example, each partition should have exactly one record: - final ArrayList> datatToWrite = + final ArrayList> dataToWrite = new ArrayList>(); for (int i = 0; i < NUM_PARTITITONS; i++) { - datatToWrite.add(new Tuple2(i, i)); + dataToWrite.add(new Tuple2(i, i)); } final UnsafeShuffleWriter writer = createWriter(true); - writer.write(datatToWrite.iterator()); + writer.write(dataToWrite.iterator()); final Option mapStatus = writer.stop(true); Assert.assertTrue(mapStatus.isDefined()); Assert.assertTrue(mergedOutputFile.exists()); @@ -215,7 +215,42 @@ public void writeWithoutSpilling() throws Exception { assertSpillFilesWereCleanedUp(); } + private void testMergingSpills(boolean transferToEnabled) throws IOException { + final UnsafeShuffleWriter writer = createWriter(true); + writer.insertRecordIntoSorter(new Tuple2(1, 1)); + writer.insertRecordIntoSorter(new Tuple2(2, 2)); + writer.insertRecordIntoSorter(new Tuple2(3, 3)); + writer.insertRecordIntoSorter(new Tuple2(4, 4)); + writer.forceSorterToSpill(); + writer.insertRecordIntoSorter(new Tuple2(4, 4)); + writer.insertRecordIntoSorter(new Tuple2(2, 2)); + writer.closeAndWriteOutput(); + final Option mapStatus = writer.stop(true); + Assert.assertTrue(mapStatus.isDefined()); + Assert.assertTrue(mergedOutputFile.exists()); + Assert.assertEquals(2, spillFilesCreated.size()); + + long sumOfPartitionSizes = 0; + for (long size: partitionSizesInMergedFile) { + sumOfPartitionSizes += size; + } + Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes); + + assertSpillFilesWereCleanedUp(); + } + + @Test + public void mergeSpillsWithTransferTo() throws Exception { + testMergingSpills(true); + } + + @Test + public void mergeSpillsWithFileStream() throws Exception { + testMergingSpills(false); + } + // TODO: actually try to read the shuffle output? // TODO: add a test that manually triggers spills in order to exercise the merging. +// } } From 9b7ebed6d699cdc7de899729dbca27e4f9341983 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 9 May 2015 09:35:05 -0700 Subject: [PATCH 46/92] More defensive programming RE: cleaning up spill files and memory after errors --- .../unsafe/UnsafeShuffleExternalSorter.java | 29 ++++++++-- .../shuffle/unsafe/UnsafeShuffleWriter.java | 53 ++++++++----------- .../unsafe/UnsafeShuffleWriterSuite.java | 23 ++++++-- 3 files changed, 66 insertions(+), 39 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 8f7c3b4232691..6ca3e09e3e439 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -252,6 +252,22 @@ private long freeMemory() { return memoryFreed; } + /** + * Force all memory and spill files to be deleted; called by shuffle error-handling code. + */ + public void cleanupAfterError() { + freeMemory(); + for (SpillInfo spill : spills) { + if (spill.file.exists() && !spill.file.delete()) { + logger.error("Unable to delete spill file {}", spill.file.getPath()); + } + } + if (spillingEnabled && sorter != null) { + shuffleMemoryManager.release(sorter.getMemoryUsage()); + sorter = null; + } + } + /** * Checks whether there is enough space to insert a new record into the sorter. * @@ -362,11 +378,16 @@ public void insertRecord( * @throws IOException */ public SpillInfo[] closeAndGetSpills() throws IOException { - if (sorter != null) { - writeSpillFile(); - freeMemory(); + try { + if (sorter != null) { + writeSpillFile(); + freeMemory(); + } + return spills.toArray(new SpillInfo[spills.size()]); + } catch (IOException e) { + cleanupAfterError(); + throw e; } - return spills.toArray(new SpillInfo[0]); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index e5a942498ae00..70afea553556c 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -17,10 +17,7 @@ package org.apache.spark.shuffle.unsafe; -import java.io.File; -import java.io.FileInputStream; -import java.io.FileOutputStream; -import java.io.IOException; +import java.io.*; import java.nio.ByteBuffer; import java.nio.channels.FileChannel; import java.util.Iterator; @@ -34,6 +31,7 @@ import com.esotericsoftware.kryo.io.ByteBufferOutputStream; import com.google.common.annotations.VisibleForTesting; import com.google.common.io.ByteStreams; +import com.google.common.io.Closeables; import com.google.common.io.Files; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -152,16 +150,22 @@ void closeAndWriteOutput() throws IOException { serArray = null; serByteBuffer = null; serOutputStream = null; - final long[] partitionLengths = mergeSpills(sorter.closeAndGetSpills()); + final SpillInfo[] spills = sorter.closeAndGetSpills(); sorter = null; + final long[] partitionLengths; + try { + partitionLengths = mergeSpills(spills); + } finally { + for (SpillInfo spill : spills) { + if (spill.file.exists() && ! spill.file.delete()) { + logger.error("Error while deleting spill file {}", spill.file.getPath()); + } + } + } shuffleBlockManager.writeIndexFile(shuffleId, mapId, partitionLengths); mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } - private void freeMemory() { - // TODO - } - @VisibleForTesting void insertRecordIntoSorter(Product2 record) throws IOException{ if (sorter == null) { @@ -241,17 +245,10 @@ private long[] mergeSpillsWithFileStream(SpillInfo[] spills, File outputFile) th } } } finally { - for (int i = 0; i < spills.length; i++) { - if (spillInputStreams[i] != null) { - spillInputStreams[i].close(); - if (!spills[i].file.delete()) { - logger.error("Error while deleting spill file {}", spills[i]); - } - } - } - if (mergedFileOutputStream != null) { - mergedFileOutputStream.close(); + for (InputStream stream : spillInputStreams) { + Closeables.close(stream, false); } + Closeables.close(mergedFileOutputStream, false); } return partitionLengths; } @@ -305,16 +302,9 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th } finally { for (int i = 0; i < spills.length; i++) { assert(spillInputChannelPositions[i] == spills[i].file.length()); - if (spillInputChannels[i] != null) { - spillInputChannels[i].close(); - if (!spills[i].file.delete()) { - logger.error("Error while deleting spill file {}", spills[i]); - } - } - } - if (mergedFileOutputChannel != null) { - mergedFileOutputChannel.close(); + Closeables.close(spillInputChannels[i], false); } + Closeables.close(mergedFileOutputChannel, false); } return partitionLengths; } @@ -326,7 +316,6 @@ public Option stop(boolean success) { return Option.apply(null); } else { stopping = true; - freeMemory(); if (success) { if (mapStatus == null) { throw new IllegalStateException("Cannot call stop(true) without having called write()"); @@ -339,7 +328,11 @@ public Option stop(boolean success) { } } } finally { - freeMemory(); + if (sorter != null) { + // If sorter is non-null, then this implies that we called stop() in response to an error, + // so we need to clean up memory and spill files created by the sorter + sorter.cleanupAfterError(); + } } } } diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 09eb537c04367..a1d654c9d121e 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -57,7 +57,7 @@ public class UnsafeShuffleWriterSuite { static final int NUM_PARTITITONS = 4; - final TaskMemoryManager memoryManager = + final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS); File mergedOutputFile; @@ -82,6 +82,10 @@ public OutputStream apply(OutputStream stream) { @After public void tearDown() { Utils.deleteRecursively(tempDir); + final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory(); + if (leakedMemory != 0) { + Assert.fail("Test leaked " + leakedMemory + " bytes of managed memory"); + } } @Before @@ -154,7 +158,7 @@ private UnsafeShuffleWriter createWriter(boolean transferToEnabl return new UnsafeShuffleWriter( blockManager, shuffleBlockManager, - memoryManager, + taskMemoryManager, shuffleMemoryManager, new UnsafeShuffleHandle(0, 1, shuffleDep), 0, // map id @@ -216,7 +220,7 @@ public void writeWithoutSpilling() throws Exception { } private void testMergingSpills(boolean transferToEnabled) throws IOException { - final UnsafeShuffleWriter writer = createWriter(true); + final UnsafeShuffleWriter writer = createWriter(transferToEnabled); writer.insertRecordIntoSorter(new Tuple2(1, 1)); writer.insertRecordIntoSorter(new Tuple2(2, 2)); writer.insertRecordIntoSorter(new Tuple2(3, 3)); @@ -249,8 +253,17 @@ public void mergeSpillsWithFileStream() throws Exception { testMergingSpills(false); } + @Test + public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException { + final UnsafeShuffleWriter writer = createWriter(false); + writer.insertRecordIntoSorter(new Tuple2(1, 1)); + writer.insertRecordIntoSorter(new Tuple2(2, 2)); + writer.forceSorterToSpill(); + writer.insertRecordIntoSorter(new Tuple2(2, 2)); + writer.stop(false); + assertSpillFilesWereCleanedUp(); + } + // TODO: actually try to read the shuffle output? - // TODO: add a test that manually triggers spills in order to exercise the merging. -// } } From 1929a7439040e7c2439e386eb0994236e07c6780 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 9 May 2015 09:37:48 -0700 Subject: [PATCH 47/92] Update to reflect upstream ShuffleBlockManager -> ShuffleBlockResolver rename. --- .../spark/shuffle/unsafe/UnsafeShuffleWriter.java | 14 +++++++------- .../shuffle/unsafe/UnsafeShuffleManager.scala | 2 +- .../shuffle/unsafe/UnsafeShuffleWriterSuite.java | 10 +++++----- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 70afea553556c..b442162946afb 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -44,7 +44,7 @@ import org.apache.spark.serializer.SerializationStream; import org.apache.spark.serializer.Serializer; import org.apache.spark.serializer.SerializerInstance; -import org.apache.spark.shuffle.IndexShuffleBlockManager; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.BlockManager; @@ -59,7 +59,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); private final BlockManager blockManager; - private final IndexShuffleBlockManager shuffleBlockManager; + private final IndexShuffleBlockResolver shuffleBlockResolver; private final TaskMemoryManager memoryManager; private final ShuffleMemoryManager shuffleMemoryManager; private final SerializerInstance serializer; @@ -87,7 +87,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { public UnsafeShuffleWriter( BlockManager blockManager, - IndexShuffleBlockManager shuffleBlockManager, + IndexShuffleBlockResolver shuffleBlockResolver, TaskMemoryManager memoryManager, ShuffleMemoryManager shuffleMemoryManager, UnsafeShuffleHandle handle, @@ -95,7 +95,7 @@ public UnsafeShuffleWriter( TaskContext taskContext, SparkConf sparkConf) { this.blockManager = blockManager; - this.shuffleBlockManager = shuffleBlockManager; + this.shuffleBlockResolver = shuffleBlockResolver; this.memoryManager = memoryManager; this.shuffleMemoryManager = shuffleMemoryManager; this.mapId = mapId; @@ -162,7 +162,7 @@ void closeAndWriteOutput() throws IOException { } } } - shuffleBlockManager.writeIndexFile(shuffleId, mapId, partitionLengths); + shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths); mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } @@ -192,7 +192,7 @@ void forceSorterToSpill() throws IOException { } private long[] mergeSpills(SpillInfo[] spills) throws IOException { - final File outputFile = shuffleBlockManager.getDataFile(shuffleId, mapId); + final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId); try { if (spills.length == 0) { new FileOutputStream(outputFile).close(); // Create an empty file @@ -323,7 +323,7 @@ public Option stop(boolean success) { return Option.apply(mapStatus); } else { // The map task failed, so delete our output data. - shuffleBlockManager.removeDataByMap(shuffleId, mapId); + shuffleBlockResolver.removeDataByMap(shuffleId, mapId); return Option.apply(null); } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala index 14f29a36ec4f6..994a8c049a331 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala @@ -92,7 +92,7 @@ private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManage // TODO: do we need to do anything to register the shuffle here? new UnsafeShuffleWriter( env.blockManager, - shuffleBlockResolver.asInstanceOf[IndexShuffleBlockManager], + shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], context.taskMemoryManager(), env.shuffleMemoryManager, unsafeShuffleHandle, diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index a1d654c9d121e..ac94161d9f242 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -41,7 +41,7 @@ import org.apache.spark.*; import org.apache.spark.serializer.Serializer; -import org.apache.spark.shuffle.IndexShuffleBlockManager; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.executor.TaskMetrics; import org.apache.spark.serializer.SerializerInstance; @@ -67,7 +67,7 @@ public class UnsafeShuffleWriterSuite { @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager; @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; - @Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockManager shuffleBlockManager; + @Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockResolver shuffleBlockResolver; @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext; @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency shuffleDep; @@ -124,14 +124,14 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))) .then(returnsSecondArg()); - when(shuffleBlockManager.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile); + when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile); doAnswer(new Answer() { @Override public Void answer(InvocationOnMock invocationOnMock) throws Throwable { partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2]; return null; } - }).when(shuffleBlockManager).writeIndexFile(anyInt(), anyInt(), any(long[].class)); + }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class)); when(diskBlockManager.createTempShuffleBlock()).thenAnswer( new Answer>() { @@ -157,7 +157,7 @@ private UnsafeShuffleWriter createWriter(boolean transferToEnabl conf.set("spark.file.transferTo", String.valueOf(transferToEnabled)); return new UnsafeShuffleWriter( blockManager, - shuffleBlockManager, + shuffleBlockResolver, taskMemoryManager, shuffleMemoryManager, new UnsafeShuffleHandle(0, 1, shuffleDep), From 01afc74a3828f39d0ac54b93780ee16971d79828 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 9 May 2015 17:10:54 -0700 Subject: [PATCH 48/92] Actually read data in UnsafeShuffleWriterSuite --- .../unsafe/UnsafeShuffleWriterSuite.java | 71 +++++++++++++------ 1 file changed, 50 insertions(+), 21 deletions(-) diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index ac94161d9f242..b995ef826aaf4 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -17,15 +17,15 @@ package org.apache.spark.shuffle.unsafe; -import java.io.File; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; +import java.io.*; import java.util.*; import scala.*; +import scala.collection.Iterator; import scala.runtime.AbstractFunction1; +import com.google.common.collect.HashMultiset; +import com.google.common.io.ByteStreams; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -40,19 +40,21 @@ import static org.mockito.Mockito.*; import org.apache.spark.*; -import org.apache.spark.serializer.Serializer; -import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.network.util.LimitedInputStream; +import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.serializer.DeserializationStream; +import org.apache.spark.serializer.KryoSerializer; +import org.apache.spark.serializer.Serializer; import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.storage.*; import org.apache.spark.unsafe.memory.ExecutorMemoryManager; import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.TaskMemoryManager; import org.apache.spark.util.Utils; -import org.apache.spark.serializer.KryoSerializer; -import org.apache.spark.scheduler.MapStatus; public class UnsafeShuffleWriterSuite { @@ -64,6 +66,7 @@ public class UnsafeShuffleWriterSuite { File tempDir; long[] partitionSizesInMergedFile; final LinkedList spillFilesCreated = new LinkedList(); + final Serializer serializer = new KryoSerializer(new SparkConf()); @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager; @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; @@ -147,8 +150,7 @@ public Tuple2 answer( when(taskContext.taskMetrics()).thenReturn(new TaskMetrics()); - when(shuffleDep.serializer()).thenReturn( - Option.apply(new KryoSerializer(new SparkConf()))); + when(shuffleDep.serializer()).thenReturn(Option.apply(serializer)); when(shuffleDep.partitioner()).thenReturn(hashPartitioner); } @@ -174,6 +176,27 @@ private void assertSpillFilesWereCleanedUp() { } } + private List> readRecordsFromFile() throws IOException { + final ArrayList> recordsList = new ArrayList>(); + long startOffset = 0; + for (int i = 0; i < NUM_PARTITITONS; i++) { + final long partitionSize = partitionSizesInMergedFile[i]; + if (partitionSize > 0) { + InputStream in = new FileInputStream(mergedOutputFile); + ByteStreams.skipFully(in, startOffset); + DeserializationStream recordsStream = serializer.newInstance().deserializeStream( + new LimitedInputStream(in, partitionSize)); + Iterator> records = recordsStream.asKeyValueIterator(); + while (records.hasNext()) { + recordsList.add(records.next()); + } + recordsStream.close(); + startOffset += partitionSize; + } + } + return recordsList; + } + @Test(expected=IllegalStateException.class) public void mustCallWriteBeforeSuccessfulStop() { createWriter(false).stop(true); @@ -215,19 +238,26 @@ public void writeWithoutSpilling() throws Exception { sumOfPartitionSizes += size; } Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes); - + Assert.assertEquals( + HashMultiset.create(dataToWrite), + HashMultiset.create(readRecordsFromFile())); assertSpillFilesWereCleanedUp(); } private void testMergingSpills(boolean transferToEnabled) throws IOException { final UnsafeShuffleWriter writer = createWriter(transferToEnabled); - writer.insertRecordIntoSorter(new Tuple2(1, 1)); - writer.insertRecordIntoSorter(new Tuple2(2, 2)); - writer.insertRecordIntoSorter(new Tuple2(3, 3)); - writer.insertRecordIntoSorter(new Tuple2(4, 4)); + final ArrayList> dataToWrite = + new ArrayList>(); + for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) { + dataToWrite.add(new Tuple2(i, i)); + } + writer.insertRecordIntoSorter(dataToWrite.get(0)); + writer.insertRecordIntoSorter(dataToWrite.get(1)); + writer.insertRecordIntoSorter(dataToWrite.get(2)); + writer.insertRecordIntoSorter(dataToWrite.get(3)); writer.forceSorterToSpill(); - writer.insertRecordIntoSorter(new Tuple2(4, 4)); - writer.insertRecordIntoSorter(new Tuple2(2, 2)); + writer.insertRecordIntoSorter(dataToWrite.get(4)); + writer.insertRecordIntoSorter(dataToWrite.get(5)); writer.closeAndWriteOutput(); final Option mapStatus = writer.stop(true); Assert.assertTrue(mapStatus.isDefined()); @@ -239,7 +269,9 @@ private void testMergingSpills(boolean transferToEnabled) throws IOException { sumOfPartitionSizes += size; } Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes); - + Assert.assertEquals( + HashMultiset.create(dataToWrite), + HashMultiset.create(readRecordsFromFile())); assertSpillFilesWereCleanedUp(); } @@ -263,7 +295,4 @@ public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException { writer.stop(false); assertSpillFilesWereCleanedUp(); } - - // TODO: actually try to read the shuffle output? - } From 8f5061adb3ed8c4605c792ab37455e42e8906a62 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 9 May 2015 17:16:14 -0700 Subject: [PATCH 49/92] Strengthen assertion to check partitioning --- .../apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index b995ef826aaf4..cc8cbc534510b 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -188,7 +188,9 @@ private List> readRecordsFromFile() throws IOException { new LimitedInputStream(in, partitionSize)); Iterator> records = recordsStream.asKeyValueIterator(); while (records.hasNext()) { - recordsList.add(records.next()); + Tuple2 record = records.next(); + Assert.assertEquals(i, hashPartitioner.getPartition(record._1())); + recordsList.add(record); } recordsStream.close(); startOffset += partitionSize; From 67d25ba1be72b0326e7f3f2b319fb4b7b897f0b3 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 9 May 2015 17:31:21 -0700 Subject: [PATCH 50/92] Update Exchange operator's copying logic to account for new shuffle manager --- .../apache/spark/sql/execution/Exchange.scala | 28 ++++++++++++------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index c3d2c7019a54a..3e46596ecf6ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -17,17 +17,18 @@ package org.apache.spark.sql.execution -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv} +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.serializer.Serializer -import org.apache.spark.sql.{SQLContext, Row} +import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.shuffle.unsafe.UnsafeShuffleManager import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.util.MutablePair object Exchange { @@ -85,7 +86,9 @@ case class Exchange( // corner-cases where a partitioner constructed with `numPartitions` partitions may output // fewer partitions (like RangePartitioner, for example). val conf = child.sqlContext.sparkContext.conf - val sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] + val shuffleManager = SparkEnv.get.shuffleManager + val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] || + shuffleManager.isInstanceOf[UnsafeShuffleManager] val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) val serializeMapOutputs = conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) if (newOrdering.nonEmpty) { @@ -93,11 +96,11 @@ case class Exchange( // which requires a defensive copy. true } else if (sortBasedShuffleOn) { - // Spark's sort-based shuffle also uses `ExternalSorter` to buffer records in memory. - // However, there are two special cases where we can avoid the copy, described below: - if (partitioner.numPartitions <= bypassMergeThreshold) { - // If the number of output partitions is sufficiently small, then Spark will fall back to - // the old hash-based shuffle write path which doesn't buffer deserialized records. + val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] + if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) { + // If we're using the original SortShuffleManager and the number of output partitions is + // sufficiently small, then Spark will fall back to the hash-based shuffle write path, which + // doesn't buffer deserialized records. // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass. false } else if (serializeMapOutputs && serializer.supportsRelocationOfSerializedObjects) { @@ -105,9 +108,14 @@ case class Exchange( // them. This optimization is guarded by a feature-flag and is only applied in cases where // shuffle dependency does not specify an ordering and the record serializer has certain // properties. If this optimization is enabled, we can safely avoid the copy. + // + // This optimization also applies to UnsafeShuffleManager (added in SPARK-7081). false } else { - // None of the special cases held, so we must copy. + // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory. This code + // path is used both when SortShuffleManager is used and when UnsafeShuffleManager falls + // back to SortShuffleManager to perform a shuffle that the new fast path can't handle. In + // both cases, we must copy. true } } else { From fd4bb9e6bf8c0ef824c2515665abc3156c6d9963 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 9 May 2015 18:07:44 -0700 Subject: [PATCH 51/92] Use own ByteBufferOutputStream rather than Kryo's --- .../unsafe/ByteBufferOutputStream.java | 46 +++++++++++++++++++ .../unsafe/UnsafeShuffleExternalSorter.java | 2 +- .../shuffle/unsafe/UnsafeShuffleSorter.java | 2 +- .../shuffle/unsafe/UnsafeShuffleWriter.java | 2 - 4 files changed, 48 insertions(+), 4 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/shuffle/unsafe/ByteBufferOutputStream.java diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/ByteBufferOutputStream.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/ByteBufferOutputStream.java new file mode 100644 index 0000000000000..3410cd2911ebe --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/ByteBufferOutputStream.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; + +class ByteBufferOutputStream extends OutputStream { + + private final ByteBuffer byteBuffer; + + public ByteBufferOutputStream(ByteBuffer byteBuffer) { + this.byteBuffer = byteBuffer; + } + + @Override + public void write(int b) throws IOException { + byteBuffer.put((byte) b); + } + + @Override + public void write(byte[] b) throws IOException { + byteBuffer.put(b); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + byteBuffer.put(b, off, len); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 6ca3e09e3e439..892a78796335b 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -53,7 +53,7 @@ * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a * specialized merge procedure that avoids extra serialization/deserialization. */ -public final class UnsafeShuffleExternalSorter { +final class UnsafeShuffleExternalSorter { private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class); diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java index d9ffe9a44fec7..d15da8a7ee126 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java @@ -21,7 +21,7 @@ import org.apache.spark.util.collection.Sorter; -public final class UnsafeShuffleSorter { +final class UnsafeShuffleSorter { private final Sorter sorter; private static final class SortComparator implements Comparator { diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index b442162946afb..d20c19547adf1 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -28,7 +28,6 @@ import scala.reflect.ClassTag; import scala.reflect.ClassTag$; -import com.esotericsoftware.kryo.io.ByteBufferOutputStream; import com.google.common.annotations.VisibleForTesting; import com.google.common.io.ByteStreams; import com.google.common.io.Closeables; @@ -75,7 +74,6 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private UnsafeShuffleExternalSorter sorter = null; private byte[] serArray = null; private ByteBuffer serByteBuffer; - // TODO: we should not depend on this class from Kryo; copy its source or find an alternative private SerializationStream serOutputStream; /** From 9d1ee7cac3effba0c264f93fe2303f427ac26cdc Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 9 May 2015 18:14:26 -0700 Subject: [PATCH 52/92] Fix MiMa excludes for ShuffleWriter change --- .../apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java | 1 - project/MimaExcludes.scala | 6 ++++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index d20c19547adf1..f28e63f137bc9 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -136,7 +136,6 @@ private void open() throws IOException { sparkConf); serArray = new byte[SER_BUFFER_SIZE]; serByteBuffer = ByteBuffer.wrap(serArray); - // TODO: we should not depend on this class from Kryo; copy its source or find an alternative serOutputStream = serializer.serializeStream(new ByteBufferOutputStream(serByteBuffer)); } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index cfe387faec14b..6913ebbae9714 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -106,6 +106,12 @@ object MimaExcludes { "org.apache.spark.sql.parquet.ParquetTestData$"), ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.parquet.TestGroupWriteSupport") + ) ++ Seq( + // SPARK-7081 changed ShuffleWriter from a trait to an abstract class and removed some + // unnecessary type bounds in order to fix some compiler warnings that occurred when + // implementing this interface in Java. Note that ShuffleWriter is private[spark]. + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.shuffle.ShuffleWriter") ) case v if v.startsWith("1.3") => From fcd9a3c499eac4123139276cbef713b8f5d9471a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 10 May 2015 14:07:51 -0700 Subject: [PATCH 53/92] Add notes + tests for maximum record / page sizes. --- .../shuffle/unsafe/PackedRecordPointer.java | 24 +++++++-- .../unsafe/UnsafeShuffleExternalSorter.java | 51 +++++++++++-------- .../unsafe/UnsafeShuffleSortDataFormat.java | 2 +- .../shuffle/unsafe/UnsafeShuffleSorter.java | 2 +- .../shuffle/unsafe/UnsafeShuffleWriter.java | 18 +++++-- .../apache/spark/shuffle/ShuffleWriter.scala | 3 ++ .../unsafe/PackedRecordPointerSuite.java | 16 +++--- .../unsafe/UnsafeShuffleWriterSuite.java | 37 ++++++++++++++ 8 files changed, 113 insertions(+), 40 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java index 34c15e6bbcb0e..8c0940d23420b 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java @@ -19,9 +19,24 @@ /** * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer. + *

+ * Within the long, the data is laid out as follows: + *

+ *   [24 bit partition number][13 bit memory page number][27 bit offset in page]
+ * 
+ * This implies that the maximum addressable page size is 2^27 bits = 128 megabytes, assuming that + * our offsets in pages are not 8-byte-word-aligned. Since we have 2^13 pages (based off the + * 13-bit page numbers assigned by {@link org.apache.spark.unsafe.memory.TaskMemoryManager}), this + * implies that we can address 2^13 * 128 megabytes = 1 terabyte of RAM per task. + *

+ * Assuming word-alignment would allow for a 1 gigabyte maximum page size, but we leave this + * optimization to future work as it will require more careful design to ensure that addresses are + * properly aligned (e.g. by padding records). */ final class PackedRecordPointer { + static final int MAXIMUM_PAGE_SIZE_BYTES = 1 << 27; // 128 megabytes + /** Bit mask for the lower 40 bits of a long. */ private static final long MASK_LONG_LOWER_40_BITS = 0xFFFFFFFFFFL; @@ -55,7 +70,11 @@ public static long packPointer(long recordPointer, int partitionId) { return (((long) partitionId) << 40) | compressedAddress; } - public long packedRecordPointer; + private long packedRecordPointer; + + public void set(long packedRecordPointer) { + this.packedRecordPointer = packedRecordPointer; + } public int getPartitionId() { return (int) ((packedRecordPointer & MASK_LONG_UPPER_24_BITS) >>> 40); @@ -68,7 +87,4 @@ public long getRecordPointer() { return pageNumber | offsetInPage; } - public int getRecordLength() { - return -1; // TODO - } } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 892a78796335b..6e0d8da410231 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -57,8 +57,9 @@ final class UnsafeShuffleExternalSorter { private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class); - private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this / don't duplicate - private static final int PAGE_SIZE = 1024 * 1024; // TODO: tune this + @VisibleForTesting + static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; + private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES; private final int initialSize; private final int numPartitions; @@ -88,13 +89,13 @@ final class UnsafeShuffleExternalSorter { private long freeSpaceInCurrentPage = 0; public UnsafeShuffleExternalSorter( - TaskMemoryManager memoryManager, - ShuffleMemoryManager shuffleMemoryManager, - BlockManager blockManager, - TaskContext taskContext, - int initialSize, - int numPartitions, - SparkConf conf) throws IOException { + TaskMemoryManager memoryManager, + ShuffleMemoryManager shuffleMemoryManager, + BlockManager blockManager, + TaskContext taskContext, + int initialSize, + int numPartitions, + SparkConf conf) throws IOException { this.memoryManager = memoryManager; this.shuffleMemoryManager = shuffleMemoryManager; this.blockManager = blockManager; @@ -140,8 +141,9 @@ private SpillInfo writeSpillFile() throws IOException { // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to // be an API to directly transfer bytes from managed memory to the disk writer, we buffer - // records in a byte array. This array only needs to be big enough to hold a single record. - final byte[] arr = new byte[SER_BUFFER_SIZE]; + // data through a byte array. This array does not need to be large enough to hold a single + // record; + final byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE]; // Because this output will be read during shuffle, its compression codec must be controlled by // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use @@ -186,16 +188,23 @@ private SpillInfo writeSpillFile() throws IOException { } final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); - final int recordLength = PlatformDependent.UNSAFE.getInt( - memoryManager.getPage(recordPointer), memoryManager.getOffsetInPage(recordPointer)); - PlatformDependent.copyMemory( - memoryManager.getPage(recordPointer), - memoryManager.getOffsetInPage(recordPointer) + 4, // skip over record length - arr, - PlatformDependent.BYTE_ARRAY_OFFSET, - recordLength); - assert (writer != null); // To suppress an IntelliJ warning - writer.write(arr, 0, recordLength); + final Object recordPage = memoryManager.getPage(recordPointer); + final long recordOffsetInPage = memoryManager.getOffsetInPage(recordPointer); + int dataRemaining = PlatformDependent.UNSAFE.getInt(recordPage, recordOffsetInPage); + long recordReadPosition = recordOffsetInPage + 4; // skip over record length + while (dataRemaining > 0) { + final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining); + PlatformDependent.copyMemory( + recordPage, + recordReadPosition, + writeBuffer, + PlatformDependent.BYTE_ARRAY_OFFSET, + toTransfer); + assert (writer != null); // To suppress an IntelliJ warning + writer.write(writeBuffer, 0, toTransfer); + recordReadPosition += toTransfer; + dataRemaining -= toTransfer; + } // TODO: add a test that detects whether we leave this call out: writer.recordWritten(); } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java index 862845180584e..a66d74ee44782 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java @@ -38,7 +38,7 @@ public PackedRecordPointer newKey() { @Override public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) { - reuse.packedRecordPointer = data[pos]; + reuse.set(data[pos]); return reuse; } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java index d15da8a7ee126..5acbc6c1c4f2f 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java @@ -95,7 +95,7 @@ public boolean hasNext() { @Override public void loadNext() { - packedRecordPointer.packedRecordPointer = sortBuffer[position]; + packedRecordPointer.set(sortBuffer[position]); position++; } }; diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index f28e63f137bc9..db9f8648a93b4 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -54,7 +54,8 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class); - private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this + @VisibleForTesting + static final int MAXIMUM_RECORD_SIZE = 1024 * 1024 * 64; // 64 megabytes private static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); private final BlockManager blockManager; @@ -108,19 +109,26 @@ public UnsafeShuffleWriter( this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); } - public void write(Iterator> records) { + public void write(Iterator> records) throws IOException { write(JavaConversions.asScalaIterator(records)); } @Override - public void write(scala.collection.Iterator> records) { + public void write(scala.collection.Iterator> records) throws IOException { try { while (records.hasNext()) { insertRecordIntoSorter(records.next()); } closeAndWriteOutput(); } catch (Exception e) { - PlatformDependent.throwException(e); + // Unfortunately, we have to catch Exception here in order to ensure proper cleanup after + // errors becuase Spark's Scala code, or users' custom Serializers, might throw arbitrary + // unchecked exceptions. + try { + sorter.cleanupAfterError(); + } finally { + throw new IOException("Error during shuffle write", e); + } } } @@ -134,7 +142,7 @@ private void open() throws IOException { 4096, // Initial size (TODO: tune this!) partitioner.numPartitions(), sparkConf); - serArray = new byte[SER_BUFFER_SIZE]; + serArray = new byte[MAXIMUM_RECORD_SIZE]; serByteBuffer = ByteBuffer.wrap(serArray); serOutputStream = serializer.serializeStream(new ByteBufferOutputStream(serByteBuffer)); } diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala index e28a2459cdff9..4cc4ef5f1886e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala @@ -17,6 +17,8 @@ package org.apache.spark.shuffle +import java.io.IOException + import org.apache.spark.scheduler.MapStatus /** @@ -24,6 +26,7 @@ import org.apache.spark.scheduler.MapStatus */ private[spark] abstract class ShuffleWriter[K, V] { /** Write a sequence of records to this task's output */ + @throws[IOException] def write(records: Iterator[Product2[K, V]]): Unit /** Close this writer, passing along whether the map completed */ diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java index 53554520b22b1..ba1f89d099838 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java @@ -34,10 +34,10 @@ public void heap() { final MemoryBlock page0 = memoryManager.allocatePage(100); final MemoryBlock page1 = memoryManager.allocatePage(100); final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, 42); - PackedRecordPointer packedPointerWrapper = new PackedRecordPointer(); - packedPointerWrapper.packedRecordPointer = PackedRecordPointer.packPointer(addressInPage1, 360); - Assert.assertEquals(360, packedPointerWrapper.getPartitionId()); - Assert.assertEquals(addressInPage1, packedPointerWrapper.getRecordPointer()); + PackedRecordPointer packedPointer = new PackedRecordPointer(); + packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360)); + Assert.assertEquals(360, packedPointer.getPartitionId()); + Assert.assertEquals(addressInPage1, packedPointer.getRecordPointer()); memoryManager.cleanUpAllAllocatedMemory(); } @@ -48,10 +48,10 @@ public void offHeap() { final MemoryBlock page0 = memoryManager.allocatePage(100); final MemoryBlock page1 = memoryManager.allocatePage(100); final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, 42); - PackedRecordPointer packedPointerWrapper = new PackedRecordPointer(); - packedPointerWrapper.packedRecordPointer = PackedRecordPointer.packPointer(addressInPage1, 360); - Assert.assertEquals(360, packedPointerWrapper.getPartitionId()); - Assert.assertEquals(addressInPage1, packedPointerWrapper.getRecordPointer()); + PackedRecordPointer packedPointer = new PackedRecordPointer(); + packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360)); + Assert.assertEquals(360, packedPointer.getPartitionId()); + Assert.assertEquals(addressInPage1, packedPointer.getRecordPointer()); memoryManager.cleanUpAllAllocatedMemory(); } } diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index cc8cbc534510b..9002126bb7a4a 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.shuffle.unsafe; import java.io.*; +import java.nio.ByteBuffer; import java.util.*; import scala.*; @@ -287,6 +288,42 @@ public void mergeSpillsWithFileStream() throws Exception { testMergingSpills(false); } + @Test + public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception { + final UnsafeShuffleWriter writer = createWriter(false); + final ArrayList> dataToWrite = + new ArrayList>(); + final byte[] bytes = new byte[(int) (UnsafeShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)]; + new Random(42).nextBytes(bytes); + dataToWrite.add(new Tuple2(1, ByteBuffer.wrap(bytes))); + writer.write(dataToWrite.iterator()); + writer.stop(true); + Assert.assertEquals( + HashMultiset.create(dataToWrite), + HashMultiset.create(readRecordsFromFile())); + assertSpillFilesWereCleanedUp(); + } + + @Test + public void writeRecordsThatAreBiggerThanMaximumRecordSize() throws Exception { + final UnsafeShuffleWriter writer = createWriter(false); + final ArrayList> dataToWrite = + new ArrayList>(); + final byte[] bytes = new byte[UnsafeShuffleWriter.MAXIMUM_RECORD_SIZE * 2]; + new Random(42).nextBytes(bytes); + dataToWrite.add(new Tuple2(1, bytes)); + try { + // Insert a record and force a spill so that there's something to clean up: + writer.insertRecordIntoSorter(new Tuple2(1, 1)); + writer.forceSorterToSpill(); + writer.write(dataToWrite.iterator()); + Assert.fail("Expected exception to be thrown"); + } catch (IOException e) { + // Pass + } + assertSpillFilesWereCleanedUp(); + } + @Test public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException { final UnsafeShuffleWriter writer = createWriter(false); From 27b18b09aca66cb2dac8f779701569200deba43a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 10 May 2015 14:37:53 -0700 Subject: [PATCH 54/92] That for inserting records AT the max record size. --- .../unsafe/UnsafeShuffleExternalSorter.java | 4 +- .../shuffle/unsafe/UnsafeShuffleWriter.java | 25 ++++--- .../unsafe/UnsafeShuffleWriterSuite.java | 67 +++++++++++++++---- 3 files changed, 70 insertions(+), 26 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 6e0d8da410231..c9d818034c899 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -57,9 +57,11 @@ final class UnsafeShuffleExternalSorter { private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class); + private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES; @VisibleForTesting static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; - private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES; + @VisibleForTesting + static final int MAX_RECORD_SIZE = PAGE_SIZE - 4; private final int initialSize; private final int numPartitions; diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index db9f8648a93b4..5bf04617854bb 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -18,7 +18,6 @@ package org.apache.spark.shuffle.unsafe; import java.io.*; -import java.nio.ByteBuffer; import java.nio.channels.FileChannel; import java.util.Iterator; @@ -73,8 +72,14 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private MapStatus mapStatus = null; private UnsafeShuffleExternalSorter sorter = null; - private byte[] serArray = null; - private ByteBuffer serByteBuffer; + + /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */ + private static final class MyByteArrayOutputStream extends ByteArrayOutputStream { + public MyByteArrayOutputStream(int size) { super(size); } + public byte[] getBuf() { return buf; } + } + + private MyByteArrayOutputStream serBuffer; private SerializationStream serOutputStream; /** @@ -142,9 +147,8 @@ private void open() throws IOException { 4096, // Initial size (TODO: tune this!) partitioner.numPartitions(), sparkConf); - serArray = new byte[MAXIMUM_RECORD_SIZE]; - serByteBuffer = ByteBuffer.wrap(serArray); - serOutputStream = serializer.serializeStream(new ByteBufferOutputStream(serByteBuffer)); + serBuffer = new MyByteArrayOutputStream(1024 * 1024); + serOutputStream = serializer.serializeStream(serBuffer); } @VisibleForTesting @@ -152,8 +156,7 @@ void closeAndWriteOutput() throws IOException { if (sorter == null) { open(); } - serArray = null; - serByteBuffer = null; + serBuffer = null; serOutputStream = null; final SpillInfo[] spills = sorter.closeAndGetSpills(); sorter = null; @@ -178,16 +181,16 @@ void insertRecordIntoSorter(Product2 record) throws IOException{ } final K key = record._1(); final int partitionId = partitioner.getPartition(key); - serByteBuffer.position(0); + serBuffer.reset(); serOutputStream.writeKey(key, OBJECT_CLASS_TAG); serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG); serOutputStream.flush(); - final int serializedRecordSize = serByteBuffer.position(); + final int serializedRecordSize = serBuffer.size(); assert (serializedRecordSize > 0); sorter.insertRecord( - serArray, PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); + serBuffer.getBuf(), PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); } @VisibleForTesting diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 9002126bb7a4a..48ba85f917b87 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -23,6 +23,7 @@ import scala.*; import scala.collection.Iterator; +import scala.reflect.ClassTag; import scala.runtime.AbstractFunction1; import com.google.common.collect.HashMultiset; @@ -44,11 +45,8 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.executor.TaskMetrics; import org.apache.spark.network.util.LimitedInputStream; +import org.apache.spark.serializer.*; import org.apache.spark.scheduler.MapStatus; -import org.apache.spark.serializer.DeserializationStream; -import org.apache.spark.serializer.KryoSerializer; -import org.apache.spark.serializer.Serializer; -import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.storage.*; @@ -305,18 +303,59 @@ public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception } @Test - public void writeRecordsThatAreBiggerThanMaximumRecordSize() throws Exception { + public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception { + // Use a custom serializer so that we have exact control over the size of serialized data. + final Serializer byteArraySerializer = new Serializer() { + @Override + public SerializerInstance newInstance() { + return new SerializerInstance() { + @Override + public SerializationStream serializeStream(final OutputStream s) { + return new SerializationStream() { + @Override + public void flush() { } + + @Override + public SerializationStream writeObject(T t, ClassTag ev1) { + byte[] bytes = (byte[]) t; + try { + s.write(bytes); + } catch (IOException e) { + throw new RuntimeException(e); + } + return this; + } + + @Override + public void close() { } + }; + } + public ByteBuffer serialize(T t, ClassTag ev1) { return null; } + public DeserializationStream deserializeStream(InputStream s) { return null; } + public T deserialize(ByteBuffer b, ClassLoader l, ClassTag ev1) { return null; } + public T deserialize(ByteBuffer bytes, ClassTag ev1) { return null; } + }; + } + }; + when(shuffleDep.serializer()).thenReturn(Option.apply(byteArraySerializer)); final UnsafeShuffleWriter writer = createWriter(false); - final ArrayList> dataToWrite = - new ArrayList>(); - final byte[] bytes = new byte[UnsafeShuffleWriter.MAXIMUM_RECORD_SIZE * 2]; - new Random(42).nextBytes(bytes); - dataToWrite.add(new Tuple2(1, bytes)); + // Insert a record and force a spill so that there's something to clean up: + writer.insertRecordIntoSorter(new Tuple2(new byte[1], new byte[1])); + writer.forceSorterToSpill(); + // We should be able to write a record that's right _at_ the max record size + final byte[] atMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE]; + new Random(42).nextBytes(atMaxRecordSize); + writer.insertRecordIntoSorter(new Tuple2(new byte[0], atMaxRecordSize)); + writer.forceSorterToSpill(); + // Inserting a record that's larger than the max record size should fail: + final byte[] exceedsMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE + 1]; + new Random(42).nextBytes(exceedsMaxRecordSize); + Product2 hugeRecord = + new Tuple2(new byte[0], exceedsMaxRecordSize); try { - // Insert a record and force a spill so that there's something to clean up: - writer.insertRecordIntoSorter(new Tuple2(1, 1)); - writer.forceSorterToSpill(); - writer.write(dataToWrite.iterator()); + // Here, we write through the public `write()` interface instead of the test-only + // `insertRecordIntoSorter` interface: + writer.write(Collections.singletonList(hugeRecord).iterator()); Assert.fail("Expected exception to be thrown"); } catch (IOException e) { // Pass From 4a01c45aaf2dcf1bd4dd3ebe632a5493e75b40d7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 10 May 2015 16:13:24 -0700 Subject: [PATCH 55/92] Remove unnecessary log message --- .../apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index c9d818034c899..ccc1018a71168 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -288,7 +288,6 @@ public void cleanupAfterError() { * @return true if the record can be inserted without requiring more allocations, false otherwise. */ private boolean haveSpaceForRecord(int requiredSpace) { - logger.warn("Seeing if there's space for the record"); assert (requiredSpace > 0); // The sort array will automatically expand when inserting a new record, so we only need to // worry about it having free space when spilling is enabled. From f780fb1c19498246c1de3a86e8e7816359bf4069 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 10 May 2015 17:04:41 -0700 Subject: [PATCH 56/92] Add test demonstrating which compression codecs support concatenation. --- .../spark/io/CompressionCodecSuite.scala | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala index 8c6035fb367fe..cf6a143537889 100644 --- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.io import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import com.google.common.io.ByteStreams import org.scalatest.FunSuite import org.apache.spark.SparkConf @@ -62,6 +63,14 @@ class CompressionCodecSuite extends FunSuite { testCodec(codec) } + test("lz4 does not support concatenation of serialized streams") { + val codec = CompressionCodec.createCodec(conf, classOf[LZ4CompressionCodec].getName) + assert(codec.getClass === classOf[LZ4CompressionCodec]) + intercept[Exception] { + testConcatenationOfSerializedStreams(codec) + } + } + test("lzf compression codec") { val codec = CompressionCodec.createCodec(conf, classOf[LZFCompressionCodec].getName) assert(codec.getClass === classOf[LZFCompressionCodec]) @@ -74,6 +83,12 @@ class CompressionCodecSuite extends FunSuite { testCodec(codec) } + test("lzf supports concatenation of serialized streams") { + val codec = CompressionCodec.createCodec(conf, classOf[LZFCompressionCodec].getName) + assert(codec.getClass === classOf[LZFCompressionCodec]) + testConcatenationOfSerializedStreams(codec) + } + test("snappy compression codec") { val codec = CompressionCodec.createCodec(conf, classOf[SnappyCompressionCodec].getName) assert(codec.getClass === classOf[SnappyCompressionCodec]) @@ -86,9 +101,38 @@ class CompressionCodecSuite extends FunSuite { testCodec(codec) } + test("snappy does not support concatenation of serialized streams") { + val codec = CompressionCodec.createCodec(conf, classOf[SnappyCompressionCodec].getName) + assert(codec.getClass === classOf[SnappyCompressionCodec]) + intercept[Exception] { + testConcatenationOfSerializedStreams(codec) + } + } + test("bad compression codec") { intercept[IllegalArgumentException] { CompressionCodec.createCodec(conf, "foobar") } } + + private def testConcatenationOfSerializedStreams(codec: CompressionCodec): Unit = { + val bytes1: Array[Byte] = { + val baos = new ByteArrayOutputStream() + val out = codec.compressedOutputStream(baos) + (0 to 64).foreach(out.write) + out.close() + baos.toByteArray + } + val bytes2: Array[Byte] = { + val baos = new ByteArrayOutputStream() + val out = codec.compressedOutputStream(baos) + (65 to 127).foreach(out.write) + out.close() + baos.toByteArray + } + val concatenatedBytes = codec.compressedInputStream(new ByteArrayInputStream(bytes1 ++ bytes2)) + val decompressed: Array[Byte] = new Array[Byte](128) + ByteStreams.readFully(concatenatedBytes, decompressed) + assert(decompressed.toSeq === (0 to 127)) + } } From b57c17f61dc56182c60d4eaa0bfebe38f2b9d04a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 10 May 2015 19:54:18 -0700 Subject: [PATCH 57/92] Disable some overly-verbose logs that rendered DEBUG useless. --- .../spark/unsafe/memory/TaskMemoryManager.java | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java index 983bd99f0c4ce..2aacf637eb6a4 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java @@ -101,9 +101,6 @@ public TaskMemoryManager(ExecutorMemoryManager executorMemoryManager) { * intended for allocating large blocks of memory that will be shared between operators. */ public MemoryBlock allocatePage(long size) { - if (logger.isTraceEnabled()) { - logger.trace("Allocating {} byte page", size); - } if (size >= (1L << 51)) { throw new IllegalArgumentException("Cannot allocate a page with more than 2^51 bytes"); } @@ -120,8 +117,8 @@ public MemoryBlock allocatePage(long size) { final MemoryBlock page = executorMemoryManager.allocate(size); page.pageNumber = pageNumber; pageTable[pageNumber] = page; - if (logger.isDebugEnabled()) { - logger.debug("Allocate page number {} ({} bytes)", pageNumber, size); + if (logger.isTraceEnabled()) { + logger.trace("Allocate page number {} ({} bytes)", pageNumber, size); } return page; } @@ -130,9 +127,6 @@ public MemoryBlock allocatePage(long size) { * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}. */ public void freePage(MemoryBlock page) { - if (logger.isTraceEnabled()) { - logger.trace("Freeing page number {} ({} bytes)", page.pageNumber, page.size()); - } assert (page.pageNumber != -1) : "Called freePage() on memory that wasn't allocated with allocatePage()"; executorMemoryManager.free(page); @@ -140,8 +134,8 @@ public void freePage(MemoryBlock page) { allocatedPages.clear(page.pageNumber); } pageTable[page.pageNumber] = null; - if (logger.isDebugEnabled()) { - logger.debug("Freed page number {} ({} bytes)", page.pageNumber, page.size()); + if (logger.isTraceEnabled()) { + logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); } } From 1ef56c77b62d2413130dd975555400a0d138324b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 10 May 2015 19:55:49 -0700 Subject: [PATCH 58/92] Revise compression codec support in merger; test cross product of configurations. --- .../spark/shuffle/unsafe/SpillInfo.java | 6 +- .../unsafe/UnsafeShuffleExternalSorter.java | 4 +- .../shuffle/unsafe/UnsafeShuffleWriter.java | 62 ++++++--- .../unsafe/UnsafeShuffleWriterSuite.java | 122 ++++++++++++++---- 4 files changed, 151 insertions(+), 43 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java index 5435c2c98428f..5d13354231491 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java @@ -17,7 +17,7 @@ package org.apache.spark.shuffle.unsafe; -import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.TempShuffleBlockId; import java.io.File; @@ -27,9 +27,9 @@ final class SpillInfo { final long[] partitionLengths; final File file; - final BlockId blockId; + final TempShuffleBlockId blockId; - public SpillInfo(int numPartitions, File file, BlockId blockId) { + public SpillInfo(int numPartitions, File file, TempShuffleBlockId blockId) { this.partitionLengths = new long[numPartitions]; this.file = file; this.blockId = blockId; diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index ccc1018a71168..3cf99307c47cc 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -153,7 +153,7 @@ private SpillInfo writeSpillFile() throws IOException { final Tuple2 spilledFileInfo = blockManager.diskBlockManager().createTempShuffleBlock(); final File file = spilledFileInfo._2(); - final BlockId blockId = spilledFileInfo._1(); + final TempShuffleBlockId blockId = spilledFileInfo._1(); final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId); // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. @@ -320,7 +320,7 @@ private void allocateSpaceForRecord(int requiredSpace) throws IOException { } } if (requiredSpace > freeSpaceInCurrentPage) { - logger.debug("Required space {} is less than free space in current page ({}}", requiredSpace, + logger.trace("Required space {} is less than free space in current page ({})", requiredSpace, freeSpaceInCurrentPage); // TODO: we should track metrics on the amount of space wasted when we roll over to a new page // without using the free space at the end of the current page. We should also do this for diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 5bf04617854bb..df05a95506f4b 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -20,6 +20,7 @@ import java.io.*; import java.nio.channels.FileChannel; import java.util.Iterator; +import javax.annotation.Nullable; import scala.Option; import scala.Product2; @@ -35,6 +36,9 @@ import org.slf4j.LoggerFactory; import org.apache.spark.*; +import org.apache.spark.io.CompressionCodec; +import org.apache.spark.io.CompressionCodec$; +import org.apache.spark.io.LZFCompressionCodec; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; @@ -53,8 +57,6 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class); - @VisibleForTesting - static final int MAXIMUM_RECORD_SIZE = 1024 * 1024 * 64; // 64 megabytes private static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); private final BlockManager blockManager; @@ -201,6 +203,12 @@ void forceSorterToSpill() throws IOException { private long[] mergeSpills(SpillInfo[] spills) throws IOException { final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId); + final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true); + final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf); + final boolean fastMergeEnabled = + sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true); + final boolean fastMergeIsSupported = + !compressionEnabled || compressionCodec instanceof LZFCompressionCodec; try { if (spills.length == 0) { new FileOutputStream(outputFile).close(); // Create an empty file @@ -215,11 +223,20 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException { Files.move(spills[0].file, outputFile); return spills[0].partitionLengths; } else { - // Need to merge multiple spills. - if (transferToEnabled) { - return mergeSpillsWithTransferTo(spills, outputFile); + if (fastMergeEnabled && fastMergeIsSupported) { + // Compression is disabled or we are using an IO compression codec that supports + // decompression of concatenated compressed streams, so we can perform a fast spill merge + // that doesn't need to interpret the spilled bytes. + if (transferToEnabled) { + logger.debug("Using transferTo-based fast merge"); + return mergeSpillsWithTransferTo(spills, outputFile); + } else { + logger.debug("Using fileStream-based fast merge"); + return mergeSpillsWithFileStream(spills, outputFile, null); + } } else { - return mergeSpillsWithFileStream(spills, outputFile); + logger.debug("Using slow merge"); + return mergeSpillsWithFileStream(spills, outputFile, compressionCodec); } } } catch (IOException e) { @@ -230,27 +247,40 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException { } } - private long[] mergeSpillsWithFileStream(SpillInfo[] spills, File outputFile) throws IOException { + private long[] mergeSpillsWithFileStream( + SpillInfo[] spills, + File outputFile, + @Nullable CompressionCodec compressionCodec) throws IOException { final int numPartitions = partitioner.numPartitions(); final long[] partitionLengths = new long[numPartitions]; - final FileInputStream[] spillInputStreams = new FileInputStream[spills.length]; - FileOutputStream mergedFileOutputStream = null; + final InputStream[] spillInputStreams = new FileInputStream[spills.length]; + OutputStream mergedFileOutputStream = null; try { for (int i = 0; i < spills.length; i++) { spillInputStreams[i] = new FileInputStream(spills[i].file); } - mergedFileOutputStream = new FileOutputStream(outputFile); - for (int partition = 0; partition < numPartitions; partition++) { + final long initialFileLength = outputFile.length(); + mergedFileOutputStream = new FileOutputStream(outputFile, true); + if (compressionCodec != null) { + mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream); + } + for (int i = 0; i < spills.length; i++) { final long partitionLengthInSpill = spills[i].partitionLengths[partition]; - final FileInputStream spillInputStream = spillInputStreams[i]; - ByteStreams.copy - (new LimitedInputStream(spillInputStream, partitionLengthInSpill), - mergedFileOutputStream); - partitionLengths[partition] += partitionLengthInSpill; + if (partitionLengthInSpill > 0) { + InputStream partitionInputStream = + new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill); + if (compressionCodec != null) { + partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); + } + ByteStreams.copy(partitionInputStream, mergedFileOutputStream); + } } + mergedFileOutputStream.flush(); + mergedFileOutputStream.close(); + partitionLengths[partition] = (outputFile.length() - initialFileLength); } } finally { for (InputStream stream : spillInputStreams) { diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 48ba85f917b87..511fdfa43d543 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -37,11 +37,14 @@ import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import static org.mockito.AdditionalAnswers.returnsFirstArg; -import static org.mockito.AdditionalAnswers.returnsSecondArg; import static org.mockito.Answers.RETURNS_SMART_NULLS; import static org.mockito.Mockito.*; import org.apache.spark.*; +import org.apache.spark.io.CompressionCodec$; +import org.apache.spark.io.LZ4CompressionCodec; +import org.apache.spark.io.LZFCompressionCodec; +import org.apache.spark.io.SnappyCompressionCodec; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.executor.TaskMetrics; import org.apache.spark.network.util.LimitedInputStream; @@ -65,6 +68,7 @@ public class UnsafeShuffleWriterSuite { File tempDir; long[] partitionSizesInMergedFile; final LinkedList spillFilesCreated = new LinkedList(); + SparkConf conf; final Serializer serializer = new KryoSerializer(new SparkConf()); @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager; @@ -74,10 +78,14 @@ public class UnsafeShuffleWriterSuite { @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext; @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency shuffleDep; - private static final class CompressStream extends AbstractFunction1 { + private final class CompressStream extends AbstractFunction1 { @Override public OutputStream apply(OutputStream stream) { - return stream; + if (conf.getBoolean("spark.shuffle.compress", true)) { + return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream); + } else { + return stream; + } } } @@ -98,6 +106,7 @@ public void setUp() throws IOException { mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir); partitionSizesInMergedFile = null; spillFilesCreated.clear(); + conf = new SparkConf(); when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); @@ -123,8 +132,35 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th ); } }); - when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))) - .then(returnsSecondArg()); + when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))).thenAnswer( + new Answer() { + @Override + public InputStream answer(InvocationOnMock invocation) throws Throwable { + assert (invocation.getArguments()[0] instanceof TempShuffleBlockId); + InputStream is = (InputStream) invocation.getArguments()[1]; + if (conf.getBoolean("spark.shuffle.compress", true)) { + return CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(is); + } else { + return is; + } + } + } + ); + + when(blockManager.wrapForCompression(any(BlockId.class), any(OutputStream.class))).thenAnswer( + new Answer() { + @Override + public OutputStream answer(InvocationOnMock invocation) throws Throwable { + assert (invocation.getArguments()[0] instanceof TempShuffleBlockId); + OutputStream os = (OutputStream) invocation.getArguments()[1]; + if (conf.getBoolean("spark.shuffle.compress", true)) { + return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(os); + } else { + return os; + } + } + } + ); when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile); doAnswer(new Answer() { @@ -136,11 +172,11 @@ public Void answer(InvocationOnMock invocationOnMock) throws Throwable { }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class)); when(diskBlockManager.createTempShuffleBlock()).thenAnswer( - new Answer>() { + new Answer>() { @Override - public Tuple2 answer( + public Tuple2 answer( InvocationOnMock invocationOnMock) throws Throwable { - TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID()); + TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID()); File file = File.createTempFile("spillFile", ".spill", tempDir); spillFilesCreated.add(file); return Tuple2$.MODULE$.apply(blockId, file); @@ -154,7 +190,6 @@ public Tuple2 answer( } private UnsafeShuffleWriter createWriter(boolean transferToEnabled) { - SparkConf conf = new SparkConf(); conf.set("spark.file.transferTo", String.valueOf(transferToEnabled)); return new UnsafeShuffleWriter( blockManager, @@ -164,7 +199,7 @@ private UnsafeShuffleWriter createWriter(boolean transferToEnabl new UnsafeShuffleHandle(0, 1, shuffleDep), 0, // map id taskContext, - new SparkConf() + conf ); } @@ -183,8 +218,11 @@ private List> readRecordsFromFile() throws IOException { if (partitionSize > 0) { InputStream in = new FileInputStream(mergedOutputFile); ByteStreams.skipFully(in, startOffset); - DeserializationStream recordsStream = serializer.newInstance().deserializeStream( - new LimitedInputStream(in, partitionSize)); + in = new LimitedInputStream(in, partitionSize); + if (conf.getBoolean("spark.shuffle.compress", true)) { + in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in); + } + DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in); Iterator> records = recordsStream.asKeyValueIterator(); while (records.hasNext()) { Tuple2 record = records.next(); @@ -245,7 +283,15 @@ public void writeWithoutSpilling() throws Exception { assertSpillFilesWereCleanedUp(); } - private void testMergingSpills(boolean transferToEnabled) throws IOException { + private void testMergingSpills( + boolean transferToEnabled, + String compressionCodecName) throws IOException { + if (compressionCodecName != null) { + conf.set("spark.shuffle.compress", "true"); + conf.set("spark.io.compression.codec", compressionCodecName); + } else { + conf.set("spark.shuffle.compress", "false"); + } final UnsafeShuffleWriter writer = createWriter(transferToEnabled); final ArrayList> dataToWrite = new ArrayList>(); @@ -265,11 +311,13 @@ private void testMergingSpills(boolean transferToEnabled) throws IOException { Assert.assertTrue(mergedOutputFile.exists()); Assert.assertEquals(2, spillFilesCreated.size()); - long sumOfPartitionSizes = 0; - for (long size: partitionSizesInMergedFile) { - sumOfPartitionSizes += size; - } - Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes); + // This assertion only holds for the fast merging path: + // long sumOfPartitionSizes = 0; + // for (long size: partitionSizesInMergedFile) { + // sumOfPartitionSizes += size; + // } + // Assert.assertEquals(sumOfPartitionSizes, mergedOutputFile.length()); + Assert.assertTrue(mergedOutputFile.length() > 0); Assert.assertEquals( HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile())); @@ -277,13 +325,43 @@ private void testMergingSpills(boolean transferToEnabled) throws IOException { } @Test - public void mergeSpillsWithTransferTo() throws Exception { - testMergingSpills(true); + public void mergeSpillsWithTransferToAndLZF() throws Exception { + testMergingSpills(true, LZFCompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithFileStreamAndLZF() throws Exception { + testMergingSpills(false, LZFCompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithTransferToAndLZ4() throws Exception { + testMergingSpills(true, LZ4CompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithFileStreamAndLZ4() throws Exception { + testMergingSpills(false, LZ4CompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithTransferToAndSnappy() throws Exception { + testMergingSpills(true, SnappyCompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithFileStreamAndSnappy() throws Exception { + testMergingSpills(false, SnappyCompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithTransferToAndNoCompression() throws Exception { + testMergingSpills(true, null); } @Test - public void mergeSpillsWithFileStream() throws Exception { - testMergingSpills(false); + public void mergeSpillsWithFileStreamAndNoCompression() throws Exception { + testMergingSpills(false, null); } @Test From b3b1924e852757c93343e61c6d4cc65747697174 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 10 May 2015 21:53:02 -0700 Subject: [PATCH 59/92] Properly implement close() and flush() in DummySerializerInstance. It turns out that we actually rely on these flushing the underlying stream in order to properly close streams in DiskBlockObjectWriter; it was silly of me to not implement these methods. This should fix a failing LZ4 test in UnsafeShuffleWriterSuite. --- .../unsafe/DummySerializerInstance.java | 33 ++++++++++++++----- .../unsafe/UnsafeShuffleWriterSuite.java | 13 ++++---- 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java index 1d31a46993a22..3f746b886bc9b 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java @@ -17,15 +17,18 @@ package org.apache.spark.shuffle.unsafe; -import org.apache.spark.serializer.DeserializationStream; -import org.apache.spark.serializer.SerializationStream; -import org.apache.spark.serializer.SerializerInstance; -import scala.reflect.ClassTag; - +import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.nio.ByteBuffer; +import scala.reflect.ClassTag; + +import org.apache.spark.serializer.DeserializationStream; +import org.apache.spark.serializer.SerializationStream; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.unsafe.PlatformDependent; + /** * Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. * Our shuffle write path doesn't actually use this serializer (since we end up calling the @@ -39,10 +42,17 @@ final class DummySerializerInstance extends SerializerInstance { private DummySerializerInstance() { } @Override - public SerializationStream serializeStream(OutputStream s) { + public SerializationStream serializeStream(final OutputStream s) { return new SerializationStream() { @Override - public void flush() { } + public void flush() { + // Need to implement this because DiskObjectWriter uses it to flush the compression stream + try { + s.flush(); + } catch (IOException e) { + PlatformDependent.throwException(e); + } + } @Override public SerializationStream writeObject(T t, ClassTag ev1) { @@ -50,7 +60,14 @@ public SerializationStream writeObject(T t, ClassTag ev1) { } @Override - public void close() { } + public void close() { + // Need to implement this because DiskObjectWriter uses it to close the compression stream + try { + s.close(); + } catch (IOException e) { + PlatformDependent.throwException(e); + } + } }; } diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 511fdfa43d543..01bf7a5095970 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -311,13 +311,12 @@ private void testMergingSpills( Assert.assertTrue(mergedOutputFile.exists()); Assert.assertEquals(2, spillFilesCreated.size()); - // This assertion only holds for the fast merging path: - // long sumOfPartitionSizes = 0; - // for (long size: partitionSizesInMergedFile) { - // sumOfPartitionSizes += size; - // } - // Assert.assertEquals(sumOfPartitionSizes, mergedOutputFile.length()); - Assert.assertTrue(mergedOutputFile.length() > 0); + long sumOfPartitionSizes = 0; + for (long size: partitionSizesInMergedFile) { + sumOfPartitionSizes += size; + } + Assert.assertEquals(sumOfPartitionSizes, mergedOutputFile.length()); + Assert.assertEquals( HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile())); From 0d4d199ca126c0b7937438b0293206c66b53badb Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 10 May 2015 22:06:23 -0700 Subject: [PATCH 60/92] Bump up shuffle.memoryFraction to make tests pass. We'll want to revisit this before merging, since the large minimum memory usage means that minimum memory requirements for shuffle may be fairly high for local tests. --- .../spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java | 2 +- .../org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 3cf99307c47cc..772ed688359dc 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -337,7 +337,7 @@ private void allocateSpaceForRecord(int requiredSpace) throws IOException { final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); if (memoryAcquiredAfterSpilling != PAGE_SIZE) { shuffleMemoryManager.release(memoryAcquiredAfterSpilling); - throw new IOException("Can't allocate memory!"); + throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory"); } } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala index 8ff3abefea897..f7eefa2a3f40c 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala @@ -26,5 +26,8 @@ class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { override def beforeAll() { conf.set("spark.shuffle.manager", "unsafe") + // UnsafeShuffleManager requires at least 128 MB of memory per task in order to be able to sort + // shuffle records. + conf.set("spark.shuffle.memoryFraction", "0.5") } } From ec6d62613144b7c7cbcc08fd9eb6fecd341b303d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 11 May 2015 11:20:07 -0700 Subject: [PATCH 61/92] Add notes on maximum # of supported shuffle partitions. --- .../spark/shuffle/unsafe/PackedRecordPointer.java | 3 +++ .../spark/shuffle/unsafe/UnsafeShuffleSorter.java | 15 +++++++++++++-- .../spark/shuffle/unsafe/UnsafeShuffleWriter.java | 6 ++++++ 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java index 8c0940d23420b..ee991ee26f7a0 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java @@ -37,6 +37,8 @@ final class PackedRecordPointer { static final int MAXIMUM_PAGE_SIZE_BYTES = 1 << 27; // 128 megabytes + static final int MAXIMUM_PARTITION_ID = 1 << 24; // 16777216 + /** Bit mask for the lower 40 bits of a long. */ private static final long MASK_LONG_LOWER_40_BITS = 0xFFFFFFFFFFL; @@ -62,6 +64,7 @@ final class PackedRecordPointer { * @return a packed pointer that can be decoded using the {@link PackedRecordPointer} class. */ public static long packPointer(long recordPointer, int partitionId) { + assert (partitionId <= MAXIMUM_PARTITION_ID); // Note that without word alignment we can address 2^27 bytes = 128 megabytes per page. // Also note that this relies on some internals of how TaskMemoryManager encodes its addresses. final int pageNumber = (int) ((recordPointer & MASK_LONG_UPPER_13_BITS) >>> 51); diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java index 5acbc6c1c4f2f..8e66fbaf4c645 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java @@ -17,8 +17,10 @@ package org.apache.spark.shuffle.unsafe; +import java.io.IOException; import java.util.Comparator; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.Sorter; final class UnsafeShuffleSorter { @@ -59,8 +61,17 @@ public long getMemoryUsage() { return sortBuffer.length * 8L; } - // TODO: clairify assumption that pointer points to record length. - public void insertRecord(long recordPointer, int partitionId) { + /** + * Inserts a record to be sorted. + * + * @param recordPointer a pointer to the record, encoded by the task memory manager. Due to + * certain pointer compression techniques used by the sorter, the sort can + * only operate on pointers that point to locations in the first + * {@link PackedRecordPointer#MAXIMUM_PAGE_SIZE_BYTES} bytes of a data page. + * @param partitionId the partition id, which must be less than or equal to + * {@link PackedRecordPointer#MAXIMUM_PARTITION_ID}. + */ + public void insertRecord(long recordPointer, int partitionId) throws IOException { if (!hasSpaceForAnotherRecord()) { expandSortBuffer(); } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index df05a95506f4b..438852cd1408c 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -100,6 +100,12 @@ public UnsafeShuffleWriter( int mapId, TaskContext taskContext, SparkConf sparkConf) { + final int numPartitions = handle.dependency().partitioner().numPartitions(); + if (numPartitions > PackedRecordPointer.MAXIMUM_PARTITION_ID) { + throw new IllegalArgumentException( + "UnsafeShuffleWriter can only be used for shuffles with at most " + + PackedRecordPointer.MAXIMUM_PARTITION_ID + " reduce partitions"); + } this.blockManager = blockManager; this.shuffleBlockResolver = shuffleBlockResolver; this.memoryManager = memoryManager; From ae538dc34da99dc35127484fb480264c839f9cd5 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 11 May 2015 13:57:00 -0700 Subject: [PATCH 62/92] Document UnsafeShuffleManager. --- .../shuffle/unsafe/UnsafeShuffleManager.scala | 57 ++++++++ .../unsafe/UnsafeShuffleManagerSuite.scala | 128 ++++++++++++++++++ 2 files changed, 185 insertions(+) create mode 100644 core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala index 994a8c049a331..5641903958fb5 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala @@ -22,6 +22,9 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle._ import org.apache.spark.shuffle.sort.SortShuffleManager +/** + * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the new shuffle. + */ private class UnsafeShuffleHandle[K, V]( shuffleId: Int, override val numMaps: Int, @@ -30,6 +33,10 @@ private class UnsafeShuffleHandle[K, V]( } private[spark] object UnsafeShuffleManager extends Logging { + /** + * Helper method for determining whether a shuffle should use the optimized unsafe shuffle + * path or whether it should fall back to the original sort-based shuffle. + */ def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = { val shufId = dependency.shuffleId val serializer = Serializer.getSerializer(dependency.serializer) @@ -43,6 +50,10 @@ private[spark] object UnsafeShuffleManager extends Logging { } else if (dependency.keyOrdering.isDefined) { log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because a key ordering is defined") false + } else if (dependency.partitioner.numPartitions > PackedRecordPointer.MAXIMUM_PARTITION_ID) { + log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " + + s"${PackedRecordPointer.MAXIMUM_PARTITION_ID} partitions") + false } else { log.debug(s"Can use UnsafeShuffle for shuffle $shufId") true @@ -50,6 +61,52 @@ private[spark] object UnsafeShuffleManager extends Logging { } } +/** + * A shuffle implementation that uses directly-managed memory to implement several performance + * optimizations for certain types of shuffles. In cases where the new performance optimizations + * cannot be applied, this shuffle manager delegates to [[SortShuffleManager]] to handle those + * shuffles. + * + * UnsafeShuffleManager's optimizations will apply when _all_ of the following conditions hold: + * + * - The shuffle dependency specifies no aggregation or output ordering. + * - The shuffle serializer supports relocation of serialized values (this is currently supported + * by KryoSerializer and Spark SQL's custom serializers). + * - The shuffle produces fewer than 16777216 output partitions. + * - No individual record is larger than 128 MB when serialized. + * + * In addition, extra spill-merging optimizations are automatically applied when the shuffle + * compression codec supports concatenation of serialized streams. This is currently supported by + * Spark's LZF serializer. + * + * At a high-level, UnsafeShuffleManager's design is similar to Spark's existing SortShuffleManager. + * In sort-based shuffle, incoming records are sorted according to their target partition ids, then + * written to a single map output file. Reducers fetch contiguous regions of this file in order to + * read their portion of the map output. In cases where the map output data is too large to fit in + * memory, sorted subsets of the output can are spilled to disk and those on-disk files are merged + * to produce the final output file. + * + * UnsafeShuffleManager optimizes this process in several ways: + * + * - Its sort operates on serialized binary data rather than Java objects, which reduces memory + * consumption and GC overheads. This optimization requires the record serializer to have certain + * properties to allow serialized records to be re-ordered without requiring deserialization. + * See SPARK-4550, where this optimization was first proposed and implemented, for more details. + * + * - It uses a specialized cache-efficient sorter ([[UnsafeShuffleExternalSorter]]) that sorts + * arrays of compressed record pointers and partition ids. By using only 8 bytes of space per + * record in the sorting array, this fits more of the array into cache. + * + * - The spill merging procedure operates on blocks of serialized records that belong to the same + * partition and does not need to deserialize records during the merge. + * + * - When the spill compression codec supports concatenation of compressed data, the spill merge + * simply concatenates the serialized and compressed spill partitions to produce the final output + * partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used + * and avoids the need to allocate decompression or copying buffers during the merge. + * + * For more details on UnsafeShuffleManager's design, see SPARK-7081. + */ private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager { private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf) diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala new file mode 100644 index 0000000000000..9c91948bdc1e4 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe + +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.{FunSuite, Matchers} + +import org.apache.spark._ +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer} + +/** + * Tests for the fallback logic in UnsafeShuffleManager. Actual tests of shuffling data are + * performed in other suites. + */ +class UnsafeShuffleManagerSuite extends FunSuite with Matchers { + + import UnsafeShuffleManager.canUseUnsafeShuffle + + private class RuntimeExceptionAnswer extends Answer[Object] { + override def answer(invocation: InvocationOnMock): Object = { + throw new RuntimeException("Called non-stubbed method, " + invocation.getMethod.getName) + } + } + + private def shuffleDep( + partitioner: Partitioner, + serializer: Option[Serializer], + keyOrdering: Option[Ordering[Any]], + aggregator: Option[Aggregator[Any, Any, Any]], + mapSideCombine: Boolean): ShuffleDependency[Any, Any, Any] = { + val dep = mock(classOf[ShuffleDependency[Any, Any, Any]], new RuntimeExceptionAnswer()) + doReturn(0).when(dep).shuffleId + doReturn(partitioner).when(dep).partitioner + doReturn(serializer).when(dep).serializer + doReturn(keyOrdering).when(dep).keyOrdering + doReturn(aggregator).when(dep).aggregator + doReturn(mapSideCombine).when(dep).mapSideCombine + dep + } + + test("supported shuffle dependencies") { + val kryo = Some(new KryoSerializer(new SparkConf())) + + assert(canUseUnsafeShuffle(shuffleDep( + partitioner = new HashPartitioner(2), + serializer = kryo, + keyOrdering = None, + aggregator = None, + mapSideCombine = false + ))) + + val rangePartitioner = mock(classOf[RangePartitioner[Any, Any]]) + when(rangePartitioner.numPartitions).thenReturn(2) + assert(canUseUnsafeShuffle(shuffleDep( + partitioner = rangePartitioner, + serializer = kryo, + keyOrdering = None, + aggregator = None, + mapSideCombine = false + ))) + + } + + test("unsupported shuffle dependencies") { + val kryo = Some(new KryoSerializer(new SparkConf())) + val java = Some(new JavaSerializer(new SparkConf())) + + // We only support serializers that support object relocation + assert(!canUseUnsafeShuffle(shuffleDep( + partitioner = new HashPartitioner(2), + serializer = java, + keyOrdering = None, + aggregator = None, + mapSideCombine = false + ))) + + // We do not support shuffles with more than 16 million output partitions + assert(!canUseUnsafeShuffle(shuffleDep( + partitioner = new HashPartitioner(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1), + serializer = kryo, + keyOrdering = None, + aggregator = None, + mapSideCombine = false + ))) + + // We do not support shuffles that perform any kind of aggregation or sorting of keys + assert(!canUseUnsafeShuffle(shuffleDep( + partitioner = new HashPartitioner(2), + serializer = kryo, + keyOrdering = Some(mock(classOf[Ordering[Any]])), + aggregator = None, + mapSideCombine = false + ))) + assert(!canUseUnsafeShuffle(shuffleDep( + partitioner = new HashPartitioner(2), + serializer = kryo, + keyOrdering = None, + aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])), + mapSideCombine = false + ))) + // We do not support shuffles that perform any kind of aggregation or sorting of keys + assert(!canUseUnsafeShuffle(shuffleDep( + partitioner = new HashPartitioner(2), + serializer = kryo, + keyOrdering = Some(mock(classOf[Ordering[Any]])), + aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])), + mapSideCombine = true + ))) + } + +} From ea4f85fede7966ac26ed5ce2f3995f74bd742930 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 11 May 2015 14:04:33 -0700 Subject: [PATCH 63/92] Roll back an unnecessary change in Spillable. --- .../apache/spark/util/collection/Spillable.scala | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 841a4cd791c4c..747ecf075a397 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -20,20 +20,11 @@ package org.apache.spark.util.collection import org.apache.spark.Logging import org.apache.spark.SparkEnv -private[spark] object Spillable { - // Initial threshold for the size of a collection before we start tracking its memory usage - val initialMemoryThreshold: Long = - SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024) -} - /** * Spills contents of an in-memory collection to disk when the memory threshold * has been exceeded. */ private[spark] trait Spillable[C] extends Logging { - - import Spillable._ - /** * Spills the current in-memory collection to disk, and releases the memory. * @@ -51,6 +42,11 @@ private[spark] trait Spillable[C] extends Logging { // Memory manager that can be used to acquire/release memory private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager + // Initial threshold for the size of a collection before we start tracking its memory usage + // Exposed for testing + private[this] val initialMemoryThreshold: Long = + SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024) + // Threshold for this collection's size in bytes before we start tracking its memory usage // To avoid a large number of small spills, initialize this to a value orders of magnitude > 0 private[this] var myMemoryThreshold = initialMemoryThreshold From 1e3ad52a0a40bef7426846ab1e3c1b291516cb4b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 11 May 2015 14:32:37 -0700 Subject: [PATCH 64/92] Delete unused ByteBufferOutputStream class. --- .../unsafe/ByteBufferOutputStream.java | 46 ------------------- 1 file changed, 46 deletions(-) delete mode 100644 core/src/main/java/org/apache/spark/shuffle/unsafe/ByteBufferOutputStream.java diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/ByteBufferOutputStream.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/ByteBufferOutputStream.java deleted file mode 100644 index 3410cd2911ebe..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/ByteBufferOutputStream.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.unsafe; - -import java.io.IOException; -import java.io.OutputStream; -import java.nio.ByteBuffer; - -class ByteBufferOutputStream extends OutputStream { - - private final ByteBuffer byteBuffer; - - public ByteBufferOutputStream(ByteBuffer byteBuffer) { - this.byteBuffer = byteBuffer; - } - - @Override - public void write(int b) throws IOException { - byteBuffer.put((byte) b); - } - - @Override - public void write(byte[] b) throws IOException { - byteBuffer.put(b); - } - - @Override - public void write(byte[] b, int off, int len) throws IOException { - byteBuffer.put(b, off, len); - } -} From 39434f9afcb9639dec6fd17d9348b4551bdb16b1 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 11 May 2015 14:58:25 -0700 Subject: [PATCH 65/92] Avoid integer multiplication overflow in getMemoryUsage (thanks FindBugs!) --- .../spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 772ed688359dc..44a37fcd43951 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -246,7 +246,7 @@ void spill() throws IOException { } private long getMemoryUsage() { - return sorter.getMemoryUsage() + (allocatedPages.size() * PAGE_SIZE); + return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE); } private long freeMemory() { From e1855e556321b1d119fd361ff436fe1e91ddedb7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 11 May 2015 15:39:29 -0700 Subject: [PATCH 66/92] Fix a handful of misc. IntelliJ inspections --- .../spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java | 3 +-- .../org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java | 4 +--- .../org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 44a37fcd43951..1d1382c104fea 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -139,7 +139,7 @@ private SpillInfo writeSpillFile() throws IOException { // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this // after SPARK-5581 is fixed. - BlockObjectWriter writer = null; + BlockObjectWriter writer; // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to // be an API to directly transfer bytes from managed memory to the disk writer, we buffer @@ -202,7 +202,6 @@ private SpillInfo writeSpillFile() throws IOException { writeBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, toTransfer); - assert (writer != null); // To suppress an IntelliJ warning writer.write(writeBuffer, 0, toTransfer); recordReadPosition += toTransfer; dataRemaining -= toTransfer; diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java index 8e66fbaf4c645..f2b90617793e5 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java @@ -17,10 +17,8 @@ package org.apache.spark.shuffle.unsafe; -import java.io.IOException; import java.util.Comparator; -import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.Sorter; final class UnsafeShuffleSorter { @@ -71,7 +69,7 @@ public long getMemoryUsage() { * @param partitionId the partition id, which must be less than or equal to * {@link PackedRecordPointer#MAXIMUM_PARTITION_ID}. */ - public void insertRecord(long recordPointer, int partitionId) throws IOException { + public void insertRecord(long recordPointer, int partitionId) { if (!hasSpaceForAnotherRecord()) { expandSortBuffer(); } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 438852cd1408c..8977517c0bcbe 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -135,7 +135,7 @@ public void write(scala.collection.Iterator> records) throws IOEx closeAndWriteOutput(); } catch (Exception e) { // Unfortunately, we have to catch Exception here in order to ensure proper cleanup after - // errors becuase Spark's Scala code, or users' custom Serializers, might throw arbitrary + // errors because Spark's Scala code, or users' custom Serializers, might throw arbitrary // unchecked exceptions. try { sorter.cleanupAfterError(); From 7c953f917140986a5a9a1ced95ef85d9f905b378 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 11 May 2015 16:11:25 -0700 Subject: [PATCH 67/92] Add test that covers UnsafeShuffleSortDataFormat.swap(). --- .../unsafe/UnsafeShuffleSorterSuite.java | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorterSuite.java index 080145b90554a..3fc73b04888ee 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorterSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.shuffle.unsafe; import java.util.Arrays; +import java.util.Random; import org.junit.Assert; import org.junit.Test; @@ -107,4 +108,25 @@ public void testBasicSorting() throws Exception { } Assert.assertFalse(iter.hasNext()); } + + @Test + public void testSortingManyNumbers() throws Exception { + UnsafeShuffleSorter sorter = new UnsafeShuffleSorter(4); + int[] numbersToSort = new int[128000]; + Random random = new Random(16); + for (int i = 0; i < numbersToSort.length; i++) { + numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID); + sorter.insertRecord(0, numbersToSort[i]); + } + Arrays.sort(numbersToSort); + int[] sorterResult = new int[numbersToSort.length]; + UnsafeShuffleSorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); + int j = 0; + while (iter.hasNext()) { + iter.loadNext(); + sorterResult[j] = iter.packedRecordPointer.getPartitionId(); + j += 1; + } + Assert.assertArrayEquals(numbersToSort, sorterResult); + } } From 853128684abc302a965ade6dba8df09b08345c5d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 11 May 2015 16:42:11 -0700 Subject: [PATCH 68/92] Add tests that automatically trigger spills. This bumps up line coverage to 93% in UnsafeShuffleExternalSorter; now, the only branches that are missed are exception-handling code. --- .../shuffle/unsafe/UnsafeShuffleWriter.java | 5 ++- .../unsafe/UnsafeShuffleWriterSuite.java | 41 +++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 8977517c0bcbe..02bf7e321df12 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -59,6 +59,9 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); + @VisibleForTesting + static final int INITIAL_SORT_BUFFER_SIZE = 4096; + private final BlockManager blockManager; private final IndexShuffleBlockResolver shuffleBlockResolver; private final TaskMemoryManager memoryManager; @@ -152,7 +155,7 @@ private void open() throws IOException { shuffleMemoryManager, blockManager, taskContext, - 4096, // Initial size (TODO: tune this!) + INITIAL_SORT_BUFFER_SIZE, partitioner.numPartitions(), sparkConf); serBuffer = new MyByteArrayOutputStream(1024 * 1024); diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 01bf7a5095970..c53e0fcf44880 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -363,6 +363,47 @@ public void mergeSpillsWithFileStreamAndNoCompression() throws Exception { testMergingSpills(false, null); } + @Test + public void writeEnoughDataToTriggerSpill() throws Exception { + when(shuffleMemoryManager.tryToAcquire(anyLong())) + .then(returnsFirstArg()) // Allocate initial sort buffer + .then(returnsFirstArg()) // Allocate initial data page + .thenReturn(0L) // Deny request to allocate new data page + .then(returnsFirstArg()); // Grant new sort buffer and data page. + final UnsafeShuffleWriter writer = createWriter(false); + final ArrayList> dataToWrite = new ArrayList>(); + final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128]; + for (int i = 0; i < 128 + 1; i++) { + dataToWrite.add(new Tuple2(i, bigByteArray)); + } + writer.write(dataToWrite.iterator()); + verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong()); + Assert.assertEquals(2, spillFilesCreated.size()); + writer.stop(true); + readRecordsFromFile(); + assertSpillFilesWereCleanedUp(); + } + + @Test + public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception { + when(shuffleMemoryManager.tryToAcquire(anyLong())) + .then(returnsFirstArg()) // Allocate initial sort buffer + .then(returnsFirstArg()) // Allocate initial data page + .thenReturn(0L) // Deny request to grow sort buffer + .then(returnsFirstArg()); // Grant new sort buffer and data page. + final UnsafeShuffleWriter writer = createWriter(false); + final ArrayList> dataToWrite = new ArrayList>(); + for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) { + dataToWrite.add(new Tuple2(i, i)); + } + writer.write(dataToWrite.iterator()); + verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong()); + Assert.assertEquals(2, spillFilesCreated.size()); + writer.stop(true); + readRecordsFromFile(); + assertSpillFilesWereCleanedUp(); + } + @Test public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception { final UnsafeShuffleWriter writer = createWriter(false); From 69d58992a16c9368eee31a79ecc321cec47b6801 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 11 May 2015 16:51:42 -0700 Subject: [PATCH 69/92] Remove some unnecessary override vals --- .../apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala index 5641903958fb5..5db64e2144abe 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala @@ -27,8 +27,8 @@ import org.apache.spark.shuffle.sort.SortShuffleManager */ private class UnsafeShuffleHandle[K, V]( shuffleId: Int, - override val numMaps: Int, - override val dependency: ShuffleDependency[K, V, V]) + numMaps: Int, + dependency: ShuffleDependency[K, V, V]) extends BaseShuffleHandle(shuffleId, numMaps, dependency) { } From d4e6d89152685b2821f158e101b99ea35dba2ac3 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 11 May 2015 16:55:20 -0700 Subject: [PATCH 70/92] Update to bit shifting constants --- .../spark/shuffle/unsafe/PackedRecordPointer.java | 10 ++++------ .../org/apache/spark/shuffle/unsafe/SpillInfo.java | 4 ++-- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java index ee991ee26f7a0..35d7b7b6651d0 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java @@ -37,25 +37,23 @@ final class PackedRecordPointer { static final int MAXIMUM_PAGE_SIZE_BYTES = 1 << 27; // 128 megabytes - static final int MAXIMUM_PARTITION_ID = 1 << 24; // 16777216 + static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1; // 16777215 /** Bit mask for the lower 40 bits of a long. */ - private static final long MASK_LONG_LOWER_40_BITS = 0xFFFFFFFFFFL; + private static final long MASK_LONG_LOWER_40_BITS = (1L << 40) - 1; /** Bit mask for the upper 24 bits of a long */ private static final long MASK_LONG_UPPER_24_BITS = ~MASK_LONG_LOWER_40_BITS; /** Bit mask for the lower 27 bits of a long. */ - private static final long MASK_LONG_LOWER_27_BITS = 0x7FFFFFFL; + private static final long MASK_LONG_LOWER_27_BITS = (1L << 27) - 1; /** Bit mask for the lower 51 bits of a long. */ - private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL; + private static final long MASK_LONG_LOWER_51_BITS = (1L << 51) - 1; /** Bit mask for the upper 13 bits of a long */ private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS; - // TODO: this shifting is probably extremely inefficient; this is just for prototyping - /** * Pack a record address and partition id into a single word. * diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java index 5d13354231491..7bac0dc0bbeb6 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java @@ -17,10 +17,10 @@ package org.apache.spark.shuffle.unsafe; -import org.apache.spark.storage.TempShuffleBlockId; - import java.io.File; +import org.apache.spark.storage.TempShuffleBlockId; + /** * Metadata for a block of data written by {@link UnsafeShuffleExternalSorter}. */ From 4f0b770829698c319eee24ecdc959ab3fd1cdb29 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 11 May 2015 20:07:17 -0700 Subject: [PATCH 71/92] Attempt to implement proper shuffle write metrics. --- core/pom.xml | 10 +++ .../unsafe/TimeTrackingOutputStream.java | 70 +++++++++++++++++ .../unsafe/UnsafeShuffleExternalSorter.java | 54 ++++++++----- .../shuffle/unsafe/UnsafeShuffleWriter.java | 44 ++++++++--- .../unsafe/UnsafeShuffleWriterSuite.java | 76 +++++++++++++------ pom.xml | 12 +++ 6 files changed, 214 insertions(+), 52 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/shuffle/unsafe/TimeTrackingOutputStream.java diff --git a/core/pom.xml b/core/pom.xml index fc42f48973fe9..e36d1a45aa9a2 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -361,6 +361,16 @@ junit test + + org.hamcrest + hamcrest-core + test + + + org.hamcrest + hamcrest-library + test + com.novocode junit-interface diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/TimeTrackingOutputStream.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/TimeTrackingOutputStream.java new file mode 100644 index 0000000000000..8b5ba49e67204 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/TimeTrackingOutputStream.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import org.apache.spark.executor.ShuffleWriteMetrics; + +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; + +/** + * Intercepts write calls and tracks total time spent writing. + */ +final class TimeTrackingFileOutputStream extends OutputStream { + + private final ShuffleWriteMetrics writeMetrics; + private final FileOutputStream outputStream; + + public TimeTrackingFileOutputStream( + ShuffleWriteMetrics writeMetrics, + FileOutputStream outputStream) { + this.writeMetrics = writeMetrics; + this.outputStream = outputStream; + } + + @Override + public void write(int b) throws IOException { + final long startTime = System.nanoTime(); + outputStream.write(b); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } + + @Override + public void write(byte[] b) throws IOException { + final long startTime = System.nanoTime(); + outputStream.write(b); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + final long startTime = System.nanoTime(); + outputStream.write(b, off, len); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } + + @Override + public void flush() throws IOException { + outputStream.flush(); + } + + @Override + public void close() throws IOException { + outputStream.close(); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 1d1382c104fea..c4d26288de33d 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -70,6 +70,7 @@ final class UnsafeShuffleExternalSorter { private final BlockManager blockManager; private final TaskContext taskContext; private final boolean spillingEnabled; + private final ShuffleWriteMetrics writeMetrics; /** The buffer size to use when writing spills using DiskBlockObjectWriter */ private final int fileBufferSize; @@ -97,7 +98,8 @@ public UnsafeShuffleExternalSorter( TaskContext taskContext, int initialSize, int numPartitions, - SparkConf conf) throws IOException { + SparkConf conf, + ShuffleWriteMetrics writeMetrics) throws IOException { this.memoryManager = memoryManager; this.shuffleMemoryManager = shuffleMemoryManager; this.blockManager = blockManager; @@ -107,6 +109,7 @@ public UnsafeShuffleExternalSorter( this.spillingEnabled = conf.getBoolean("spark.shuffle.spill", true); // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.writeMetrics = writeMetrics; openSorter(); } @@ -131,8 +134,24 @@ private void openSorter() throws IOException { /** * Sorts the in-memory records and writes the sorted records to a spill file. * This method does not free the sort data structures. + * + * @param isSpill if true, this indicates that we're writing a spill and that bytes written should + * be counted towards shuffle spill metrics rather than shuffle write metrics. */ - private SpillInfo writeSpillFile() throws IOException { + private void writeSpillFile(boolean isSpill) throws IOException { + + final ShuffleWriteMetrics writeMetricsToUse; + + if (isSpill) { + // We're spilling, so bytes written should be counted towards spill rather than write. + // Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count + // them towards shuffle bytes written. + writeMetricsToUse = new ShuffleWriteMetrics(); + } else { + // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes. + writeMetricsToUse = writeMetrics; + } + // This call performs the actual sort. final UnsafeShuffleSorter.UnsafeShuffleSorterIterator sortedRecords = sorter.getSortedIterator(); @@ -161,17 +180,8 @@ private SpillInfo writeSpillFile() throws IOException { // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work // around this, we pass a dummy no-op serializer. final SerializerInstance ser = DummySerializerInstance.INSTANCE; - // TODO: audit the metrics-related code and ensure proper metrics integration: - // It's not clear how we should handle shuffle write metrics for spill files; currently, Spark - // doesn't report IO time spent writing spill files (see SPARK-7413). This method, - // writeSpillFile(), is called both when writing spill files and when writing the single output - // file in cases where we didn't spill. As a result, we don't necessarily know whether this - // should be reported as bytes spilled or as shuffle bytes written. We could defer the updating - // of these metrics until the end of the shuffle write, but that would mean that that users - // wouldn't get useful metrics updates in the UI from long-running tasks. Given this complexity, - // I'm deferring these decisions to a separate follow-up commit or patch. - writer = - blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, new ShuffleWriteMetrics()); + + writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetricsToUse); int currentPartition = -1; while (sortedRecords.hasNext()) { @@ -185,8 +195,7 @@ private SpillInfo writeSpillFile() throws IOException { spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length(); } currentPartition = partition; - writer = - blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, new ShuffleWriteMetrics()); + writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetricsToUse); } final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); @@ -220,7 +229,14 @@ private SpillInfo writeSpillFile() throws IOException { spills.add(spillInfo); } } - return spillInfo; + + if (isSpill) { + writeMetrics.incShuffleRecordsWritten(writeMetricsToUse.shuffleRecordsWritten()); + // Consistent with ExternalSorter, we do not count this IO towards shuffle write time. + // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this. + // writeMetrics.incShuffleWriteTime(writeMetricsToUse.shuffleWriteTime()); + taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.shuffleBytesWritten()); + } } /** @@ -233,13 +249,12 @@ void spill() throws IOException { org.apache.spark.util.Utils.bytesToString(getMemoryUsage()) + " to disk (" + (spills.size() + (spills.size() > 1 ? " times" : " time")) + " so far)"); - final SpillInfo spillInfo = writeSpillFile(); + writeSpillFile(true); final long sorterMemoryUsage = sorter.getMemoryUsage(); sorter = null; shuffleMemoryManager.release(sorterMemoryUsage); final long spillSize = freeMemory(); taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); - taskContext.taskMetrics().incDiskBytesSpilled(spillInfo.file.length()); openSorter(); } @@ -389,7 +404,8 @@ public void insertRecord( public SpillInfo[] closeAndGetSpills() throws IOException { try { if (sorter != null) { - writeSpillFile(); + // Do not count the final file towards the spill count. + writeSpillFile(false); freeMemory(); } return spills.toArray(new SpillInfo[spills.size()]); diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 02bf7e321df12..7544ebbfeaad5 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -157,7 +157,8 @@ private void open() throws IOException { taskContext, INITIAL_SORT_BUFFER_SIZE, partitioner.numPartitions(), - sparkConf); + sparkConf, + writeMetrics); serBuffer = new MyByteArrayOutputStream(1024 * 1024); serOutputStream = serializer.serializeStream(serBuffer); } @@ -210,6 +211,12 @@ void forceSorterToSpill() throws IOException { sorter.spill(); } + /** + * Merge zero or more spill files together, choosing the fastest merging strategy based on the + * number of spills and the IO compression codec. + * + * @return the partition lengths in the merged file. + */ private long[] mergeSpills(SpillInfo[] spills) throws IOException { final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId); final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true); @@ -223,30 +230,42 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException { new FileOutputStream(outputFile).close(); // Create an empty file return new long[partitioner.numPartitions()]; } else if (spills.length == 1) { - // Note: we'll have to watch out for corner-cases in this code path when working on shuffle - // metrics integration, since any metrics updates that are performed during the merge will - // also have to be done here. In this branch, the shuffle technically didn't need to spill - // because we're only trying to merge one file, so we may need to ensure that metrics that - // would otherwise be counted as spill metrics are actually counted as regular write - // metrics. + // Here, we don't need to perform any metrics updates because the bytes written to this + // output file would have already been counted as shuffle bytes written. Files.move(spills[0].file, outputFile); return spills[0].partitionLengths; } else { + final long[] partitionLengths; + // There are multiple spills to merge, so none of these spill files' lengths were counted + // towards our shuffle write count or shuffle write time. If we use the slow merge path, + // then the final output file's size won't necessarily be equal to the sum of the spill + // files' sizes. To guard against this case, we look at the output file's actual size when + // computing shuffle bytes written. + // + // We allow the individual merge methods to report their own IO times since different merge + // strategies use different IO techniques. We count IO during merge towards the shuffle + // shuffle write time, which appears to be consistent with the "not bypassing merge-sort" + // branch in ExternalSorter. if (fastMergeEnabled && fastMergeIsSupported) { // Compression is disabled or we are using an IO compression codec that supports // decompression of concatenated compressed streams, so we can perform a fast spill merge // that doesn't need to interpret the spilled bytes. if (transferToEnabled) { logger.debug("Using transferTo-based fast merge"); - return mergeSpillsWithTransferTo(spills, outputFile); + partitionLengths = mergeSpillsWithTransferTo(spills, outputFile); } else { logger.debug("Using fileStream-based fast merge"); - return mergeSpillsWithFileStream(spills, outputFile, null); + partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null); } } else { logger.debug("Using slow merge"); - return mergeSpillsWithFileStream(spills, outputFile, compressionCodec); + partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec); } + // The final shuffle spill's write would have directly updated shuffleBytesWritten, so + // we need to decrement to avoid double-counting this write. + writeMetrics.decShuffleBytesWritten(spills[spills.length - 1].file.length()); + writeMetrics.incShuffleBytesWritten(outputFile.length()); + return partitionLengths; } } catch (IOException e) { if (outputFile.exists() && !outputFile.delete()) { @@ -271,7 +290,8 @@ private long[] mergeSpillsWithFileStream( } for (int partition = 0; partition < numPartitions; partition++) { final long initialFileLength = outputFile.length(); - mergedFileOutputStream = new FileOutputStream(outputFile, true); + mergedFileOutputStream = + new TimeTrackingFileOutputStream(writeMetrics, new FileOutputStream(outputFile, true)); if (compressionCodec != null) { mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream); } @@ -321,6 +341,7 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th final long partitionLengthInSpill = spills[i].partitionLengths[partition]; long bytesToTransfer = partitionLengthInSpill; final FileChannel spillInputChannel = spillInputChannels[i]; + final long writeStartTime = System.nanoTime(); while (bytesToTransfer > 0) { final long actualBytesTransferred = spillInputChannel.transferTo( spillInputChannelPositions[i], @@ -329,6 +350,7 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th spillInputChannelPositions[i] += actualBytesTransferred; bytesToTransfer -= actualBytesTransferred; } + writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime); bytesWrittenToMergedFile += partitionLengthInSpill; partitionLengths[partition] += partitionLengthInSpill; } diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index c53e0fcf44880..8451e8d9a9785 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -29,13 +29,16 @@ import com.google.common.collect.HashMultiset; import com.google.common.io.ByteStreams; import org.junit.After; -import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; +import static org.junit.Assert.*; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.lessThan; import static org.mockito.AdditionalAnswers.returnsFirstArg; import static org.mockito.Answers.RETURNS_SMART_NULLS; import static org.mockito.Mockito.*; @@ -70,6 +73,7 @@ public class UnsafeShuffleWriterSuite { final LinkedList spillFilesCreated = new LinkedList(); SparkConf conf; final Serializer serializer = new KryoSerializer(new SparkConf()); + TaskMetrics taskMetrics; @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager; @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; @@ -94,7 +98,7 @@ public void tearDown() { Utils.deleteRecursively(tempDir); final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory(); if (leakedMemory != 0) { - Assert.fail("Test leaked " + leakedMemory + " bytes of managed memory"); + fail("Test leaked " + leakedMemory + " bytes of managed memory"); } } @@ -107,6 +111,7 @@ public void setUp() throws IOException { partitionSizesInMergedFile = null; spillFilesCreated.clear(); conf = new SparkConf(); + taskMetrics = new TaskMetrics(); when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); @@ -183,7 +188,7 @@ public Tuple2 answer( } }); - when(taskContext.taskMetrics()).thenReturn(new TaskMetrics()); + when(taskContext.taskMetrics()).thenReturn(taskMetrics); when(shuffleDep.serializer()).thenReturn(Option.apply(serializer)); when(shuffleDep.partitioner()).thenReturn(hashPartitioner); @@ -205,7 +210,7 @@ private UnsafeShuffleWriter createWriter(boolean transferToEnabl private void assertSpillFilesWereCleanedUp() { for (File spillFile : spillFilesCreated) { - Assert.assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up", + assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up", spillFile.exists()); } } @@ -226,7 +231,7 @@ private List> readRecordsFromFile() throws IOException { Iterator> records = recordsStream.asKeyValueIterator(); while (records.hasNext()) { Tuple2 record = records.next(); - Assert.assertEquals(i, hashPartitioner.getPartition(record._1())); + assertEquals(i, hashPartitioner.getPartition(record._1())); recordsList.add(record); } recordsStream.close(); @@ -251,9 +256,13 @@ public void writeEmptyIterator() throws Exception { final UnsafeShuffleWriter writer = createWriter(true); writer.write(Collections.>emptyIterator()); final Option mapStatus = writer.stop(true); - Assert.assertTrue(mapStatus.isDefined()); - Assert.assertTrue(mergedOutputFile.exists()); - Assert.assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile); + assertTrue(mapStatus.isDefined()); + assertTrue(mergedOutputFile.exists()); + assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile); + assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleRecordsWritten()); + assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleBytesWritten()); + assertEquals(0, taskMetrics.diskBytesSpilled()); + assertEquals(0, taskMetrics.memoryBytesSpilled()); } @Test @@ -267,20 +276,25 @@ public void writeWithoutSpilling() throws Exception { final UnsafeShuffleWriter writer = createWriter(true); writer.write(dataToWrite.iterator()); final Option mapStatus = writer.stop(true); - Assert.assertTrue(mapStatus.isDefined()); - Assert.assertTrue(mergedOutputFile.exists()); + assertTrue(mapStatus.isDefined()); + assertTrue(mergedOutputFile.exists()); long sumOfPartitionSizes = 0; for (long size: partitionSizesInMergedFile) { // All partitions should be the same size: - Assert.assertEquals(partitionSizesInMergedFile[0], size); + assertEquals(partitionSizesInMergedFile[0], size); sumOfPartitionSizes += size; } - Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes); - Assert.assertEquals( + assertEquals(mergedOutputFile.length(), sumOfPartitionSizes); + assertEquals( HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile())); assertSpillFilesWereCleanedUp(); + ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); + assertEquals(0, taskMetrics.diskBytesSpilled()); + assertEquals(0, taskMetrics.memoryBytesSpilled()); + assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); } private void testMergingSpills( @@ -307,20 +321,26 @@ private void testMergingSpills( writer.insertRecordIntoSorter(dataToWrite.get(5)); writer.closeAndWriteOutput(); final Option mapStatus = writer.stop(true); - Assert.assertTrue(mapStatus.isDefined()); - Assert.assertTrue(mergedOutputFile.exists()); - Assert.assertEquals(2, spillFilesCreated.size()); + assertTrue(mapStatus.isDefined()); + assertTrue(mergedOutputFile.exists()); + assertEquals(2, spillFilesCreated.size()); long sumOfPartitionSizes = 0; for (long size: partitionSizesInMergedFile) { sumOfPartitionSizes += size; } - Assert.assertEquals(sumOfPartitionSizes, mergedOutputFile.length()); + assertEquals(sumOfPartitionSizes, mergedOutputFile.length()); - Assert.assertEquals( + assertEquals( HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile())); assertSpillFilesWereCleanedUp(); + ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); + assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); + assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); + assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); + assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); } @Test @@ -378,10 +398,16 @@ public void writeEnoughDataToTriggerSpill() throws Exception { } writer.write(dataToWrite.iterator()); verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong()); - Assert.assertEquals(2, spillFilesCreated.size()); + assertEquals(2, spillFilesCreated.size()); writer.stop(true); readRecordsFromFile(); assertSpillFilesWereCleanedUp(); + ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); + assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); + assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); + assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); + assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); } @Test @@ -398,10 +424,16 @@ public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exce } writer.write(dataToWrite.iterator()); verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong()); - Assert.assertEquals(2, spillFilesCreated.size()); + assertEquals(2, spillFilesCreated.size()); writer.stop(true); readRecordsFromFile(); assertSpillFilesWereCleanedUp(); + ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); + assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); + assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); + assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); + assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); } @Test @@ -414,7 +446,7 @@ public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception dataToWrite.add(new Tuple2(1, ByteBuffer.wrap(bytes))); writer.write(dataToWrite.iterator()); writer.stop(true); - Assert.assertEquals( + assertEquals( HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile())); assertSpillFilesWereCleanedUp(); @@ -474,7 +506,7 @@ public void close() { } // Here, we write through the public `write()` interface instead of the test-only // `insertRecordIntoSorter` interface: writer.write(Collections.singletonList(hugeRecord).iterator()); - Assert.fail("Expected exception to be thrown"); + fail("Expected exception to be thrown"); } catch (IOException e) { // Pass } diff --git a/pom.xml b/pom.xml index c83dec4f42399..8985c2988783d 100644 --- a/pom.xml +++ b/pom.xml @@ -684,6 +684,18 @@ 4.10 test + + org.hamcrest + hamcrest-core + 1.3 + test + + + org.hamcrest + hamcrest-library + 1.3 + test + com.novocode junit-interface From e58a6b407fd264d36f527b44ea47aff8f5ff5568 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 12 May 2015 11:34:29 -0700 Subject: [PATCH 72/92] Add more tests for PackedRecordPointer encoding. --- .../unsafe/PackedRecordPointerSuite.java | 46 +++++++++++++++++-- .../unsafe/UnsafeShuffleWriterSuite.java | 2 +- unsafe/pom.xml | 4 ++ .../unsafe/memory/TaskMemoryManager.java | 10 +++- 4 files changed, 54 insertions(+), 8 deletions(-) diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java index ba1f89d099838..4fda87ab57c49 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java @@ -17,13 +17,14 @@ package org.apache.spark.shuffle.unsafe; -import org.junit.Assert; import org.junit.Test; +import static org.junit.Assert.*; import org.apache.spark.unsafe.memory.ExecutorMemoryManager; import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; +import static org.apache.spark.shuffle.unsafe.PackedRecordPointer.*; public class PackedRecordPointerSuite { @@ -36,8 +37,8 @@ public void heap() { final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, 42); PackedRecordPointer packedPointer = new PackedRecordPointer(); packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360)); - Assert.assertEquals(360, packedPointer.getPartitionId()); - Assert.assertEquals(addressInPage1, packedPointer.getRecordPointer()); + assertEquals(360, packedPointer.getPartitionId()); + assertEquals(addressInPage1, packedPointer.getRecordPointer()); memoryManager.cleanUpAllAllocatedMemory(); } @@ -50,8 +51,43 @@ public void offHeap() { final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, 42); PackedRecordPointer packedPointer = new PackedRecordPointer(); packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360)); - Assert.assertEquals(360, packedPointer.getPartitionId()); - Assert.assertEquals(addressInPage1, packedPointer.getRecordPointer()); + assertEquals(360, packedPointer.getPartitionId()); + assertEquals(addressInPage1, packedPointer.getRecordPointer()); memoryManager.cleanUpAllAllocatedMemory(); } + + @Test + public void maximumPartitionIdCanBeEncoded() { + PackedRecordPointer packedPointer = new PackedRecordPointer(); + packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID)); + assertEquals(MAXIMUM_PARTITION_ID, packedPointer.getPartitionId()); + } + + @Test + public void partitionIdsGreaterThanMaximumPartitionIdWillOverflowOrTriggerError() { + PackedRecordPointer packedPointer = new PackedRecordPointer(); + try { + // Pointers greater than the maximum partition ID will overflow or trigger an assertion error + packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID + 1)); + assertFalse(MAXIMUM_PARTITION_ID + 1 == packedPointer.getPartitionId()); + } catch (AssertionError e ) { + // pass + } + } + + @Test + public void maximumOffsetInPageCanBeEncoded() { + PackedRecordPointer packedPointer = new PackedRecordPointer(); + long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES - 1); + packedPointer.set(PackedRecordPointer.packPointer(address, 0)); + assertEquals(address, packedPointer.getRecordPointer()); + } + + @Test + public void offsetsPastMaxOffsetInPageWillOverflow() { + PackedRecordPointer packedPointer = new PackedRecordPointer(); + long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES); + packedPointer.set(PackedRecordPointer.packPointer(address, 0)); + assertEquals(0, packedPointer.getRecordPointer()); + } } diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 8451e8d9a9785..61511de6a5219 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -35,10 +35,10 @@ import org.mockito.MockitoAnnotations; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; -import static org.junit.Assert.*; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.lessThan; +import static org.junit.Assert.*; import static org.mockito.AdditionalAnswers.returnsFirstArg; import static org.mockito.Answers.RETURNS_SMART_NULLS; import static org.mockito.Mockito.*; diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 5b0733206b2bc..9e151fc7a9141 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -42,6 +42,10 @@ com.google.code.findbugs jsr305 + + com.google.guava + guava + diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java index 2aacf637eb6a4..cfd54035bee99 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java @@ -19,6 +19,7 @@ import java.util.*; +import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -169,8 +170,13 @@ public void free(MemoryBlock memory) { * This address will remain valid as long as the corresponding page has not been freed. */ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { - assert (page.pageNumber != -1) : "encodePageNumberAndOffset called with invalid page"; - return (((long) page.pageNumber) << 51) | (offsetInPage & MASK_LONG_LOWER_51_BITS); + return encodePageNumberAndOffset(page.pageNumber, offsetInPage); + } + + @VisibleForTesting + public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) { + assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page"; + return (((long) pageNumber) << 51) | (offsetInPage & MASK_LONG_LOWER_51_BITS); } /** From e995d1a14565bfed400ba19a92d6a3035c3cca5a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 12 May 2015 11:51:01 -0700 Subject: [PATCH 73/92] Introduce MAX_SHUFFLE_OUTPUT_PARTITIONS. --- .../spark/shuffle/unsafe/PackedRecordPointer.java | 3 +++ .../spark/shuffle/unsafe/UnsafeShuffleManager.scala | 10 ++++++++-- .../shuffle/unsafe/UnsafeShuffleManagerSuite.scala | 2 +- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java index 35d7b7b6651d0..6d61b1b9e34da 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java @@ -37,6 +37,9 @@ final class PackedRecordPointer { static final int MAXIMUM_PAGE_SIZE_BYTES = 1 << 27; // 128 megabytes + /** + * The maximum partition identifier that can be encoded. Note that partition ids start from 0. + */ static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1; // 16777215 /** Bit mask for the lower 40 bits of a long. */ diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala index 5db64e2144abe..4785e8c0f91a3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala @@ -33,6 +33,12 @@ private class UnsafeShuffleHandle[K, V]( } private[spark] object UnsafeShuffleManager extends Logging { + + /** + * The maximum number of shuffle output partitions that UnsafeShuffleManager supports. + */ + val MAX_SHUFFLE_OUTPUT_PARTITIONS = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1 + /** * Helper method for determining whether a shuffle should use the optimized unsafe shuffle * path or whether it should fall back to the original sort-based shuffle. @@ -50,9 +56,9 @@ private[spark] object UnsafeShuffleManager extends Logging { } else if (dependency.keyOrdering.isDefined) { log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because a key ordering is defined") false - } else if (dependency.partitioner.numPartitions > PackedRecordPointer.MAXIMUM_PARTITION_ID) { + } else if (dependency.partitioner.numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS) { log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " + - s"${PackedRecordPointer.MAXIMUM_PARTITION_ID} partitions") + s"$MAX_SHUFFLE_OUTPUT_PARTITIONS partitions") false } else { log.debug(s"Can use UnsafeShuffle for shuffle $shufId") diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala index 9c91948bdc1e4..49a04a2a45280 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala @@ -93,7 +93,7 @@ class UnsafeShuffleManagerSuite extends FunSuite with Matchers { // We do not support shuffles with more than 16 million output partitions assert(!canUseUnsafeShuffle(shuffleDep( - partitioner = new HashPartitioner(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1), + partitioner = new HashPartitioner(UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS + 1), serializer = kryo, keyOrdering = None, aggregator = None, From 56781a16f367ed4955006fb05900951dacc40e8a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 12 May 2015 11:54:40 -0700 Subject: [PATCH 74/92] Rename UnsafeShuffleSorter to UnsafeShuffleInMemorySorter --- .../unsafe/UnsafeShuffleExternalSorter.java | 12 ++++++------ ...orter.java => UnsafeShuffleInMemorySorter.java} | 4 ++-- ....java => UnsafeShuffleInMemorySorterSuite.java} | 14 +++++++------- 3 files changed, 15 insertions(+), 15 deletions(-) rename core/src/main/java/org/apache/spark/shuffle/unsafe/{UnsafeShuffleSorter.java => UnsafeShuffleInMemorySorter.java} (97%) rename core/src/test/java/org/apache/spark/shuffle/unsafe/{UnsafeShuffleSorterSuite.java => UnsafeShuffleInMemorySorterSuite.java} (88%) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index c4d26288de33d..36f05e3df8753 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -42,9 +42,9 @@ *

* Incoming records are appended to data pages. When all records have been inserted (or when the * current thread's shuffle memory limit is reached), the in-memory records are sorted according to - * their partition ids (using a {@link UnsafeShuffleSorter}). The sorted records are then written - * to a single output file (or multiple files, if we've spilled). The format of the output files is - * the same as the format of the final output file written by + * their partition ids (using a {@link UnsafeShuffleInMemorySorter}). The sorted records are then + * written to a single output file (or multiple files, if we've spilled). The format of the output + * files is the same as the format of the final output file written by * {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are * written as a single serialized, compressed stream that can be read with a new decompression and * deserialization stream. @@ -86,7 +86,7 @@ final class UnsafeShuffleExternalSorter { private final LinkedList spills = new LinkedList(); // All three of these variables are reset after spilling: - private UnsafeShuffleSorter sorter; + private UnsafeShuffleInMemorySorter sorter; private MemoryBlock currentPage = null; private long currentPagePosition = -1; private long freeSpaceInCurrentPage = 0; @@ -128,7 +128,7 @@ private void openSorter() throws IOException { } } - this.sorter = new UnsafeShuffleSorter(initialSize); + this.sorter = new UnsafeShuffleInMemorySorter(initialSize); } /** @@ -153,7 +153,7 @@ private void writeSpillFile(boolean isSpill) throws IOException { } // This call performs the actual sort. - final UnsafeShuffleSorter.UnsafeShuffleSorterIterator sortedRecords = + final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords = sorter.getSortedIterator(); // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java similarity index 97% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java rename to core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java index f2b90617793e5..6fb87848df38b 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java @@ -21,7 +21,7 @@ import org.apache.spark.util.collection.Sorter; -final class UnsafeShuffleSorter { +final class UnsafeShuffleInMemorySorter { private final Sorter sorter; private static final class SortComparator implements Comparator { @@ -39,7 +39,7 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) { */ private int sortBufferInsertPosition = 0; - public UnsafeShuffleSorter(int initialSize) { + public UnsafeShuffleInMemorySorter(int initialSize) { assert (initialSize > 0); this.sortBuffer = new long[initialSize]; this.sorter = new Sorter(UnsafeShuffleSortDataFormat.INSTANCE); diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java similarity index 88% rename from core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorterSuite.java rename to core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java index 3fc73b04888ee..d1c45092693f2 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java @@ -30,7 +30,7 @@ import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; -public class UnsafeShuffleSorterSuite { +public class UnsafeShuffleInMemorySorterSuite { private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) { final byte[] strBytes = new byte[strLength]; @@ -44,8 +44,8 @@ private static String getStringFromDataPage(Object baseObject, long baseOffset, @Test public void testSortingEmptyInput() { - final UnsafeShuffleSorter sorter = new UnsafeShuffleSorter(100); - final UnsafeShuffleSorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); + final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(100); + final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); assert(!iter.hasNext()); } @@ -66,7 +66,7 @@ public void testBasicSorting() throws Exception { new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); final MemoryBlock dataPage = memoryManager.allocatePage(2048); final Object baseObject = dataPage.getBaseObject(); - final UnsafeShuffleSorter sorter = new UnsafeShuffleSorter(4); + final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4); final HashPartitioner hashPartitioner = new HashPartitioner(4); // Write the records into the data page and store pointers into the sorter @@ -87,7 +87,7 @@ public void testBasicSorting() throws Exception { } // Sort the records - final UnsafeShuffleSorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); + final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); int prevPartitionId = -1; Arrays.sort(dataToSort); for (int i = 0; i < dataToSort.length; i++) { @@ -111,7 +111,7 @@ public void testBasicSorting() throws Exception { @Test public void testSortingManyNumbers() throws Exception { - UnsafeShuffleSorter sorter = new UnsafeShuffleSorter(4); + UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4); int[] numbersToSort = new int[128000]; Random random = new Random(16); for (int i = 0; i < numbersToSort.length; i++) { @@ -120,7 +120,7 @@ public void testSortingManyNumbers() throws Exception { } Arrays.sort(numbersToSort); int[] sorterResult = new int[numbersToSort.length]; - UnsafeShuffleSorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); + UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); int j = 0; while (iter.hasNext()) { iter.loadNext(); From 0ad34da135f209c008d285dad0c5737c6683d25a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 12 May 2015 12:05:59 -0700 Subject: [PATCH 75/92] Fix off-by-one in nextInt() call --- .../spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java index d1c45092693f2..8fa72597db24d 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java @@ -115,7 +115,7 @@ public void testSortingManyNumbers() throws Exception { int[] numbersToSort = new int[128000]; Random random = new Random(16); for (int i = 0; i < numbersToSort.length; i++) { - numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID); + numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1); sorter.insertRecord(0, numbersToSort[i]); } Arrays.sort(numbersToSort); From 85da63fb741b22a21b147b28aa4145e884f91a50 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 12 May 2015 12:09:09 -0700 Subject: [PATCH 76/92] Cleanup in UnsafeShuffleSorterIterator. --- .../unsafe/UnsafeShuffleInMemorySorter.java | 37 ++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java index 6fb87848df38b..57b24547125f6 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java @@ -78,14 +78,29 @@ public void insertRecord(long recordPointer, int partitionId) { sortBufferInsertPosition++; } - public static abstract class UnsafeShuffleSorterIterator { + /** + * An iterator-like class that's used instead of Java's Iterator in order to facilitate inlining. + */ + public static final class UnsafeShuffleSorterIterator { + private final long[] sortBuffer; + private final int numRecords; final PackedRecordPointer packedRecordPointer = new PackedRecordPointer(); + private int position = 0; - public abstract boolean hasNext(); + public UnsafeShuffleSorterIterator(int numRecords, long[] sortBuffer) { + this.numRecords = numRecords; + this.sortBuffer = sortBuffer; + } - public abstract void loadNext(); + public boolean hasNext() { + return position < numRecords; + } + public void loadNext() { + packedRecordPointer.set(sortBuffer[position]); + position++; + } } /** @@ -93,20 +108,6 @@ public static abstract class UnsafeShuffleSorterIterator { */ public UnsafeShuffleSorterIterator getSortedIterator() { sorter.sort(sortBuffer, 0, sortBufferInsertPosition, SORT_COMPARATOR); - return new UnsafeShuffleSorterIterator() { - - private int position = 0; - - @Override - public boolean hasNext() { - return position < sortBufferInsertPosition; - } - - @Override - public void loadNext() { - packedRecordPointer.set(sortBuffer[position]); - position++; - } - }; + return new UnsafeShuffleSorterIterator(sortBufferInsertPosition, sortBuffer); } } From fdcac0838e5f4d53c2fec0ff815f0f26a568357f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 12 May 2015 12:14:42 -0700 Subject: [PATCH 77/92] Guard against overflow when expanding sort buffer. --- .../shuffle/unsafe/UnsafeShuffleInMemorySorter.java | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java index 57b24547125f6..b4055141a4ec5 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java @@ -47,7 +47,9 @@ public UnsafeShuffleInMemorySorter(int initialSize) { public void expandSortBuffer() { final long[] oldBuffer = sortBuffer; - sortBuffer = new long[oldBuffer.length * 2]; + // Guard against overflow: + final int newLength = oldBuffer.length * 2 > 0 ? (oldBuffer.length * 2) : Integer.MAX_VALUE; + sortBuffer = new long[newLength]; System.arraycopy(oldBuffer, 0, sortBuffer, 0, oldBuffer.length); } @@ -71,7 +73,11 @@ public long getMemoryUsage() { */ public void insertRecord(long recordPointer, int partitionId) { if (!hasSpaceForAnotherRecord()) { - expandSortBuffer(); + if (sortBuffer.length == Integer.MAX_VALUE) { + throw new IllegalStateException("Sort buffer has reached maximum size"); + } else { + expandSortBuffer(); + } } sortBuffer[sortBufferInsertPosition] = PackedRecordPointer.packPointer(recordPointer, partitionId); From 2d4e4f42f917aa78dacbc0d58e818f5014401e32 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 12 May 2015 13:36:21 -0700 Subject: [PATCH 78/92] Address some minor comments in UnsafeShuffleExternalSorter. --- .../unsafe/UnsafeShuffleExternalSorter.java | 52 ++++++++++--------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 36f05e3df8753..d4db1d298e833 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -36,6 +36,7 @@ import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; /** * An external sorter that is specialized for sort-based shuffle. @@ -85,7 +86,7 @@ final class UnsafeShuffleExternalSorter { private final LinkedList spills = new LinkedList(); - // All three of these variables are reset after spilling: + // These variables are reset after spilling: private UnsafeShuffleInMemorySorter sorter; private MemoryBlock currentPage = null; private long currentPagePosition = -1; @@ -110,21 +111,20 @@ public UnsafeShuffleExternalSorter( // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; this.writeMetrics = writeMetrics; - openSorter(); + initializeForWriting(); } /** - * Allocates a new sorter. Called when opening the spill writer for the first time and after - * each spill. + * Allocates new sort data structures. Called when creating the sorter and after each spill. */ - private void openSorter() throws IOException { + private void initializeForWriting() throws IOException { // TODO: move this sizing calculation logic into a static method of sorter: final long memoryRequested = initialSize * 8L; if (spillingEnabled) { final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested); if (memoryAcquired != memoryRequested) { shuffleMemoryManager.release(memoryAcquired); - throw new IOException("Could not acquire memory!"); + throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); } } @@ -132,24 +132,25 @@ private void openSorter() throws IOException { } /** - * Sorts the in-memory records and writes the sorted records to a spill file. + * Sorts the in-memory records and writes the sorted records to an on-disk file. * This method does not free the sort data structures. * - * @param isSpill if true, this indicates that we're writing a spill and that bytes written should - * be counted towards shuffle spill metrics rather than shuffle write metrics. + * @param isLastFile if true, this indicates that we're writing the final output file and that the + * bytes written should be counted towards shuffle spill metrics rather than + * shuffle write metrics. */ - private void writeSpillFile(boolean isSpill) throws IOException { + private void writeSortedFile(boolean isLastFile) throws IOException { final ShuffleWriteMetrics writeMetricsToUse; - if (isSpill) { + if (isLastFile) { + // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes. + writeMetricsToUse = writeMetrics; + } else { // We're spilling, so bytes written should be counted towards spill rather than write. // Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count // them towards shuffle bytes written. writeMetricsToUse = new ShuffleWriteMetrics(); - } else { - // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes. - writeMetricsToUse = writeMetrics; } // This call performs the actual sort. @@ -221,16 +222,16 @@ private void writeSpillFile(boolean isSpill) throws IOException { if (writer != null) { writer.commitAndClose(); - // If `writeSpillFile()` was called from `closeAndGetSpills()` and no records were inserted, - // then the spill file might be empty. Note that it might be better to avoid calling - // writeSpillFile() in that case. + // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted, + // then the file might be empty. Note that it might be better to avoid calling + // writeSortedFile() in that case. if (currentPartition != -1) { spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length(); spills.add(spillInfo); } } - if (isSpill) { + if (!isLastFile) { writeMetrics.incShuffleRecordsWritten(writeMetricsToUse.shuffleRecordsWritten()); // Consistent with ExternalSorter, we do not count this IO towards shuffle write time. // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this. @@ -244,19 +245,20 @@ private void writeSpillFile(boolean isSpill) throws IOException { */ @VisibleForTesting void spill() throws IOException { - final long threadId = Thread.currentThread().getId(); - logger.info("Thread " + threadId + " spilling sort data of " + - org.apache.spark.util.Utils.bytesToString(getMemoryUsage()) + " to disk (" + - (spills.size() + (spills.size() > 1 ? " times" : " time")) + " so far)"); + logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", + Thread.currentThread().getId(), + Utils.bytesToString(getMemoryUsage()), + spills.size(), + spills.size() > 1 ? " times" : " time"); - writeSpillFile(true); + writeSortedFile(false); final long sorterMemoryUsage = sorter.getMemoryUsage(); sorter = null; shuffleMemoryManager.release(sorterMemoryUsage); final long spillSize = freeMemory(); taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); - openSorter(); + initializeForWriting(); } private long getMemoryUsage() { @@ -405,7 +407,7 @@ public SpillInfo[] closeAndGetSpills() throws IOException { try { if (sorter != null) { // Do not count the final file towards the spill count. - writeSpillFile(false); + writeSortedFile(true); freeMemory(); } return spills.toArray(new SpillInfo[spills.size()]); From 57312c95b7b308622584dd25f5be6e0152d4118d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 12 May 2015 13:42:22 -0700 Subject: [PATCH 79/92] Clarify fileBufferSize units --- .../shuffle/unsafe/UnsafeShuffleExternalSorter.java | 12 +++++++----- .../spark/shuffle/FileShuffleBlockResolver.scala | 2 +- .../util/collection/ExternalAppendOnlyMap.scala | 2 +- .../spark/util/collection/ExternalSorter.scala | 2 +- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index d4db1d298e833..aaa4945eca9b2 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -74,7 +74,7 @@ final class UnsafeShuffleExternalSorter { private final ShuffleWriteMetrics writeMetrics; /** The buffer size to use when writing spills using DiskBlockObjectWriter */ - private final int fileBufferSize; + private final int fileBufferSizeBytes; /** * Memory pages that hold the records being sorted. The pages in this list are freed when @@ -108,8 +108,9 @@ public UnsafeShuffleExternalSorter( this.initialSize = initialSize; this.numPartitions = numPartitions; this.spillingEnabled = conf.getBoolean("spark.shuffle.spill", true); - // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units - this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided + this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.writeMetrics = writeMetrics; initializeForWriting(); } @@ -182,7 +183,7 @@ private void writeSortedFile(boolean isLastFile) throws IOException { // around this, we pass a dummy no-op serializer. final SerializerInstance ser = DummySerializerInstance.INSTANCE; - writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetricsToUse); + writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse); int currentPartition = -1; while (sortedRecords.hasNext()) { @@ -196,7 +197,8 @@ private void writeSortedFile(boolean isLastFile) throws IOException { spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length(); } currentPartition = partition; - writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetricsToUse); + writer = + blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse); } final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index 6ad427bcac7f9..6c3b3080d2605 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -76,7 +76,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) private val consolidateShuffleFiles = conf.getBoolean("spark.shuffle.consolidateFiles", false) - // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val bufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index b850973145077..df2d6ad3b41a4 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -90,7 +90,7 @@ class ExternalAppendOnlyMap[K, V, C]( // Number of bytes spilled in total private var _diskBytesSpilled = 0L - // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val fileBufferSize = sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 7d5cf7b61e56a..3b9d14f9372b6 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -110,7 +110,7 @@ private[spark] class ExternalSorter[K, V, C]( private val conf = SparkEnv.get.conf private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true) - // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true) From 6276168cf9083e8cfc9e2ffb806145b97f6cf082 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 12 May 2015 13:54:45 -0700 Subject: [PATCH 80/92] Remove ability to disable spilling in UnsafeShuffleExternalSorter. There's no obvious use-case for allowing users to disable spark.shuffle.spill and run out of memory. Because this configuration isn't deprecated as of this patch, I've added code to log a warning to let users know if their preference will be ignored by the new shuffle manager. --- .../unsafe/UnsafeShuffleExternalSorter.java | 47 +++++++------------ .../shuffle/unsafe/UnsafeShuffleManager.scala | 9 +++- 2 files changed, 26 insertions(+), 30 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index aaa4945eca9b2..e674195b67d4f 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -70,7 +70,6 @@ final class UnsafeShuffleExternalSorter { private final ShuffleMemoryManager shuffleMemoryManager; private final BlockManager blockManager; private final TaskContext taskContext; - private final boolean spillingEnabled; private final ShuffleWriteMetrics writeMetrics; /** The buffer size to use when writing spills using DiskBlockObjectWriter */ @@ -107,7 +106,6 @@ public UnsafeShuffleExternalSorter( this.taskContext = taskContext; this.initialSize = initialSize; this.numPartitions = numPartitions; - this.spillingEnabled = conf.getBoolean("spark.shuffle.spill", true); // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; @@ -121,12 +119,10 @@ public UnsafeShuffleExternalSorter( private void initializeForWriting() throws IOException { // TODO: move this sizing calculation logic into a static method of sorter: final long memoryRequested = initialSize * 8L; - if (spillingEnabled) { - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested); - if (memoryAcquired != memoryRequested) { - shuffleMemoryManager.release(memoryAcquired); - throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); - } + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested); + if (memoryAcquired != memoryRequested) { + shuffleMemoryManager.release(memoryAcquired); + throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); } this.sorter = new UnsafeShuffleInMemorySorter(initialSize); @@ -291,7 +287,7 @@ public void cleanupAfterError() { logger.error("Unable to delete spill file {}", spill.file.getPath()); } } - if (spillingEnabled && sorter != null) { + if (sorter != null) { shuffleMemoryManager.release(sorter.getMemoryUsage()); sorter = null; } @@ -307,24 +303,19 @@ public void cleanupAfterError() { */ private boolean haveSpaceForRecord(int requiredSpace) { assert (requiredSpace > 0); - // The sort array will automatically expand when inserting a new record, so we only need to - // worry about it having free space when spilling is enabled. - final boolean sortBufferHasSpace = !spillingEnabled || sorter.hasSpaceForAnotherRecord(); - final boolean dataPageHasSpace = requiredSpace <= freeSpaceInCurrentPage; - return (sortBufferHasSpace && dataPageHasSpace); + return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage)); } /** - * Allocates more memory in order to insert an additional record. If spilling is enabled, this - * will request additional memory from the {@link ShuffleMemoryManager} and spill if the requested - * memory can not be obtained. If spilling is disabled, then this will allocate memory without - * coordinating with the ShuffleMemoryManager. + * Allocates more memory in order to insert an additional record. This will request additional + * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be + * obtained. * * @param requiredSpace the required space in the data page, in bytes, including space for storing * the record size. */ private void allocateSpaceForRecord(int requiredSpace) throws IOException { - if (spillingEnabled && !sorter.hasSpaceForAnotherRecord()) { + if (!sorter.hasSpaceForAnotherRecord()) { logger.debug("Attempting to expand sort buffer"); final long oldSortBufferMemoryUsage = sorter.getMemoryUsage(); final long memoryToGrowSortBuffer = oldSortBufferMemoryUsage * 2; @@ -347,16 +338,14 @@ private void allocateSpaceForRecord(int requiredSpace) throws IOException { throw new IOException("Required space " + requiredSpace + " is greater than page size (" + PAGE_SIZE + ")"); } else { - if (spillingEnabled) { - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); - if (memoryAcquired < PAGE_SIZE) { - shuffleMemoryManager.release(memoryAcquired); - spill(); - final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); - if (memoryAcquiredAfterSpilling != PAGE_SIZE) { - shuffleMemoryManager.release(memoryAcquiredAfterSpilling); - throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory"); - } + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); + if (memoryAcquired < PAGE_SIZE) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); + if (memoryAcquiredAfterSpilling != PAGE_SIZE) { + shuffleMemoryManager.release(memoryAcquiredAfterSpilling); + throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory"); } } currentPage = memoryManager.allocatePage(PAGE_SIZE); diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala index 4785e8c0f91a3..2da7691f5af8f 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala @@ -113,7 +113,14 @@ private[spark] object UnsafeShuffleManager extends Logging { * * For more details on UnsafeShuffleManager's design, see SPARK-7081. */ -private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager { +private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + + if (!conf.getBoolean("spark.shuffle.spill", true)) { + logWarning( + "spark.shuffle.spill was set to false, but this is ignored by UnsafeShuffleManager; " + + "its optimized shuffles will continue to spill to disk when necessary.") + } + private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf) From 4a2c7859f1c03a8ab904409535f0928d9926ddf1 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 12 May 2015 14:01:49 -0700 Subject: [PATCH 81/92] rename 'sort buffer' to 'pointer array' --- .../unsafe/UnsafeShuffleExternalSorter.java | 14 +++--- .../unsafe/UnsafeShuffleInMemorySorter.java | 49 ++++++++++--------- 2 files changed, 34 insertions(+), 29 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index e674195b67d4f..7a137aa85cf8d 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -316,16 +316,16 @@ private boolean haveSpaceForRecord(int requiredSpace) { */ private void allocateSpaceForRecord(int requiredSpace) throws IOException { if (!sorter.hasSpaceForAnotherRecord()) { - logger.debug("Attempting to expand sort buffer"); - final long oldSortBufferMemoryUsage = sorter.getMemoryUsage(); - final long memoryToGrowSortBuffer = oldSortBufferMemoryUsage * 2; - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowSortBuffer); - if (memoryAcquired < memoryToGrowSortBuffer) { + logger.debug("Attempting to expand sort pointer array"); + final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage(); + final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2; + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray); + if (memoryAcquired < memoryToGrowPointerArray) { shuffleMemoryManager.release(memoryAcquired); spill(); } else { - sorter.expandSortBuffer(); - shuffleMemoryManager.release(oldSortBufferMemoryUsage); + sorter.expandPointerArray(); + shuffleMemoryManager.release(oldPointerArrayMemoryUsage); } } if (requiredSpace > freeSpaceInCurrentPage) { diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java index b4055141a4ec5..5bab501da9364 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java @@ -32,33 +32,38 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) { } private static final SortComparator SORT_COMPARATOR = new SortComparator(); - private long[] sortBuffer; + /** + * An array of record pointers and partition ids that have been encoded by + * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating + * records. + */ + private long[] pointerArray; /** - * The position in the sort buffer where new records can be inserted. + * The position in the pointer array where new records can be inserted. */ - private int sortBufferInsertPosition = 0; + private int pointerArrayInsertPosition = 0; public UnsafeShuffleInMemorySorter(int initialSize) { assert (initialSize > 0); - this.sortBuffer = new long[initialSize]; + this.pointerArray = new long[initialSize]; this.sorter = new Sorter(UnsafeShuffleSortDataFormat.INSTANCE); } - public void expandSortBuffer() { - final long[] oldBuffer = sortBuffer; + public void expandPointerArray() { + final long[] oldArray = pointerArray; // Guard against overflow: - final int newLength = oldBuffer.length * 2 > 0 ? (oldBuffer.length * 2) : Integer.MAX_VALUE; - sortBuffer = new long[newLength]; - System.arraycopy(oldBuffer, 0, sortBuffer, 0, oldBuffer.length); + final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE; + pointerArray = new long[newLength]; + System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length); } public boolean hasSpaceForAnotherRecord() { - return sortBufferInsertPosition + 1 < sortBuffer.length; + return pointerArrayInsertPosition + 1 < pointerArray.length; } public long getMemoryUsage() { - return sortBuffer.length * 8L; + return pointerArray.length * 8L; } /** @@ -73,15 +78,15 @@ public long getMemoryUsage() { */ public void insertRecord(long recordPointer, int partitionId) { if (!hasSpaceForAnotherRecord()) { - if (sortBuffer.length == Integer.MAX_VALUE) { - throw new IllegalStateException("Sort buffer has reached maximum size"); + if (pointerArray.length == Integer.MAX_VALUE) { + throw new IllegalStateException("Sort pointer array has reached maximum size"); } else { - expandSortBuffer(); + expandPointerArray(); } } - sortBuffer[sortBufferInsertPosition] = + pointerArray[pointerArrayInsertPosition] = PackedRecordPointer.packPointer(recordPointer, partitionId); - sortBufferInsertPosition++; + pointerArrayInsertPosition++; } /** @@ -89,14 +94,14 @@ public void insertRecord(long recordPointer, int partitionId) { */ public static final class UnsafeShuffleSorterIterator { - private final long[] sortBuffer; + private final long[] pointerArray; private final int numRecords; final PackedRecordPointer packedRecordPointer = new PackedRecordPointer(); private int position = 0; - public UnsafeShuffleSorterIterator(int numRecords, long[] sortBuffer) { + public UnsafeShuffleSorterIterator(int numRecords, long[] pointerArray) { this.numRecords = numRecords; - this.sortBuffer = sortBuffer; + this.pointerArray = pointerArray; } public boolean hasNext() { @@ -104,7 +109,7 @@ public boolean hasNext() { } public void loadNext() { - packedRecordPointer.set(sortBuffer[position]); + packedRecordPointer.set(pointerArray[position]); position++; } } @@ -113,7 +118,7 @@ public void loadNext() { * Return an iterator over record pointers in sorted order. */ public UnsafeShuffleSorterIterator getSortedIterator() { - sorter.sort(sortBuffer, 0, sortBufferInsertPosition, SORT_COMPARATOR); - return new UnsafeShuffleSorterIterator(sortBufferInsertPosition, sortBuffer); + sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR); + return new UnsafeShuffleSorterIterator(pointerArrayInsertPosition, pointerArray); } } From e3b88550b69bfe8f07b1c46ef169306db3ab766a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 12 May 2015 14:28:49 -0700 Subject: [PATCH 82/92] Cleanup in UnsafeShuffleWriter --- .../shuffle/unsafe/UnsafeShuffleWriter.java | 63 +++++++++++++------ .../unsafe/UnsafeShuffleWriterSuite.java | 7 ++- 2 files changed, 48 insertions(+), 22 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 7544ebbfeaad5..e2a942a425e87 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -102,7 +102,7 @@ public UnsafeShuffleWriter( UnsafeShuffleHandle handle, int mapId, TaskContext taskContext, - SparkConf sparkConf) { + SparkConf sparkConf) throws IOException { final int numPartitions = handle.dependency().partitioner().numPartitions(); if (numPartitions > PackedRecordPointer.MAXIMUM_PARTITION_ID) { throw new IllegalArgumentException( @@ -123,27 +123,29 @@ public UnsafeShuffleWriter( this.taskContext = taskContext; this.sparkConf = sparkConf; this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); + open(); } + /** + * This convenience method should only be called in test code. + */ + @VisibleForTesting public void write(Iterator> records) throws IOException { write(JavaConversions.asScalaIterator(records)); } @Override public void write(scala.collection.Iterator> records) throws IOException { + boolean success = false; try { while (records.hasNext()) { insertRecordIntoSorter(records.next()); } closeAndWriteOutput(); - } catch (Exception e) { - // Unfortunately, we have to catch Exception here in order to ensure proper cleanup after - // errors because Spark's Scala code, or users' custom Serializers, might throw arbitrary - // unchecked exceptions. - try { + success = true; + } finally { + if (!success) { sorter.cleanupAfterError(); - } finally { - throw new IOException("Error during shuffle write", e); } } } @@ -165,9 +167,6 @@ private void open() throws IOException { @VisibleForTesting void closeAndWriteOutput() throws IOException { - if (sorter == null) { - open(); - } serBuffer = null; serOutputStream = null; final SpillInfo[] spills = sorter.closeAndGetSpills(); @@ -187,10 +186,7 @@ void closeAndWriteOutput() throws IOException { } @VisibleForTesting - void insertRecordIntoSorter(Product2 record) throws IOException{ - if (sorter == null) { - open(); - } + void insertRecordIntoSorter(Product2 record) throws IOException { final K key = record._1(); final int partitionId = partitioner.getPartition(key); serBuffer.reset(); @@ -275,15 +271,29 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException { } } + /** + * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge, + * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in + * cases where the IO compression codec does not support concatenation of compressed data, or in + * cases where users have explicitly disabled use of {@code transferTo} in order to work around + * kernel bugs. + * + * @param spills the spills to merge. + * @param outputFile the file to write the merged data to. + * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled. + * @return the partition lengths in the merged file. + */ private long[] mergeSpillsWithFileStream( SpillInfo[] spills, File outputFile, @Nullable CompressionCodec compressionCodec) throws IOException { + assert (spills.length >= 2); final int numPartitions = partitioner.numPartitions(); final long[] partitionLengths = new long[numPartitions]; final InputStream[] spillInputStreams = new FileInputStream[spills.length]; OutputStream mergedFileOutputStream = null; + boolean threwException = true; try { for (int i = 0; i < spills.length; i++) { spillInputStreams[i] = new FileInputStream(spills[i].file); @@ -311,22 +321,34 @@ private long[] mergeSpillsWithFileStream( mergedFileOutputStream.close(); partitionLengths[partition] = (outputFile.length() - initialFileLength); } + threwException = false; } finally { + // To avoid masking exceptions that caused us to prematurely enter the finally block, only + // throw exceptions during cleanup if threwException == false. for (InputStream stream : spillInputStreams) { - Closeables.close(stream, false); + Closeables.close(stream, threwException); } - Closeables.close(mergedFileOutputStream, false); + Closeables.close(mergedFileOutputStream, threwException); } return partitionLengths; } + /** + * Merges spill files by using NIO's transferTo to concatenate spill partitions' bytes. + * This is only safe when the IO compression codec and serializer support concatenation of + * serialized streams. + * + * @return the partition lengths in the merged file. + */ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException { + assert (spills.length >= 2); final int numPartitions = partitioner.numPartitions(); final long[] partitionLengths = new long[numPartitions]; final FileChannel[] spillInputChannels = new FileChannel[spills.length]; final long[] spillInputChannelPositions = new long[spills.length]; FileChannel mergedFileOutputChannel = null; + boolean threwException = true; try { for (int i = 0; i < spills.length; i++) { spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel(); @@ -368,12 +390,15 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th "to disable this NIO feature." ); } + threwException = false; } finally { + // To avoid masking exceptions that caused us to prematurely enter the finally block, only + // throw exceptions during cleanup if threwException == false. for (int i = 0; i < spills.length; i++) { assert(spillInputChannelPositions[i] == spills[i].file.length()); - Closeables.close(spillInputChannels[i], false); + Closeables.close(spillInputChannels[i], threwException); } - Closeables.close(mergedFileOutputChannel, false); + Closeables.close(mergedFileOutputChannel, threwException); } return partitionLengths; } diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 61511de6a5219..730d265c87f88 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -194,7 +194,8 @@ public Tuple2 answer( when(shuffleDep.partitioner()).thenReturn(hashPartitioner); } - private UnsafeShuffleWriter createWriter(boolean transferToEnabled) { + private UnsafeShuffleWriter createWriter( + boolean transferToEnabled) throws IOException { conf.set("spark.file.transferTo", String.valueOf(transferToEnabled)); return new UnsafeShuffleWriter( blockManager, @@ -242,12 +243,12 @@ private List> readRecordsFromFile() throws IOException { } @Test(expected=IllegalStateException.class) - public void mustCallWriteBeforeSuccessfulStop() { + public void mustCallWriteBeforeSuccessfulStop() throws IOException { createWriter(false).stop(true); } @Test - public void doNotNeedToCallWriteBeforeUnsuccessfulStop() { + public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException { createWriter(false).stop(false); } From c2ce78e221aceb322f41edc51581750b045654a2 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 12 May 2015 14:31:27 -0700 Subject: [PATCH 83/92] Fix a missed usage of MAX_PARTITION_ID --- .../org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index e2a942a425e87..f097af35d87f8 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -104,10 +104,10 @@ public UnsafeShuffleWriter( TaskContext taskContext, SparkConf sparkConf) throws IOException { final int numPartitions = handle.dependency().partitioner().numPartitions(); - if (numPartitions > PackedRecordPointer.MAXIMUM_PARTITION_ID) { + if (numPartitions > UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS()) { throw new IllegalArgumentException( "UnsafeShuffleWriter can only be used for shuffles with at most " + - PackedRecordPointer.MAXIMUM_PARTITION_ID + " reduce partitions"); + UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS() + " reduce partitions"); } this.blockManager = blockManager; this.shuffleBlockResolver = shuffleBlockResolver; From 5e189c6f9347dab9aa8fada32117b6c74780dc06 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 12 May 2015 15:17:11 -0700 Subject: [PATCH 84/92] Track time spend closing / flushing files; split TimeTrackingOutputStream into separate file. --- .../shuffle/unsafe/UnsafeShuffleWriter.java | 3 ++- .../TimeTrackingOutputStream.java | 25 +++++++++++-------- .../spark/storage/BlockObjectWriter.scala | 24 +++--------------- 3 files changed, 20 insertions(+), 32 deletions(-) rename core/src/main/java/org/apache/spark/{shuffle/unsafe => storage}/TimeTrackingOutputStream.java (75%) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index f097af35d87f8..9c288dc7e8f77 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -50,6 +50,7 @@ import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.TimeTrackingOutputStream; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.memory.TaskMemoryManager; @@ -301,7 +302,7 @@ private long[] mergeSpillsWithFileStream( for (int partition = 0; partition < numPartitions; partition++) { final long initialFileLength = outputFile.length(); mergedFileOutputStream = - new TimeTrackingFileOutputStream(writeMetrics, new FileOutputStream(outputFile, true)); + new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true)); if (compressionCodec != null) { mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream); } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/TimeTrackingOutputStream.java b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java similarity index 75% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/TimeTrackingOutputStream.java rename to core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java index 8b5ba49e67204..0cd3c7d242660 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/TimeTrackingOutputStream.java +++ b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java @@ -15,25 +15,23 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.storage; -import org.apache.spark.executor.ShuffleWriteMetrics; - -import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; +import org.apache.spark.executor.ShuffleWriteMetrics; + /** - * Intercepts write calls and tracks total time spent writing. + * Intercepts write calls and tracks total time spent writing in order to update shuffle write + * metrics. Not thread safe. */ -final class TimeTrackingFileOutputStream extends OutputStream { +public final class TimeTrackingOutputStream extends OutputStream { private final ShuffleWriteMetrics writeMetrics; - private final FileOutputStream outputStream; + private final OutputStream outputStream; - public TimeTrackingFileOutputStream( - ShuffleWriteMetrics writeMetrics, - FileOutputStream outputStream) { + public TimeTrackingOutputStream(ShuffleWriteMetrics writeMetrics, OutputStream outputStream) { this.writeMetrics = writeMetrics; this.outputStream = outputStream; } @@ -49,7 +47,8 @@ public void write(int b) throws IOException { public void write(byte[] b) throws IOException { final long startTime = System.nanoTime(); outputStream.write(b); - writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); } + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } @Override public void write(byte[] b, int off, int len) throws IOException { @@ -60,11 +59,15 @@ public void write(byte[] b, int off, int len) throws IOException { @Override public void flush() throws IOException { + final long startTime = System.nanoTime(); outputStream.flush(); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); } @Override public void close() throws IOException { + final long startTime = System.nanoTime(); outputStream.close(); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 8bc4e205bc3c6..a33f22ef52687 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -86,16 +86,6 @@ private[spark] class DiskBlockObjectWriter( extends BlockObjectWriter(blockId) with Logging { - /** Intercepts write calls and tracks total time spent writing. Not thread safe. */ - private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream { - override def write(i: Int): Unit = callWithTiming(out.write(i)) - override def write(b: Array[Byte]): Unit = callWithTiming(out.write(b)) - override def write(b: Array[Byte], off: Int, len: Int): Unit = { - callWithTiming(out.write(b, off, len)) - } - override def close(): Unit = out.close() - override def flush(): Unit = out.flush() - } /** The file channel, used for repositioning / truncating the file. */ private var channel: FileChannel = null @@ -136,7 +126,7 @@ private[spark] class DiskBlockObjectWriter( throw new IllegalStateException("Writer already closed. Cannot be reopened.") } fos = new FileOutputStream(file, true) - ts = new TimeTrackingOutputStream(fos) + ts = new TimeTrackingOutputStream(writeMetrics, fos) channel = fos.getChannel() bs = compressStream(new BufferedOutputStream(ts, bufferSize)) objOut = serializerInstance.serializeStream(bs) @@ -150,9 +140,9 @@ private[spark] class DiskBlockObjectWriter( if (syncWrites) { // Force outstanding writes to disk and track how long it takes objOut.flush() - callWithTiming { - fos.getFD.sync() - } + val start = System.nanoTime() + fos.getFD.sync() + writeMetrics.incShuffleWriteTime(System.nanoTime() - start) } } { objOut.close() @@ -251,12 +241,6 @@ private[spark] class DiskBlockObjectWriter( reportedPosition = pos } - private def callWithTiming(f: => Unit) = { - val start = System.nanoTime() - f - writeMetrics.incShuffleWriteTime(System.nanoTime() - start) - } - // For testing private[spark] override def flush() { objOut.flush() From df07699b186130e494aea41a56dcf13740b0956b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 12 May 2015 15:35:23 -0700 Subject: [PATCH 85/92] Attempt to clarify confusing metrics update code --- .../unsafe/UnsafeShuffleExternalSorter.java | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 7a137aa85cf8d..9e9ed94b7890c 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -214,7 +214,6 @@ private void writeSortedFile(boolean isLastFile) throws IOException { recordReadPosition += toTransfer; dataRemaining -= toTransfer; } - // TODO: add a test that detects whether we leave this call out: writer.recordWritten(); } @@ -229,11 +228,23 @@ private void writeSortedFile(boolean isLastFile) throws IOException { } } - if (!isLastFile) { - writeMetrics.incShuffleRecordsWritten(writeMetricsToUse.shuffleRecordsWritten()); + if (!isLastFile) { // i.e. this is a spill file + // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records + // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter + // relies on its `recordWritten()` method being called in order to trigger periodic updates to + // `shuffleBytesWritten`. If we were to remove the `recordWritten()` call and increment that + // counter at a higher-level, then the in-progress metrics for records written and bytes + // written would get out of sync. + // + // When writing the last file, we pass `writeMetrics` directly to the DiskBlockObjectWriter; + // in all other cases, we pass in a dummy write metrics to capture metrics, then copy those + // metrics to the true write metrics here. The reason for performing this copying is so that + // we can avoid reporting spilled bytes as shuffle write bytes. + // + // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`. // Consistent with ExternalSorter, we do not count this IO towards shuffle write time. // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this. - // writeMetrics.incShuffleWriteTime(writeMetricsToUse.shuffleWriteTime()); + writeMetrics.incShuffleRecordsWritten(writeMetricsToUse.shuffleRecordsWritten()); taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.shuffleBytesWritten()); } } From de40b9d7109536feceed1aff219f1cb7a2da64a8 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 12 May 2015 15:38:57 -0700 Subject: [PATCH 86/92] More comments to try to explain metrics code --- .../apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 9c288dc7e8f77..22791965dd962 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -258,8 +258,11 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException { logger.debug("Using slow merge"); partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec); } - // The final shuffle spill's write would have directly updated shuffleBytesWritten, so - // we need to decrement to avoid double-counting this write. + // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has + // in-memory records, we write out the in-memory records to a file but do not count that + // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs + // to be counted as shuffle write, but this will lead to double-counting of the final + // SpillInfo's bytes. writeMetrics.decShuffleBytesWritten(spills[spills.length - 1].file.length()); writeMetrics.incShuffleBytesWritten(outputFile.length()); return partitionLengths; From 4023fa403d4fb6396aaa7ee75cb389f211328960 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 12 May 2015 15:58:37 -0700 Subject: [PATCH 87/92] Add @Private annotation to some Java classes. --- .../org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java | 2 ++ .../java/org/apache/spark/storage/TimeTrackingOutputStream.java | 2 ++ 2 files changed, 4 insertions(+) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 22791965dd962..ad7eb04afcd8c 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -36,6 +36,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.*; +import org.apache.spark.annotation.Private; import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; import org.apache.spark.io.LZFCompressionCodec; @@ -54,6 +55,7 @@ import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.memory.TaskMemoryManager; +@Private public class UnsafeShuffleWriter extends ShuffleWriter { private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class); diff --git a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java index 0cd3c7d242660..dc2aa30466cc6 100644 --- a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java +++ b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java @@ -20,12 +20,14 @@ import java.io.IOException; import java.io.OutputStream; +import org.apache.spark.annotation.Private; import org.apache.spark.executor.ShuffleWriteMetrics; /** * Intercepts write calls and tracks total time spent writing in order to update shuffle write * metrics. Not thread safe. */ +@Private public final class TimeTrackingOutputStream extends OutputStream { private final ShuffleWriteMetrics writeMetrics; From 51812a75caaed811153223e01a82407fc71102f6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 12 May 2015 17:03:15 -0700 Subject: [PATCH 88/92] Change shuffle manager sort name to tungsten-sort --- core/src/main/scala/org/apache/spark/SparkEnv.scala | 2 +- .../apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala | 4 ++-- .../org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 8c40bc93863b2..a5d831c7e68ad 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -314,7 +314,7 @@ object SparkEnv extends Logging { val shortShuffleMgrNames = Map( "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager", "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager", - "unsafe" -> "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager") + "tungsten-sort" -> "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager") val shuffleMgrName = conf.get("spark.shuffle.manager", "sort") val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala index 2da7691f5af8f..ce684fbe59d79 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala @@ -117,8 +117,8 @@ private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManage if (!conf.getBoolean("spark.shuffle.spill", true)) { logWarning( - "spark.shuffle.spill was set to false, but this is ignored by UnsafeShuffleManager; " + - "its optimized shuffles will continue to spill to disk when necessary.") + "spark.shuffle.spill was set to false, but this is ignored by the tungsten-sort shuffle " + + "manager; its optimized shuffles will continue to spill to disk when necessary.") } diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala index f7eefa2a3f40c..e68261a730d3a 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala @@ -25,7 +25,7 @@ class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { // This test suite should run all tests in ShuffleSuite with unsafe-based shuffle. override def beforeAll() { - conf.set("spark.shuffle.manager", "unsafe") + conf.set("spark.shuffle.manager", "tungsten-sort") // UnsafeShuffleManager requires at least 128 MB of memory per task in order to be able to sort // shuffle records. conf.set("spark.shuffle.memoryFraction", "0.5") From 52a99819506aa32e3d146b639cf597948f14c8cd Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 12 May 2015 18:31:45 -0700 Subject: [PATCH 89/92] Fix some bugs in the address packing code. The problem is that TaskMemoryManager expects offsets to include the page base address whereas PackedRecordPointer did not. --- .../shuffle/unsafe/PackedRecordPointer.java | 10 ++-- .../unsafe/PackedRecordPointerSuite.java | 16 ++++-- .../unsafe/memory/TaskMemoryManager.java | 52 +++++++++++++++---- .../unsafe/memory/TaskMemoryManagerSuite.java | 7 ++- 4 files changed, 63 insertions(+), 22 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java index 6d61b1b9e34da..4ee6a82c0423e 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java @@ -68,9 +68,8 @@ public static long packPointer(long recordPointer, int partitionId) { assert (partitionId <= MAXIMUM_PARTITION_ID); // Note that without word alignment we can address 2^27 bytes = 128 megabytes per page. // Also note that this relies on some internals of how TaskMemoryManager encodes its addresses. - final int pageNumber = (int) ((recordPointer & MASK_LONG_UPPER_13_BITS) >>> 51); - final long compressedAddress = - (((long) pageNumber) << 27) | (recordPointer & MASK_LONG_LOWER_27_BITS); + final long pageNumber = (recordPointer & MASK_LONG_UPPER_13_BITS) >>> 24; + final long compressedAddress = pageNumber | (recordPointer & MASK_LONG_LOWER_27_BITS); return (((long) partitionId) << 40) | compressedAddress; } @@ -85,9 +84,8 @@ public int getPartitionId() { } public long getRecordPointer() { - final long compressedAddress = packedRecordPointer & MASK_LONG_LOWER_40_BITS; - final long pageNumber = (compressedAddress << 24) & MASK_LONG_UPPER_13_BITS; - final long offsetInPage = compressedAddress & MASK_LONG_LOWER_27_BITS; + final long pageNumber = (packedRecordPointer << 24) & MASK_LONG_UPPER_13_BITS; + final long offsetInPage = packedRecordPointer & MASK_LONG_LOWER_27_BITS; return pageNumber | offsetInPage; } diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java index 4fda87ab57c49..db9e82759090a 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java @@ -34,11 +34,15 @@ public void heap() { new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); final MemoryBlock page0 = memoryManager.allocatePage(100); final MemoryBlock page1 = memoryManager.allocatePage(100); - final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, 42); + final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, + page1.getBaseOffset() + 42); PackedRecordPointer packedPointer = new PackedRecordPointer(); packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360)); assertEquals(360, packedPointer.getPartitionId()); - assertEquals(addressInPage1, packedPointer.getRecordPointer()); + final long recordPointer = packedPointer.getRecordPointer(); + assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer)); + assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer)); + assertEquals(addressInPage1, recordPointer); memoryManager.cleanUpAllAllocatedMemory(); } @@ -48,11 +52,15 @@ public void offHeap() { new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE)); final MemoryBlock page0 = memoryManager.allocatePage(100); final MemoryBlock page1 = memoryManager.allocatePage(100); - final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, 42); + final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, + page1.getBaseOffset() + 42); PackedRecordPointer packedPointer = new PackedRecordPointer(); packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360)); assertEquals(360, packedPointer.getPartitionId()); - assertEquals(addressInPage1, packedPointer.getRecordPointer()); + final long recordPointer = packedPointer.getRecordPointer(); + assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer)); + assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer)); + assertEquals(addressInPage1, recordPointer); memoryManager.cleanUpAllAllocatedMemory(); } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java index cfd54035bee99..2906ac8abad1a 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java @@ -48,10 +48,18 @@ public final class TaskMemoryManager { private final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class); - /** - * The number of entries in the page table. - */ - private static final int PAGE_TABLE_SIZE = 1 << 13; + /** The number of bits used to address the page table. */ + private static final int PAGE_NUMBER_BITS = 13; + + /** The number of bits used to encode offsets in data pages. */ + @VisibleForTesting + static final int OFFSET_BITS = 64 - PAGE_NUMBER_BITS; // 51 + + /** The number of entries in the page table. */ + private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS; + + /** Maximum supported data page size */ + private static final long MAXIMUM_PAGE_SIZE = (1L << OFFSET_BITS); /** Bit mask for the lower 51 bits of a long. */ private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL; @@ -102,8 +110,9 @@ public TaskMemoryManager(ExecutorMemoryManager executorMemoryManager) { * intended for allocating large blocks of memory that will be shared between operators. */ public MemoryBlock allocatePage(long size) { - if (size >= (1L << 51)) { - throw new IllegalArgumentException("Cannot allocate a page with more than 2^51 bytes"); + if (size > MAXIMUM_PAGE_SIZE) { + throw new IllegalArgumentException( + "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE + " bytes"); } final int pageNumber; @@ -168,15 +177,36 @@ public void free(MemoryBlock memory) { /** * Given a memory page and offset within that page, encode this address into a 64-bit long. * This address will remain valid as long as the corresponding page has not been freed. + * + * @param page a data page allocated by {@link TaskMemoryManager#allocate(long)}. + * @param offsetInPage an offset in this page which incorporates the base offset. In other words, + * this should be the value that you would pass as the base offset into an + * UNSAFE call (e.g. page.baseOffset() + something). + * @return an encoded page address. */ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { + if (!inHeap) { + // In off-heap mode, an offset is an absolute address that may require a full 64 bits to + // encode. Due to our page size limitation, though, we can convert this into an offset that's + // relative to the page's base offset; this relative offset will fit in 51 bits. + offsetInPage -= page.getBaseOffset(); + } return encodePageNumberAndOffset(page.pageNumber, offsetInPage); } @VisibleForTesting public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) { assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page"; - return (((long) pageNumber) << 51) | (offsetInPage & MASK_LONG_LOWER_51_BITS); + return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS); + } + + @VisibleForTesting + public static int decodePageNumber(long pagePlusOffsetAddress) { + return (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> OFFSET_BITS); + } + + private static long decodeOffset(long pagePlusOffsetAddress) { + return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS); } /** @@ -185,7 +215,7 @@ public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) */ public Object getPage(long pagePlusOffsetAddress) { if (inHeap) { - final int pageNumber = (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> 51); + final int pageNumber = decodePageNumber(pagePlusOffsetAddress); assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); final Object page = pageTable[pageNumber].getBaseObject(); assert (page != null); @@ -200,11 +230,13 @@ public Object getPage(long pagePlusOffsetAddress) { * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} */ public long getOffsetInPage(long pagePlusOffsetAddress) { - final long offsetInPage = (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS); + final long offsetInPage = decodeOffset(pagePlusOffsetAddress); if (inHeap) { return offsetInPage; } else { - final int pageNumber = (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> 51); + // In off-heap mode, an offset is an absolute address. In encodePageNumberAndOffset, we + // converted the absolute address into a relative address. Here, we invert that operation: + final int pageNumber = decodePageNumber(pagePlusOffsetAddress); assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); return pageTable[pageNumber].getBaseOffset() + offsetInPage; } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java index 8ace8625abb64..06fb081183659 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java @@ -43,9 +43,12 @@ public void encodePageNumberAndOffsetOffHeap() { final TaskMemoryManager manager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE)); final MemoryBlock dataPage = manager.allocatePage(256); - final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64); + // In off-heap mode, an offset is an absolute address that may require more than 51 bits to + // encode. This test exercises that corner-case: + final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10); + final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, offset); Assert.assertEquals(null, manager.getPage(encodedAddress)); - Assert.assertEquals(dataPage.getBaseOffset() + 64, manager.getOffsetInPage(encodedAddress)); + Assert.assertEquals(offset, manager.getOffsetInPage(encodedAddress)); } @Test From d494ffe50ffd1b6f90e89a9a9947a4362c6ac531 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 13 May 2015 13:20:32 -0700 Subject: [PATCH 90/92] Fix deserialization of JavaSerializer instances. This caused a failure in a new test; this problem occurs when calls ShuffledRDD.setSerializer() with a JavaSerializer. --- .../spark/serializer/JavaSerializer.scala | 2 ++ .../serializer/JavaSerializerSuite.scala | 29 +++++++++++++++++++ 2 files changed, 31 insertions(+) create mode 100644 core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index dfbde7c8a1b0d..698d1384d580d 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -121,6 +121,8 @@ class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable { private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100) private var extraDebugInfo = conf.getBoolean("spark.serializer.extraDebugInfo", true) + protected def this() = this(new SparkConf()) // For deserialization only + override def newInstance(): SerializerInstance = { val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader) new JavaSerializerInstance(counterReset, extraDebugInfo, classLoader) diff --git a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala new file mode 100644 index 0000000000000..ed4d8ce632e16 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import org.apache.spark.SparkConf +import org.scalatest.FunSuite + +class JavaSerializerSuite extends FunSuite { + test("JavaSerializer instances are serializable") { + val serializer = new JavaSerializer(new SparkConf()) + val instance = serializer.newInstance() + instance.deserialize[JavaSerializer](instance.serialize(serializer)) + } +} From 7610f2f7613050e5b32eb9314245d79c0dac7b94 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 13 May 2015 13:21:18 -0700 Subject: [PATCH 91/92] Add tests for proper cleanup of shuffle data. --- .../shuffle/sort/SortShuffleManager.scala | 2 +- .../shuffle/unsafe/UnsafeShuffleManager.scala | 26 +++++-- .../shuffle/unsafe/UnsafeShuffleSuite.scala | 72 ++++++++++++++++++- 3 files changed, 92 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 15842941daaab..d7fab351ca3b8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -72,7 +72,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager true } - override def shuffleBlockResolver: IndexShuffleBlockResolver = { + override val shuffleBlockResolver: IndexShuffleBlockResolver = { indexShuffleBlockResolver } diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala index ce684fbe59d79..f2bfef376d3ca 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala @@ -17,6 +17,9 @@ package org.apache.spark.shuffle.unsafe +import java.util.Collections +import java.util.concurrent.ConcurrentHashMap + import org.apache.spark._ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle._ @@ -25,7 +28,7 @@ import org.apache.spark.shuffle.sort.SortShuffleManager /** * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the new shuffle. */ -private class UnsafeShuffleHandle[K, V]( +private[spark] class UnsafeShuffleHandle[K, V]( shuffleId: Int, numMaps: Int, dependency: ShuffleDependency[K, V, V]) @@ -121,8 +124,10 @@ private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManage "manager; its optimized shuffles will continue to spill to disk when necessary.") } - private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf) + private[this] val shufflesThatFellBackToSortShuffle = + Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]()) + private[this] val numMapsForShufflesThatUsedNewPath = new ConcurrentHashMap[Int, Int]() /** * Register a shuffle with the manager and obtain a handle for it to pass to tasks. @@ -158,8 +163,8 @@ private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManage context: TaskContext): ShuffleWriter[K, V] = { handle match { case unsafeShuffleHandle: UnsafeShuffleHandle[K, V] => + numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps) val env = SparkEnv.get - // TODO: do we need to do anything to register the shuffle here? new UnsafeShuffleWriter( env.blockManager, shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], @@ -170,17 +175,26 @@ private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManage context, env.conf) case other => + shufflesThatFellBackToSortShuffle.add(handle.shuffleId) sortShuffleManager.getWriter(handle, mapId, context) } } /** Remove a shuffle's metadata from the ShuffleManager. */ override def unregisterShuffle(shuffleId: Int): Boolean = { - // TODO: need to do something here for our unsafe path - sortShuffleManager.unregisterShuffle(shuffleId) + if (shufflesThatFellBackToSortShuffle.remove(shuffleId)) { + sortShuffleManager.unregisterShuffle(shuffleId) + } else { + Option(numMapsForShufflesThatUsedNewPath.remove(shuffleId)).foreach { numMaps => + (0 until numMaps).foreach { mapId => + shuffleBlockResolver.removeDataByMap(shuffleId, mapId) + } + } + true + } } - override def shuffleBlockResolver: ShuffleBlockResolver = { + override val shuffleBlockResolver: IndexShuffleBlockResolver = { sortShuffleManager.shuffleBlockResolver } diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala index e68261a730d3a..64569f1c60927 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala @@ -17,9 +17,17 @@ package org.apache.spark.shuffle.unsafe -import org.apache.spark.ShuffleSuite +import scala.collection.JavaConverters._ + +import org.apache.commons.io.FileUtils +import org.apache.commons.io.filefilter.TrueFileFilter import org.scalatest.BeforeAndAfterAll +import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkContext, ShuffleSuite} +import org.apache.spark.rdd.ShuffledRDD +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.util.Utils + class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { // This test suite should run all tests in ShuffleSuite with unsafe-based shuffle. @@ -30,4 +38,66 @@ class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { // shuffle records. conf.set("spark.shuffle.memoryFraction", "0.5") } + + test("UnsafeShuffleManager properly cleans up files for shuffles that use the new shuffle path") { + val tmpDir = Utils.createTempDir() + try { + val myConf = conf.clone() + .set("spark.local.dir", tmpDir.getAbsolutePath) + sc = new SparkContext("local", "test", myConf) + // Create a shuffled RDD and verify that it will actually use the new UnsafeShuffle path + val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) + val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) + .setSerializer(new KryoSerializer(myConf)) + val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep)) + def getAllFiles = + FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet + val filesBeforeShuffle = getAllFiles + // Force the shuffle to be performed + shuffledRdd.count() + // Ensure that the shuffle actually created files that will need to be cleaned up + val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle + filesCreatedByShuffle.map(_.getName) should be + Set("shuffle_0_0_0.data", "shuffle_0_0_0.index") + // Check that the cleanup actually removes the files + sc.env.blockManager.master.removeShuffle(0, blocking = true) + for (file <- filesCreatedByShuffle) { + assert (!file.exists(), s"Shuffle file $file was not cleaned up") + } + } finally { + Utils.deleteRecursively(tmpDir) + } + } + + test("UnsafeShuffleManager properly cleans up files for shuffles that use the old shuffle path") { + val tmpDir = Utils.createTempDir() + try { + val myConf = conf.clone() + .set("spark.local.dir", tmpDir.getAbsolutePath) + sc = new SparkContext("local", "test", myConf) + // Create a shuffled RDD and verify that it will actually use the old SortShuffle path + val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) + val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) + .setSerializer(new JavaSerializer(myConf)) + val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(!UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep)) + def getAllFiles = + FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet + val filesBeforeShuffle = getAllFiles + // Force the shuffle to be performed + shuffledRdd.count() + // Ensure that the shuffle actually created files that will need to be cleaned up + val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle + filesCreatedByShuffle.map(_.getName) should be + Set("shuffle_0_0_0.data", "shuffle_0_0_0.index") + // Check that the cleanup actually removes the files + sc.env.blockManager.master.removeShuffle(0, blocking = true) + for (file <- filesCreatedByShuffle) { + assert (!file.exists(), s"Shuffle file $file was not cleaned up") + } + } finally { + Utils.deleteRecursively(tmpDir) + } + } } From ef0a86e41e9b390e6c0d60a6ed2105dbc54431f7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 13 May 2015 13:50:11 -0700 Subject: [PATCH 92/92] Fix scalastyle errors --- .../apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala index 64569f1c60927..6351539e91e97 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.shuffle.unsafe +import java.io.File + import scala.collection.JavaConverters._ import org.apache.commons.io.FileUtils @@ -51,7 +53,7 @@ class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { .setSerializer(new KryoSerializer(myConf)) val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] assert(UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep)) - def getAllFiles = + def getAllFiles: Set[File] = FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet val filesBeforeShuffle = getAllFiles // Force the shuffle to be performed @@ -82,7 +84,7 @@ class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { .setSerializer(new JavaSerializer(myConf)) val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] assert(!UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep)) - def getAllFiles = + def getAllFiles: Set[File] = FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet val filesBeforeShuffle = getAllFiles // Force the shuffle to be performed