[SPARK-34840][SHUFFLE] Fixes cases of corruption in merged shuffle …
### What changes were proposed in this pull request?
This PR fixes bugs that cause corruption of push-merged blocks when a client terminates while pushing a block. `RemoteBlockPushResolver` was introduced in #30062 (SPARK-32916).

There are 2 scenarios where the merged blocks get corrupted:
1. `StreamCallback.onFailure()` is called more than once. Initially we assumed that the `onFailure` callback would be called just once per stream. However, we observed that it is called twice when a client connection is reset. When the client connection is reset, two events are triggered in this order:
 - `exceptionCaught`. This event is propagated to `StreamInterceptor`. `StreamInterceptor.exceptionCaught()` invokes `callback.onFailure(streamId, cause)`. This is the first time `StreamCallback.onFailure()` is invoked.
 - `channelInactive`. Since the channel closes, the `channelInactive` event is triggered and is again propagated to `StreamInterceptor`. `StreamInterceptor.channelInactive()` invokes `callback.onFailure(streamId, new ClosedChannelException())`. This is the second time `StreamCallback.onFailure()` is invoked.

2. The flag `isWriting` is set prematurely to true. This introduces an edge case where a stream that is trying to merge a duplicate block (created because of a speculative task) may interfere with an active stream if the duplicate stream fails.
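
Both scenarios come down to when a stream is allowed to release the partition. The following is a minimal sketch of that guard, not the actual `RemoteBlockPushResolver` code: `PartitionState`, `StreamSketch`, and their members are hypothetical stand-ins, and the model assumes `onFailure` only unsets the current map index when `isWriting` is true and then clears the flag.

```java
import java.nio.channels.ClosedChannelException;

public class OnFailureGuardSketch {
  /** Hypothetical stand-in for the partition state; -1 means no stream is merging. */
  static class PartitionState {
    int currentMapIndex = -1;
  }

  /** Hypothetical stand-in for one push stream targeting the partition. */
  static class StreamSketch {
    private final PartitionState partition;
    private final int mapIndex;
    private boolean isWriting = false;

    StreamSketch(PartitionState partition, int mapIndex) {
      this.partition = partition;
      this.mapIndex = mapIndex;
    }

    /** Returns true if the stream may write now, false if the data would be deferred. */
    boolean onData() {
      if (partition.currentMapIndex < 0 || partition.currentMapIndex == mapIndex) {
        partition.currentMapIndex = mapIndex;
        // Set only once the stream really starts writing (assumption of this sketch).
        isWriting = true;
        return true;
      }
      return false;
    }

    /** Safe to call more than once: only a stream that was writing releases the partition. */
    void onFailure(Throwable cause) {
      if (isWriting) {
        partition.currentMapIndex = -1;
      }
      isWriting = false;
    }
  }

  public static void main(String[] args) {
    PartitionState partition = new PartitionState();
    StreamSketch stream1 = new StreamSketch(partition, 0);
    StreamSketch stream2 = new StreamSketch(partition, 1);

    stream1.onData();
    stream1.onFailure(new RuntimeException("connection reset")); // like exceptionCaught
    stream2.onData();                                            // stream2 becomes active
    stream1.onFailure(new ClosedChannelException());             // like channelInactive

    // stream2 still owns the partition after the second onFailure on stream1.
    System.out.println(partition.currentMapIndex); // prints 1
  }
}
```

Without the `isWriting` check, the second `onFailure` on `stream1` would reset the map index while `stream2` is still merging, which is the kind of interference described above.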

This PR also includes changes that improve the code:

1. Using positional writes all the time, because this simplifies the code and microbenchmarking has not shown any performance impact.
2. Additional minor changes suggested by mridulm during an internal review.
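
To illustrate why positional writes simplify recovery, here is a small self-contained sketch. It is not the resolver's code: `writeAt`, `committedPos`, and the temp file are hypothetical, and the final `truncate` is only there so the sketch prints just the committed bytes. It relies on `FileChannel.write(ByteBuffer, long)`, which writes at an explicit offset and leaves the channel's own position unchanged.

```java
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.Arrays;

public class PositionalWriteSketch {
  /** Writes all of buf starting at startPos and returns the number of bytes written. */
  static int writeAt(FileChannel channel, ByteBuffer buf, long startPos) throws IOException {
    int length = 0;
    while (buf.hasRemaining()) {
      // Positional write: the offset is always derived from the committed position.
      length += channel.write(buf, startPos + length);
    }
    return length;
  }

  public static void main(String[] args) throws IOException {
    Path file = Files.createTempFile("merged", ".data");
    try (FileChannel channel = FileChannel.open(file, StandardOpenOption.WRITE)) {
      long committedPos = 0; // offset of the last successfully merged block

      // A stream writes some bytes and then fails, so committedPos is not advanced.
      writeAt(channel, ByteBuffer.wrap(new byte[] {9, 9, 9}), committedPos);

      // The next stream starts at the same committed offset and overwrites the partial data.
      int written = writeAt(channel, ByteBuffer.wrap(new byte[] {1, 2}), committedPos);
      committedPos += written;

      // Drop leftover bytes past the committed position (assumption of this sketch).
      channel.truncate(committedPos);
    }
    System.out.println(Arrays.toString(Files.readAllBytes(file))); // prints [1, 2]
    Files.delete(file);
  }
}
```

Because every write names its offset, no separate "encountered failure" flag or channel-position reset is needed after a failed stream.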

### Why are the changes needed?
These changes fix bugs and simplify the code.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added unit tests. I have also tested these changes in LinkedIn's internal fork on a cluster.

Co-authored-by: Chandni Singh <chsingh@linkedin.com>
Co-authored-by: Min Shen <mshen@linkedin.com>

Closes #31934 from otterc/SPARK-32916-followup.

Lead-authored-by: Chandni Singh <singh.chandni@gmail.com>
Co-authored-by: Min Shen <mshen@linkedin.com>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
2 people authored and wakun committed Jul 30, 2022
1 parent c0e93b3 commit b6cace7
Showing 2 changed files with 97 additions and 43 deletions.
@@ -30,7 +30,6 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentMap;
@@ -445,9 +444,9 @@ static class PushBlockStreamCallback implements StreamCallbackWithID {
private final AppShufflePartitionInfo partitionInfo;
private int length = 0;
// This indicates that this stream got the opportunity to write the blocks to the merged file.
// Once this is set to true and the stream encounters a failure then it will take necessary
// action to overwrite any partial written data. This is reset to false when the stream
// completes without any failures.
// Once this is set to true and the stream encounters a failure then it will unset the
// currentMapId of the partition so that another stream can start merging the blocks to the
// partition. This is reset to false when the stream completes.
private boolean isWriting = false;
// Use on-heap instead of direct ByteBuffer since these buffers will be GC'ed very quickly
private List<ByteBuffer> deferredBufs;
@@ -477,16 +476,11 @@ public String getID() {
*/
private void writeBuf(ByteBuffer buf) throws IOException {
while (buf.hasRemaining()) {
if (partitionInfo.isEncounteredFailure()) {
long updatedPos = partitionInfo.getDataFilePos() + length;
logger.debug(
"{} shuffleId {} reduceId {} encountered failure current pos {} updated pos {}",
partitionInfo.appShuffleId.appId, partitionInfo.appShuffleId.shuffleId,
partitionInfo.reduceId, partitionInfo.getDataFilePos(), updatedPos);
length += partitionInfo.dataChannel.write(buf, updatedPos);
} else {
length += partitionInfo.dataChannel.write(buf);
}
long updatedPos = partitionInfo.getDataFilePos() + length;
logger.debug("{} shuffleId {} reduceId {} current pos {} updated pos {}",
partitionInfo.appShuffleId.appId, partitionInfo.appShuffleId.shuffleId,
partitionInfo.reduceId, partitionInfo.getDataFilePos(), updatedPos);
length += partitionInfo.dataChannel.write(buf, updatedPos);
}
}

@@ -581,7 +575,6 @@ public void onData(String streamId, ByteBuffer buf) throws IOException {
}
// Check whether we can write to disk
if (allowedToWrite()) {
isWriting = true;
// Identify duplicate block generated by speculative tasks. We respond success to
// the client in cases of duplicate even though no data is written.
if (isDuplicateBlock()) {
@@ -598,6 +591,7 @@ public void onData(String streamId, ByteBuffer buf) throws IOException {

// If we got here, it's safe to write the block data to the merged shuffle file. We
// first write any deferred block.
isWriting = true;
try {
if (deferredBufs != null && !deferredBufs.isEmpty()) {
writeDeferredBufs();
@@ -609,16 +603,6 @@ public void onData(String streamId, ByteBuffer buf) throws IOException {
// back to the client so the block could be retried.
throw ioe;
}
// If we got here, it means we successfully write the current chunk of block to merged
// shuffle file. If we encountered failure while writing the previous block, we should
// reset the file channel position and the status of partitionInfo to indicate that we
// have recovered from previous disk write failure. However, we do not update the
// position tracked by partitionInfo here. That is only updated while the entire block
// is successfully written to merged shuffle file.
if (partitionInfo.isEncounteredFailure()) {
partitionInfo.dataChannel.position(partitionInfo.getDataFilePos() + length);
partitionInfo.setEncounteredFailure(false);
}
} else {
logger.trace("{} shuffleId {} reduceId {} onData deferred",
partitionInfo.appShuffleId.appId, partitionInfo.appShuffleId.shuffleId,
@@ -639,7 +623,7 @@ public void onData(String streamId, ByteBuffer buf) throws IOException {
// written to disk due to this reason. We thus decide to optimize for server
// throughput and memory usage.
if (deferredBufs == null) {
deferredBufs = new LinkedList<>();
deferredBufs = new ArrayList<>();
}
// Write the buffer to the in-memory deferred cache. Since buf is a slice of a larger
// byte buffer, we cache only the relevant bytes not the entire large buffer to save
@@ -670,7 +654,6 @@ public void onComplete(String streamId) throws IOException {
}
// Check if we can commit this block
if (allowedToWrite()) {
isWriting = true;
// Identify duplicate block generated by speculative tasks. We respond success to
// the client in cases of duplicate even though no data is written.
if (isDuplicateBlock()) {
@@ -681,6 +664,7 @@ public void onComplete(String streamId) throws IOException {
try {
if (deferredBufs != null && !deferredBufs.isEmpty()) {
abortIfNecessary();
isWriting = true;
writeDeferredBufs();
}
} catch (IOException ioe) {
@@ -738,14 +722,14 @@ public void onFailure(String streamId, Throwable throwable) throws IOException {
Map<Integer, AppShufflePartitionInfo> shufflePartitions =
mergeManager.partitions.get(partitionInfo.appShuffleId);
if (shufflePartitions != null && shufflePartitions.containsKey(partitionInfo.reduceId)) {
logger.debug("{} shuffleId {} reduceId {} set encountered failure",
logger.debug("{} shuffleId {} reduceId {} encountered failure",
partitionInfo.appShuffleId.appId, partitionInfo.appShuffleId.shuffleId,
partitionInfo.reduceId);
partitionInfo.setCurrentMapIndex(-1);
partitionInfo.setEncounteredFailure(true);
}
}
}
isWriting = false;
}

@VisibleForTesting
@@ -802,8 +786,6 @@ public static class AppShufflePartitionInfo {
public FileChannel dataChannel;
// Location offset of the last successfully merged block for this shuffle partition
private long dataFilePos;
// Indicating whether failure was encountered when merging the previous block
private boolean encounteredFailure;
// Track the map index whose block is being merged for this shuffle partition
private int currentMapIndex;
// Bitmap tracking which mapper's blocks have been merged for this shuffle partition
@@ -836,7 +818,6 @@ public static class AppShufflePartitionInfo {
// Writing 0 offset so that we can reuse ShuffleIndexInformation.getIndex()
updateChunkInfo(0L, -1);
this.dataFilePos = 0;
this.encounteredFailure = false;
this.mapTracker = new RoaringBitmap();
this.chunkTracker = new RoaringBitmap();
}
@@ -851,14 +832,6 @@ public void setDataFilePos(long dataFilePos) {
this.dataFilePos = dataFilePos;
}

boolean isEncounteredFailure() {
return encounteredFailure;
}

void setEncounteredFailure(boolean encounteredFailure) {
this.encounteredFailure = encounteredFailure;
}

int getCurrentMapIndex() {
return currentMapIndex;
}
@@ -28,6 +28,7 @@
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.concurrent.Semaphore;
import java.util.concurrent.ThreadLocalRandom;

import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableMap;
@@ -292,18 +293,32 @@ public void testTooLateArrival() throws IOException {
@Test
public void testIncompleteStreamsAreOverwritten() throws IOException {
registerExecutor(TEST_APP, prepareLocalDirs(localDirs));
byte[] expectedBytes = new byte[4];
ThreadLocalRandom.current().nextBytes(expectedBytes);

StreamCallbackWithID stream1 =
pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 0, 0, 0));
stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[4]));
byte[] data = new byte[10];
ThreadLocalRandom.current().nextBytes(data);
stream1.onData(stream1.getID(), ByteBuffer.wrap(data));
// There is a failure
stream1.onFailure(stream1.getID(), new RuntimeException("forced error"));
StreamCallbackWithID stream2 =
pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 1, 0, 0));
stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[5]));
ByteBuffer nextBuf= ByteBuffer.wrap(expectedBytes, 0, 2);
stream2.onData(stream2.getID(), nextBuf);
stream2.onComplete(stream2.getID());
StreamCallbackWithID stream3 =
pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 2, 0, 0));
nextBuf = ByteBuffer.wrap(expectedBytes, 2, 2);
stream3.onData(stream3.getID(), nextBuf);
stream3.onComplete(stream3.getID());
pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, 0));
MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0);
validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{5}, new int[][]{{1}});
validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{4}, new int[][]{{1, 2}});
FileSegmentManagedBuffer mb =
(FileSegmentManagedBuffer) pushResolver.getMergedBlockData(TEST_APP, 0, 0, 0);
assertArrayEquals(expectedBytes, mb.nioByteBuffer().array());
}

@Test (expected = RuntimeException.class)
@@ -740,6 +755,72 @@ public void testFailureWhileTruncatingFiles() throws IOException {
validateChunks(TEST_APP, 0, 1, meta, new int[]{5, 3}, new int[][]{{0},{1}});
}

@Test
public void testOnFailureInvokedMoreThanOncePerBlock() throws IOException {
StreamCallbackWithID stream1 =
pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 0, 0, 0));
stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
stream1.onFailure(stream1.getID(), new RuntimeException("forced error"));
StreamCallbackWithID stream2 =
pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 1, 0, 0));
stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[5]));
// On failure on stream1 gets invoked again and should cause no interference
stream1.onFailure(stream1.getID(), new RuntimeException("2nd forced error"));
StreamCallbackWithID stream3 =
pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 3, 0, 0));
// This should be deferred as stream 2 is still the active stream
stream3.onData(stream3.getID(), ByteBuffer.wrap(new byte[2]));
// Stream 2 writes more and completes
stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[4]));
stream2.onComplete(stream2.getID());
stream3.onComplete(stream3.getID());
pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, 0));
MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0);
validateChunks(TEST_APP, 0, 0, blockMeta, new int[] {9, 2}, new int[][] {{1},{3}});
removeApplication(TEST_APP);
}

@Test (expected = RuntimeException.class)
public void testFailureAfterDuplicateBlockDoesNotInterfereActiveStream() throws IOException {
StreamCallbackWithID stream1 =
pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 0, 0, 0));
StreamCallbackWithID stream1Duplicate =
pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 0, 0, 0));
stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
stream1.onComplete(stream1.getID());
stream1Duplicate.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));

StreamCallbackWithID stream2 =
pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 1, 0, 0));
stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[5]));
// Should not change the current map id of the reduce partition
stream1Duplicate.onFailure(stream2.getID(), new RuntimeException("forced error"));

StreamCallbackWithID stream3 =
pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 2, 0, 0));
// This should be deferred as stream 2 is still the active stream
stream3.onData(stream3.getID(), ByteBuffer.wrap(new byte[2]));
RuntimeException failedEx = null;
try {
stream3.onComplete(stream3.getID());
} catch (RuntimeException re) {
assertEquals(
"Couldn't find an opportunity to write block shufflePush_0_2_0 to merged shuffle",
re.getMessage());
failedEx = re;
}
// Stream 2 writes more and completes
stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[4]));
stream2.onComplete(stream2.getID());
pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, 0));
MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0);
validateChunks(TEST_APP, 0, 0, blockMeta, new int[] {11}, new int[][] {{0, 1}});
removeApplication(TEST_APP);
if (failedEx != null) {
throw failedEx;
}
}

private void useTestFiles(boolean useTestIndexFile, boolean useTestMetaFile) throws IOException {
pushResolver = new RemoteBlockPushResolver(conf) {
@Override