[SPARK-10708] Consolidate sort shuffle implementations
There's a lot of duplication between SortShuffleManager and UnsafeShuffleManager. Now that UnsafeShuffleManager supports large records, the two provide the same set of functionality, so I think we should replace SortShuffleManager's serialized shuffle implementation with UnsafeShuffleManager's and merge the two managers together.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #8829 from JoshRosen/consolidate-sort-shuffle-implementations.
JoshRosen committed Oct 22, 2015
1 parent 94e2064 commit f6d06ad
Showing 30 changed files with 456 additions and 1,317 deletions.
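
To make the consolidation described above concrete, here is a rough, hypothetical sketch (not code from this commit) of how a merged sort-based shuffle manager can route each shuffle to one of the three writer paths that appear in this diff: the bypass-merge-sort path, the serialized ("unsafe") path, and the ordinary deserialized sort path. The class and method names below are illustrative assumptions; only the general decision criteria (map-side combine, serializer relocation support, partition-count limits) reflect Spark's documented behavior.

    // Illustrative sketch only; names are assumptions, not code from this commit.
    final class WriterPathChooser {
      enum Path { BYPASS_MERGE_SORT, SERIALIZED, DESERIALIZED_SORT }

      /**
       * Picks a writer path from properties of the shuffle dependency.
       * bypassMergeThreshold corresponds to spark.shuffle.sort.bypassMergeThreshold
       * (200 by default); maxSerializedPartitions corresponds to the 24-bit
       * partition-id limit of PackedRecordPointer (about 16.7 million).
       */
      static Path choose(boolean mapSideCombine,
                         boolean serializerSupportsRelocation,
                         int numPartitions,
                         int bypassMergeThreshold,
                         int maxSerializedPartitions) {
        if (!mapSideCombine && numPartitions <= bypassMergeThreshold) {
          // Few partitions, no aggregation: write one file per reduce partition
          // and concatenate them (BypassMergeSortShuffleWriter, first file below).
          return Path.BYPASS_MERGE_SORT;
        } else if (!mapSideCombine
            && serializerSupportsRelocation
            && numPartitions <= maxSerializedPartitions) {
          // Sort serialized records by partition id (UnsafeShuffleWriter path).
          return Path.SERIALIZED;
        } else {
          // General fallback: deserialized sort (SortShuffleWriter).
          return Path.DESERIALIZED_SORT;
        }
      }
    }
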
BypassMergeSortShuffleWriter.java
@@ -21,21 +21,30 @@
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import javax.annotation.Nullable;

import scala.None$;
import scala.Option;
import scala.Product2;
import scala.Tuple2;
import scala.collection.Iterator;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.Closeables;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.*;
import org.apache.spark.util.Utils;

@@ -62,7 +71,7 @@
* <p>
* There have been proposals to completely remove this code path; see SPARK-6026 for details.
*/
final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<K, V> {
final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {

private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class);

@@ -72,31 +81,52 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
private final BlockManager blockManager;
private final Partitioner partitioner;
private final ShuffleWriteMetrics writeMetrics;
private final int shuffleId;
private final int mapId;
private final Serializer serializer;
private final IndexShuffleBlockResolver shuffleBlockResolver;

/** Array of file writers, one for each partition */
private DiskBlockObjectWriter[] partitionWriters;
@Nullable private MapStatus mapStatus;
private long[] partitionLengths;

/**
* Are we in the process of stopping? Because map tasks can call stop() with success = true
* and then call stop() with success = false if they get an exception, we want to make sure
* we don't try deleting files, etc twice.
*/
private boolean stopping = false;

public BypassMergeSortShuffleWriter(
SparkConf conf,
BlockManager blockManager,
Partitioner partitioner,
ShuffleWriteMetrics writeMetrics,
Serializer serializer) {
IndexShuffleBlockResolver shuffleBlockResolver,
BypassMergeSortShuffleHandle<K, V> handle,
int mapId,
TaskContext taskContext,
SparkConf conf) {
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
this.numPartitions = partitioner.numPartitions();
this.blockManager = blockManager;
this.partitioner = partitioner;
this.writeMetrics = writeMetrics;
this.serializer = serializer;
final ShuffleDependency<K, V, V> dep = handle.dependency();
this.mapId = mapId;
this.shuffleId = dep.shuffleId();
this.partitioner = dep.partitioner();
this.numPartitions = partitioner.numPartitions();
this.writeMetrics = new ShuffleWriteMetrics();
taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
this.serializer = Serializer.getSerializer(dep.serializer());
this.shuffleBlockResolver = shuffleBlockResolver;
}

@Override
public void insertAll(Iterator<Product2<K, V>> records) throws IOException {
public void write(Iterator<Product2<K, V>> records) throws IOException {
assert (partitionWriters == null);
if (!records.hasNext()) {
partitionLengths = new long[numPartitions];
shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
return;
}
final SerializerInstance serInstance = serializer.newInstance();
@@ -124,13 +154,24 @@ public void insertAll(Iterator<Product2<K, V>> records) throws IOException {
for (DiskBlockObjectWriter writer : partitionWriters) {
writer.commitAndClose();
}

partitionLengths =
writePartitionedFile(shuffleBlockResolver.getDataFile(shuffleId, mapId));
shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
}

@Override
public long[] writePartitionedFile(
BlockId blockId,
TaskContext context,
File outputFile) throws IOException {
@VisibleForTesting
long[] getPartitionLengths() {
return partitionLengths;
}

/**
* Concatenate all of the per-partition files into a single combined file.
*
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker).
*/
private long[] writePartitionedFile(File outputFile) throws IOException {
// Track location of the partition starts in the output file
final long[] lengths = new long[numPartitions];
if (partitionWriters == null) {
@@ -165,18 +206,33 @@ public long[] writePartitionedFile(
}

@Override
public void stop() throws IOException {
if (partitionWriters != null) {
try {
for (DiskBlockObjectWriter writer : partitionWriters) {
// This method explicitly does _not_ throw exceptions:
File file = writer.revertPartialWritesAndClose();
if (!file.delete()) {
logger.error("Error while deleting file {}", file.getAbsolutePath());
public Option<MapStatus> stop(boolean success) {
if (stopping) {
return None$.empty();
} else {
stopping = true;
if (success) {
if (mapStatus == null) {
throw new IllegalStateException("Cannot call stop(true) without having called write()");
}
return Option.apply(mapStatus);
} else {
// The map task failed, so delete our output data.
if (partitionWriters != null) {
try {
for (DiskBlockObjectWriter writer : partitionWriters) {
// This method explicitly does _not_ throw exceptions:
File file = writer.revertPartialWritesAndClose();
if (!file.delete()) {
logger.error("Error while deleting file {}", file.getAbsolutePath());
}
}
} finally {
partitionWriters = null;
}
}
} finally {
partitionWriters = null;
shuffleBlockResolver.removeDataByMap(shuffleId, mapId);
return None$.empty();
}
}
}
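
For orientation, a minimal sketch of the writer lifecycle implied by the rewritten class above: the caller (ultimately the shuffle map task, via the consolidated shuffle manager) constructs the writer with the arguments shown in the new constructor, calls write(), and then calls stop(true) to obtain the MapStatus, or stop(false) on failure. This is an assumed usage pattern based only on the signatures visible in this diff; BypassMergeSortShuffleWriter itself is package-private, so the snippet is illustrative rather than something you could compile outside org.apache.spark.shuffle.sort.

    // Hypothetical caller; all inputs are assumed to be supplied by Spark internals.
    static MapStatus runMapTask(
        BlockManager blockManager,
        IndexShuffleBlockResolver resolver,
        BypassMergeSortShuffleHandle<String, String> handle,
        int mapId,
        TaskContext taskContext,
        SparkConf conf,
        Iterator<Product2<String, String>> records) throws Exception {
      BypassMergeSortShuffleWriter<String, String> writer =
          new BypassMergeSortShuffleWriter<>(
              blockManager, resolver, handle, mapId, taskContext, conf);
      try {
        // Writes one file per reduce partition, concatenates them into the shuffle
        // data file, writes the index file, and records a MapStatus internally.
        writer.write(records);
        // On success, stop(true) hands the MapStatus back to the caller.
        return writer.stop(true).get();
      } catch (Exception e) {
        // On failure, stop(false) reverts partial writes and deletes the map output.
        writer.stop(false);
        throw e;
      }
    }
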
PackedRecordPointer.java (moved from org.apache.spark.shuffle.unsafe to org.apache.spark.shuffle.sort)
@@ -15,7 +15,7 @@
* limitations under the License.
*/

package org.apache.spark.shuffle.unsafe;
package org.apache.spark.shuffle.sort;

/**
* Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer.
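
The javadoc above describes the packed layout: a 24-bit partition number and a 40-bit record pointer sharing one 8-byte word. Below is a stand-alone sketch of that packing scheme, with shift widths and masks derived from the description rather than copied from PackedRecordPointer (whose real encoding also folds a page number and in-page offset into the 40 pointer bits):

    // Illustration of the 24-bit / 40-bit split described above; constant and
    // method names here are assumptions, not PackedRecordPointer's actual API.
    final class PackedPointerSketch {
      static final long POINTER_MASK = (1L << 40) - 1;    // low 40 bits: record pointer
      static final int MAX_PARTITION_ID = (1 << 24) - 1;  // high 24 bits: partition id

      static long pack(int partitionId, long recordPointer) {
        assert partitionId >= 0 && partitionId <= MAX_PARTITION_ID;
        return ((long) partitionId << 40) | (recordPointer & POINTER_MASK);
      }

      static int partitionId(long packed) {
        return (int) (packed >>> 40);
      }

      static long recordPointer(long packed) {
        return packed & POINTER_MASK;
      }
    }

With this layout, ordering the packed words by partition id (as the in-memory sorter's comparator does) groups all records for a partition together.
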
ShuffleExternalSorter.java (renamed from UnsafeShuffleExternalSorter.java)
@@ -15,7 +15,7 @@
* limitations under the License.
*/

package org.apache.spark.shuffle.unsafe;
package org.apache.spark.shuffle.sort;

import javax.annotation.Nullable;
import java.io.File;
@@ -48,7 +48,7 @@
* <p>
* Incoming records are appended to data pages. When all records have been inserted (or when the
* current thread's shuffle memory limit is reached), the in-memory records are sorted according to
* their partition ids (using a {@link UnsafeShuffleInMemorySorter}). The sorted records are then
* their partition ids (using a {@link ShuffleInMemorySorter}). The sorted records are then
* written to a single output file (or multiple files, if we've spilled). The format of the output
* files is the same as the format of the final output file written by
* {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are
@@ -59,9 +59,9 @@
* spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a
* specialized merge procedure that avoids extra serialization/deserialization.
*/
final class UnsafeShuffleExternalSorter {
final class ShuffleExternalSorter {

private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class);
private final Logger logger = LoggerFactory.getLogger(ShuffleExternalSorter.class);

@VisibleForTesting
static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
@@ -76,6 +76,10 @@ final class UnsafeShuffleExternalSorter {
private final BlockManager blockManager;
private final TaskContext taskContext;
private final ShuffleWriteMetrics writeMetrics;
private long numRecordsInsertedSinceLastSpill = 0;

/** Force this sorter to spill when there are this many elements in memory. For testing only */
private final long numElementsForSpillThreshold;

/** The buffer size to use when writing spills using DiskBlockObjectWriter */
private final int fileBufferSizeBytes;
@@ -94,12 +98,12 @@ final class UnsafeShuffleExternalSorter {
private long peakMemoryUsedBytes;

// These variables are reset after spilling:
@Nullable private UnsafeShuffleInMemorySorter inMemSorter;
@Nullable private ShuffleInMemorySorter inMemSorter;
@Nullable private MemoryBlock currentPage = null;
private long currentPagePosition = -1;
private long freeSpaceInCurrentPage = 0;

public UnsafeShuffleExternalSorter(
public ShuffleExternalSorter(
TaskMemoryManager memoryManager,
ShuffleMemoryManager shuffleMemoryManager,
BlockManager blockManager,
@@ -117,6 +121,8 @@ public UnsafeShuffleExternalSorter(
this.numPartitions = numPartitions;
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
this.numElementsForSpillThreshold =
conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE);
this.pageSizeBytes = (int) Math.min(
PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, shuffleMemoryManager.pageSizeBytes());
this.maxRecordSizeBytes = pageSizeBytes - 4;
@@ -140,7 +146,8 @@ private void initializeForWriting() throws IOException {
throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
}

this.inMemSorter = new UnsafeShuffleInMemorySorter(initialSize);
this.inMemSorter = new ShuffleInMemorySorter(initialSize);
numRecordsInsertedSinceLastSpill = 0;
}

/**
@@ -166,7 +173,7 @@ private void writeSortedFile(boolean isLastFile) throws IOException {
}

// This call performs the actual sort.
final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords =
final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords =
inMemSorter.getSortedIterator();

// Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this
@@ -406,6 +413,10 @@ public void insertRecord(
int lengthInBytes,
int partitionId) throws IOException {

if (numRecordsInsertedSinceLastSpill > numElementsForSpillThreshold) {
spill();
}

growPointerArrayIfNecessary();
// Need 4 bytes to store the record length.
final int totalSpaceRequired = lengthInBytes + 4;
@@ -453,6 +464,7 @@ public void insertRecord(
recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes);
assert(inMemSorter != null);
inMemSorter.insertRecord(recordAddress, partitionId);
numRecordsInsertedSinceLastSpill += 1;
}
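
The behavioral addition in this file, beyond the rename, is the test-only spill threshold: numElementsForSpillThreshold is read from spark.shuffle.spill.numElementsForceSpillThreshold (default Long.MAX_VALUE, i.e. never), and insertRecord() now forces a spill once the number of records inserted since the last spill exceeds it. Below is a hedged sketch of how a test might use the new setting to exercise the multi-spill merge path; the config keys are taken from this diff, the values are examples:

    import org.apache.spark.SparkConf;

    // Illustrative test helper only; key names appear in this diff, values are examples.
    final class ForceSpillConf {
      static SparkConf forceFrequentSpills() {
        return new SparkConf()
            // Force a spill every 10 inserted records so the merge of multiple
            // spill files gets exercised.
            .set("spark.shuffle.spill.numElementsForceSpillThreshold", "10")
            // Keep the per-writer shuffle file buffer small; this key also
            // appears in this diff.
            .set("spark.shuffle.file.buffer", "32k");
      }
    }
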

ShuffleInMemorySorter.java (renamed from UnsafeShuffleInMemorySorter.java)
@@ -15,13 +15,13 @@
* limitations under the License.
*/

package org.apache.spark.shuffle.unsafe;
package org.apache.spark.shuffle.sort;

import java.util.Comparator;

import org.apache.spark.util.collection.Sorter;

final class UnsafeShuffleInMemorySorter {
final class ShuffleInMemorySorter {

private final Sorter<PackedRecordPointer, long[]> sorter;
private static final class SortComparator implements Comparator<PackedRecordPointer> {
@@ -44,10 +44,10 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) {
*/
private int pointerArrayInsertPosition = 0;

public UnsafeShuffleInMemorySorter(int initialSize) {
public ShuffleInMemorySorter(int initialSize) {
assert (initialSize > 0);
this.pointerArray = new long[initialSize];
this.sorter = new Sorter<PackedRecordPointer, long[]>(UnsafeShuffleSortDataFormat.INSTANCE);
this.sorter = new Sorter<PackedRecordPointer, long[]>(ShuffleSortDataFormat.INSTANCE);
}

public void expandPointerArray() {
Expand Down Expand Up @@ -92,14 +92,14 @@ public void insertRecord(long recordPointer, int partitionId) {
/**
* An iterator-like class that's used instead of Java's Iterator in order to facilitate inlining.
*/
public static final class UnsafeShuffleSorterIterator {
public static final class ShuffleSorterIterator {

private final long[] pointerArray;
private final int numRecords;
final PackedRecordPointer packedRecordPointer = new PackedRecordPointer();
private int position = 0;

public UnsafeShuffleSorterIterator(int numRecords, long[] pointerArray) {
public ShuffleSorterIterator(int numRecords, long[] pointerArray) {
this.numRecords = numRecords;
this.pointerArray = pointerArray;
}
@@ -117,8 +117,8 @@ public void loadNext() {
/**
* Return an iterator over record pointers in sorted order.
*/
public UnsafeShuffleSorterIterator getSortedIterator() {
public ShuffleSorterIterator getSortedIterator() {
sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR);
return new UnsafeShuffleSorterIterator(pointerArrayInsertPosition, pointerArray);
return new ShuffleSorterIterator(pointerArrayInsertPosition, pointerArray);
}
}
ShuffleSortDataFormat.java (renamed from UnsafeShuffleSortDataFormat.java)
@@ -15,15 +15,15 @@
* limitations under the License.
*/

package org.apache.spark.shuffle.unsafe;
package org.apache.spark.shuffle.sort;

import org.apache.spark.util.collection.SortDataFormat;

final class UnsafeShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> {
final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> {

public static final UnsafeShuffleSortDataFormat INSTANCE = new UnsafeShuffleSortDataFormat();
public static final ShuffleSortDataFormat INSTANCE = new ShuffleSortDataFormat();

private UnsafeShuffleSortDataFormat() { }
private ShuffleSortDataFormat() { }

@Override
public PackedRecordPointer getKey(long[] data, int pos) {