diff --git a/common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java b/common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java new file mode 100644 index 0000000000000..df71f7fb08181 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.corruption; + +public enum Cause { + DISK, NETWORK, UNKNOWN; +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java index a6bdc13e93234..967bad3880022 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java @@ -33,6 +33,7 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.corruption.Cause; import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; import org.apache.spark.network.shuffle.protocol.GetLocalDirsForExecutors; import org.apache.spark.network.shuffle.protocol.LocalDirsForExecutors; @@ -47,6 +48,15 @@ public abstract class BlockStoreClient implements Closeable { protected volatile TransportClientFactory clientFactory; protected String appId; + public Cause diagnoseCorruption( + String host, + int port, + String execId, + String blockId, + long checksum) { + return Cause.UNKNOWN; + } + /** * Fetch a sequence of blocks from a remote node asynchronously, * diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java index 56c06e640acda..7e4d059a2d2fc 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java @@ -34,6 +34,7 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import 
org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.corruption.Cause; import org.apache.spark.network.crypto.AuthClientBootstrap; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.NoOpRpcHandler; @@ -82,6 +83,16 @@ public void init(String appId) { clientFactory = context.createClientFactory(bootstraps); } + @Override + public Cause diagnoseCorruption( + String host, + int port, + String execId, + String blockId, + long checksum) { + return Cause.UNKNOWN; + } + @Override public void fetchBlocks( String host, diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index 7f5058124988f..785281b426ca3 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -48,7 +48,8 @@ public enum Type { OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4), HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), REMOVE_BLOCKS(7), BLOCKS_REMOVED(8), FETCH_SHUFFLE_BLOCKS(9), GET_LOCAL_DIRS_FOR_EXECUTORS(10), LOCAL_DIRS_FOR_EXECUTORS(11), - PUSH_BLOCK_STREAM(12), FINALIZE_SHUFFLE_MERGE(13), MERGE_STATUSES(14); + PUSH_BLOCK_STREAM(12), FINALIZE_SHUFFLE_MERGE(13), MERGE_STATUSES(14), + DIAGNOSE_CORRUPTION(15), CORRUPTION_CAUSE(16); private final byte id; @@ -82,6 +83,8 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) { case 12: return PushBlockStream.decode(buf); case 13: return FinalizeShuffleMerge.decode(buf); case 14: return MergeStatuses.decode(buf); + case 15: return DiagnoseCorruption.decode(buf); + case 16: return CorruptionCause.decode(buf); default: throw new IllegalArgumentException("Unknown message type: " + type); } } 
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/CorruptionCause.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/CorruptionCause.java new file mode 100644 index 0000000000000..1dd6c3411da7a --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/CorruptionCause.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle.protocol; + +import io.netty.buffer.ByteBuf; +import org.apache.commons.lang3.builder.ToStringBuilder; +import org.apache.commons.lang3.builder.ToStringStyle; +import org.apache.spark.network.corruption.Cause; + +/** Response to the {@link DiagnoseCorruption} */ +public class CorruptionCause extends BlockTransferMessage { + public final Cause cause; + + public CorruptionCause(Cause cause) { + this.cause = cause; + } + + @Override + protected Type type() { + return Type.CORRUPTION_CAUSE; + } + + @Override + public String toString() { + return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) + .append("cause", cause) + .toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + CorruptionCause that = (CorruptionCause) o; + return cause == that.cause; + } + + @Override + public int hashCode() { + return cause.hashCode(); + } + + @Override + public int encodedLength() { + return 4; /* encoded length of cause */ + } + + @Override + public void encode(ByteBuf buf) { + buf.writeInt(cause.ordinal()); + } + + public static CorruptionCause decode(ByteBuf buf) { + int ordinal = buf.readInt(); + return new CorruptionCause(Cause.values()[ordinal]); + } + +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java new file mode 100644 index 0000000000000..1698768c2b4ef --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import io.netty.buffer.ByteBuf; +import org.apache.commons.lang3.builder.ToStringBuilder; +import org.apache.commons.lang3.builder.ToStringStyle; +import org.apache.spark.network.protocol.Encoders; + +/** Request to get the cause of a corrupted block. Returns {@link CorruptionCause} */ +public class DiagnoseCorruption extends BlockTransferMessage { + private final String appId; + private final String execId; + public final String blockId; + public final long checksum; + + public DiagnoseCorruption(String appId, String execId, String blockId, long checksum) { + this.appId = appId; + this.execId = execId; + this.blockId = blockId; + this.checksum = checksum; + } + + @Override + protected Type type() { + return Type.DIAGNOSE_CORRUPTION; + } + + @Override + public String toString() { + return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) + .append("appId", appId) + .append("execId", execId) + .append("blockId", blockId) + .append("checksum", checksum) + .toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + DiagnoseCorruption that = (DiagnoseCorruption) o; + + if (!appId.equals(that.appId)) return false; + if (!execId.equals(that.execId)) return false; + if (!blockId.equals(that.blockId)) return false; + return 
checksum == that.checksum; + } + + @Override + public int hashCode() { + int result = appId.hashCode(); + result = 31 * result + execId.hashCode(); + result = 31 * result + blockId.hashCode(); + result = 31 * result + Long.hashCode(checksum); + return result; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + Encoders.Strings.encodedLength(blockId) + + 8; /* encoded length of checksum */ + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + Encoders.Strings.encode(buf, blockId); + buf.writeLong(checksum); + } + + public static DiagnoseCorruption decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + String blockId = Encoders.Strings.decode(buf); + long checksum = buf.readLong(); + return new DiagnoseCorruption(appId, execId, blockId, checksum); + } +} diff --git a/core/src/main/java/org/apache/spark/io/CountingWritableChannel.java b/core/src/main/java/org/apache/spark/io/CountingWritableChannel.java new file mode 100644 index 0000000000000..6db100b9d63b8 --- /dev/null +++ b/core/src/main/java/org/apache/spark/io/CountingWritableChannel.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.io; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; + +import org.apache.spark.annotation.Private; + +/** + * :: Private :: + * A wrapper for {@link java.nio.channels.WritableByteChannel} with the ability of counting + * written bytes. + */ +@Private +public class CountingWritableChannel implements WritableByteChannel { + + private final WritableByteChannel delegate; + + private long count; + + public CountingWritableChannel(WritableByteChannel delegate) { + this.delegate = delegate; + this.count = 0; + } + + public long getCount() { + return this.count; + } + + @Override + public int write(ByteBuffer src) throws IOException { + int written = delegate.write(src); + if (written > 0) { + count += written; + } + return written; + } + + @Override + public boolean isOpen() { + return delegate.isOpen(); + } + + @Override + public void close() throws IOException { + delegate.close(); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java index cad8dcfda52bc..ba3d5a603e052 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java @@ -32,5 +32,8 @@ public interface SingleSpillShuffleMapOutputWriter { /** * Transfer a file that contains the bytes of all the partitions written by this map task. 
*/ - void transferMapSpillFile(File mapOutputFile, long[] partitionLengths) throws IOException; + void transferMapSpillFile( + File mapOutputFile, + long[] partitionLengths, + long[] checksums) throws IOException; } diff --git a/core/src/main/java/org/apache/spark/shuffle/api/metadata/MapOutputCommitMessage.java b/core/src/main/java/org/apache/spark/shuffle/api/metadata/MapOutputCommitMessage.java index c5ded5e75a2d7..71ee1b78076d2 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/metadata/MapOutputCommitMessage.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/metadata/MapOutputCommitMessage.java @@ -36,27 +36,39 @@ public final class MapOutputCommitMessage { private final long[] partitionLengths; + private final long[] partitionChecksums; private final Optional mapOutputMetadata; private MapOutputCommitMessage( - long[] partitionLengths, Optional mapOutputMetadata) { + long[] partitionLengths, + long[] partitionChecksums, + Optional mapOutputMetadata) { this.partitionLengths = partitionLengths; + this.partitionChecksums = partitionChecksums; this.mapOutputMetadata = mapOutputMetadata; } public static MapOutputCommitMessage of(long[] partitionLengths) { - return new MapOutputCommitMessage(partitionLengths, Optional.empty()); + return new MapOutputCommitMessage(partitionLengths, null, Optional.empty()); + } + + public static MapOutputCommitMessage of(long[] partitionLengths, long[] partitionChecksums) { + return new MapOutputCommitMessage(partitionLengths, partitionChecksums, Optional.empty()); } public static MapOutputCommitMessage of( long[] partitionLengths, MapOutputMetadata mapOutputMetadata) { - return new MapOutputCommitMessage(partitionLengths, Optional.of(mapOutputMetadata)); + return new MapOutputCommitMessage(partitionLengths, null, Optional.of(mapOutputMetadata)); } public long[] getPartitionLengths() { return partitionLengths; } + public long[] getPartitionChecksums() { + return partitionChecksums; + } + public Optional 
getMapOutputMetadata() { return mapOutputMetadata; } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 3dbee1b13d287..b262ddaa8aab2 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -21,6 +21,7 @@ import java.io.FileInputStream; import java.io.IOException; import java.io.OutputStream; +import java.nio.ByteBuffer; import java.nio.channels.FileChannel; import java.util.Optional; import javax.annotation.Nullable; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 833744f4777ce..731753fa31612 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.util.LinkedList; +import org.apache.spark.SparkEnv$; import scala.Tuple2; import com.google.common.annotations.VisibleForTesting; @@ -179,7 +180,14 @@ private void writeSortedFile(boolean isLastFile) { blockManager.diskBlockManager().createTempShuffleBlock(); final File file = spilledFileInfo._2(); final TempShuffleBlockId blockId = spilledFileInfo._1(); - final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId); + // If this is going to be the only spilled file, calculate the checksum for it here when + // checksum is enabled. In this case, we'll use `LocalDiskSingleSpillMapOutputWriter` + // to write the map output file, which simply renames this spill file to the map output file + // for efficiency. So it's only written once and here's our only chance to calculate the + // checksum for it. 
+ final boolean checksumEnabled = (boolean) SparkEnv$.MODULE$.get().conf() + .get(package$.MODULE$.SHUFFLE_CHECKSUM()) && isLastFile && spills.isEmpty(); + final SpillInfo spillInfo = new SpillInfo(numPartitions, file, checksumEnabled); // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. // Our write path doesn't actually use this serializer (since we end up calling the `write()` @@ -192,6 +200,10 @@ private void writeSortedFile(boolean isLastFile) { try (DiskBlockObjectWriter writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse)) { + if (checksumEnabled) { + writer.enableChecksum(); + } + final int uaoSize = UnsafeAlignedOffset.getUaoSize(); while (sortedRecords.hasNext()) { sortedRecords.loadNext(); @@ -202,6 +214,9 @@ private void writeSortedFile(boolean isLastFile) { if (currentPartition != -1) { final FileSegment fileSegment = writer.commitAndGet(); spillInfo.partitionLengths[currentPartition] = fileSegment.length(); + if (checksumEnabled) { + spillInfo.partitionChecksums[currentPartition] = (long) fileSegment.checksum().get(); + } } currentPartition = partition; } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java index 865def6b83c53..25b6d317840fd 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java @@ -26,12 +26,12 @@ */ final class SpillInfo { final long[] partitionLengths; + final long[] partitionChecksums; final File file; - final TempShuffleBlockId blockId; - SpillInfo(int numPartitions, File file, TempShuffleBlockId blockId) { + SpillInfo(int numPartitions, File file, boolean checksumEnabled) { this.partitionLengths = new long[numPartitions]; + this.partitionChecksums = checksumEnabled ? 
new long[numPartitions] : new long[0]; this.file = file; - this.blockId = blockId; } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index e8f94ba8ffeee..de18ac40e64b7 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -277,7 +277,8 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException { partitionLengths = spills[0].partitionLengths; logger.debug("Merge shuffle spills for mapId {} with length {}", mapId, partitionLengths.length); - maybeSingleFileWriter.get().transferMapSpillFile(spills[0].file, partitionLengths); + maybeSingleFileWriter.get() + .transferMapSpillFile(spills[0].file, partitionLengths, spills[0].partitionChecksums); } else { partitionLengths = mergeSpillsUsingStandardWriter(spills); } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java index 0b286264be43d..5120600b80b26 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java @@ -17,14 +17,14 @@ package org.apache.spark.shuffle.sort.io; -import java.io.BufferedOutputStream; -import java.io.File; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.OutputStream; -import java.nio.channels.FileChannel; +import java.io.*; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; import java.nio.channels.WritableByteChannel; import java.util.Optional; +import java.util.zip.Adler32; +import java.util.zip.CheckedOutputStream; +import java.util.zip.Checksum; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -34,6 +34,7 @@ import 
org.apache.spark.shuffle.api.ShufflePartitionWriter; import org.apache.spark.shuffle.api.WritableByteChannelWrapper; import org.apache.spark.internal.config.package$; +import org.apache.spark.io.CountingWritableChannel; import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.api.metadata.MapOutputCommitMessage; import org.apache.spark.util.Utils; @@ -57,11 +58,15 @@ public class LocalDiskShuffleMapOutputWriter implements ShuffleMapOutputWriter { private long currChannelPosition; private long bytesWrittenToMergedFile = 0L; + private Checksum checksumCal = null; + private long[] partitionChecksums = new long[0]; + private final File outputFile; private File outputTempFile; private FileOutputStream outputFileStream; - private FileChannel outputFileChannel; + private CountingWritableChannel outputChannel; private BufferedOutputStream outputBufferedFileStream; + private CheckedOutputStream checkedOutputStream; public LocalDiskShuffleMapOutputWriter( int shuffleId, @@ -76,6 +81,11 @@ public LocalDiskShuffleMapOutputWriter( (int) (long) sparkConf.get( package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024; this.partitionLengths = new long[numPartitions]; + boolean checksumEnabled = (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_CHECKSUM()); + if (checksumEnabled) { + this.checksumCal = new Adler32(); + this.partitionChecksums = new long[numPartitions]; + } this.outputFile = blockResolver.getDataFile(shuffleId, mapId); this.outputTempFile = null; } @@ -89,8 +99,8 @@ public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws I if (outputTempFile == null) { outputTempFile = Utils.tempFileWith(outputFile); } - if (outputFileChannel != null) { - currChannelPosition = outputFileChannel.position(); + if (outputChannel != null) { + currChannelPosition = outputChannel.getCount(); } else { currChannelPosition = 0L; } @@ -100,12 +110,12 @@ public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) 
throws I @Override public MapOutputCommitMessage commitAllPartitions() throws IOException { // Check the position after transferTo loop to see if it is in the right position and raise a - // exception if it is incorrect. The position will not be increased to the expected length + // exception if it is incorrect. The position will not be increased to the expected length // after calling transferTo in kernel version 2.6.32. This issue is described at // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948. - if (outputFileChannel != null && outputFileChannel.position() != bytesWrittenToMergedFile) { + if (outputChannel != null && outputChannel.getCount() != bytesWrittenToMergedFile) { throw new IOException( - "Current position " + outputFileChannel.position() + " does not equal expected " + + "Current position " + outputChannel.getCount() + " does not equal expected " + "position " + bytesWrittenToMergedFile + " after transferTo. Please check your " + " kernel version to see if it is 2.6.32, as there is a kernel bug which will lead " + "to unexpected behavior when using transferTo. You can set " + @@ -115,8 +125,8 @@ public MapOutputCommitMessage commitAllPartitions() throws IOException { File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? 
outputTempFile : null; log.debug("Writing shuffle index file for mapId {} with length {}", mapId, partitionLengths.length); - blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp); - return MapOutputCommitMessage.of(partitionLengths); + blockResolver.writeMetadataFileAndCommit(shuffleId, mapId, partitionLengths, partitionChecksums, resolvedTmp); + return MapOutputCommitMessage.of(partitionLengths, partitionChecksums); } @Override @@ -131,28 +141,40 @@ private void cleanUp() throws IOException { if (outputBufferedFileStream != null) { outputBufferedFileStream.close(); } - if (outputFileChannel != null) { - outputFileChannel.close(); + if (outputChannel != null) { + outputChannel.close(); + } + if (checkedOutputStream != null) { + checkedOutputStream.close(); } if (outputFileStream != null) { outputFileStream.close(); } + if (checksumCal != null) { + checksumCal.reset(); + } } private void initStream() throws IOException { if (outputFileStream == null) { outputFileStream = new FileOutputStream(outputTempFile, true); } + if (checksumCal != null && checkedOutputStream == null) { + checkedOutputStream = new CheckedOutputStream(outputFileStream, checksumCal); + } if (outputBufferedFileStream == null) { - outputBufferedFileStream = new BufferedOutputStream(outputFileStream, bufferSize); + outputBufferedFileStream = new BufferedOutputStream( + checksumCal != null ? checkedOutputStream : outputFileStream, bufferSize); } } private void initChannel() throws IOException { // This file needs to opened in append mode in order to work around a Linux kernel bug that // affects transferTo; see SPARK-3948 for more details. - if (outputFileChannel == null) { - outputFileChannel = new FileOutputStream(outputTempFile, true).getChannel(); + if (outputChannel == null) { + FileOutputStream fileOut = new FileOutputStream(outputTempFile, true); + OutputStream out = checksumCal != null ? 
new CheckedOutputStream(fileOut, checksumCal) : fileOut; + outputChannel = new CountingWritableChannel(Channels.newChannel(out)); } } @@ -169,7 +191,7 @@ private LocalDiskShufflePartitionWriter(int partitionId) { @Override public OutputStream openStream() throws IOException { if (partStream == null) { - if (outputFileChannel != null) { + if (outputChannel != null) { throw new IllegalStateException("Requested an output channel for a previous write but" + " now an output stream has been requested. Should not be using both channels" + " and streams to write."); @@ -243,6 +265,10 @@ public void close() { isClosed = true; partitionLengths[partitionId] = count; bytesWrittenToMergedFile += count; + if (checksumCal != null) { + partitionChecksums[partitionId] = checksumCal.getValue(); + checksumCal.reset(); + } } private void verifyNotClosed() { @@ -261,19 +287,23 @@ private class PartitionWriterChannel implements WritableByteChannelWrapper { } public long getCount() throws IOException { - long writtenPosition = outputFileChannel.position(); + long writtenPosition = outputChannel.getCount(); return writtenPosition - currChannelPosition; } @Override public WritableByteChannel channel() { - return outputFileChannel; + return outputChannel; } @Override public void close() throws IOException { partitionLengths[partitionId] = getCount(); bytesWrittenToMergedFile += partitionLengths[partitionId]; + if (checksumCal != null) { + partitionChecksums[partitionId] = checksumCal.getValue(); + checksumCal.reset(); + } } } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java index c8b41992a8919..52ef4342ef2b7 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java @@ -44,12 +44,14 @@ public 
LocalDiskSingleSpillMapOutputWriter( @Override public void transferMapSpillFile( File mapSpillFile, - long[] partitionLengths) throws IOException { + long[] partitionLengths, + long[] partitionChecksums) throws IOException { // The map spill file already has the proper format, and it contains all of the partition data. // So just transfer it directly to the destination without any merging. File outputFile = blockResolver.getDataFile(shuffleId, mapId); File tempFile = Utils.tempFileWith(outputFile); Files.move(mapSpillFile.toPath(), tempFile.toPath()); - blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tempFile); + blockResolver.writeMetadataFileAndCommit( + shuffleId, mapId, partitionLengths, partitionChecksums, tempFile); } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 1a18856e4156c..c49bf0117da90 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1359,6 +1359,14 @@ package object config { s"The buffer size must be greater than 0 and less than or equal to ${Int.MaxValue}.") .createWithDefault(4096) + private[spark] val SHUFFLE_CHECKSUM = + ConfigBuilder("spark.shuffle.checksum") + .doc("Whether to calculate the checksum of shuffle output. If enabled, Spark will try " + + "its best to tell if shuffle data corruption is caused by network or disk or others.") + .version("3.2.0") + .booleanConf + .createWithDefault(true) + private[spark] val SHUFFLE_COMPRESS = ConfigBuilder("spark.shuffle.compress") .doc("Whether to compress shuffle output. 
Compression will use " + diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index cafb39ea82ad9..e258545e1da02 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -22,11 +22,14 @@ import scala.reflect.ClassTag import org.apache.spark.TaskContext import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.StreamCallbackWithID +import org.apache.spark.network.corruption.Cause import org.apache.spark.storage.{BlockId, StorageLevel} private[spark] trait BlockDataManager { + def diagnoseShuffleBlockCorruption(blockId: BlockId, clientChecksum: Long): Cause + /** * Get the local directories that used by BlockManager to save the blocks to disk */ diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 5f831dc666ca5..c6194f5fc7d8b 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -54,6 +54,11 @@ class NettyBlockRpcServer( logTrace(s"Received request: $message") message match { + case diagnose: DiagnoseCorruption => + val cause = blockManager + .diagnoseShuffleBlockCorruption(BlockId.apply(diagnose.blockId), diagnose.checksum) + responseContext.onSuccess(new CorruptionCause(cause).toByteBuffer) + case openBlocks: OpenBlocks => val blocksNum = openBlocks.blockIds.length val blocks = (0 until blocksNum).map { i => diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 828849812bbd1..ca630c80f036c 100644 --- 
a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -20,9 +20,11 @@ package org.apache.spark.network.netty import java.io.IOException import java.nio.ByteBuffer import java.util.{HashMap => JHashMap, Map => JMap} +import java.util.concurrent.TimeoutException import scala.collection.JavaConverters._ import scala.concurrent.{Future, Promise} +import scala.concurrent.duration._ import scala.reflect.ClassTag import scala.util.{Success, Try} @@ -31,15 +33,17 @@ import com.codahale.metrics.{Metric, MetricSet} import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.ExecutorDeadException import org.apache.spark.internal.config +import org.apache.spark.internal.config.Network import org.apache.spark.network._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap} +import org.apache.spark.network.corruption.Cause import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap} import org.apache.spark.network.server._ import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, OneForOneBlockFetcher, RetryingBlockFetcher} -import org.apache.spark.network.shuffle.protocol.{UploadBlock, UploadBlockStream} +import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, CorruptionCause, DiagnoseCorruption, UploadBlock, UploadBlockStream} import org.apache.spark.network.util.JavaUtils -import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.rpc.{RpcEndpointRef, RpcTimeout} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.storage.BlockManagerMessages.IsExecutorAlive @@ -104,6 +108,40 @@ private[spark] class NettyBlockTransferService( } } + override def diagnoseCorruption( + host: String, + port: 
Int, + execId: String, + blockId: String, + checksum: Long): Cause = { + // A monitor for the thread to wait on. + val result = Promise[Cause]() + val client = clientFactory.createClient(host, port) + client.sendRpc(new DiagnoseCorruption(appId, execId, blockId, checksum).toByteBuffer, + new RpcResponseCallback { + override def onSuccess(response: ByteBuffer): Unit = { + val cause = BlockTransferMessage.Decoder + .fromByteBuffer(response).asInstanceOf[CorruptionCause] + result.success(cause.cause) + } + + override def onFailure(e: Throwable): Unit = { + logger.warn("Failed to get the corruption cause.", e) + result.success(Cause.UNKNOWN) + } + }) + val timeout = new RpcTimeout( + conf.get(Network.NETWORK_TIMEOUT).seconds, + Network.NETWORK_TIMEOUT.key) + try { + timeout.awaitResult(result.future) + } catch { + case _: TimeoutException => + logger.warn("Failed to get the corruption cause due to timeout.") + Cause.UNKNOWN + } + } + override def fetchBlocks( host: String, port: Int, diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index a92d9fab6efc6..350320eef30c1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1762,7 +1762,7 @@ private[spark] class DAGScheduler( } if (shouldAbortStage) { - val abortMessage = if (disallowStageRetryForTest) { + val abortMessage = if (disallowStageRetryForTest) { "Fetch failure will not retry stage due to testing config" } else { s"""$failedStage (${failedStage.name}) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index d30b73a74e12e..a9ba05ecf5ce8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -22,6 +22,8
@@ import java.nio.ByteBuffer import java.nio.channels.Channels import java.nio.file.Files +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.{SparkConf, SparkEnv, SparkException} import org.apache.spark.internal.{config, Logging} import org.apache.spark.io.NioBufferedFileInputStream @@ -110,6 +112,22 @@ private[spark] class IndexShuffleBlockResolver( .getOrElse(blockManager.diskBlockManager.getFile(blockId)) } + /** + * Get the shuffle checksum file. + * + * When the dirs parameter is None then use the disk manager's local directories. Otherwise, + * read from the specified directories. + */ + def getChecksumFile( + shuffleId: Int, + mapId: Long, + dirs: Option[Array[String]] = None): File = { + val blockId = ShuffleChecksumBlockId(shuffleId, mapId, NOOP_REDUCE_ID) + dirs + .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, blockId.name)) + .getOrElse(blockManager.diskBlockManager.getFile(blockId)) + } + /** * Remove data file and index file that contain the output data from one map. */ @@ -174,6 +192,26 @@ private[spark] class IndexShuffleBlockResolver( } } + private def getChecksums(checksumFile: File, blockNum: Int): Array[Long] = { + if (!checksumFile.exists()) return null + val checksums = new ArrayBuffer[Long] + // Read exactly `blockNum` checksums; null is returned if the file is missing or too short + var in: DataInputStream = null + try { + in = new DataInputStream(new NioBufferedFileInputStream(checksumFile)) + while (checksums.size < blockNum) { + checksums += in.readLong() + } + } catch { + case _: IOException | _: EOFException => + return null + } finally { + if (in != null) in.close() // null if the file could not be opened + } + + checksums.toArray + } + /** * Write a provided shuffle block as a stream. Used for block migrations. * ShuffleBlockBatchIds must contain the full range represented in the ShuffleIndexBlock. @@ -276,32 +314,55 @@ private[spark] class IndexShuffleBlockResolver( /** - * Write an index file with the offsets of each block, plus a final offset at the end for the - * end of the output file.
This will be used by getBlockData to figure out where each block - * begins and ends. + * Commit the data and metadata files as an atomic operation, use the existing ones, or + * replace them with new ones. Note that the metadata parameters (`lengths`, `checksums`) + * will be updated to match the existing ones if the existing ones are used. * - * It will commit the data and index file as an atomic operation, use the existing ones, or - * replace them with new ones. + * There are two kinds of metadata files: * - * Note: the `lengths` will be updated to match the existing index file if use the existing ones. + * - index file + * An index file contains the offsets of each block, plus a final offset at the end + * for the end of the output file. It will be used by [[getBlockData]] to figure out + * where each block begins and ends. + * + * - checksum file (optional) + * A checksum file contains the checksum of each block. It will be used to diagnose + * the cause when a block is corrupted. Note that empty `checksums` indicate that + * checksum is disabled, in which case no checksum is read, copied back or written. */ - def writeIndexFileAndCommit( + def writeMetadataFileAndCommit( shuffleId: Int, mapId: Long, lengths: Array[Long], + checksums: Array[Long], dataTmp: File): Unit = { val indexFile = getIndexFile(shuffleId, mapId) val indexTmp = Utils.tempFileWith(indexFile) + + val checksumEnabled = checksums.nonEmpty + val (checksumFileOpt, checksumTmpOpt) = if (checksumEnabled) { + assert(lengths.length == checksums.length, + "The size of partition lengths and checksums should be equal") + val checksumFile = getChecksumFile(shuffleId, mapId) + (Some(checksumFile), Some(Utils.tempFileWith(checksumFile))) + } else { + (None, None) + } + try { val dataFile = getDataFile(shuffleId, mapId) // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure // the following check and rename are atomic.
this.synchronized { val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length) - if (existingLengths != null) { + val existingChecksums = + checksumFileOpt.map(getChecksums(_, checksums.length)).getOrElse(checksums) + if (existingLengths != null && existingChecksums != null) { // Another attempt for the same task has already written our map outputs successfully, // so just use the existing partition lengths and delete our temporary map outputs. System.arraycopy(existingLengths, 0, lengths, 0, lengths.length) + System.arraycopy(existingChecksums, 0, checksums, 0, checksums.length) + if (dataTmp != null && dataTmp.exists()) { dataTmp.delete() } @@ -333,6 +394,28 @@ private[spark] class IndexShuffleBlockResolver( if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) { throw new IOException("fail to rename file " + dataTmp + " to " + dataFile) } + + checksumTmpOpt.zip(checksumFileOpt).foreach { case (checksumTmp, checksumFile) => + val out = new DataOutputStream( + new BufferedOutputStream( + new FileOutputStream(checksumTmp) + ) + ) + Utils.tryWithSafeFinally { + checksums.foreach(out.writeLong) + } { + out.close() + } + + if (checksumFile.exists()) { + checksumFile.delete() + } + if (!checksumTmp.renameTo(checksumFile)) { + // It's not worthwhile to fail here after index file and data file are already + // successfully stored due to checksum is only used for the corner error case.
+ logWarning("fail to rename file " + checksumTmp + " to " + checksumFile) + } + } } } } finally { @@ -340,6 +423,11 @@ private[spark] class IndexShuffleBlockResolver( if (indexTmp.exists() && !indexTmp.delete()) { logError(s"Failed to delete temporary index file at ${indexTmp.getAbsolutePath}") } + checksumTmpOpt.foreach { checksumTmp => + if (checksumTmp.exists() && !checksumTmp.delete()) { + logError(s"Failed to delete temporary checksum file at ${checksumTmp.getAbsolutePath}") + } + } } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 73bf809a08a68..6d43c099f261c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -81,6 +81,11 @@ case class ShuffleIndexBlockId(shuffleId: Int, mapId: Long, reduceId: Int) exten override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index" } +@DeveloperApi +case class ShuffleChecksumBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId { + override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".checksum" +} + @Since("3.2.0") @DeveloperApi case class ShufflePushBlockId(shuffleId: Int, mapIndex: Int, reduceId: Int) extends BlockId { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 4c09e1615affb..ddb974c5e59cf 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -21,8 +21,10 @@ import java.io._ import java.lang.ref.{ReferenceQueue => JReferenceQueue, WeakReference} import java.nio.ByteBuffer import java.nio.channels.Channels +import java.nio.file.Files import java.util.Collections import java.util.concurrent.{CompletableFuture, ConcurrentHashMap, TimeUnit} +import java.util.zip.{Adler32, 
CheckedInputStream} import scala.collection.JavaConverters._ import scala.collection.mutable @@ -42,11 +44,13 @@ import org.apache.spark.executor.DataReadMethod import org.apache.spark.internal.Logging import org.apache.spark.internal.config import org.apache.spark.internal.config.Network +import org.apache.spark.io.CountingWritableChannel import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.metrics.source.Source import org.apache.spark.network._ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.client.StreamCallbackWithID +import org.apache.spark.network.corruption.Cause import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle._ import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo @@ -54,7 +58,7 @@ import org.apache.spark.network.util.TransportConf import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.serializer.{SerializerInstance, SerializerManager} -import org.apache.spark.shuffle.{MigratableResolver, ShuffleManager, ShuffleWriteMetricsReporter} +import org.apache.spark.shuffle.{IndexShuffleBlockResolver, MigratableResolver, ShuffleManager, ShuffleWriteMetricsReporter} import org.apache.spark.storage.BlockManagerMessages.{DecommissionBlockManager, ReplicateBlock} import org.apache.spark.storage.memory._ import org.apache.spark.unsafe.Platform @@ -275,6 +279,58 @@ private[spark] class BlockManager( override def getLocalDiskDirs: Array[String] = diskBlockManager.localDirsString + /** + * Diagnose the suspected cause of a corrupted shuffle block: the checksum stored in the + * checksum file is compared with one recalculated from the block data on disk (mismatch + * means DISK) and then with the checksum computed by the reader (mismatch means NETWORK). + * If both match, the checksum file is missing, or a non-fatal error occurs while + * diagnosing, the cause is UNKNOWN. + */ + override def diagnoseShuffleBlockCorruption(blockId: BlockId, clientChecksum: Long): Cause = { + assert(blockId.isInstanceOf[ShuffleBlockId], + s"Corruption diagnosis only supports shuffle block yet, but got $blockId") + val shuffleBlock = blockId.asInstanceOf[ShuffleBlockId] + val resolver = shuffleManager.shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver] + val
checksumFile = resolver.getChecksumFile(shuffleBlock.shuffleId, shuffleBlock.mapId) + val reduceId = shuffleBlock.reduceId + if (checksumFile.exists()) { + var in: DataInputStream = null + try { + val channel = Files.newByteChannel(checksumFile.toPath) + channel.position(reduceId * 8L) + in = new DataInputStream(Channels.newInputStream(channel)) + val goldenChecksum = in.readLong() + val blockData = resolver.getBlockData(blockId) + val checksumIn = new CheckedInputStream(blockData.createInputStream(), new Adler32) + val recalculatedChecksum = try { + val buffer = new Array[Byte](8192) + while (checksumIn.read(buffer, 0, 8192) != -1) {} + checksumIn.getChecksum.getValue + } finally { + // Always release the stream over the block data, even if reading it fails. + checksumIn.close() + } + if (goldenChecksum != recalculatedChecksum) { + Cause.DISK + } else if (goldenChecksum != clientChecksum) { + Cause.NETWORK + } else { + Cause.UNKNOWN + } + } catch { + case NonFatal(e) => + logWarning("Exception throws while diagnosing shuffle block corruption.", e) + Cause.UNKNOWN + } finally { + // `in` is still null when the checksum file could not be opened or positioned. + if (in != null) in.close() + } + } else { + // Even if checksum is enabled, a checksum file may not exist if error throws during writing. + Cause.UNKNOWN + } + } + /** * Abstraction for storing blocks from bytes, whether they start in memory or on disk.
* diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index e55c09274cd9a..2d7ab5f86567c 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import java.io.{BufferedOutputStream, File, FileOutputStream, OutputStream} import java.nio.channels.{ClosedByInterruptException, FileChannel} +import java.util.zip.{Adler32, CheckedOutputStream, Checksum} import org.apache.spark.internal.Logging import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} @@ -76,6 +77,9 @@ private[spark] class DiskBlockObjectWriter( private var initialized = false private var streamOpen = false private var hasBeenClosed = false + private var checksumEnabled = false + private var checksumCal: Checksum = null + private var checksumOutputStream: CheckedOutputStream = null /** * Cursors used to represent positions in the file. @@ -101,12 +105,27 @@ private[spark] class DiskBlockObjectWriter( */ private var numRecordsWritten = 0 + /** + * Enable the checksum calculation of this writer. It's invalid to call on this + * when the writer has already opened. 
+ */ + def enableChecksum(): Unit = { + assert(!streamOpen && !initialized, + "Can't call enableChecksum() when the writer has already opened") + checksumEnabled = true + checksumCal = new Adler32() + } + private def initialize(): Unit = { fos = new FileOutputStream(file, true) channel = fos.getChannel() ts = new TimeTrackingOutputStream(writeMetrics, fos) + if (checksumEnabled) { + checksumOutputStream = new CheckedOutputStream(ts, checksumCal) + } class ManualCloseBufferedOutputStream - extends BufferedOutputStream(ts, bufferSize) with ManualCloseOutputStream + extends BufferedOutputStream(if (checksumEnabled) checksumOutputStream else ts, bufferSize) + with ManualCloseOutputStream mcs = new ManualCloseBufferedOutputStream } @@ -183,7 +202,14 @@ private[spark] class DiskBlockObjectWriter( } val pos = channel.position() - val fileSegment = new FileSegment(file, committedPosition, pos - committedPosition) + val checksum = if (checksumEnabled) { + val value = checksumCal.getValue + checksumCal.reset() + Some(value) + } else { + None + } + val fileSegment = new FileSegment(file, committedPosition, pos - committedPosition, checksum) committedPosition = pos // In certain compression codecs, more bytes are written after streams are closed writeMetrics.incBytesWritten(committedPosition - reportedPosition) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index fbda4912e15ad..29debd5ecd956 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -31,6 +31,7 @@ import org.apache.commons.io.FileUtils import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.internal.{config, Logging} +import org.apache.spark.io.CountingWritableChannel import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.util.{AbstractFileRegion, JavaUtils} import 
org.apache.spark.security.CryptoStreamUtils @@ -328,23 +329,3 @@ private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize: override def deallocate(): Unit = source.close() } - -private class CountingWritableChannel(sink: WritableByteChannel) extends WritableByteChannel { - - private var count = 0L - - def getCount: Long = count - - override def write(src: ByteBuffer): Int = { - val written = sink.write(src) - if (written > 0) { - count += written - } - written - } - - override def isOpen(): Boolean = sink.isOpen() - - override def close(): Unit = sink.close() - -} diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala index 021a9facfb0b2..b60466202a8b0 100644 --- a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala +++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala @@ -23,10 +23,15 @@ import java.io.File * References a particular segment of a file (potentially the entire file), * based off an offset and a length. 
*/ -private[spark] class FileSegment(val file: File, val offset: Long, val length: Long) { +private[spark] class FileSegment( + val file: File, + val offset: Long, + val length: Long, + val checksum: Option[Long] = None) { require(offset >= 0, s"File segment offset cannot be negative (got $offset)") require(length >= 0, s"File segment length cannot be negative (got $length)") override def toString: String = { - "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length) + val checksumStr = checksum.map(c => s", checksum=$c").getOrElse("") + s"(name=${file.getName}, offset=$offset, length=$length$checksumStr)" } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index f30437f404455..af76b873e8212 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -20,6 +20,7 @@ package org.apache.spark.storage import java.io.{InputStream, IOException} import java.nio.channels.ClosedByInterruptException import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} +import java.util.zip.{Adler32, CheckedInputStream} import javax.annotation.concurrent.GuardedBy import scala.collection.mutable @@ -28,9 +29,10 @@ import scala.util.{Failure, Success} import org.apache.commons.io.IOUtils -import org.apache.spark.{SparkException, TaskContext} -import org.apache.spark.internal.Logging +import org.apache.spark.{SparkEnv, SparkException, TaskContext} +import org.apache.spark.internal.{config, Logging} import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.corruption.Cause import org.apache.spark.network.shuffle._ import org.apache.spark.network.util.TransportConf import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} @@ -152,6 +154,8 @@ final class 
ShuffleBlockFetcherIterator( */ private[this] val corruptedBlocks = mutable.HashSet[BlockId]() + private[this] val checksumEnabled = SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM) + /** * Whether the iterator is still active. If isZombie is true, the callback interface will no * longer place fetched blocks into [[results]]. @@ -584,6 +588,8 @@ final class ShuffleBlockFetcherIterator( var result: FetchResult = null var input: InputStream = null + // it's only initialized when checksum enabled. + var checkedIn: CheckedInputStream = null var streamCompressedOrEncrypted: Boolean = false // Take the next fetched result and try to decompress it to detect data corruption, // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch @@ -636,7 +642,12 @@ final class ShuffleBlockFetcherIterator( } val in = try { - buf.createInputStream() + var bufIn = buf.createInputStream() + if (checksumEnabled) { + checkedIn = new CheckedInputStream(bufIn, new Adler32) + bufIn = checkedIn + } + bufIn } catch { // The exception could only be throwed by local shuffle block case e: IOException => @@ -663,16 +674,30 @@ final class ShuffleBlockFetcherIterator( } } catch { case e: IOException => - buf.release() if (buf.isInstanceOf[FileSegmentManagedBuffer] || corruptedBlocks.contains(blockId)) { + buf.release() throwFetchFailedException(blockId, mapIndex, address, e) } else { - logWarning(s"got an corrupted block $blockId from $address, fetch again", e) - corruptedBlocks += blockId - fetchRequests += FetchRequest( - address, Array(FetchBlockInfo(blockId, size, mapIndex))) - result = null + logWarning(s"Got an corrupted block $blockId from $address", e) + // A disk issue indicates the data on disk has already corrupted, so it's + // meaningless to retry on this case. 
We'll give a retry in the case of
+                // network issue and other unknown issues (in order to keep the same
+                // behavior as previously)
+                val allowRetry = !checksumEnabled ||
+                  diagnoseCorruption(checkedIn, address, blockId) != Cause.DISK
+                buf.release()
+                if (allowRetry) {
+                  logInfo(s"Will retry the block $blockId")
+                  corruptedBlocks += blockId
+                  fetchRequests += FetchRequest(
+                    address, Array(FetchBlockInfo(blockId, size, mapIndex)))
+                  result = null
+                } else {
+                  logError(s"Block $blockId is corrupted due to disk issue, won't retry.")
+                  throwFetchFailedException(
+                    blockId, mapIndex, address, e, s"Block $blockId is corrupted due to disk issue")
+                }
               }
           } finally {
             // TODO: release the buf here to free memory earlier
@@ -699,7 +724,57 @@ final class ShuffleBlockFetcherIterator(
         currentResult.blockId,
         currentResult.mapIndex,
         currentResult.address,
-        detectCorrupt && streamCompressedOrEncrypted))
+        detectCorrupt && streamCompressedOrEncrypted,
+        Option(checkedIn)))
+  }
+
+  /**
+   * Get the suspect corruption cause for the corrupted block. It should be only invoked
+   * when checksum is enabled.
+   *
+   * This first consumes the rest of the corrupted block's stream in order to calculate the
+   * checksum of the block. Then, it makes a synchronous RPC call, along with the checksum,
+   * to ask the server (where the corrupted block is fetched from) to diagnose the cause of
+   * corruption and return it.
+   *
+   * Corruption diagnosis is only a best effort: an IOException while consuming the stream, or
+   * any non-fatal exception thrown by the RPC, results in [[Cause.UNKNOWN]] rather than being
+   * propagated to the caller.
+   *
+   * @param checkedIn the [[CheckedInputStream]] which is used to calculate the checksum.
+   * @param address the address where the corrupted block is fetched from.
+   * @param blockId the blockId of the corrupted block.
+   * @return the cause of corruption, which should be one of the [[Cause]].
+   */
+  private[storage] def diagnoseCorruption(
+      checkedIn: CheckedInputStream,
+      address: BlockManagerId,
+      blockId: BlockId): Cause = {
+    logInfo("Start corruption diagnosis.")
+    val startTimeNs = System.nanoTime()
+    val buffer = new Array[Byte](8192)
+    // consume the remaining data to calculate the checksum
+    try {
+      while (checkedIn.read(buffer, 0, buffer.length) != -1) {}
+    } catch {
+      case e: IOException =>
+        logWarning("IOException throws while consuming the rest stream of the corrupted block", e)
+        return Cause.UNKNOWN
+    }
+    val checksum = checkedIn.getChecksum.getValue
+    val cause = try {
+      shuffleClient.diagnoseCorruption(
+        address.host, address.port, address.executorId, blockId.toString, checksum)
+    } catch {
+      // Best effort: a failed diagnosis RPC must not replace the corruption error that the
+      // caller is about to report, so fall back to UNKNOWN.
+      case scala.util.control.NonFatal(e) =>
+        logWarning(s"Failed to diagnose the corruption cause of block $blockId", e)
+        Cause.UNKNOWN
+    }
+    val duration = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)
+    logInfo(s"Finished corruption diagnosis in ${duration}ms, cause: $cause")
+    cause
   }
 
   def toCompletionIterator: Iterator[(BlockId, InputStream)] = {
@@ -766,12 +841,14 @@ final class ShuffleBlockFetcherIterator(
       blockId: BlockId,
       mapIndex: Int,
       address: BlockManagerId,
-      e: Throwable) = {
+      e: Throwable,
+      message: String = null) = {
+    val msg = Option(message).getOrElse(e.getMessage)
     blockId match {
       case ShuffleBlockId(shufId, mapId, reduceId) =>
-        throw new FetchFailedException(address, shufId, mapId, mapIndex, reduceId, e)
+        throw new FetchFailedException(address, shufId, mapId, mapIndex, reduceId, msg, e)
       case ShuffleBlockBatchId(shuffleId, mapId, startReduceId, _) =>
-        throw new FetchFailedException(address, shuffleId, mapId, mapIndex, startReduceId, e)
+        throw new FetchFailedException(address, shuffleId, mapId, mapIndex, startReduceId, msg, e)
       case _ =>
         throw new SparkException(
           "Failed to get block " + blockId + ", which is not a shuffle block", e)
@@ -790,7 +867,8 @@ private class BufferReleasingInputStream(
     private val blockId: BlockId,
     private val mapIndex: Int,
     private val address: BlockManagerId,
-    private val detectCorruption: Boolean)
+    private val detectCorruption: Boolean,
+    private val checkedInOpt: Option[CheckedInputStream])
   extends InputStream {
   private[this] var closed = false
 
@@ -832,8 +910,14 @@ private class BufferReleasingInputStream(
       block
     } catch {
       case e: IOException if detectCorruption =>
+        val message = checkedInOpt.map { checkedIn =>
+          val cause = iterator.diagnoseCorruption(checkedIn, address, blockId)
+          s"Block $blockId is corrupted due to $cause issue"
+        }.orNull
         IOUtils.closeQuietly(this)
-        iterator.throwFetchFailedException(blockId, mapIndex, address, e)
+        // We'd never retry the block whatever the cause is since the block has been
+        // partially consumed by downstream RDDs.
+        iterator.throwFetchFailedException(blockId, mapIndex, address, e, message)
     }
   }
 }
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 5666bb3e5f140..ea7e9d769594e 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -152,11 +152,13 @@ public void setUp() throws Exception {
 
     doAnswer(renameTempAnswer)
         .when(shuffleBlockResolver)
-        .writeIndexFileAndCommit(anyInt(), anyLong(), any(long[].class), any(File.class));
+        .writeMetadataFileAndCommit(
+          anyInt(), anyLong(), any(long[].class), any(long[].class), any(File.class));
 
     doAnswer(renameTempAnswer)
         .when(shuffleBlockResolver)
-        .writeIndexFileAndCommit(anyInt(), anyLong(), any(long[].class), eq(null));
+        .writeMetadataFileAndCommit(
+          anyInt(), anyLong(), any(long[].class), any(long[].class), eq(null));
 
     when(diskBlockManager.createTempShuffleBlock()).thenAnswer(invocationOnMock -> {
       TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID());
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 126faec334e77..6cbb4f1c0de37 100644
---
a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark
 
-import java.io.File
+import java.io.{File, RandomAccessFile}
 import java.util.{Locale, Properties}
 import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService }
 
@@ -302,6 +302,44 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
     rdd.count()
   }
 
+  test("SPARK-18188: shuffle checksum detect disk corruption") {
+    conf.set(config.SHUFFLE_CHECKSUM, true)
+    sc = new SparkContext("local-cluster[2, 1, 2048]", "test", conf)
+    val rdd = sc.parallelize(1 to 10, 2).map((_, 1)).reduceByKey(_ + _)
+    // materialize the shuffle map outputs
+    rdd.count()
+
+    sc.parallelize(1 to 10, 2).barrier().mapPartitions { iter =>
+      var dataFile = SparkEnv.get.blockManager
+        .diskBlockManager.getFile(ShuffleDataBlockId(0, 0, 0))
+      if (!dataFile.exists()) {
+        dataFile = SparkEnv.get.blockManager
+          .diskBlockManager.getFile(ShuffleDataBlockId(0, 1, 0))
+      }
+
+      if (dataFile.exists()) {
+        // Open for random access rather than in append mode: a positional write through an
+        // append-mode channel is unspecified and may land at the end of the file instead
+        // of overwriting the existing shuffle data.
+        val f = new RandomAccessFile(dataFile, "rw")
+        try {
+          // corrupt the shuffle data files by writing some arbitrary bytes
+          f.seek(0)
+          f.write(Array[Byte](12))
+        } finally {
+          f.close()
+        }
+      }
+      BarrierTaskContext.get().barrier()
+      iter
+    }.collect()
+
+    val e = intercept[SparkException] {
+      rdd.count()
+    }
+    assert(e.getMessage.contains("corrupted due to DISK issue"))
+  }
+
   test("cannot find its local shuffle file if no execution of the stage and rerun shuffle") {
     sc = new SparkContext("local", "test", conf.clone())
     val rdd = sc.parallelize(1 to 10, 1).map((_, 1)).reduceByKey(_ + _)
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index 7fd0bf626fda1..33bc6ca6f51c8 100644
---
a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -76,8 +76,9 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
     when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
     when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager)
 
-    when(blockResolver.writeIndexFileAndCommit(
-      anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File])))
+    when(blockResolver.writeMetadataFileAndCommit(
+      anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[Array[Long]]), any(classOf[File])))
       .thenAnswer { invocationOnMock =>
-        val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File]
+        // The File moved to index 4: the new Array[Long] parameter now occupies index 3.
+        val tmp = invocationOnMock.getArguments()(4).asInstanceOf[File]
         if (tmp != null) {
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
index da98ad3d1c982..4e33bed27c807 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
@@ -72,7 +72,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
     } {
       out.close()
     }
-    resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp)
+    resolver.writeMetadataFileAndCommit(shuffleId, mapId, lengths, Array.empty[Long], dataTmp)
 
     val indexFile = new File(tempDir.getAbsolutePath, idxName)
     val dataFile = resolver.getDataFile(shuffleId, mapId)
@@ -92,7 +92,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
     } {
       out2.close()
     }
-    resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths2, dataTmp2)
+    resolver.writeMetadataFileAndCommit(shuffleId, mapId, lengths2, Array.empty[Long], dataTmp2)
 
     assert(indexFile.length() === (lengths.length + 1) * 8)
     assert(lengths2.toSeq === lengths.toSeq)
@@ -131,7 +131,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
     } {
       out3.close()
     }
-    resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths3, dataTmp3)
+    resolver.writeMetadataFileAndCommit(shuffleId, mapId, lengths3, Array.empty[Long], dataTmp3)
     assert(indexFile.length() === (lengths3.length + 1) * 8)
     assert(lengths3.toSeq != lengths.toSeq)
     assert(dataFile.exists())
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
index ef5c615bf7591..bb24869e41519 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
@@ -74,8 +74,9 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA
       .set("spark.app.id", "example.spark.app")
       .set("spark.shuffle.unsafe.file.output.buffer", "16k")
     when(blockResolver.getDataFile(anyInt, anyLong)).thenReturn(mergedOutputFile)
-    when(blockResolver.writeIndexFileAndCommit(
-      anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File])))
+    when(blockResolver.writeMetadataFileAndCommit(
+      anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[Array[Long]]), any(classOf[File])))
       .thenAnswer { invocationOnMock =>
         partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]]
-        val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File]
+        // The File moved to index 4: the new Array[Long] parameter now occupies index 3.
+        val tmp: File = invocationOnMock.getArguments()(4).asInstanceOf[File]