@@ -54,7 +54,7 @@ public class BatchShuffleReadBufferPool {
* Memory size in bytes that can be allocated from this buffer pool for a single request (4M is
* chosen for better sequential reads).
*/
- private static final int NUM_BYTES_PER_REQUEST = 4 * 1024 * 1024;
+ public static final int NUM_BYTES_PER_REQUEST = 4 * 1024 * 1024;

/**
* Wait for at most 2 seconds before returning if there are not enough available buffers currently.
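The constant becomes public so that code outside the buffer pool, such as the new test further down, can size network buffer pools in terms of the read request size. A rough sketch of that arithmetic, mirroring the test (the 32 KiB segment size is an assumed example value, not something this change prescribes):

    // Enough network buffers to back exactly two sequential read requests of
    // NUM_BYTES_PER_REQUEST bytes each; bufferSize is an assumed example value.
    int bufferSize = 32 * 1024;
    int numNetworkBuffers = 2 * BatchShuffleReadBufferPool.NUM_BYTES_PER_REQUEST / bufferSize;
    NetworkBufferPool networkBufferPool = new NetworkBufferPool(numNetworkBuffers, bufferSize);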
@@ -186,6 +186,11 @@ protected void setupInternal() throws IOException {
}
}

// reserve the "guaranteed" buffers for this buffer pool to avoid the case where those
// buffers are taken by other result partitions and cannot be released, which may cause
// deadlock
requestGuaranteedBuffers();

// initialize the buffer pool eagerly to avoid reporting errors such as OOM too late
readBufferPool.initialize();
LOG.info("Sort-merge partition {} initialized.", getPartitionId());
@@ -325,7 +330,7 @@ private DataBuffer createNewDataBuffer() throws IOException {
}
}

- private void requestNetworkBuffers() throws IOException {
+ private void requestGuaranteedBuffers() throws IOException {
int numRequiredBuffer = bufferPool.getNumberOfRequiredMemorySegments();
if (numRequiredBuffer < 2) {
throw new IOException(
@@ -339,8 +344,13 @@ private void requestNetworkBuffers() throws IOException {
freeSegments.add(checkNotNull(bufferPool.requestMemorySegmentBlocking()));
}
} catch (InterruptedException exception) {
freeSegments.forEach(bufferPool::recycle);
throw new IOException("Failed to allocate buffers for result partition.", exception);
}
}

private void requestNetworkBuffers() throws IOException {
requestGuaranteedBuffers();

// avoid taking too many buffers in one result partition
while (freeSegments.size() < bufferPool.getMaxNumberOfMemorySegments()) {
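Read together, the hunks above amount to roughly the following shape (a simplified reassembly for readability, not the verbatim file; bufferPool and freeSegments are fields of the surrounding partition class, and the numRequiredBuffer sanity check is omitted):

    // Called from setupInternal(): eagerly reserve the pool's guaranteed minimum so that
    // these segments are not taken by other result partitions, which could otherwise
    // cause a deadlock.
    private void requestGuaranteedBuffers() throws IOException {
        int numRequiredBuffer = bufferPool.getNumberOfRequiredMemorySegments();
        try {
            while (freeSegments.size() < numRequiredBuffer) {
                freeSegments.add(checkNotNull(bufferPool.requestMemorySegmentBlocking()));
            }
        } catch (InterruptedException exception) {
            // give back whatever was already reserved before failing the setup
            freeSegments.forEach(bufferPool::recycle);
            throw new IOException("Failed to allocate buffers for result partition.", exception);
        }
    }

    private void requestNetworkBuffers() throws IOException {
        requestGuaranteedBuffers();
        // ...followed by the pre-existing logic that keeps requesting segments up to the
        // pool's maximum, so that a single result partition does not take too many buffers.
    }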
@@ -50,6 +50,7 @@
import java.util.Collection;
import java.util.Queue;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Consumer;
@@ -325,7 +326,7 @@ void testReleaseWhileWriting() throws Exception {

BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
SortMergeResultPartition partition = createSortMergedPartition(10, bufferPool);
- assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(0);
+ assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(numBuffers);

partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 0);
partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 1);
@@ -348,7 +349,7 @@ void testRelease() throws Exception {

BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
SortMergeResultPartition partition = createSortMergedPartition(10, bufferPool);
- assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(0);
+ assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(numBuffers);

partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 0);
partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 1);
@@ -381,7 +382,7 @@ void testCloseReleasesAllBuffers() throws Exception {

BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
SortMergeResultPartition partition = createSortMergedPartition(10, bufferPool);
- assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(0);
+ assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(numBuffers);

partition.emitRecord(ByteBuffer.allocate(bufferSize * (numBuffers - 1)), 5);
assertThat(bufferPool.bestEffortGetNumOfUsedBuffers())
@@ -423,6 +424,107 @@ void testNumBytesProducedCounterForBroadcast() throws IOException {
testResultPartitionBytesCounter(true);
}

@TestTemplate
void testNetworkBufferReservation() throws IOException {
int numBuffers = 10;

BufferPool bufferPool = globalPool.createBufferPool(numBuffers, 2 * numBuffers);
SortMergeResultPartition partition = createSortMergedPartition(1, bufferPool);
assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(numBuffers);

partition.finish();
partition.close();
}

@TestTemplate
void testNoDeadlockOnSpecificConsumptionOrder() throws Exception {
// see https://issues.apache.org/jira/browse/FLINK-31386 for more information
int numNetworkBuffers = 2 * BatchShuffleReadBufferPool.NUM_BYTES_PER_REQUEST / bufferSize;
NetworkBufferPool networkBufferPool = new NetworkBufferPool(numNetworkBuffers, bufferSize);
BatchShuffleReadBufferPool readBufferPool =
new BatchShuffleReadBufferPool(
BatchShuffleReadBufferPool.NUM_BYTES_PER_REQUEST, bufferSize);

BufferPool bufferPool =
networkBufferPool.createBufferPool(numNetworkBuffers, numNetworkBuffers);
SortMergeResultPartition partition =
createSortMergedPartition(1, bufferPool, readBufferPool);
for (int i = 0; i < numNetworkBuffers; ++i) {
partition.emitRecord(ByteBuffer.allocate(bufferSize), 0);
}
partition.finish();
partition.close();

CountDownLatch condition1 = new CountDownLatch(1);
CountDownLatch condition2 = new CountDownLatch(1);

Runnable task1 =
() -> {
try {
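// Consumer 1: create a local pool guaranteed half of the global segments, read a single
// buffer from the partition, then wait for consumer 2 before draining the remaining data.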
ResultSubpartitionView view = partition.createSubpartitionView(0, listener);
BufferPool bufferPool1 =
networkBufferPool.createBufferPool(
numNetworkBuffers / 2, numNetworkBuffers);
SortMergeResultPartition partition1 =
createSortMergedPartition(1, bufferPool1);
readAndEmitData(view, partition1);

condition1.countDown();
condition2.await();
readAndEmitAllData(view, partition1);
} catch (Exception ignored) {
}
};
Thread consumer1 = new Thread(task1);
consumer1.start();

Runnable task2 =
() -> {
try {
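// Consumer 2: wait until consumer 1 has started reading, then create the second local
// pool over the same global pool and read all of the data through its own partition.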
condition1.await();
BufferPool bufferPool2 =
networkBufferPool.createBufferPool(
numNetworkBuffers / 2, numNetworkBuffers);
condition2.countDown();

SortMergeResultPartition partition2 =
createSortMergedPartition(1, bufferPool2);
ResultSubpartitionView view = partition.createSubpartitionView(0, listener);
readAndEmitAllData(view, partition2);
} catch (Exception ignored) {
}
};
Thread consumer2 = new Thread(task2);
consumer2.start();

consumer1.join();
consumer2.join();
}

private boolean readAndEmitData(ResultSubpartitionView view, SortMergeResultPartition partition)
throws Exception {
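// Poll the view until a buffer is available, re-emit its contents into the given
// partition, and report whether it carried data (false for an event, ending the read loop).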
MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(bufferSize);
ResultSubpartition.BufferAndBacklog buffer;
do {
buffer = view.getNextBuffer();
if (buffer != null) {
Buffer data = ((CompositeBuffer) buffer.buffer()).getFullBufferData(segment);
partition.emitRecord(data.getNioBufferReadable(), 0);
if (!data.isRecycled()) {
data.recycleBuffer();
}
return buffer.buffer().isBuffer();
}
} while (true);
}

private void readAndEmitAllData(ResultSubpartitionView view, SortMergeResultPartition partition)
throws Exception {
while (readAndEmitData(view, partition)) {}
partition.finish();
partition.close();
}

private void testResultPartitionBytesCounter(boolean isBroadcast) throws IOException {
int numBuffers = useHashDataBuffer ? 100 : 15;
int numSubpartitions = 2;
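Why this particular consumption order could previously deadlock: both consumer-side partitions draw from one global pool of numNetworkBuffers segments, and each local pool is guaranteed half of them while being allowed to grow to all of them. As the new comment in setupInternal() puts it, the buffers making up one partition's guaranteed share could be taken by another result partition and not released; reserving the guaranteed minimum during setup closes that window. The critical pool configuration, taken from the test above:

    // Each local pool is guaranteed numNetworkBuffers / 2 segments but may use up to
    // numNetworkBuffers, so the two maxima overlap on the same global pool.
    BufferPool bufferPool1 =
            networkBufferPool.createBufferPool(numNetworkBuffers / 2, numNetworkBuffers);
    BufferPool bufferPool2 =
            networkBufferPool.createBufferPool(numNetworkBuffers / 2, numNetworkBuffers);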