diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/disk/BatchShuffleReadBufferPool.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/disk/BatchShuffleReadBufferPool.java index fc9aa1b0ef224..863c0429400c1 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/disk/BatchShuffleReadBufferPool.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/disk/BatchShuffleReadBufferPool.java @@ -54,7 +54,7 @@ public class BatchShuffleReadBufferPool { * Memory size in bytes can be allocated from this buffer pool for a single request (4M is for * better sequential read). */ - private static final int NUM_BYTES_PER_REQUEST = 4 * 1024 * 1024; + public static final int NUM_BYTES_PER_REQUEST = 4 * 1024 * 1024; /** * Wait for at most 2 seconds before return if there is no enough available buffers currently. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java index def80f4057489..ba84f8b6e0b93 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java @@ -186,6 +186,11 @@ protected void setupInternal() throws IOException { } } + // reserve the "guaranteed" buffers for this buffer pool to avoid the case that those + // buffers are taken by other result partitions and cannot be released, which may cause + // a deadlock + requestGuaranteedBuffers(); + // initialize the buffer pool eagerly to avoid reporting errors such as OOM too late readBufferPool.initialize(); LOG.info("Sort-merge partition {} initialized.", getPartitionId()); @@ -325,7 +330,7 @@ private DataBuffer createNewDataBuffer() throws IOException { } } - private void requestNetworkBuffers() throws IOException { + private void 
requestGuaranteedBuffers() throws IOException { int numRequiredBuffer = bufferPool.getNumberOfRequiredMemorySegments(); if (numRequiredBuffer < 2) { throw new IOException( @@ -339,8 +344,13 @@ private void requestNetworkBuffers() throws IOException { freeSegments.add(checkNotNull(bufferPool.requestMemorySegmentBlocking())); } } catch (InterruptedException exception) { + freeSegments.forEach(bufferPool::recycle); throw new IOException("Failed to allocate buffers for result partition.", exception); } + } + + private void requestNetworkBuffers() throws IOException { + requestGuaranteedBuffers(); // avoid taking too many buffers in one result partition while (freeSegments.size() < bufferPool.getMaxNumberOfMemorySegments()) { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java index 6f01823c3fd9e..c202669e9f328 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartitionTest.java @@ -50,6 +50,7 @@ import java.util.Collection; import java.util.Queue; import java.util.Random; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.function.Consumer; @@ -325,7 +326,7 @@ void testReleaseWhileWriting() throws Exception { BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers); SortMergeResultPartition partition = createSortMergedPartition(10, bufferPool); - assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(0); + assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(numBuffers); partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 0); partition.emitRecord(ByteBuffer.allocate(bufferSize 
* (numBuffers - 1)), 1); @@ -348,7 +349,7 @@ void testRelease() throws Exception { BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers); SortMergeResultPartition partition = createSortMergedPartition(10, bufferPool); - assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(0); + assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(numBuffers); partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 0); partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 1); @@ -381,7 +382,7 @@ void testCloseReleasesAllBuffers() throws Exception { BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers); SortMergeResultPartition partition = createSortMergedPartition(10, bufferPool); - assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(0); + assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(numBuffers); partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 5); assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()) @@ -423,6 +424,107 @@ void testNumBytesProducedCounterForBroadcast() throws IOException { testResultPartitionBytesCounter(true); } + @TestTemplate + void testNetworkBufferReservation() throws IOException { + int numBuffers = 10; + + BufferPool bufferPool = globalPool.createBufferPool(numBuffers, 2 * numBuffers); + SortMergeResultPartition partition = createSortMergedPartition(1, bufferPool); + assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(numBuffers); + + partition.finish(); + partition.close(); + } + + @TestTemplate + void testNoDeadlockOnSpecificConsumptionOrder() throws Exception { + // see https://issues.apache.org/jira/browse/FLINK-31386 for more information + int numNetworkBuffers = 2 * BatchShuffleReadBufferPool.NUM_BYTES_PER_REQUEST / bufferSize; + NetworkBufferPool networkBufferPool = new NetworkBufferPool(numNetworkBuffers, bufferSize); + BatchShuffleReadBufferPool readBufferPool = + new 
BatchShuffleReadBufferPool( + BatchShuffleReadBufferPool.NUM_BYTES_PER_REQUEST, bufferSize); + + BufferPool bufferPool = + networkBufferPool.createBufferPool(numNetworkBuffers, numNetworkBuffers); + SortMergeResultPartition partition = + createSortMergedPartition(1, bufferPool, readBufferPool); + for (int i = 0; i < numNetworkBuffers; ++i) { + partition.emitRecord(ByteBuffer.allocate(bufferSize), 0); + } + partition.finish(); + partition.close(); + + CountDownLatch condition1 = new CountDownLatch(1); + CountDownLatch condition2 = new CountDownLatch(1); + + Runnable task1 = + () -> { + try { + ResultSubpartitionView view = partition.createSubpartitionView(0, listener); + BufferPool bufferPool1 = + networkBufferPool.createBufferPool( + numNetworkBuffers / 2, numNetworkBuffers); + SortMergeResultPartition partition1 = + createSortMergedPartition(1, bufferPool1); + readAndEmitData(view, partition1); + + condition1.countDown(); + condition2.await(); + readAndEmitAllData(view, partition1); + } catch (Exception ignored) { + } + }; + Thread consumer1 = new Thread(task1); + consumer1.start(); + + Runnable task2 = + () -> { + try { + condition1.await(); + BufferPool bufferPool2 = + networkBufferPool.createBufferPool( + numNetworkBuffers / 2, numNetworkBuffers); + condition2.countDown(); + + SortMergeResultPartition partition2 = + createSortMergedPartition(1, bufferPool2); + ResultSubpartitionView view = partition.createSubpartitionView(0, listener); + readAndEmitAllData(view, partition2); + } catch (Exception ignored) { + } + }; + Thread consumer2 = new Thread(task2); + consumer2.start(); + + consumer1.join(); + consumer2.join(); + } + + private boolean readAndEmitData(ResultSubpartitionView view, SortMergeResultPartition partition) + throws Exception { + MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(bufferSize); + ResultSubpartition.BufferAndBacklog buffer; + do { + buffer = view.getNextBuffer(); + if (buffer != null) { + Buffer data = 
((CompositeBuffer) buffer.buffer()).getFullBufferData(segment); + partition.emitRecord(data.getNioBufferReadable(), 0); + if (!data.isRecycled()) { + data.recycleBuffer(); + } + return buffer.buffer().isBuffer(); + } + } while (true); + } + + private void readAndEmitAllData(ResultSubpartitionView view, SortMergeResultPartition partition) + throws Exception { + while (readAndEmitData(view, partition)) {} + partition.finish(); + partition.close(); + } + private void testResultPartitionBytesCounter(boolean isBroadcast) throws IOException { int numBuffers = useHashDataBuffer ? 100 : 15; int numSubpartitions = 2;