From 7b5e8691797be18987cded1131ef6e4e18889bf3 Mon Sep 17 00:00:00 2001 From: "kevin.cyj" Date: Thu, 9 Mar 2023 11:52:21 +0800 Subject: [PATCH] [FLINK-31386][network] Fix the potential deadlock issue of blocking shuffle Currently, the SortMergeResultPartition may allocate more network buffers than the guaranteed size of the LocalBufferPool. As a result, some result partitions may need to wait for other result partitions to release the over-allocated network buffers to continue. However, the result partitions which have allocated more than the guaranteed buffers rely on the processing of input data to trigger data spilling and buffer recycling. The input data further relies on batch reading buffers used by the SortMergeResultPartitionReadScheduler which may already be taken by those blocked result partitions that are waiting for buffers. Then a deadlock occurs. This patch fixes the deadlock issue by reserving the guaranteed buffers on initialization. This closes #22148. --- .../io/disk/BatchShuffleReadBufferPool.java | 2 +- .../partition/SortMergeResultPartition.java | 12 +- .../SortMergeResultPartitionTest.java | 108 +++++++++++++++++- 3 files changed, 117 insertions(+), 5 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/disk/BatchShuffleReadBufferPool.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/disk/BatchShuffleReadBufferPool.java index fc9aa1b0ef224..863c0429400c1 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/disk/BatchShuffleReadBufferPool.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/disk/BatchShuffleReadBufferPool.java @@ -54,7 +54,7 @@ public class BatchShuffleReadBufferPool { * Memory size in bytes can be allocated from this buffer pool for a single request (4M is for * better sequential read). 
*/ - private static final int NUM_BYTES_PER_REQUEST = 4 * 1024 * 1024; + public static final int NUM_BYTES_PER_REQUEST = 4 * 1024 * 1024; /** * Wait for at most 2 seconds before return if there is no enough available buffers currently. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java index def80f4057489..ba84f8b6e0b93 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java @@ -186,6 +186,11 @@ protected void setupInternal() throws IOException { } } + // reserve the "guaranteed" buffers for this buffer pool to avoid the case that those + // buffers are taken by other result partitions and can not be released, which may cause + // deadlock + requestGuaranteedBuffers(); + // initialize the buffer pool eagerly to avoid reporting errors such as OOM too late readBufferPool.initialize(); LOG.info("Sort-merge partition {} initialized.", getPartitionId()); @@ -325,7 +330,7 @@ private DataBuffer createNewDataBuffer() throws IOException { } } - private void requestNetworkBuffers() throws IOException { + private void requestGuaranteedBuffers() throws IOException { int numRequiredBuffer = bufferPool.getNumberOfRequiredMemorySegments(); if (numRequiredBuffer < 2) { throw new IOException( @@ -339,8 +344,13 @@ private void requestNetworkBuffers() throws IOException { freeSegments.add(checkNotNull(bufferPool.requestMemorySegmentBlocking())); } } catch (InterruptedException exception) { + freeSegments.forEach(bufferPool::recycle); throw new IOException("Failed to allocate buffers for result partition.", exception); } + } + + private void requestNetworkBuffers() throws IOException { + requestGuaranteedBuffers(); // avoid taking too many buffers 
in one result partition while (freeSegments.size() < bufferPool.getMaxNumberOfMemorySegments()) { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java index 6f01823c3fd9e..c202669e9f328 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java @@ -50,6 +50,7 @@ import java.util.Collection; import java.util.Queue; import java.util.Random; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.function.Consumer; @@ -325,7 +326,7 @@ void testReleaseWhileWriting() throws Exception { BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers); SortMergeResultPartition partition = createSortMergedPartition(10, bufferPool); - assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(0); + assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(numBuffers); partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 0); partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 1); @@ -348,7 +349,7 @@ void testRelease() throws Exception { BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers); SortMergeResultPartition partition = createSortMergedPartition(10, bufferPool); - assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(0); + assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(numBuffers); partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 0); partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 1); @@ -381,7 +382,7 @@ void testCloseReleasesAllBuffers() throws Exception { BufferPool 
bufferPool = globalPool.createBufferPool(numBuffers, numBuffers); SortMergeResultPartition partition = createSortMergedPartition(10, bufferPool); - assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(0); + assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(numBuffers); partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 5); assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()) @@ -423,6 +424,107 @@ void testNumBytesProducedCounterForBroadcast() throws IOException { testResultPartitionBytesCounter(true); } + @TestTemplate + void testNetworkBufferReservation() throws IOException { + int numBuffers = 10; + + BufferPool bufferPool = globalPool.createBufferPool(numBuffers, 2 * numBuffers); + SortMergeResultPartition partition = createSortMergedPartition(1, bufferPool); + assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(numBuffers); + + partition.finish(); + partition.close(); + } + + @TestTemplate + void testNoDeadlockOnSpecificConsumptionOrder() throws Exception { + // see https://issues.apache.org/jira/browse/FLINK-31386 for more information + int numNetworkBuffers = 2 * BatchShuffleReadBufferPool.NUM_BYTES_PER_REQUEST / bufferSize; + NetworkBufferPool networkBufferPool = new NetworkBufferPool(numNetworkBuffers, bufferSize); + BatchShuffleReadBufferPool readBufferPool = + new BatchShuffleReadBufferPool( + BatchShuffleReadBufferPool.NUM_BYTES_PER_REQUEST, bufferSize); + + BufferPool bufferPool = + networkBufferPool.createBufferPool(numNetworkBuffers, numNetworkBuffers); + SortMergeResultPartition partition = + createSortMergedPartition(1, bufferPool, readBufferPool); + for (int i = 0; i < numNetworkBuffers; ++i) { + partition.emitRecord(ByteBuffer.allocate(bufferSize), 0); + } + partition.finish(); + partition.close(); + + CountDownLatch condition1 = new CountDownLatch(1); + CountDownLatch condition2 = new CountDownLatch(1); + + Runnable task1 = + () -> { + try { + ResultSubpartitionView view = 
partition.createSubpartitionView(0, listener); + BufferPool bufferPool1 = + networkBufferPool.createBufferPool( + numNetworkBuffers / 2, numNetworkBuffers); + SortMergeResultPartition partition1 = + createSortMergedPartition(1, bufferPool1); + readAndEmitData(view, partition1); + + condition1.countDown(); + condition2.await(); + readAndEmitAllData(view, partition1); + } catch (Exception ignored) { + } + }; + Thread consumer1 = new Thread(task1); + consumer1.start(); + + Runnable task2 = + () -> { + try { + condition1.await(); + BufferPool bufferPool2 = + networkBufferPool.createBufferPool( + numNetworkBuffers / 2, numNetworkBuffers); + condition2.countDown(); + + SortMergeResultPartition partition2 = + createSortMergedPartition(1, bufferPool2); + ResultSubpartitionView view = partition.createSubpartitionView(0, listener); + readAndEmitAllData(view, partition2); + } catch (Exception ignored) { + } + }; + Thread consumer2 = new Thread(task2); + consumer2.start(); + + consumer1.join(); + consumer2.join(); + } + + private boolean readAndEmitData(ResultSubpartitionView view, SortMergeResultPartition partition) + throws Exception { + MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(bufferSize); + ResultSubpartition.BufferAndBacklog buffer; + do { + buffer = view.getNextBuffer(); + if (buffer != null) { + Buffer data = ((CompositeBuffer) buffer.buffer()).getFullBufferData(segment); + partition.emitRecord(data.getNioBufferReadable(), 0); + if (!data.isRecycled()) { + data.recycleBuffer(); + } + return buffer.buffer().isBuffer(); + } + } while (true); + } + + private void readAndEmitAllData(ResultSubpartitionView view, SortMergeResultPartition partition) + throws Exception { + while (readAndEmitData(view, partition)) {} + partition.finish(); + partition.close(); + } + private void testResultPartitionBytesCounter(boolean isBroadcast) throws IOException { int numBuffers = useHashDataBuffer ? 100 : 15; int numSubpartitions = 2;