From f6d06adf05afa9c5386dc2396c94e7a98730289f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 22 Oct 2015 09:46:30 -0700 Subject: [PATCH] [SPARK-10708] Consolidate sort shuffle implementations There's a lot of duplication between SortShuffleManager and UnsafeShuffleManager. Given that these now provide the same set of functionality, now that UnsafeShuffleManager supports large records, I think that we should replace SortShuffleManager's serialized shuffle implementation with UnsafeShuffleManager's and should merge the two managers together. Author: Josh Rosen Closes #8829 from JoshRosen/consolidate-sort-shuffle-implementations. --- .../sort/BypassMergeSortShuffleWriter.java | 106 +++++-- .../{unsafe => sort}/PackedRecordPointer.java | 2 +- .../ShuffleExternalSorter.java} | 28 +- .../ShuffleInMemorySorter.java} | 16 +- .../ShuffleSortDataFormat.java} | 8 +- .../shuffle/sort/SortShuffleFileWriter.java | 53 ---- .../shuffle/{unsafe => sort}/SpillInfo.java | 4 +- .../{unsafe => sort}/UnsafeShuffleWriter.java | 12 +- .../scala/org/apache/spark/SparkEnv.scala | 2 +- .../shuffle/sort/SortShuffleManager.scala | 175 +++++++++-- .../shuffle/sort/SortShuffleWriter.scala | 28 +- .../shuffle/unsafe/UnsafeShuffleManager.scala | 202 ------------- .../spark/util/collection/ChainedBuffer.scala | 146 ---------- .../util/collection/ExternalSorter.scala | 35 +-- .../PartitionedSerializedPairBuffer.scala | 273 ------------------ .../PackedRecordPointerSuite.java | 5 +- .../ShuffleInMemorySorterSuite.java} | 16 +- .../UnsafeShuffleWriterSuite.java | 10 +- .../org/apache/spark/SortShuffleSuite.scala | 65 +++++ .../spark/scheduler/DAGSchedulerSuite.scala | 6 +- .../BypassMergeSortShuffleWriterSuite.scala | 64 ++-- .../SortShuffleManagerSuite.scala} | 30 +- .../shuffle/sort/SortShuffleWriterSuite.scala | 45 --- .../shuffle/unsafe/UnsafeShuffleSuite.scala | 102 ------- .../util/collection/ChainedBufferSuite.scala | 144 --------- ...PartitionedSerializedPairBufferSuite.scala | 148 ---------- docs/configuration.md | 7 +- project/MimaExcludes.scala | 9 +- .../apache/spark/sql/execution/Exchange.scala | 23 +- .../execution/UnsafeRowSerializerSuite.scala | 9 +- 30 files changed, 456 insertions(+), 1317 deletions(-) rename core/src/main/java/org/apache/spark/shuffle/{unsafe => sort}/PackedRecordPointer.java (98%) rename core/src/main/java/org/apache/spark/shuffle/{unsafe/UnsafeShuffleExternalSorter.java => sort/ShuffleExternalSorter.java} (95%) rename core/src/main/java/org/apache/spark/shuffle/{unsafe/UnsafeShuffleInMemorySorter.java => sort/ShuffleInMemorySorter.java} (88%) rename core/src/main/java/org/apache/spark/shuffle/{unsafe/UnsafeShuffleSortDataFormat.java => sort/ShuffleSortDataFormat.java} (86%) delete mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java rename core/src/main/java/org/apache/spark/shuffle/{unsafe => sort}/SpillInfo.java (90%) rename core/src/main/java/org/apache/spark/shuffle/{unsafe => sort}/UnsafeShuffleWriter.java (98%) delete mode 100644 core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala delete mode 100644 core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala delete mode 100644 core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala rename core/src/test/java/org/apache/spark/shuffle/{unsafe => sort}/PackedRecordPointerSuite.java (96%) rename core/src/test/java/org/apache/spark/shuffle/{unsafe/UnsafeShuffleInMemorySorterSuite.java => sort/ShuffleInMemorySorterSuite.java} 
(87%) rename core/src/test/java/org/apache/spark/shuffle/{unsafe => sort}/UnsafeShuffleWriterSuite.java (98%) rename core/src/test/scala/org/apache/spark/shuffle/{unsafe/UnsafeShuffleManagerSuite.scala => sort/SortShuffleManagerSuite.scala} (80%) delete mode 100644 core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index f5d80bbcf3557..ee82d679935c0 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -21,21 +21,30 @@ import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; +import javax.annotation.Nullable; +import scala.None$; +import scala.Option; import scala.Product2; import scala.Tuple2; import scala.collection.Iterator; +import com.google.common.annotations.VisibleForTesting; import com.google.common.io.Closeables; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.Partitioner; +import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.scheduler.MapStatus$; import org.apache.spark.serializer.Serializer; import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.*; import org.apache.spark.util.Utils; @@ -62,7 +71,7 @@ *

* There have been proposals to completely remove this code path; see SPARK-6026 for details. */ -final class BypassMergeSortShuffleWriter implements SortShuffleFileWriter { +final class BypassMergeSortShuffleWriter extends ShuffleWriter { private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class); @@ -72,31 +81,52 @@ final class BypassMergeSortShuffleWriter implements SortShuffleFileWriter< private final BlockManager blockManager; private final Partitioner partitioner; private final ShuffleWriteMetrics writeMetrics; + private final int shuffleId; + private final int mapId; private final Serializer serializer; + private final IndexShuffleBlockResolver shuffleBlockResolver; /** Array of file writers, one for each partition */ private DiskBlockObjectWriter[] partitionWriters; + @Nullable private MapStatus mapStatus; + private long[] partitionLengths; + + /** + * Are we in the process of stopping? Because map tasks can call stop() with success = true + * and then call stop() with success = false if they get an exception, we want to make sure + * we don't try deleting files, etc twice. + */ + private boolean stopping = false; public BypassMergeSortShuffleWriter( - SparkConf conf, BlockManager blockManager, - Partitioner partitioner, - ShuffleWriteMetrics writeMetrics, - Serializer serializer) { + IndexShuffleBlockResolver shuffleBlockResolver, + BypassMergeSortShuffleHandle handle, + int mapId, + TaskContext taskContext, + SparkConf conf) { // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); - this.numPartitions = partitioner.numPartitions(); this.blockManager = blockManager; - this.partitioner = partitioner; - this.writeMetrics = writeMetrics; - this.serializer = serializer; + final ShuffleDependency dep = handle.dependency(); + this.mapId = mapId; + this.shuffleId = dep.shuffleId(); + this.partitioner = dep.partitioner(); + this.numPartitions = partitioner.numPartitions(); + this.writeMetrics = new ShuffleWriteMetrics(); + taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics)); + this.serializer = Serializer.getSerializer(dep.serializer()); + this.shuffleBlockResolver = shuffleBlockResolver; } @Override - public void insertAll(Iterator> records) throws IOException { + public void write(Iterator> records) throws IOException { assert (partitionWriters == null); if (!records.hasNext()) { + partitionLengths = new long[numPartitions]; + shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); return; } final SerializerInstance serInstance = serializer.newInstance(); @@ -124,13 +154,24 @@ public void insertAll(Iterator> records) throws IOException { for (DiskBlockObjectWriter writer : partitionWriters) { writer.commitAndClose(); } + + partitionLengths = + writePartitionedFile(shuffleBlockResolver.getDataFile(shuffleId, mapId)); + shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } - @Override - public long[] writePartitionedFile( - BlockId blockId, - TaskContext context, - File outputFile) throws IOException { + @VisibleForTesting + long[] getPartitionLengths() { + return partitionLengths; + } + + /** + * Concatenate all of the 
per-partition files into a single combined file. + * + * @return array of lengths, in bytes, of each partition of the file (used by map output tracker). + */ + private long[] writePartitionedFile(File outputFile) throws IOException { // Track location of the partition starts in the output file final long[] lengths = new long[numPartitions]; if (partitionWriters == null) { @@ -165,18 +206,33 @@ public long[] writePartitionedFile( } @Override - public void stop() throws IOException { - if (partitionWriters != null) { - try { - for (DiskBlockObjectWriter writer : partitionWriters) { - // This method explicitly does _not_ throw exceptions: - File file = writer.revertPartialWritesAndClose(); - if (!file.delete()) { - logger.error("Error while deleting file {}", file.getAbsolutePath()); + public Option stop(boolean success) { + if (stopping) { + return None$.empty(); + } else { + stopping = true; + if (success) { + if (mapStatus == null) { + throw new IllegalStateException("Cannot call stop(true) without having called write()"); + } + return Option.apply(mapStatus); + } else { + // The map task failed, so delete our output data. + if (partitionWriters != null) { + try { + for (DiskBlockObjectWriter writer : partitionWriters) { + // This method explicitly does _not_ throw exceptions: + File file = writer.revertPartialWritesAndClose(); + if (!file.delete()) { + logger.error("Error while deleting file {}", file.getAbsolutePath()); + } + } + } finally { + partitionWriters = null; } } - } finally { - partitionWriters = null; + shuffleBlockResolver.removeDataByMap(shuffleId, mapId); + return None$.empty(); } } } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java similarity index 98% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java rename to core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java index 4ee6a82c0423e..c11711966fa8c 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; /** * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer. diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java similarity index 95% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java rename to core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index e73ba39468828..85fdaa8115fa3 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import javax.annotation.Nullable; import java.io.File; @@ -48,7 +48,7 @@ *

* Incoming records are appended to data pages. When all records have been inserted (or when the * current thread's shuffle memory limit is reached), the in-memory records are sorted according to - * their partition ids (using a {@link UnsafeShuffleInMemorySorter}). The sorted records are then + * their partition ids (using a {@link ShuffleInMemorySorter}). The sorted records are then * written to a single output file (or multiple files, if we've spilled). The format of the output * files is the same as the format of the final output file written by * {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are @@ -59,9 +59,9 @@ * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a * specialized merge procedure that avoids extra serialization/deserialization. */ -final class UnsafeShuffleExternalSorter { +final class ShuffleExternalSorter { - private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class); + private final Logger logger = LoggerFactory.getLogger(ShuffleExternalSorter.class); @VisibleForTesting static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; @@ -76,6 +76,10 @@ final class UnsafeShuffleExternalSorter { private final BlockManager blockManager; private final TaskContext taskContext; private final ShuffleWriteMetrics writeMetrics; + private long numRecordsInsertedSinceLastSpill = 0; + + /** Force this sorter to spill when there are this many elements in memory. For testing only */ + private final long numElementsForSpillThreshold; /** The buffer size to use when writing spills using DiskBlockObjectWriter */ private final int fileBufferSizeBytes; @@ -94,12 +98,12 @@ final class UnsafeShuffleExternalSorter { private long peakMemoryUsedBytes; // These variables are reset after spilling: - @Nullable private UnsafeShuffleInMemorySorter inMemSorter; + @Nullable private ShuffleInMemorySorter inMemSorter; @Nullable private MemoryBlock currentPage = null; private long currentPagePosition = -1; private long freeSpaceInCurrentPage = 0; - public UnsafeShuffleExternalSorter( + public ShuffleExternalSorter( TaskMemoryManager memoryManager, ShuffleMemoryManager shuffleMemoryManager, BlockManager blockManager, @@ -117,6 +121,8 @@ public UnsafeShuffleExternalSorter( this.numPartitions = numPartitions; // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.numElementsForSpillThreshold = + conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE); this.pageSizeBytes = (int) Math.min( PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, shuffleMemoryManager.pageSizeBytes()); this.maxRecordSizeBytes = pageSizeBytes - 4; @@ -140,7 +146,8 @@ private void initializeForWriting() throws IOException { throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); } - this.inMemSorter = new UnsafeShuffleInMemorySorter(initialSize); + this.inMemSorter = new ShuffleInMemorySorter(initialSize); + numRecordsInsertedSinceLastSpill = 0; } /** @@ -166,7 +173,7 @@ private void writeSortedFile(boolean isLastFile) throws IOException { } // This call performs the actual sort. 
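[Editorial sketch] The force-spill threshold introduced in this hunk, spark.shuffle.spill.numElementsForceSpillThreshold, defaults to Long.MAX_VALUE, so it only fires when a test lowers it. A minimal, self-contained Scala sketch of the bookkeeping around that threshold; the class and method names below are illustrative and are not the ones in this patch:

    // Counter-based force-spill check: spill once the number of records inserted
    // since the last spill exceeds the configured threshold.
    class ForceSpillSketch(numElementsForSpillThreshold: Long) {
      private var numRecordsInsertedSinceLastSpill = 0L
      private var spillCount = 0

      def insertRecord(): Unit = {
        if (numRecordsInsertedSinceLastSpill > numElementsForSpillThreshold) {
          spill()
        }
        numRecordsInsertedSinceLastSpill += 1
      }

      private def spill(): Unit = {
        spillCount += 1
        numRecordsInsertedSinceLastSpill = 0L // reset, as initializeForWriting() does after a spill
      }

      def spills: Int = spillCount
    }

    object ForceSpillSketchDemo extends App {
      val sorter = new ForceSpillSketch(numElementsForSpillThreshold = 3)
      (1 to 10).foreach(_ => sorter.insertRecord())
      println(s"spills triggered: ${sorter.spills}") // prints 2
    }

With the default threshold of Long.MaxValue the guard is effectively a no-op, which is why the production write path is unchanged by this addition.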
- final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords = + final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords = inMemSorter.getSortedIterator(); // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this @@ -406,6 +413,10 @@ public void insertRecord( int lengthInBytes, int partitionId) throws IOException { + if (numRecordsInsertedSinceLastSpill > numElementsForSpillThreshold) { + spill(); + } + growPointerArrayIfNecessary(); // Need 4 bytes to store the record length. final int totalSpaceRequired = lengthInBytes + 4; @@ -453,6 +464,7 @@ public void insertRecord( recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes); assert(inMemSorter != null); inMemSorter.insertRecord(recordAddress, partitionId); + numRecordsInsertedSinceLastSpill += 1; } /** diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java similarity index 88% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java rename to core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index 5bab501da9364..a8dee6c6101c1 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import java.util.Comparator; import org.apache.spark.util.collection.Sorter; -final class UnsafeShuffleInMemorySorter { +final class ShuffleInMemorySorter { private final Sorter sorter; private static final class SortComparator implements Comparator { @@ -44,10 +44,10 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) { */ private int pointerArrayInsertPosition = 0; - public UnsafeShuffleInMemorySorter(int initialSize) { + public ShuffleInMemorySorter(int initialSize) { assert (initialSize > 0); this.pointerArray = new long[initialSize]; - this.sorter = new Sorter(UnsafeShuffleSortDataFormat.INSTANCE); + this.sorter = new Sorter(ShuffleSortDataFormat.INSTANCE); } public void expandPointerArray() { @@ -92,14 +92,14 @@ public void insertRecord(long recordPointer, int partitionId) { /** * An iterator-like class that's used instead of Java's Iterator in order to facilitate inlining. */ - public static final class UnsafeShuffleSorterIterator { + public static final class ShuffleSorterIterator { private final long[] pointerArray; private final int numRecords; final PackedRecordPointer packedRecordPointer = new PackedRecordPointer(); private int position = 0; - public UnsafeShuffleSorterIterator(int numRecords, long[] pointerArray) { + public ShuffleSorterIterator(int numRecords, long[] pointerArray) { this.numRecords = numRecords; this.pointerArray = pointerArray; } @@ -117,8 +117,8 @@ public void loadNext() { /** * Return an iterator over record pointers in sorted order. 
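[Editorial sketch] ShuffleInMemorySorter is cache-friendly because it sorts an array of packed 8-byte words instead of the records themselves. A rough, self-contained Scala sketch of the packing idea, assuming only what the PackedRecordPointer comment states (a 24-bit partition number in the high bits and a 40-bit record pointer in the low bits; the real class further encodes a page number and an in-page offset inside the 40-bit half):

    // Pack a partition id and a record address into one long; sorting the packed longs
    // groups records by partition without ever touching the serialized records.
    object PackedPointerSketch {
      val MaximumPartitionId: Int = (1 << 24) - 1 // 16777215

      def pack(recordPointer: Long, partitionId: Int): Long = {
        require(partitionId <= MaximumPartitionId, "partition id does not fit in 24 bits")
        (partitionId.toLong << 40) | (recordPointer & ((1L << 40) - 1))
      }

      def partitionId(packed: Long): Int = (packed >>> 40).toInt
      def recordPointer(packed: Long): Long = packed & ((1L << 40) - 1)

      def main(args: Array[String]): Unit = {
        val pointers = Array(pack(1000L, 7), pack(2000L, 2), pack(3000L, 7), pack(4000L, 0))
        java.util.Arrays.sort(pointers) // orders by partition id, then by pointer
        pointers.foreach(p => println(s"partition=${partitionId(p)} pointer=${recordPointer(p)}"))
      }
    }

The 24-bit partition field is also where the 16777216-partition ceiling mentioned later in this patch comes from.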
*/ - public UnsafeShuffleSorterIterator getSortedIterator() { + public ShuffleSorterIterator getSortedIterator() { sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR); - return new UnsafeShuffleSorterIterator(pointerArrayInsertPosition, pointerArray); + return new ShuffleSorterIterator(pointerArrayInsertPosition, pointerArray); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java similarity index 86% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java rename to core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java index a66d74ee44782..8a1e5aec6ff0e 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java @@ -15,15 +15,15 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import org.apache.spark.util.collection.SortDataFormat; -final class UnsafeShuffleSortDataFormat extends SortDataFormat { +final class ShuffleSortDataFormat extends SortDataFormat { - public static final UnsafeShuffleSortDataFormat INSTANCE = new UnsafeShuffleSortDataFormat(); + public static final ShuffleSortDataFormat INSTANCE = new ShuffleSortDataFormat(); - private UnsafeShuffleSortDataFormat() { } + private ShuffleSortDataFormat() { } @Override public PackedRecordPointer getKey(long[] data, int pos) { diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java deleted file mode 100644 index 656ea0401a144..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.sort; - -import java.io.File; -import java.io.IOException; - -import scala.Product2; -import scala.collection.Iterator; - -import org.apache.spark.annotation.Private; -import org.apache.spark.TaskContext; -import org.apache.spark.storage.BlockId; - -/** - * Interface for objects that {@link SortShuffleWriter} uses to write its output files. - */ -@Private -public interface SortShuffleFileWriter { - - void insertAll(Iterator> records) throws IOException; - - /** - * Write all the data added into this shuffle sorter into a file in the disk store. This is - * called by the SortShuffleWriter and can go through an efficient path of just concatenating - * binary files if we decided to avoid merge-sorting. - * - * @param blockId block ID to write to. 
The index file will be blockId.name + ".index". - * @param context a TaskContext for a running Spark task, for us to update shuffle metrics. - * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) - */ - long[] writePartitionedFile( - BlockId blockId, - TaskContext context, - File outputFile) throws IOException; - - void stop() throws IOException; -} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java similarity index 90% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java rename to core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java index 7bac0dc0bbeb6..df9f7b7abe028 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import java.io.File; import org.apache.spark.storage.TempShuffleBlockId; /** - * Metadata for a block of data written by {@link UnsafeShuffleExternalSorter}. + * Metadata for a block of data written by {@link ShuffleExternalSorter}. */ final class SpillInfo { final long[] partitionLengths; diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java similarity index 98% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java rename to core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index fdb309e365f69..e8f050cb2dab1 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import javax.annotation.Nullable; import java.io.*; @@ -80,7 +80,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final boolean transferToEnabled; @Nullable private MapStatus mapStatus; - @Nullable private UnsafeShuffleExternalSorter sorter; + @Nullable private ShuffleExternalSorter sorter; private long peakMemoryUsedBytes = 0; /** Subclass of ByteArrayOutputStream that exposes `buf` directly. 
*/ @@ -104,15 +104,15 @@ public UnsafeShuffleWriter( IndexShuffleBlockResolver shuffleBlockResolver, TaskMemoryManager memoryManager, ShuffleMemoryManager shuffleMemoryManager, - UnsafeShuffleHandle handle, + SerializedShuffleHandle handle, int mapId, TaskContext taskContext, SparkConf sparkConf) throws IOException { final int numPartitions = handle.dependency().partitioner().numPartitions(); - if (numPartitions > UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS()) { + if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) { throw new IllegalArgumentException( "UnsafeShuffleWriter can only be used for shuffles with at most " + - UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS() + " reduce partitions"); + SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() + " reduce partitions"); } this.blockManager = blockManager; this.shuffleBlockResolver = shuffleBlockResolver; @@ -195,7 +195,7 @@ public void write(scala.collection.Iterator> records) throws IOEx private void open() throws IOException { assert (sorter == null); - sorter = new UnsafeShuffleExternalSorter( + sorter = new ShuffleExternalSorter( memoryManager, shuffleMemoryManager, blockManager, diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index c32998345145a..704158bfc7643 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -330,7 +330,7 @@ object SparkEnv extends Logging { val shortShuffleMgrNames = Map( "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager", "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager", - "tungsten-sort" -> "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager") + "tungsten-sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager") val shuffleMgrName = conf.get("spark.shuffle.manager", "sort") val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 9df4e551669cc..1105167d39d8d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -19,9 +19,53 @@ package org.apache.spark.shuffle.sort import java.util.concurrent.ConcurrentHashMap -import org.apache.spark.{Logging, SparkConf, TaskContext, ShuffleDependency} +import org.apache.spark._ +import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle._ +/** + * In sort-based shuffle, incoming records are sorted according to their target partition ids, then + * written to a single map output file. Reducers fetch contiguous regions of this file in order to + * read their portion of the map output. In cases where the map output data is too large to fit in + * memory, sorted subsets of the output can are spilled to disk and those on-disk files are merged + * to produce the final output file. + * + * Sort-based shuffle has two different write paths for producing its map output files: + * + * - Serialized sorting: used when all three of the following conditions hold: + * 1. The shuffle dependency specifies no aggregation or output ordering. + * 2. 
The shuffle serializer supports relocation of serialized values (this is currently + * supported by KryoSerializer and Spark SQL's custom serializers). + * 3. The shuffle produces fewer than 16777216 output partitions. + * - Deserialized sorting: used to handle all other cases. + * + * ----------------------- + * Serialized sorting mode + * ----------------------- + * + * In the serialized sorting mode, incoming records are serialized as soon as they are passed to the + * shuffle writer and are buffered in a serialized form during sorting. This write path implements + * several optimizations: + * + * - Its sort operates on serialized binary data rather than Java objects, which reduces memory + * consumption and GC overheads. This optimization requires the record serializer to have certain + * properties to allow serialized records to be re-ordered without requiring deserialization. + * See SPARK-4550, where this optimization was first proposed and implemented, for more details. + * + * - It uses a specialized cache-efficient sorter ([[ShuffleExternalSorter]]) that sorts + * arrays of compressed record pointers and partition ids. By using only 8 bytes of space per + * record in the sorting array, this fits more of the array into cache. + * + * - The spill merging procedure operates on blocks of serialized records that belong to the same + * partition and does not need to deserialize records during the merge. + * + * - When the spill compression codec supports concatenation of compressed data, the spill merge + * simply concatenates the serialized and compressed spill partitions to produce the final output + * partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used + * and avoids the need to allocate decompression or copying buffers during the merge. + * + * For more details on these optimizations, see SPARK-7081. + */ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { if (!conf.getBoolean("spark.shuffle.spill", true)) { @@ -30,8 +74,12 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager " Shuffle will continue to spill to disk when necessary.") } - private val indexShuffleBlockResolver = new IndexShuffleBlockResolver(conf) - private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]() + /** + * A mapping from shuffle ids to the number of mappers producing output for those shuffles. + */ + private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]() + + override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) /** * Register a shuffle with the manager and obtain a handle for it to pass to tasks. @@ -40,7 +88,22 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager shuffleId: Int, numMaps: Int, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { - new BaseShuffleHandle(shuffleId, numMaps, dependency) + if (SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf, dependency)) { + // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't + // need map-side aggregation, then write numPartitions files directly and just concatenate + // them at the end. This avoids doing serialization and deserialization twice to merge + // together the spilled files, which would happen with the normal code path. The downside is + // having multiple files open at a time and thus more memory allocated to buffers. 
+ new BypassMergeSortShuffleHandle[K, V]( + shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) { + // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient: + new SerializedShuffleHandle[K, V]( + shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else { + // Otherwise, buffer map outputs in a deserialized form: + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } } /** @@ -52,38 +115,114 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager startPartition: Int, endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { - // We currently use the same block store shuffle fetcher as the hash-based shuffle. new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) } /** Get a writer for a given partition. Called on executors by map tasks. */ - override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext) - : ShuffleWriter[K, V] = { - val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, V, _]] - shuffleMapNumber.putIfAbsent(baseShuffleHandle.shuffleId, baseShuffleHandle.numMaps) - new SortShuffleWriter( - shuffleBlockResolver, baseShuffleHandle, mapId, context) + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Int, + context: TaskContext): ShuffleWriter[K, V] = { + numMapsForShuffle.putIfAbsent( + handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps) + val env = SparkEnv.get + handle match { + case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => + new UnsafeShuffleWriter( + env.blockManager, + shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], + context.taskMemoryManager(), + env.shuffleMemoryManager, + unsafeShuffleHandle, + mapId, + context, + env.conf) + case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => + new BypassMergeSortShuffleWriter( + env.blockManager, + shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], + bypassMergeSortHandle, + mapId, + context, + env.conf) + case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => + new SortShuffleWriter(shuffleBlockResolver, other, mapId, context) + } } /** Remove a shuffle's metadata from the ShuffleManager. */ override def unregisterShuffle(shuffleId: Int): Boolean = { - if (shuffleMapNumber.containsKey(shuffleId)) { - val numMaps = shuffleMapNumber.remove(shuffleId) - (0 until numMaps).map{ mapId => + Option(numMapsForShuffle.remove(shuffleId)).foreach { numMaps => + (0 until numMaps).foreach { mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId) } } true } - override val shuffleBlockResolver: IndexShuffleBlockResolver = { - indexShuffleBlockResolver - } - /** Shut down this ShuffleManager. */ override def stop(): Unit = { shuffleBlockResolver.stop() } } + +private[spark] object SortShuffleManager extends Logging { + + /** + * The maximum number of shuffle output partitions that SortShuffleManager supports when + * buffering map outputs in a serialized form. This is an extreme defensive programming measure, + * since it's extremely unlikely that a single shuffle produces over 16 million output partitions. 
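[Editorial sketch] The registerShuffle/getWriter dispatch above boils down to three rules: bypass-merge when there is no map-side combine and at most spark.shuffle.sort.bypassMergeThreshold partitions, serialized mode when the serializer supports relocation, no aggregator is defined, and the partition count fits the packed-pointer limit, and the deserialized ExternalSorter path otherwise. A self-contained Scala sketch of that decision procedure; DepInfo and WritePathChooser are illustrative names that do not exist in Spark:

    sealed trait WritePath
    case object BypassMergeSort extends WritePath   // one file per partition, concatenated at the end
    case object SerializedSort extends WritePath    // UnsafeShuffleWriter over serialized records
    case object DeserializedSort extends WritePath  // ExternalSorter over Java objects

    case class DepInfo(
        numPartitions: Int,
        mapSideCombine: Boolean,
        hasAggregator: Boolean,
        serializerSupportsRelocation: Boolean)

    object WritePathChooser {
      val bypassMergeThreshold = 200             // spark.shuffle.sort.bypassMergeThreshold default
      val maxSerializedModePartitions = 1 << 24  // PackedRecordPointer.MAXIMUM_PARTITION_ID + 1

      def choose(dep: DepInfo): WritePath = {
        if (!dep.mapSideCombine && dep.numPartitions <= bypassMergeThreshold) {
          BypassMergeSort
        } else if (dep.serializerSupportsRelocation && !dep.hasAggregator &&
                   dep.numPartitions <= maxSerializedModePartitions) {
          SerializedSort
        } else {
          DeserializedSort
        }
      }
    }

For example, DepInfo(5000, mapSideCombine = false, hasAggregator = false, serializerSupportsRelocation = true) resolves to SerializedSort, the same dependency with 100 partitions takes the bypass path, and adding an aggregator pushes it onto the deserialized path.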
+ * */ + val MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE = + PackedRecordPointer.MAXIMUM_PARTITION_ID + 1 + + /** + * Helper method for determining whether a shuffle should use an optimized serialized shuffle + * path or whether it should fall back to the original path that operates on deserialized objects. + */ + def canUseSerializedShuffle(dependency: ShuffleDependency[_, _, _]): Boolean = { + val shufId = dependency.shuffleId + val numPartitions = dependency.partitioner.numPartitions + val serializer = Serializer.getSerializer(dependency.serializer) + if (!serializer.supportsRelocationOfSerializedObjects) { + log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " + + s"${serializer.getClass.getName}, does not support object relocation") + false + } else if (dependency.aggregator.isDefined) { + log.debug( + s"Can't use serialized shuffle for shuffle $shufId because an aggregator is defined") + false + } else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) { + log.debug(s"Can't use serialized shuffle for shuffle $shufId because it has more than " + + s"$MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE partitions") + false + } else { + log.debug(s"Can use serialized shuffle for shuffle $shufId") + true + } + } +} + +/** + * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the + * serialized shuffle. + */ +private[spark] class SerializedShuffleHandle[K, V]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, numMaps, dependency) { +} + +/** + * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the + * bypass merge sort shuffle path. + */ +private[spark] class BypassMergeSortShuffleHandle[K, V]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, numMaps, dependency) { +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 5865e7640c1cf..bbd9c1ab53cd8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -20,7 +20,6 @@ package org.apache.spark.shuffle.sort import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus -import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle} import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.collection.ExternalSorter @@ -36,7 +35,7 @@ private[spark] class SortShuffleWriter[K, V, C]( private val blockManager = SparkEnv.get.blockManager - private var sorter: SortShuffleFileWriter[K, V] = null + private var sorter: ExternalSorter[K, V, _] = null // Are we in the process of stopping? 
Because map tasks can call stop() with success = true // and then call stop() with success = false if they get an exception, we want to make sure @@ -54,15 +53,6 @@ private[spark] class SortShuffleWriter[K, V, C]( require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") new ExternalSorter[K, V, C]( dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) - } else if (SortShuffleWriter.shouldBypassMergeSort( - SparkEnv.get.conf, dep.partitioner.numPartitions, aggregator = None, keyOrdering = None)) { - // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't - // need local aggregation and sorting, write numPartitions files directly and just concatenate - // them at the end. This avoids doing serialization and deserialization twice to merge - // together the spilled files, which would happen with the normal code path. The downside is - // having multiple files open at a time and thus more memory allocated to buffers. - new BypassMergeSortShuffleWriter[K, V](SparkEnv.get.conf, blockManager, dep.partitioner, - writeMetrics, Serializer.getSerializer(dep.serializer)) } else { // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't // care whether the keys get sorted in each partition; that will be done on the reduce side @@ -111,12 +101,14 @@ private[spark] class SortShuffleWriter[K, V, C]( } private[spark] object SortShuffleWriter { - def shouldBypassMergeSort( - conf: SparkConf, - numPartitions: Int, - aggregator: Option[Aggregator[_, _, _]], - keyOrdering: Option[Ordering[_]]): Boolean = { - val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - numPartitions <= bypassMergeThreshold && aggregator.isEmpty && keyOrdering.isEmpty + def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = { + // We cannot bypass sorting if we need to do map-side aggregation. + if (dep.mapSideCombine) { + require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") + false + } else { + val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + dep.partitioner.numPartitions <= bypassMergeThreshold + } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala deleted file mode 100644 index 75f22f642b9d1..0000000000000 --- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala +++ /dev/null @@ -1,202 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.unsafe - -import java.util.Collections -import java.util.concurrent.ConcurrentHashMap - -import org.apache.spark._ -import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle._ -import org.apache.spark.shuffle.sort.SortShuffleManager - -/** - * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the new shuffle. - */ -private[spark] class UnsafeShuffleHandle[K, V]( - shuffleId: Int, - numMaps: Int, - dependency: ShuffleDependency[K, V, V]) - extends BaseShuffleHandle(shuffleId, numMaps, dependency) { -} - -private[spark] object UnsafeShuffleManager extends Logging { - - /** - * The maximum number of shuffle output partitions that UnsafeShuffleManager supports. - */ - val MAX_SHUFFLE_OUTPUT_PARTITIONS = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1 - - /** - * Helper method for determining whether a shuffle should use the optimized unsafe shuffle - * path or whether it should fall back to the original sort-based shuffle. - */ - def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = { - val shufId = dependency.shuffleId - val serializer = Serializer.getSerializer(dependency.serializer) - if (!serializer.supportsRelocationOfSerializedObjects) { - log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because the serializer, " + - s"${serializer.getClass.getName}, does not support object relocation") - false - } else if (dependency.aggregator.isDefined) { - log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined") - false - } else if (dependency.partitioner.numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS) { - log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " + - s"$MAX_SHUFFLE_OUTPUT_PARTITIONS partitions") - false - } else { - log.debug(s"Can use UnsafeShuffle for shuffle $shufId") - true - } - } -} - -/** - * A shuffle implementation that uses directly-managed memory to implement several performance - * optimizations for certain types of shuffles. In cases where the new performance optimizations - * cannot be applied, this shuffle manager delegates to [[SortShuffleManager]] to handle those - * shuffles. - * - * UnsafeShuffleManager's optimizations will apply when _all_ of the following conditions hold: - * - * - The shuffle dependency specifies no aggregation or output ordering. - * - The shuffle serializer supports relocation of serialized values (this is currently supported - * by KryoSerializer and Spark SQL's custom serializers). - * - The shuffle produces fewer than 16777216 output partitions. - * - No individual record is larger than 128 MB when serialized. - * - * In addition, extra spill-merging optimizations are automatically applied when the shuffle - * compression codec supports concatenation of serialized streams. This is currently supported by - * Spark's LZF serializer. - * - * At a high-level, UnsafeShuffleManager's design is similar to Spark's existing SortShuffleManager. - * In sort-based shuffle, incoming records are sorted according to their target partition ids, then - * written to a single map output file. Reducers fetch contiguous regions of this file in order to - * read their portion of the map output. In cases where the map output data is too large to fit in - * memory, sorted subsets of the output can are spilled to disk and those on-disk files are merged - * to produce the final output file. 
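[Editorial sketch] Every write path in the consolidated manager ends up producing the same on-disk layout described here: a single data file whose partitions are laid out contiguously in partition order, plus an index of byte offsets that lets reducers seek to their region. A self-contained Scala sketch of that layout, assuming the index is simply a running sequence of longs (treat the exact format as illustrative; the authoritative definition lives in IndexShuffleBlockResolver):

    import java.io.{DataOutputStream, File, FileOutputStream}
    import java.nio.file.Files

    object MapOutputLayoutSketch {
      /** Concatenate per-partition files in partition order; return each partition's byte length. */
      def writeDataAndIndex(perPartitionFiles: Seq[File], dataFile: File, indexFile: File): Array[Long] = {
        val lengths = new Array[Long](perPartitionFiles.size)
        val out = new FileOutputStream(dataFile)
        try {
          perPartitionFiles.zipWithIndex.foreach { case (f, i) =>
            lengths(i) = Files.copy(f.toPath, out) // append partition i's bytes to the data file
          }
        } finally out.close()

        // Index file: running byte offset of every partition boundary (numPartitions + 1 longs).
        val index = new DataOutputStream(new FileOutputStream(indexFile))
        try {
          var offset = 0L
          index.writeLong(offset)
          lengths.foreach { len => offset += len; index.writeLong(offset) }
        } finally index.close()
        lengths
      }
    }

The lengths array returned here plays the same role as the partitionLengths passed to MapStatus in the writers above.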
- * - * UnsafeShuffleManager optimizes this process in several ways: - * - * - Its sort operates on serialized binary data rather than Java objects, which reduces memory - * consumption and GC overheads. This optimization requires the record serializer to have certain - * properties to allow serialized records to be re-ordered without requiring deserialization. - * See SPARK-4550, where this optimization was first proposed and implemented, for more details. - * - * - It uses a specialized cache-efficient sorter ([[UnsafeShuffleExternalSorter]]) that sorts - * arrays of compressed record pointers and partition ids. By using only 8 bytes of space per - * record in the sorting array, this fits more of the array into cache. - * - * - The spill merging procedure operates on blocks of serialized records that belong to the same - * partition and does not need to deserialize records during the merge. - * - * - When the spill compression codec supports concatenation of compressed data, the spill merge - * simply concatenates the serialized and compressed spill partitions to produce the final output - * partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used - * and avoids the need to allocate decompression or copying buffers during the merge. - * - * For more details on UnsafeShuffleManager's design, see SPARK-7081. - */ -private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { - - if (!conf.getBoolean("spark.shuffle.spill", true)) { - logWarning( - "spark.shuffle.spill was set to false, but this is ignored by the tungsten-sort shuffle " + - "manager; its optimized shuffles will continue to spill to disk when necessary.") - } - - private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf) - private[this] val shufflesThatFellBackToSortShuffle = - Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]()) - private[this] val numMapsForShufflesThatUsedNewPath = new ConcurrentHashMap[Int, Int]() - - /** - * Register a shuffle with the manager and obtain a handle for it to pass to tasks. - */ - override def registerShuffle[K, V, C]( - shuffleId: Int, - numMaps: Int, - dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { - if (UnsafeShuffleManager.canUseUnsafeShuffle(dependency)) { - new UnsafeShuffleHandle[K, V]( - shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) - } else { - new BaseShuffleHandle(shuffleId, numMaps, dependency) - } - } - - /** - * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). - * Called on executors by reduce tasks. - */ - override def getReader[K, C]( - handle: ShuffleHandle, - startPartition: Int, - endPartition: Int, - context: TaskContext): ShuffleReader[K, C] = { - sortShuffleManager.getReader(handle, startPartition, endPartition, context) - } - - /** Get a writer for a given partition. Called on executors by map tasks. 
*/ - override def getWriter[K, V]( - handle: ShuffleHandle, - mapId: Int, - context: TaskContext): ShuffleWriter[K, V] = { - handle match { - case unsafeShuffleHandle: UnsafeShuffleHandle[K @unchecked, V @unchecked] => - numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps) - val env = SparkEnv.get - new UnsafeShuffleWriter( - env.blockManager, - shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], - context.taskMemoryManager(), - env.shuffleMemoryManager, - unsafeShuffleHandle, - mapId, - context, - env.conf) - case other => - shufflesThatFellBackToSortShuffle.add(handle.shuffleId) - sortShuffleManager.getWriter(handle, mapId, context) - } - } - - /** Remove a shuffle's metadata from the ShuffleManager. */ - override def unregisterShuffle(shuffleId: Int): Boolean = { - if (shufflesThatFellBackToSortShuffle.remove(shuffleId)) { - sortShuffleManager.unregisterShuffle(shuffleId) - } else { - Option(numMapsForShufflesThatUsedNewPath.remove(shuffleId)).foreach { numMaps => - (0 until numMaps).foreach { mapId => - shuffleBlockResolver.removeDataByMap(shuffleId, mapId) - } - } - true - } - } - - override val shuffleBlockResolver: IndexShuffleBlockResolver = { - sortShuffleManager.shuffleBlockResolver - } - - /** Shut down this ShuffleManager. */ - override def stop(): Unit = { - sortShuffleManager.stop() - } -} diff --git a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala deleted file mode 100644 index ae60f3b0cb555..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util.collection - -import java.io.OutputStream - -import scala.collection.mutable.ArrayBuffer - -/** - * A logical byte buffer that wraps a list of byte arrays. All the byte arrays have equal size. The - * advantage of this over a standard ArrayBuffer is that it can grow without claiming large amounts - * of memory and needing to copy the full contents. The disadvantage is that the contents don't - * occupy a contiguous segment of memory. - */ -private[spark] class ChainedBuffer(chunkSize: Int) { - - private val chunkSizeLog2: Int = java.lang.Long.numberOfTrailingZeros( - java.lang.Long.highestOneBit(chunkSize)) - assert((1 << chunkSizeLog2) == chunkSize, - s"ChainedBuffer chunk size $chunkSize must be a power of two") - private val chunks: ArrayBuffer[Array[Byte]] = new ArrayBuffer[Array[Byte]]() - private var _size: Long = 0 - - /** - * Feed bytes from this buffer into a DiskBlockObjectWriter. - * - * @param pos Offset in the buffer to read from. - * @param os OutputStream to read into. 
- * @param len Number of bytes to read. - */ - def read(pos: Long, os: OutputStream, len: Int): Unit = { - if (pos + len > _size) { - throw new IndexOutOfBoundsException( - s"Read of $len bytes at position $pos would go past size ${_size} of buffer") - } - var chunkIndex: Int = (pos >> chunkSizeLog2).toInt - var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt - var written: Int = 0 - while (written < len) { - val toRead: Int = math.min(len - written, chunkSize - posInChunk) - os.write(chunks(chunkIndex), posInChunk, toRead) - written += toRead - chunkIndex += 1 - posInChunk = 0 - } - } - - /** - * Read bytes from this buffer into a byte array. - * - * @param pos Offset in the buffer to read from. - * @param bytes Byte array to read into. - * @param offs Offset in the byte array to read to. - * @param len Number of bytes to read. - */ - def read(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = { - if (pos + len > _size) { - throw new IndexOutOfBoundsException( - s"Read of $len bytes at position $pos would go past size of buffer") - } - var chunkIndex: Int = (pos >> chunkSizeLog2).toInt - var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt - var written: Int = 0 - while (written < len) { - val toRead: Int = math.min(len - written, chunkSize - posInChunk) - System.arraycopy(chunks(chunkIndex), posInChunk, bytes, offs + written, toRead) - written += toRead - chunkIndex += 1 - posInChunk = 0 - } - } - - /** - * Write bytes from a byte array into this buffer. - * - * @param pos Offset in the buffer to write to. - * @param bytes Byte array to write from. - * @param offs Offset in the byte array to write from. - * @param len Number of bytes to write. - */ - def write(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = { - if (pos > _size) { - throw new IndexOutOfBoundsException( - s"Write at position $pos starts after end of buffer ${_size}") - } - // Grow if needed - val endChunkIndex: Int = ((pos + len - 1) >> chunkSizeLog2).toInt - while (endChunkIndex >= chunks.length) { - chunks += new Array[Byte](chunkSize) - } - - var chunkIndex: Int = (pos >> chunkSizeLog2).toInt - var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt - var written: Int = 0 - while (written < len) { - val toWrite: Int = math.min(len - written, chunkSize - posInChunk) - System.arraycopy(bytes, offs + written, chunks(chunkIndex), posInChunk, toWrite) - written += toWrite - chunkIndex += 1 - posInChunk = 0 - } - - _size = math.max(_size, pos + len) - } - - /** - * Total size of buffer that can be written to without allocating additional memory. - */ - def capacity: Long = chunks.size.toLong * chunkSize - - /** - * Size of the logical buffer. - */ - def size: Long = _size -} - -/** - * Output stream that writes to a ChainedBuffer. 
- */ -private[spark] class ChainedBufferOutputStream(chainedBuffer: ChainedBuffer) extends OutputStream { - private var pos: Long = 0 - - override def write(b: Int): Unit = { - throw new UnsupportedOperationException() - } - - override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = { - chainedBuffer.write(pos, bytes, offs, len) - pos += len - } -} diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 749be34d8e8fd..c48c453a90d01 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -29,7 +29,6 @@ import com.google.common.io.ByteStreams import org.apache.spark._ import org.apache.spark.serializer._ import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.shuffle.sort.{SortShuffleFileWriter, SortShuffleWriter} import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} /** @@ -69,8 +68,8 @@ import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} * At a high level, this class works internally as follows: * * - We repeatedly fill up buffers of in-memory data, using either a PartitionedAppendOnlyMap if - * we want to combine by key, or a PartitionedSerializedPairBuffer or PartitionedPairBuffer if we - * don't. Inside these buffers, we sort elements by partition ID and then possibly also by key. + * we want to combine by key, or a PartitionedPairBuffer if we don't. + * Inside these buffers, we sort elements by partition ID and then possibly also by key. * To avoid calling the partitioner multiple times with each key, we store the partition ID * alongside each record. * @@ -93,8 +92,7 @@ private[spark] class ExternalSorter[K, V, C]( ordering: Option[Ordering[K]] = None, serializer: Option[Serializer] = None) extends Logging - with Spillable[WritablePartitionedPairCollection[K, C]] - with SortShuffleFileWriter[K, V] { + with Spillable[WritablePartitionedPairCollection[K, C]] { private val conf = SparkEnv.get.conf @@ -104,13 +102,6 @@ private[spark] class ExternalSorter[K, V, C]( if (shouldPartition) partitioner.get.getPartition(key) else 0 } - // Since SPARK-7855, bypassMergeSort optimization is no longer performed as part of this class. - // As a sanity check, make sure that we're not handling a shuffle which should use that path. - if (SortShuffleWriter.shouldBypassMergeSort(conf, numPartitions, aggregator, ordering)) { - throw new IllegalArgumentException("ExternalSorter should not be used to handle " - + " a sort that the BypassMergeSortShuffleWriter should handle") - } - private val blockManager = SparkEnv.get.blockManager private val diskBlockManager = blockManager.diskBlockManager private val ser = Serializer.getSerializer(serializer) @@ -128,23 +119,11 @@ private[spark] class ExternalSorter[K, V, C]( // grow internal data structures by growing + copying every time the number of objects doubles. 
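[Editorial sketch] With PartitionedSerializedPairBuffer removed, ExternalSorter is left with exactly two in-memory buffering modes: combine into a partitioned map when an aggregator is defined, or append to a plain partitioned pair buffer otherwise. A compact, self-contained Scala sketch of that split, with spilling and size estimation omitted and illustrative names:

    import scala.collection.mutable

    // Two buffering modes keyed by (partition, key): combine-by-key vs. plain append.
    class BufferingSketch[K, V, C](
        aggregator: Option[(V => C, (C, V) => C)], // (createCombiner, mergeValue)
        getPartition: K => Int) {

      private val map = mutable.Map.empty[(Int, K), C]
      private val buffer = mutable.ArrayBuffer.empty[((Int, K), V)]

      def insertAll(records: Iterator[(K, V)]): Unit = aggregator match {
        case Some((createCombiner, mergeValue)) =>
          records.foreach { case (k, v) =>
            val key = (getPartition(k), k)
            val updated = map.get(key).map(c => mergeValue(c, v)).getOrElse(createCombiner(v))
            map.update(key, updated)
          }
        case None =>
          records.foreach { case (k, v) => buffer.append(((getPartition(k), k), v)) }
      }

      /** Number of buffered records, the quantity a spill check would be run against. */
      def size: Int = if (aggregator.isDefined) map.size else buffer.size
    }

Storing the partition id alongside each record, as both structures do, is what lets the later sort and spill steps order purely by partition without re-invoking the partitioner.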
private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000) - private val useSerializedPairBuffer = - ordering.isEmpty && - conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) && - ser.supportsRelocationOfSerializedObjects - private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB - private def newBuffer(): WritablePartitionedPairCollection[K, C] with SizeTracker = { - if (useSerializedPairBuffer) { - new PartitionedSerializedPairBuffer(metaInitialRecords = 256, kvChunkSize, serInstance) - } else { - new PartitionedPairBuffer[K, C] - } - } // Data structures to store in-memory objects before we spill. Depending on whether we have an // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we // store them in an array buffer. private var map = new PartitionedAppendOnlyMap[K, C] - private var buffer = newBuffer() + private var buffer = new PartitionedPairBuffer[K, C] // Total spilling statistics private var _diskBytesSpilled = 0L @@ -192,7 +171,7 @@ private[spark] class ExternalSorter[K, V, C]( */ private[spark] def numSpills: Int = spills.size - override def insertAll(records: Iterator[Product2[K, V]]): Unit = { + def insertAll(records: Iterator[Product2[K, V]]): Unit = { // TODO: stop combining if we find that the reduction factor isn't high val shouldCombine = aggregator.isDefined @@ -236,7 +215,7 @@ private[spark] class ExternalSorter[K, V, C]( } else { estimatedSize = buffer.estimateSize() if (maybeSpill(buffer, estimatedSize)) { - buffer = newBuffer() + buffer = new PartitionedPairBuffer[K, C] } } @@ -659,7 +638,7 @@ private[spark] class ExternalSorter[K, V, C]( * @param context a TaskContext for a running Spark task, for us to update shuffle metrics. * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) */ - override def writePartitionedFile( + def writePartitionedFile( blockId: BlockId, context: TaskContext, outputFile: File): Array[Long] = { diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala deleted file mode 100644 index 87a786b02d651..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala +++ /dev/null @@ -1,273 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.util.collection - -import java.io.InputStream -import java.nio.IntBuffer -import java.util.Comparator - -import org.apache.spark.serializer.{JavaSerializerInstance, SerializerInstance} -import org.apache.spark.storage.DiskBlockObjectWriter -import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._ - -/** - * Append-only buffer of key-value pairs, each with a corresponding partition ID, that serializes - * its records upon insert and stores them as raw bytes. - * - * We use two data-structures to store the contents. The serialized records are stored in a - * ChainedBuffer that can expand gracefully as records are added. This buffer is accompanied by a - * metadata buffer that stores pointers into the data buffer as well as the partition ID of each - * record. Each entry in the metadata buffer takes up a fixed amount of space. - * - * Sorting the collection means swapping entries in the metadata buffer - the record buffer need not - * be modified at all. Storing the partition IDs in the metadata buffer means that comparisons can - * happen without following any pointers, which should minimize cache misses. - * - * Currently, only sorting by partition is supported. - * - * Each record is laid out inside the the metaBuffer as follows. keyStart, a long, is split across - * two integers: - * - * +-------------+------------+------------+-------------+ - * | keyStart | keyValLen | partitionId | - * +-------------+------------+------------+-------------+ - * - * The buffer can support up to `536870911 (2 ^ 29 - 1)` records. - * - * @param metaInitialRecords The initial number of entries in the metadata buffer. - * @param kvBlockSize The size of each byte buffer in the ChainedBuffer used to store the records. - * @param serializerInstance the serializer used for serializing inserted records. 
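As an aside (illustrative only, not part of the patch): the layout described above stores keyStart, a 64-bit offset into the record buffer, across two 32-bit slots of the metadata IntBuffer. A tiny sketch of the split and of the reassembly that getKeyStartPos performs further down:

    val keyStart = 5368709130L                  // hypothetical offset larger than Int.MaxValue
    val lower32 = keyStart.toInt                // written to slot KEY_START
    val upper32 = (keyStart >> 32).toInt        // written to slot KEY_START + 1
    val roundTripped = (upper32.toLong << 32) | (lower32 & 0xFFFFFFFFL)
    assert(roundTripped == keyStart)            // the original offset is recovered exactly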
- */ -private[spark] class PartitionedSerializedPairBuffer[K, V]( - metaInitialRecords: Int, - kvBlockSize: Int, - serializerInstance: SerializerInstance) - extends WritablePartitionedPairCollection[K, V] with SizeTracker { - - if (serializerInstance.isInstanceOf[JavaSerializerInstance]) { - throw new IllegalArgumentException("PartitionedSerializedPairBuffer does not support" + - " Java-serialized objects.") - } - - require(metaInitialRecords <= MAXIMUM_RECORDS, - s"Can't make capacity bigger than ${MAXIMUM_RECORDS} records") - private var metaBuffer = IntBuffer.allocate(metaInitialRecords * RECORD_SIZE) - - private val kvBuffer: ChainedBuffer = new ChainedBuffer(kvBlockSize) - private val kvOutputStream = new ChainedBufferOutputStream(kvBuffer) - private val kvSerializationStream = serializerInstance.serializeStream(kvOutputStream) - - def insert(partition: Int, key: K, value: V): Unit = { - if (metaBuffer.position == metaBuffer.capacity) { - growMetaBuffer() - } - - val keyStart = kvBuffer.size - kvSerializationStream.writeKey[Any](key) - kvSerializationStream.writeValue[Any](value) - kvSerializationStream.flush() - val keyValLen = (kvBuffer.size - keyStart).toInt - - // keyStart, a long, gets split across two ints - metaBuffer.put(keyStart.toInt) - metaBuffer.put((keyStart >> 32).toInt) - metaBuffer.put(keyValLen) - metaBuffer.put(partition) - } - - /** Double the size of the array because we've reached capacity */ - private def growMetaBuffer(): Unit = { - if (metaBuffer.capacity >= MAXIMUM_META_BUFFER_CAPACITY) { - throw new IllegalStateException(s"Can't insert more than ${MAXIMUM_RECORDS} records") - } - val newCapacity = - if (metaBuffer.capacity * 2 < 0 || metaBuffer.capacity * 2 > MAXIMUM_META_BUFFER_CAPACITY) { - // Overflow - MAXIMUM_META_BUFFER_CAPACITY - } else { - metaBuffer.capacity * 2 - } - val newMetaBuffer = IntBuffer.allocate(newCapacity) - newMetaBuffer.put(metaBuffer.array) - metaBuffer = newMetaBuffer - } - - /** Iterate through the data in a given order. For this class this is not really destructive. 
*/ - override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]]) - : Iterator[((Int, K), V)] = { - sort(keyComparator) - val is = orderedInputStream - val deserStream = serializerInstance.deserializeStream(is) - new Iterator[((Int, K), V)] { - var metaBufferPos = 0 - def hasNext: Boolean = metaBufferPos < metaBuffer.position - def next(): ((Int, K), V) = { - val key = deserStream.readKey[Any]().asInstanceOf[K] - val value = deserStream.readValue[Any]().asInstanceOf[V] - val partition = metaBuffer.get(metaBufferPos + PARTITION) - metaBufferPos += RECORD_SIZE - ((partition, key), value) - } - } - } - - override def estimateSize: Long = metaBuffer.capacity * 4L + kvBuffer.capacity - - override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]]) - : WritablePartitionedIterator = { - sort(keyComparator) - new WritablePartitionedIterator { - // current position in the meta buffer in ints - var pos = 0 - - def writeNext(writer: DiskBlockObjectWriter): Unit = { - val keyStart = getKeyStartPos(metaBuffer, pos) - val keyValLen = metaBuffer.get(pos + KEY_VAL_LEN) - pos += RECORD_SIZE - kvBuffer.read(keyStart, writer, keyValLen) - writer.recordWritten() - } - def nextPartition(): Int = metaBuffer.get(pos + PARTITION) - def hasNext(): Boolean = pos < metaBuffer.position - } - } - - // Visible for testing - def orderedInputStream: OrderedInputStream = { - new OrderedInputStream(metaBuffer, kvBuffer) - } - - private def sort(keyComparator: Option[Comparator[K]]): Unit = { - val comparator = if (keyComparator.isEmpty) { - new Comparator[Int]() { - def compare(partition1: Int, partition2: Int): Int = { - partition1 - partition2 - } - } - } else { - throw new UnsupportedOperationException() - } - - val sorter = new Sorter(new SerializedSortDataFormat) - sorter.sort(metaBuffer, 0, metaBuffer.position / RECORD_SIZE, comparator) - } -} - -private[spark] class OrderedInputStream(metaBuffer: IntBuffer, kvBuffer: ChainedBuffer) - extends InputStream { - - import PartitionedSerializedPairBuffer._ - - private var metaBufferPos = 0 - private var kvBufferPos = - if (metaBuffer.position > 0) getKeyStartPos(metaBuffer, metaBufferPos) else 0 - - override def read(bytes: Array[Byte]): Int = read(bytes, 0, bytes.length) - - override def read(bytes: Array[Byte], offs: Int, len: Int): Int = { - if (metaBufferPos >= metaBuffer.position) { - return -1 - } - val bytesRemainingInRecord = (metaBuffer.get(metaBufferPos + KEY_VAL_LEN) - - (kvBufferPos - getKeyStartPos(metaBuffer, metaBufferPos))).toInt - val toRead = math.min(bytesRemainingInRecord, len) - kvBuffer.read(kvBufferPos, bytes, offs, toRead) - if (toRead == bytesRemainingInRecord) { - metaBufferPos += RECORD_SIZE - if (metaBufferPos < metaBuffer.position) { - kvBufferPos = getKeyStartPos(metaBuffer, metaBufferPos) - } - } else { - kvBufferPos += toRead - } - toRead - } - - override def read(): Int = { - throw new UnsupportedOperationException() - } -} - -private[spark] class SerializedSortDataFormat extends SortDataFormat[Int, IntBuffer] { - - private val META_BUFFER_TMP = new Array[Int](RECORD_SIZE) - - /** Return the sort key for the element at the given index. */ - override protected def getKey(metaBuffer: IntBuffer, pos: Int): Int = { - metaBuffer.get(pos * RECORD_SIZE + PARTITION) - } - - /** Swap two elements. 
*/ - override def swap(metaBuffer: IntBuffer, pos0: Int, pos1: Int): Unit = { - val iOff = pos0 * RECORD_SIZE - val jOff = pos1 * RECORD_SIZE - System.arraycopy(metaBuffer.array, iOff, META_BUFFER_TMP, 0, RECORD_SIZE) - System.arraycopy(metaBuffer.array, jOff, metaBuffer.array, iOff, RECORD_SIZE) - System.arraycopy(META_BUFFER_TMP, 0, metaBuffer.array, jOff, RECORD_SIZE) - } - - /** Copy a single element from src(srcPos) to dst(dstPos). */ - override def copyElement( - src: IntBuffer, - srcPos: Int, - dst: IntBuffer, - dstPos: Int): Unit = { - val srcOff = srcPos * RECORD_SIZE - val dstOff = dstPos * RECORD_SIZE - System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE) - } - - /** - * Copy a range of elements starting at src(srcPos) to dst, starting at dstPos. - * Overlapping ranges are allowed. - */ - override def copyRange( - src: IntBuffer, - srcPos: Int, - dst: IntBuffer, - dstPos: Int, - length: Int): Unit = { - val srcOff = srcPos * RECORD_SIZE - val dstOff = dstPos * RECORD_SIZE - System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE * length) - } - - /** - * Allocates a Buffer that can hold up to 'length' elements. - * All elements of the buffer should be considered invalid until data is explicitly copied in. - */ - override def allocate(length: Int): IntBuffer = { - IntBuffer.allocate(length * RECORD_SIZE) - } -} - -private object PartitionedSerializedPairBuffer { - val KEY_START = 0 // keyStart, a long, gets split across two ints - val KEY_VAL_LEN = 2 - val PARTITION = 3 - val RECORD_SIZE = PARTITION + 1 // num ints of metadata - - val MAXIMUM_RECORDS = Int.MaxValue / RECORD_SIZE // (2 ^ 29) - 1 - val MAXIMUM_META_BUFFER_CAPACITY = MAXIMUM_RECORDS * RECORD_SIZE // (2 ^ 31) - 4 - - def getKeyStartPos(metaBuffer: IntBuffer, metaBufferPos: Int): Long = { - val lower32 = metaBuffer.get(metaBufferPos + KEY_START) - val upper32 = metaBuffer.get(metaBufferPos + KEY_START + 1) - (upper32.toLong << 32) | (lower32 & 0xFFFFFFFFL) - } -} diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java similarity index 96% rename from core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java rename to core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java index 934b7e03050b6..232ae4d926bcd 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java @@ -15,8 +15,9 @@ * limitations under the License. 
*/ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; +import org.apache.spark.shuffle.sort.PackedRecordPointer; import org.junit.Test; import static org.junit.Assert.*; @@ -24,7 +25,7 @@ import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; -import static org.apache.spark.shuffle.unsafe.PackedRecordPointer.*; +import static org.apache.spark.shuffle.sort.PackedRecordPointer.*; public class PackedRecordPointerSuite { diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java similarity index 87% rename from core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java rename to core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java index 40fefe2c9d140..1ef3c5ff64bac 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import java.util.Arrays; import java.util.Random; @@ -30,7 +30,7 @@ import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; -public class UnsafeShuffleInMemorySorterSuite { +public class ShuffleInMemorySorterSuite { private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) { final byte[] strBytes = new byte[strLength]; @@ -40,8 +40,8 @@ private static String getStringFromDataPage(Object baseObject, long baseOffset, @Test public void testSortingEmptyInput() { - final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(100); - final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); + final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(100); + final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator(); assert(!iter.hasNext()); } @@ -62,7 +62,7 @@ public void testBasicSorting() throws Exception { new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); final MemoryBlock dataPage = memoryManager.allocatePage(2048); final Object baseObject = dataPage.getBaseObject(); - final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4); + final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4); final HashPartitioner hashPartitioner = new HashPartitioner(4); // Write the records into the data page and store pointers into the sorter @@ -79,7 +79,7 @@ public void testBasicSorting() throws Exception { } // Sort the records - final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); + final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator(); int prevPartitionId = -1; Arrays.sort(dataToSort); for (int i = 0; i < dataToSort.length; i++) { @@ -103,7 +103,7 @@ public void testBasicSorting() throws Exception { @Test public void testSortingManyNumbers() throws Exception { - UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4); + ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4); int[] numbersToSort = new int[128000]; Random random = new Random(16); for (int i = 0; i < numbersToSort.length; i++) { @@ -112,7 
+112,7 @@ public void testSortingManyNumbers() throws Exception { } Arrays.sort(numbersToSort); int[] sorterResult = new int[numbersToSort.length]; - UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); + ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator(); int j = 0; while (iter.hasNext()) { iter.loadNext(); diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java similarity index 98% rename from core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java rename to core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index d218344cd4520..29d9823b1f71b 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import java.io.*; import java.nio.ByteBuffer; @@ -23,7 +23,6 @@ import scala.*; import scala.collection.Iterator; -import scala.reflect.ClassTag; import scala.runtime.AbstractFunction1; import com.google.common.collect.Iterators; @@ -56,6 +55,7 @@ import org.apache.spark.scheduler.MapStatus; import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.shuffle.sort.SerializedShuffleHandle; import org.apache.spark.storage.*; import org.apache.spark.unsafe.memory.ExecutorMemoryManager; import org.apache.spark.unsafe.memory.MemoryAllocator; @@ -204,7 +204,7 @@ private UnsafeShuffleWriter createWriter( shuffleBlockResolver, taskMemoryManager, shuffleMemoryManager, - new UnsafeShuffleHandle(0, 1, shuffleDep), + new SerializedShuffleHandle(0, 1, shuffleDep), 0, // map id taskContext, conf @@ -461,7 +461,7 @@ public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception final UnsafeShuffleWriter writer = createWriter(false); final ArrayList> dataToWrite = new ArrayList>(); - final byte[] bytes = new byte[(int) (UnsafeShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)]; + final byte[] bytes = new byte[(int) (ShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)]; new Random(42).nextBytes(bytes); dataToWrite.add(new Tuple2(1, ByteBuffer.wrap(bytes))); writer.write(dataToWrite.iterator()); @@ -516,7 +516,7 @@ public void testPeakMemoryUsed() throws Exception { shuffleBlockResolver, taskMemoryManager, shuffleMemoryManager, - new UnsafeShuffleHandle<>(0, 1, shuffleDep), + new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, conf); diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala index 63358172ea1f4..b8ab227517cc4 100644 --- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala @@ -17,13 +17,78 @@ package org.apache.spark +import java.io.File + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.FileUtils +import org.apache.commons.io.filefilter.TrueFileFilter import org.scalatest.BeforeAndAfterAll +import org.apache.spark.rdd.ShuffledRDD +import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.util.Utils + class SortShuffleSuite extends 
ShuffleSuite with BeforeAndAfterAll { // This test suite should run all tests in ShuffleSuite with sort-based shuffle. + private var tempDir: File = _ + override def beforeAll() { conf.set("spark.shuffle.manager", "sort") } + + override def beforeEach(): Unit = { + tempDir = Utils.createTempDir() + conf.set("spark.local.dir", tempDir.getAbsolutePath) + } + + override def afterEach(): Unit = { + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterEach() + } + } + + test("SortShuffleManager properly cleans up files for shuffles that use the serialized path") { + sc = new SparkContext("local", "test", conf) + // Create a shuffled RDD and verify that it actually uses the new serialized map output path + val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) + val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) + .setSerializer(new KryoSerializer(conf)) + val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(SortShuffleManager.canUseSerializedShuffle(shuffleDep)) + ensureFilesAreCleanedUp(shuffledRdd) + } + + test("SortShuffleManager properly cleans up files for shuffles that use the deserialized path") { + sc = new SparkContext("local", "test", conf) + // Create a shuffled RDD and verify that it actually uses the old deserialized map output path + val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) + val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) + .setSerializer(new JavaSerializer(conf)) + val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(!SortShuffleManager.canUseSerializedShuffle(shuffleDep)) + ensureFilesAreCleanedUp(shuffledRdd) + } + + private def ensureFilesAreCleanedUp(shuffledRdd: ShuffledRDD[_, _, _]): Unit = { + def getAllFiles: Set[File] = + FileUtils.listFiles(tempDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet + val filesBeforeShuffle = getAllFiles + // Force the shuffle to be performed + shuffledRdd.count() + // Ensure that the shuffle actually created files that will need to be cleaned up + val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle + filesCreatedByShuffle.map(_.getName) should be + Set("shuffle_0_0_0.data", "shuffle_0_0_0.index") + // Check that the cleanup actually removes the files + sc.env.blockManager.master.removeShuffle(0, blocking = true) + for (file <- filesCreatedByShuffle) { + assert (!file.exists(), s"Shuffle file $file was not cleaned up") + } + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 5b01ddb298c39..3816b8c4a09aa 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1062,10 +1062,10 @@ class DAGSchedulerSuite */ test("don't submit stage until its dependencies map outputs are registered (SPARK-5259)") { val firstRDD = new MyRDD(sc, 3, Nil) - val firstShuffleDep = new ShuffleDependency(firstRDD, null) + val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(2)) val firstShuffleId = firstShuffleDep.shuffleId val shuffleMapRdd = new MyRDD(sc, 3, List(firstShuffleDep)) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) submit(reduceRdd, Array(0)) @@ -1175,7 +1175,7 @@ class 
DAGSchedulerSuite */ test("register map outputs correctly after ExecutorLost and task Resubmitted") { val firstRDD = new MyRDD(sc, 3, Nil) - val firstShuffleDep = new ShuffleDependency(firstRDD, null) + val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(2)) val reduceRdd = new MyRDD(sc, 5, List(firstShuffleDep)) submit(reduceRdd, Array(0)) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 341f56df2dafc..b92a302806f76 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -33,7 +33,8 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark._ import org.apache.spark.executor.{TaskMetrics, ShuffleWriteMetrics} -import org.apache.spark.serializer.{SerializerInstance, Serializer, JavaSerializer} +import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.serializer.{JavaSerializer, SerializerInstance} import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -42,25 +43,31 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _ @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _ @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _ + @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ + @Mock(answer = RETURNS_SMART_NULLS) private var dependency: ShuffleDependency[Int, Int, Int] = _ private var taskMetrics: TaskMetrics = _ - private var shuffleWriteMetrics: ShuffleWriteMetrics = _ private var tempDir: File = _ private var outputFile: File = _ private val conf: SparkConf = new SparkConf(loadDefaults = false) private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]() private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File] - private val shuffleBlockId: ShuffleBlockId = new ShuffleBlockId(0, 0, 0) - private val serializer: Serializer = new JavaSerializer(conf) + private var shuffleHandle: BypassMergeSortShuffleHandle[Int, Int] = _ override def beforeEach(): Unit = { tempDir = Utils.createTempDir() outputFile = File.createTempFile("shuffle", null, tempDir) - shuffleWriteMetrics = new ShuffleWriteMetrics taskMetrics = new TaskMetrics - taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics) MockitoAnnotations.initMocks(this) + shuffleHandle = new BypassMergeSortShuffleHandle[Int, Int]( + shuffleId = 0, + numMaps = 2, + dependency = dependency + ) + when(dependency.partitioner).thenReturn(new HashPartitioner(7)) + when(dependency.serializer).thenReturn(Some(new JavaSerializer(conf))) when(taskContext.taskMetrics()).thenReturn(taskMetrics) + when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) when(blockManager.diskBlockManager).thenReturn(diskBlockManager) when(blockManager.getDiskWriter( any[BlockId], @@ -107,18 +114,20 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte test("write empty iterator") { val writer = new BypassMergeSortShuffleWriter[Int, Int]( - new SparkConf(loadDefaults = false), blockManager, - new HashPartitioner(7), - shuffleWriteMetrics, - serializer + blockResolver, + shuffleHandle, + 0, // MapId + taskContext, + 
conf ) - writer.insertAll(Iterator.empty) - val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile) - assert(partitionLengths.sum === 0) + writer.write(Iterator.empty) + writer.stop( /* success = */ true) + assert(writer.getPartitionLengths.sum === 0) assert(outputFile.exists()) assert(outputFile.length() === 0) assert(temporaryFilesCreated.isEmpty) + val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics.get assert(shuffleWriteMetrics.shuffleBytesWritten === 0) assert(shuffleWriteMetrics.shuffleRecordsWritten === 0) assert(taskMetrics.diskBytesSpilled === 0) @@ -129,17 +138,19 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte def records: Iterator[(Int, Int)] = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) val writer = new BypassMergeSortShuffleWriter[Int, Int]( - new SparkConf(loadDefaults = false), blockManager, - new HashPartitioner(7), - shuffleWriteMetrics, - serializer + blockResolver, + shuffleHandle, + 0, // MapId + taskContext, + conf ) - writer.insertAll(records) + writer.write(records) + writer.stop( /* success = */ true) assert(temporaryFilesCreated.nonEmpty) - val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile) - assert(partitionLengths.sum === outputFile.length()) + assert(writer.getPartitionLengths.sum === outputFile.length()) assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted + val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics.get assert(shuffleWriteMetrics.shuffleBytesWritten === outputFile.length()) assert(shuffleWriteMetrics.shuffleRecordsWritten === records.length) assert(taskMetrics.diskBytesSpilled === 0) @@ -148,14 +159,15 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte test("cleanup of intermediate files after errors") { val writer = new BypassMergeSortShuffleWriter[Int, Int]( - new SparkConf(loadDefaults = false), blockManager, - new HashPartitioner(7), - shuffleWriteMetrics, - serializer + blockResolver, + shuffleHandle, + 0, // MapId + taskContext, + conf ) intercept[SparkException] { - writer.insertAll((0 until 100000).iterator.map(i => { + writer.write((0 until 100000).iterator.map(i => { if (i == 99990) { throw new SparkException("Intentional failure") } @@ -163,7 +175,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte })) } assert(temporaryFilesCreated.nonEmpty) - writer.stop() + writer.stop( /* success = */ false) assert(temporaryFilesCreated.count(_.exists()) === 0) } diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala similarity index 80% rename from core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala rename to core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala index 6727934d8c7ca..8744a072cb3f6 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.shuffle.unsafe +package org.apache.spark.shuffle.sort import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock @@ -29,9 +29,9 @@ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer} * Tests for the fallback logic in UnsafeShuffleManager. Actual tests of shuffling data are * performed in other suites. */ -class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers { +class SortShuffleManagerSuite extends SparkFunSuite with Matchers { - import UnsafeShuffleManager.canUseUnsafeShuffle + import SortShuffleManager.canUseSerializedShuffle private class RuntimeExceptionAnswer extends Answer[Object] { override def answer(invocation: InvocationOnMock): Object = { @@ -55,10 +55,10 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers { dep } - test("supported shuffle dependencies") { + test("supported shuffle dependencies for serialized shuffle") { val kryo = Some(new KryoSerializer(new SparkConf())) - assert(canUseUnsafeShuffle(shuffleDep( + assert(canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = kryo, keyOrdering = None, @@ -68,7 +68,7 @@ val rangePartitioner = mock(classOf[RangePartitioner[Any, Any]]) when(rangePartitioner.numPartitions).thenReturn(2) - assert(canUseUnsafeShuffle(shuffleDep( + assert(canUseSerializedShuffle(shuffleDep( partitioner = rangePartitioner, serializer = kryo, keyOrdering = None, @@ -77,7 +77,7 @@ // Shuffles with key orderings are supported as long as no aggregator is specified - assert(canUseUnsafeShuffle(shuffleDep( + assert(canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = kryo, keyOrdering = Some(mock(classOf[Ordering[Any]])), @@ -87,12 +87,12 @@ } - test("unsupported shuffle dependencies") { + test("unsupported shuffle dependencies for serialized shuffle") { val kryo = Some(new KryoSerializer(new SparkConf())) val java = Some(new JavaSerializer(new SparkConf())) // We only support serializers that support object relocation - assert(!canUseUnsafeShuffle(shuffleDep( + assert(!canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = java, keyOrdering = None, @@ -100,9 +100,11 @@ mapSideCombine = false ))) - // We do not support shuffles with more than 16 million output partitions - assert(!canUseUnsafeShuffle(shuffleDep( - partitioner = new HashPartitioner(UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS + 1), + // The serialized shuffle path does not support shuffles with more than 16 million output + // partitions, due to a limitation in its sorter implementation.
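For context, a rough illustration of where that cap comes from (a sketch only, not the exact bit layout of PackedRecordPointer, which should be consulted for the real encoding): the serialized path packs each record's partition id into 24 bits of the 8-byte pointers it sorts, so at most 2^24 = 16,777,216 partitions can be addressed.

    val partitionBits = 24
    val maxPartitions = 1 << partitionBits                       // 16777216
    def pack(partitionId: Int, pointer: Long): Long = {
      require(partitionId < maxPartitions, "partition id does not fit in 24 bits")
      (partitionId.toLong << 40) | (pointer & ((1L << 40) - 1))  // 24 bits of id, 40 bits of pointer
    }
    def partitionOf(packed: Long): Int = (packed >>> 40).toInt
    assert(partitionOf(pack(123456, 42L)) == 123456)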
+ assert(!canUseSerializedShuffle(shuffleDep( + partitioner = new HashPartitioner( + SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE + 1), serializer = kryo, keyOrdering = None, aggregator = None, @@ -110,14 +112,14 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers { ))) // We do not support shuffles that perform aggregation - assert(!canUseUnsafeShuffle(shuffleDep( + assert(!canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = kryo, keyOrdering = None, aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])), mapSideCombine = false ))) - assert(!canUseUnsafeShuffle(shuffleDep( + assert(!canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = kryo, keyOrdering = Some(mock(classOf[Ordering[Any]])), diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala deleted file mode 100644 index 34b4984f12c09..0000000000000 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.sort - -import org.mockito.Mockito._ - -import org.apache.spark.{Aggregator, SparkConf, SparkFunSuite} - -class SortShuffleWriterSuite extends SparkFunSuite { - - import SortShuffleWriter._ - - test("conditions for bypassing merge-sort") { - val conf = new SparkConf(loadDefaults = false) - val agg = mock(classOf[Aggregator[_, _, _]], RETURNS_SMART_NULLS) - val ord = implicitly[Ordering[Int]] - - // Numbers of partitions that are above and below the default bypassMergeThreshold - val FEW_PARTITIONS = 50 - val MANY_PARTITIONS = 10000 - - // Shuffles with no ordering or aggregator: should bypass unless # of partitions is high - assert(shouldBypassMergeSort(conf, FEW_PARTITIONS, None, None)) - assert(!shouldBypassMergeSort(conf, MANY_PARTITIONS, None, None)) - - // Shuffles with an ordering or aggregator: should not bypass even if they have few partitions - assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, None, Some(ord))) - assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, Some(agg), None)) - } -} diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala deleted file mode 100644 index 259020a2ddc34..0000000000000 --- a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.unsafe - -import java.io.File - -import scala.collection.JavaConverters._ - -import org.apache.commons.io.FileUtils -import org.apache.commons.io.filefilter.TrueFileFilter -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkContext, ShuffleSuite} -import org.apache.spark.rdd.ShuffledRDD -import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} -import org.apache.spark.util.Utils - -class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { - - // This test suite should run all tests in ShuffleSuite with unsafe-based shuffle. - - override def beforeAll() { - conf.set("spark.shuffle.manager", "tungsten-sort") - } - - test("UnsafeShuffleManager properly cleans up files for shuffles that use the new shuffle path") { - val tmpDir = Utils.createTempDir() - try { - val myConf = conf.clone() - .set("spark.local.dir", tmpDir.getAbsolutePath) - sc = new SparkContext("local", "test", myConf) - // Create a shuffled RDD and verify that it will actually use the new UnsafeShuffle path - val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) - val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) - .setSerializer(new KryoSerializer(myConf)) - val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] - assert(UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep)) - def getAllFiles: Set[File] = - FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet - val filesBeforeShuffle = getAllFiles - // Force the shuffle to be performed - shuffledRdd.count() - // Ensure that the shuffle actually created files that will need to be cleaned up - val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle - filesCreatedByShuffle.map(_.getName) should be - Set("shuffle_0_0_0.data", "shuffle_0_0_0.index") - // Check that the cleanup actually removes the files - sc.env.blockManager.master.removeShuffle(0, blocking = true) - for (file <- filesCreatedByShuffle) { - assert (!file.exists(), s"Shuffle file $file was not cleaned up") - } - } finally { - Utils.deleteRecursively(tmpDir) - } - } - - test("UnsafeShuffleManager properly cleans up files for shuffles that use the old shuffle path") { - val tmpDir = Utils.createTempDir() - try { - val myConf = conf.clone() - .set("spark.local.dir", tmpDir.getAbsolutePath) - sc = new SparkContext("local", "test", myConf) - // Create a shuffled RDD and verify that it will actually use the old SortShuffle path - val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) - val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) - .setSerializer(new JavaSerializer(myConf)) - val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] - 
assert(!UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep)) - def getAllFiles: Set[File] = - FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet - val filesBeforeShuffle = getAllFiles - // Force the shuffle to be performed - shuffledRdd.count() - // Ensure that the shuffle actually created files that will need to be cleaned up - val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle - filesCreatedByShuffle.map(_.getName) should be - Set("shuffle_0_0_0.data", "shuffle_0_0_0.index") - // Check that the cleanup actually removes the files - sc.env.blockManager.master.removeShuffle(0, blocking = true) - for (file <- filesCreatedByShuffle) { - assert (!file.exists(), s"Shuffle file $file was not cleaned up") - } - } finally { - Utils.deleteRecursively(tmpDir) - } - } -} diff --git a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala deleted file mode 100644 index 05306f408847d..0000000000000 --- a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.util.collection - -import java.nio.ByteBuffer - -import org.scalatest.Matchers._ - -import org.apache.spark.SparkFunSuite - -class ChainedBufferSuite extends SparkFunSuite { - test("write and read at start") { - // write from start of source array - val buffer = new ChainedBuffer(8) - buffer.capacity should be (0) - verifyWriteAndRead(buffer, 0, 0, 0, 4) - buffer.capacity should be (8) - - // write from middle of source array - verifyWriteAndRead(buffer, 0, 5, 0, 4) - buffer.capacity should be (8) - - // read to middle of target array - verifyWriteAndRead(buffer, 0, 0, 5, 4) - buffer.capacity should be (8) - - // write up to border - verifyWriteAndRead(buffer, 0, 0, 0, 8) - buffer.capacity should be (8) - - // expand into second buffer - verifyWriteAndRead(buffer, 0, 0, 0, 12) - buffer.capacity should be (16) - - // expand into multiple buffers - verifyWriteAndRead(buffer, 0, 0, 0, 28) - buffer.capacity should be (32) - } - - test("write and read at middle") { - val buffer = new ChainedBuffer(8) - - // fill to a middle point - verifyWriteAndRead(buffer, 0, 0, 0, 3) - - // write from start of source array - verifyWriteAndRead(buffer, 3, 0, 0, 4) - buffer.capacity should be (8) - - // write from middle of source array - verifyWriteAndRead(buffer, 3, 5, 0, 4) - buffer.capacity should be (8) - - // read to middle of target array - verifyWriteAndRead(buffer, 3, 0, 5, 4) - buffer.capacity should be (8) - - // write up to border - verifyWriteAndRead(buffer, 3, 0, 0, 5) - buffer.capacity should be (8) - - // expand into second buffer - verifyWriteAndRead(buffer, 3, 0, 0, 12) - buffer.capacity should be (16) - - // expand into multiple buffers - verifyWriteAndRead(buffer, 3, 0, 0, 28) - buffer.capacity should be (32) - } - - test("write and read at later buffer") { - val buffer = new ChainedBuffer(8) - - // fill to a middle point - verifyWriteAndRead(buffer, 0, 0, 0, 11) - - // write from start of source array - verifyWriteAndRead(buffer, 11, 0, 0, 4) - buffer.capacity should be (16) - - // write from middle of source array - verifyWriteAndRead(buffer, 11, 5, 0, 4) - buffer.capacity should be (16) - - // read to middle of target array - verifyWriteAndRead(buffer, 11, 0, 5, 4) - buffer.capacity should be (16) - - // write up to border - verifyWriteAndRead(buffer, 11, 0, 0, 5) - buffer.capacity should be (16) - - // expand into second buffer - verifyWriteAndRead(buffer, 11, 0, 0, 12) - buffer.capacity should be (24) - - // expand into multiple buffers - verifyWriteAndRead(buffer, 11, 0, 0, 28) - buffer.capacity should be (40) - } - - - // Used to make sure we're writing different bytes each time - var rangeStart = 0 - - /** - * @param buffer The buffer to write to and read from. - * @param offsetInBuffer The offset to write to in the buffer. - * @param offsetInSource The offset in the array that the bytes are written from. - * @param offsetInTarget The offset in the array to read the bytes into. 
- * @param length The number of bytes to read and write - */ - def verifyWriteAndRead( - buffer: ChainedBuffer, - offsetInBuffer: Int, - offsetInSource: Int, - offsetInTarget: Int, - length: Int): Unit = { - val source = new Array[Byte](offsetInSource + length) - (rangeStart until rangeStart + length).map(_.toByte).copyToArray(source, offsetInSource) - buffer.write(offsetInBuffer, source, offsetInSource, length) - val target = new Array[Byte](offsetInTarget + length) - buffer.read(offsetInBuffer, target, offsetInTarget, length) - ByteBuffer.wrap(source, offsetInSource, length) should be - (ByteBuffer.wrap(target, offsetInTarget, length)) - - rangeStart += 100 - } -} diff --git a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala deleted file mode 100644 index 3b67f6206495a..0000000000000 --- a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala +++ /dev/null @@ -1,148 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.util.collection - -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} - -import com.google.common.io.ByteStreams - -import org.mockito.Matchers.any -import org.mockito.Mockito._ -import org.mockito.Mockito.RETURNS_SMART_NULLS -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer -import org.scalatest.Matchers._ - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.serializer.KryoSerializer -import org.apache.spark.storage.DiskBlockObjectWriter - -class PartitionedSerializedPairBufferSuite extends SparkFunSuite { - test("OrderedInputStream single record") { - val serializerInstance = new KryoSerializer(new SparkConf()).newInstance - - val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance) - val struct = SomeStruct("something", 5) - buffer.insert(4, 10, struct) - - val bytes = ByteStreams.toByteArray(buffer.orderedInputStream) - - val baos = new ByteArrayOutputStream() - val stream = serializerInstance.serializeStream(baos) - stream.writeObject(10) - stream.writeObject(struct) - stream.close() - - baos.toByteArray should be (bytes) - } - - test("insert single record") { - val serializerInstance = new KryoSerializer(new SparkConf()).newInstance - val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance) - val struct = SomeStruct("something", 5) - buffer.insert(4, 10, struct) - val elements = buffer.partitionedDestructiveSortedIterator(None).toArray - elements.size should be (1) - elements.head should be (((4, 10), struct)) - } - - test("insert multiple records") { - val serializerInstance = new KryoSerializer(new SparkConf()).newInstance - val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance) - val struct1 = SomeStruct("something1", 8) - buffer.insert(6, 1, struct1) - val struct2 = SomeStruct("something2", 9) - buffer.insert(4, 2, struct2) - val struct3 = SomeStruct("something3", 10) - buffer.insert(5, 3, struct3) - - val elements = buffer.partitionedDestructiveSortedIterator(None).toArray - elements.size should be (3) - elements(0) should be (((4, 2), struct2)) - elements(1) should be (((5, 3), struct3)) - elements(2) should be (((6, 1), struct1)) - } - - test("write single record") { - val serializerInstance = new KryoSerializer(new SparkConf()).newInstance - val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance) - val struct = SomeStruct("something", 5) - buffer.insert(4, 10, struct) - val it = buffer.destructiveSortedWritablePartitionedIterator(None) - val (writer, baos) = createMockWriter() - assert(it.hasNext) - it.nextPartition should be (4) - it.writeNext(writer) - assert(!it.hasNext) - - val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray)) - stream.readObject[AnyRef]() should be (10) - stream.readObject[AnyRef]() should be (struct) - } - - test("write multiple records") { - val serializerInstance = new KryoSerializer(new SparkConf()).newInstance - val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance) - val struct1 = SomeStruct("something1", 8) - buffer.insert(6, 1, struct1) - val struct2 = SomeStruct("something2", 9) - buffer.insert(4, 2, struct2) - val struct3 = SomeStruct("something3", 10) - buffer.insert(5, 3, struct3) - - val it = buffer.destructiveSortedWritablePartitionedIterator(None) - val (writer, baos) = createMockWriter() - assert(it.hasNext) - 
it.nextPartition should be (4) - it.writeNext(writer) - assert(it.hasNext) - it.nextPartition should be (5) - it.writeNext(writer) - assert(it.hasNext) - it.nextPartition should be (6) - it.writeNext(writer) - assert(!it.hasNext) - - val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray)) - val iter = stream.asIterator - iter.next() should be (2) - iter.next() should be (struct2) - iter.next() should be (3) - iter.next() should be (struct3) - iter.next() should be (1) - iter.next() should be (struct1) - assert(!iter.hasNext) - } - - def createMockWriter(): (DiskBlockObjectWriter, ByteArrayOutputStream) = { - val writer = mock(classOf[DiskBlockObjectWriter], RETURNS_SMART_NULLS) - val baos = new ByteArrayOutputStream() - when(writer.write(any(), any(), any())).thenAnswer(new Answer[Unit] { - override def answer(invocationOnMock: InvocationOnMock): Unit = { - val args = invocationOnMock.getArguments - val bytes = args(0).asInstanceOf[Array[Byte]] - val offset = args(1).asInstanceOf[Int] - val length = args(2).asInstanceOf[Int] - baos.write(bytes, offset, length) - } - }) - (writer, baos) - } -} - -case class SomeStruct(str: String, num: Int) diff --git a/docs/configuration.md b/docs/configuration.md index 46d92ceb762d6..be9c36bdfe3de 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -437,12 +437,9 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.manager sort - Implementation to use for shuffling data. There are three implementations available: - sort, hash and the new (1.5+) tungsten-sort. + Implementation to use for shuffling data. There are two implementations available: + sort and hash. Sort-based shuffle is more memory-efficient and is the default option starting in 1.2. - Tungsten-sort is similar to the sort based shuffle, with a direct binary cache-friendly - implementation with a fall back to regular sort based shuffle if its requirements are not - met. diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 0872d3f3e7093..b5e661d3ecfa8 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -37,6 +37,7 @@ object MimaExcludes { Seq( MimaBuild.excludeSparkPackage("deploy"), MimaBuild.excludeSparkPackage("network"), + MimaBuild.excludeSparkPackage("unsafe"), // These are needed if checking against the sbt build, since they are part of // the maven-generated artifacts in 1.3. excludePackage("org.spark-project.jetty"), @@ -44,7 +45,11 @@ object MimaExcludes { // SQL execution is considered private. excludePackage("org.apache.spark.sql.execution"), // SQL columnar is considered private. - excludePackage("org.apache.spark.sql.columnar") + excludePackage("org.apache.spark.sql.columnar"), + // The shuffle package is considered private. + excludePackage("org.apache.spark.util.collection") is handled below. + // The collections utilities are considered private.
+ excludePackage("org.apache.spark.util.collection") ) ++ MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++ MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++ @@ -750,4 +755,4 @@ object MimaExcludes { MimaBuild.excludeSparkClass("mllib.regression.LinearRegressionWithSGD") case _ => Seq() } -} \ No newline at end of file +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 1d3379a5e2d91..7f60c8f5eaa95 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -24,7 +24,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.shuffle.unsafe.UnsafeShuffleManager import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree @@ -87,10 +86,8 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una // fewer partitions (like RangePartitioner, for example). val conf = child.sqlContext.sparkContext.conf val shuffleManager = SparkEnv.get.shuffleManager - val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] || - shuffleManager.isInstanceOf[UnsafeShuffleManager] + val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - val serializeMapOutputs = conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) if (sortBasedShuffleOn) { val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) { @@ -99,22 +96,18 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una // doesn't buffer deserialized records. // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass. false - } else if (serializeMapOutputs && serializer.supportsRelocationOfSerializedObjects) { - // SPARK-4550 extended sort-based shuffle to serialize individual records prior to sorting - // them. This optimization is guarded by a feature-flag and is only applied in cases where - // shuffle dependency does not specify an aggregator or ordering and the record serializer - // has certain properties. If this optimization is enabled, we can safely avoid the copy. + } else if (serializer.supportsRelocationOfSerializedObjects) { + // SPARK-4550 and SPARK-7081 extended sort-based shuffle to serialize individual records + // prior to sorting them. This optimization is only applied in cases where shuffle + // dependency does not specify an aggregator or ordering and the record serializer has + // certain properties. If this optimization is enabled, we can safely avoid the copy. // // Exchange never configures its ShuffledRDDs with aggregators or key orderings, so we only // need to check whether the optimization is enabled and supported by our serializer. - // - // This optimization also applies to UnsafeShuffleManager (added in SPARK-7081). false } else { - // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory. 
This code - // path is used both when SortShuffleManager is used and when UnsafeShuffleManager falls - // back to SortShuffleManager to perform a shuffle that the new fast path can't handle. In - // both cases, we must copy. + // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory, so we must + // copy. true } } else if (shuffleManager.isInstanceOf[HashShuffleManager]) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index 75d1fced594c4..1680d7e0a85ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -101,7 +101,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { val oldEnv = SparkEnv.get // save the old SparkEnv, as it will be overwritten Utils.tryWithSafeFinally { val conf = new SparkConf() - .set("spark.shuffle.spill.initialMemoryThreshold", "1024") + .set("spark.shuffle.spill.initialMemoryThreshold", "1") .set("spark.shuffle.sort.bypassMergeThreshold", "0") .set("spark.testing.memory", "80000") @@ -109,7 +109,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "") // prepare data val converter = unsafeRowConverter(Array(IntegerType)) - val data = (1 to 1000).iterator.map { i => + val data = (1 to 10000).iterator.map { i => (i, converter(Row(i))) } val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( @@ -141,9 +141,8 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { } } - test("SPARK-10403: unsafe row serializer with UnsafeShuffleManager") { - val conf = new SparkConf() - .set("spark.shuffle.manager", "tungsten-sort") + test("SPARK-10403: unsafe row serializer with SortShuffleManager") { + val conf = new SparkConf().set("spark.shuffle.manager", "sort") sc = new SparkContext("local", "test", conf) val row = Row("Hello", 123) val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType))