From 3556bc644e85bb1d25a1af537760ad97f1df0d00 Mon Sep 17 00:00:00 2001 From: jiangxin Date: Wed, 20 Dec 2023 20:38:55 +0800 Subject: [PATCH] [FLINK-33879] Avoids the potential hang of Hybrid Shuffle during redistribution --- .../tiered/shuffle/TieredResultPartition.java | 1 + .../shuffle/TieredResultPartitionFactory.java | 17 +- .../storage/TieredStorageMemoryManager.java | 30 +++- .../TieredStorageMemoryManagerImpl.java | 161 +++++++++++++++--- .../storage/TieredStorageMemorySpec.java | 16 ++ .../tier/memory/MemoryTierProducerAgent.java | 3 +- .../TestingTieredStorageMemoryManager.java | 18 ++ .../shuffle/TieredResultPartitionTest.java | 91 ++++++++-- .../TieredStorageMemoryManagerImplTest.java | 79 ++++++++- .../memory/MemoryTierProducerAgentTest.java | 1 + 10 files changed, 372 insertions(+), 45 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredResultPartition.java index f7070165fa1fd..f69d7a42bcddd 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredResultPartition.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredResultPartition.java @@ -187,6 +187,7 @@ public void finish() throws IOException { @Override public void close() { + storageMemoryManager.release(); super.close(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredResultPartitionFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredResultPartitionFactory.java index 538fda6a0f097..6390979ec8344 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredResultPartitionFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredResultPartitionFactory.java @@ -176,13 +176,16 @@ private BufferAccumulator createBufferAccumulator( List tieredStorageMemorySpecs = new ArrayList<>(); tieredStorageMemorySpecs.add( + // Accumulators are also treated as {@code guaranteedReclaimable}, since these + // buffers can always be transferred to the other tiers. new TieredStorageMemorySpec( bufferAccumulator, 2 * Math.min( numberOfSubpartitions + 1, tieredStorageConfiguration - .getAccumulatorExclusiveBuffers()))); + .getAccumulatorExclusiveBuffers()), + true)); List tierExclusiveBuffers = tieredStorageConfiguration.getEachTierExclusiveBufferNum(); @@ -208,8 +211,16 @@ private BufferAccumulator createBufferAccumulator( numberOfSubpartitions), tieredStorageConfiguration.getDiskIOSchedulerBufferRequestTimeout()); tierProducerAgents.add(producerAgent); - tieredStorageMemorySpecs.add( - new TieredStorageMemorySpec(producerAgent, tierExclusiveBuffers.get(index))); + + if (tierFactory.getClass() == MemoryTierFactory.class) { + tieredStorageMemorySpecs.add( + new TieredStorageMemorySpec( + producerAgent, tierExclusiveBuffers.get(index), false)); + } else { + tieredStorageMemorySpecs.add( + new TieredStorageMemorySpec( + producerAgent, tierExclusiveBuffers.get(index), true)); + } } return Tuple2.of(tierProducerAgents, tieredStorageMemorySpecs); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageMemoryManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageMemoryManager.java index 6b44eba76ad12..108c74767d537 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageMemoryManager.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageMemoryManager.java @@ -42,6 +42,11 @@ * can request. Instead, it only simply provides memory usage hints to memory users. It is very * important to note that only users with non-reclaimable should check the memory * hints by calling {@code getMaxNonReclaimableBuffers} before requesting buffers. + * + *

The {@link TieredStorageMemoryManager} needs to ensure that it would not hinder reclaimable + * users from acquiring buffers due to non-reclaimable users not releasing the buffers they have + * requested. So it is very important to note that only users with non-reclaimable + * should call {@code ensureCapacity} before requesting buffers to reserve enough buffers. */ public interface TieredStorageMemoryManager { @@ -74,11 +79,11 @@ public interface TieredStorageMemoryManager { void listenBufferReclaimRequest(Runnable onBufferReclaimRequest); /** - * Request a {@link BufferBuilder} instance from {@link BufferPool} for a specific owner. The - * {@link TieredStorageMemoryManagerImpl} will not check whether a buffer can be requested. The - * manager only records the number of requested buffers. If the buffers in the {@link - * BufferPool} is not enough, the manager will request each tiered storage to reclaim their - * requested buffers as much as possible. + * Request a {@link BufferBuilder} instance for a specific owner. The {@link + * TieredStorageMemoryManagerImpl} will not check whether a buffer can be requested. The manager + * only records the number of requested buffers. If the buffers is not enough to meet the + * request, the manager will request each tiered storage to reclaim their requested buffers as + * much as possible. * *

This is not thread safe and is expected to be called only from the task thread. * @@ -101,6 +106,21 @@ public interface TieredStorageMemoryManager { */ int getMaxNonReclaimableBuffers(Object owner); + /** + * Try best to reserve enough buffers that are guaranteed reclaimable along with the additional + * ones. + * + *

Note that the available buffers are calculated dynamically based on some conditions, for + * example, the state of the {@link BufferPool}, the {@link TieredStorageMemorySpec} of the + * owner, etc. So the caller should always ensure capacity before requesting non-reclaimable + * buffers. + * + * @param numAdditionalBuffers the number of buffers that need to also be reserved in addition + * to guaranteed reclaimable buffers. + * @return True if the capacity meets the requirements, false otherwise. + */ + boolean ensureCapacity(int numAdditionalBuffers); + /** * Return the number of requested buffers belonging to a specific owner. * diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageMemoryManagerImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageMemoryManagerImpl.java index dee3c65867eca..bf5081e0721eb 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageMemoryManagerImpl.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageMemoryManagerImpl.java @@ -31,18 +31,23 @@ import org.apache.flink.shaded.guava31.com.google.common.util.concurrent.ThreadFactoryBuilder; import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; import static org.apache.flink.util.Preconditions.checkNotNull; import static org.apache.flink.util.Preconditions.checkState; @@ -91,6 +96,18 @@ public class TieredStorageMemoryManagerImpl implements TieredStorageMemoryManage */ private final Map numOwnerRequestedBuffers; + /** + * The queue that contains all available buffers. This field should be thread-safe because it + * can be touched both by the task thread and the netty thread. + */ + private final BlockingQueue bufferQueue; + + /** The lock guarding concurrency issues during releasing. */ + private final ReadWriteLock releasedStateLock; + + /** The number of buffers that are guaranteed to be reclaimed. */ + private int numGuaranteedReclaimableBuffers; + /** * Time gauge to measure that hard backpressure time. Pre-create it to avoid checkNotNull in * hot-path for performance purpose. @@ -108,7 +125,7 @@ public class TieredStorageMemoryManagerImpl implements TieredStorageMemoryManage private BufferPool bufferPool; /** - * Indicate whether the {@link TieredStorageMemoryManagerImpl} is initialized. Before setting + * Indicates whether the {@link TieredStorageMemoryManagerImpl} is initialized. Before setting * up, this field is false. * *

Note that before requesting buffers or getting the maximum allowed buffers, this @@ -116,6 +133,15 @@ public class TieredStorageMemoryManagerImpl implements TieredStorageMemoryManage */ private boolean isInitialized; + /** + * Indicates whether the {@link TieredStorageMemoryManagerImpl} is released. + * + *

Note that before recycling buffers, this released state should be checked to determine + * whether to recycle the buffer back to the internal queue or to the buffer pool. + */ + @GuardedBy("readWriteLock") + private boolean isReleased; + /** * The constructor of the {@link TieredStorageMemoryManagerImpl}. * @@ -131,6 +157,9 @@ public TieredStorageMemoryManagerImpl( this.numRequestedBuffers = new AtomicInteger(0); this.numOwnerRequestedBuffers = new ConcurrentHashMap<>(); this.bufferReclaimRequestListeners = new ArrayList<>(); + this.bufferQueue = new LinkedBlockingQueue<>(); + this.releasedStateLock = new ReentrantReadWriteLock(); + this.isReleased = false; this.isInitialized = false; } @@ -142,6 +171,8 @@ public void setup(BufferPool bufferPool, List storageMe !tieredMemorySpecs.containsKey(memorySpec.getOwner()), "Duplicated memory spec."); tieredMemorySpecs.put(memorySpec.getOwner(), memorySpec); + numGuaranteedReclaimableBuffers += + memorySpec.isGuaranteedReclaimable() ? memorySpec.getNumGuaranteedBuffers() : 0; } if (mayReclaimBuffer) { @@ -173,22 +204,13 @@ public BufferBuilder requestBufferBlocking(Object owner) { reclaimBuffersIfNeeded(0); - CompletableFuture requestBufferFuture = new CompletableFuture<>(); - scheduleCheckRequestBufferFuture( - requestBufferFuture, INITIAL_REQUEST_BUFFER_TIMEOUT_FOR_RECLAIMING_MS); - MemorySegment memorySegment = bufferPool.requestMemorySegment(); - + MemorySegment memorySegment = bufferQueue.poll(); if (memorySegment == null) { - try { - hardBackpressureTimerGauge.markStart(); - memorySegment = bufferPool.requestMemorySegmentBlocking(); - hardBackpressureTimerGauge.markEnd(); - } catch (InterruptedException e) { - ExceptionUtils.rethrow(e); - } + memorySegment = requestBufferBlockingFromPool(); + } + if (memorySegment == null) { + memorySegment = checkNotNull(requestBufferBlockingFromQueue()); } - - requestBufferFuture.complete(null); incNumRequestedBuffer(owner); return new BufferBuilder( @@ -218,6 +240,31 @@ public int getMaxNonReclaimableBuffers(Object owner) { return bufferPool.getNumBuffers() - numBuffersUsedOrReservedForOtherOwners; } + @Override + public boolean ensureCapacity(int numAdditionalBuffers) { + checkIsInitialized(); + + final int numRequestedByGuaranteedReclaimableOwners = + tieredMemorySpecs.values().stream() + .filter(TieredStorageMemorySpec::isGuaranteedReclaimable) + .mapToInt(spec -> numOwnerRequestedBuffer(spec.getOwner())) + .sum(); + + while (bufferQueue.size() + numRequestedByGuaranteedReclaimableOwners + < numGuaranteedReclaimableBuffers + numAdditionalBuffers) { + if (numRequestedBuffers.get() >= bufferPool.getNumBuffers()) { + return false; + } + + MemorySegment memorySegment = requestBufferBlockingFromPool(); + if (memorySegment == null) { + return false; + } + bufferQueue.add(memorySegment); + } + return true; + } + @Override public int numOwnerRequestedBuffer(Object owner) { return numOwnerRequestedBuffers.getOrDefault(owner, 0); @@ -233,6 +280,12 @@ public void transferBufferOwnership(Object oldOwner, Object newOwner, Buffer buf @Override public void release() { + try { + releasedStateLock.writeLock().lock(); + isReleased = true; + } finally { + releasedStateLock.writeLock().unlock(); + } if (executor != null) { executor.shutdown(); try { @@ -244,6 +297,59 @@ public void release() { ExceptionUtils.rethrow(e); } } + while (!bufferQueue.isEmpty()) { + MemorySegment segment = bufferQueue.poll(); + bufferPool.recycle(segment); + numRequestedBuffers.decrementAndGet(); + } + } + + /** + * @return a memory segment from the buffer pool or null if the memory manager has requested all + * segments of the buffer pool. + */ + @Nullable + private MemorySegment requestBufferBlockingFromPool() { + MemorySegment memorySegment = null; + + hardBackpressureTimerGauge.markStart(); + while (numRequestedBuffers.get() < bufferPool.getNumBuffers()) { + memorySegment = bufferPool.requestMemorySegment(); + if (memorySegment == null) { + try { + // Wait until a buffer is available or timeout before entering the next loop + // iteration. + bufferPool.getAvailableFuture().get(100, TimeUnit.MILLISECONDS); + } catch (TimeoutException ignored) { + } catch (Exception e) { + ExceptionUtils.rethrow(e); + } + } else { + numRequestedBuffers.incrementAndGet(); + break; + } + } + hardBackpressureTimerGauge.markEnd(); + + return memorySegment; + } + + /** @return a memory segment from the internal buffer queue. */ + private MemorySegment requestBufferBlockingFromQueue() { + CompletableFuture requestBufferFuture = new CompletableFuture<>(); + scheduleCheckRequestBufferFuture( + requestBufferFuture, INITIAL_REQUEST_BUFFER_TIMEOUT_FOR_RECLAIMING_MS); + + MemorySegment memorySegment = null; + try { + memorySegment = bufferQueue.take(); + } catch (InterruptedException e) { + ExceptionUtils.rethrow(e); + } finally { + requestBufferFuture.complete(null); + } + + return memorySegment; } private void scheduleCheckRequestBufferFuture( @@ -272,13 +378,11 @@ private void internalCheckRequestBufferFuture( private void incNumRequestedBuffer(Object owner) { numOwnerRequestedBuffers.compute( owner, (ignore, numRequested) -> numRequested == null ? 1 : numRequested + 1); - numRequestedBuffers.incrementAndGet(); } private void decNumRequestedBuffer(Object owner) { numOwnerRequestedBuffers.compute( owner, (ignore, numRequested) -> checkNotNull(numRequested) - 1); - numRequestedBuffers.decrementAndGet(); } private void reclaimBuffersIfNeeded(long delayForNextCheckMs) { @@ -293,17 +397,28 @@ private boolean shouldReclaimBuffersBeforeRequesting(long delayForNextCheckMs) { // next iteration, the buffer reclaim will eventually be triggered. int numTotal = bufferPool.getNumBuffers(); int numRequested = numRequestedBuffers.get(); - return numRequested >= numTotal - // Because we do the checking before requesting buffers, we need add additional one - // buffer when calculating the usage ratio. - || ((numRequested + 1) * 1.0 / numTotal) > numTriggerReclaimBuffersRatio + + // Because we do the checking before requesting buffers, we need add additional one + // buffer when calculating the usage ratio. + return (numRequested + 1 - bufferQueue.size()) * 1.0 / numTotal + > numTriggerReclaimBuffersRatio || delayForNextCheckMs > MAX_DELAY_TIME_TO_TRIGGER_RECLAIM_BUFFER_MS - && bufferPool.getNumberOfAvailableMemorySegments() == 0; + && bufferQueue.size() == 0; } /** Note that this method may be called by the netty thread. */ private void recycleBuffer(Object owner, MemorySegment buffer) { - bufferPool.recycle(buffer); + try { + releasedStateLock.readLock().lock(); + if (!isReleased && numRequestedBuffers.get() <= bufferPool.getNumBuffers()) { + bufferQueue.add(buffer); + } else { + bufferPool.recycle(buffer); + numRequestedBuffers.decrementAndGet(); + } + } finally { + releasedStateLock.readLock().unlock(); + } decNumRequestedBuffer(owner); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageMemorySpec.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageMemorySpec.java index 6a2b9b71f9fd3..a97cf3725a539 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageMemorySpec.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageMemorySpec.java @@ -30,9 +30,21 @@ public class TieredStorageMemorySpec { /** The number of guaranteed buffers of this memory owner. */ private final int numGuaranteedBuffers; + /** + * Whether the buffers of this owner are guaranteed to be reclaimed, even if the downstream does + * not consume them promptly. + */ + private final boolean guaranteedReclaimable; + public TieredStorageMemorySpec(Object owner, int numGuaranteedBuffers) { + this(owner, numGuaranteedBuffers, true); + } + + public TieredStorageMemorySpec( + Object owner, int numGuaranteedBuffers, boolean guaranteedReclaimable) { this.owner = owner; this.numGuaranteedBuffers = numGuaranteedBuffers; + this.guaranteedReclaimable = guaranteedReclaimable; } public Object getOwner() { @@ -42,4 +54,8 @@ public Object getOwner() { public int getNumGuaranteedBuffers() { return numGuaranteedBuffers; } + + public boolean isGuaranteedReclaimable() { + return guaranteedReclaimable; + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/memory/MemoryTierProducerAgent.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/memory/MemoryTierProducerAgent.java index 97641af1970fe..fbbfc0595b92c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/memory/MemoryTierProducerAgent.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/memory/MemoryTierProducerAgent.java @@ -109,7 +109,8 @@ public boolean tryStartNewSegment(TieredStorageSubpartitionId subpartitionId, in < subpartitionMaxQueuedBuffers && (memoryManager.getMaxNonReclaimableBuffers(this) - memoryManager.numOwnerRequestedBuffer(this)) - > numBuffersPerSegment; + > numBuffersPerSegment + && memoryManager.ensureCapacity(numBuffersPerSegment); if (canStartNewSegment) { subpartitionProducerAgents[subpartitionId.getSubpartitionId()].updateSegmentId( segmentId); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/TestingTieredStorageMemoryManager.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/TestingTieredStorageMemoryManager.java index c99013b6a1e70..d11172f789a6b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/TestingTieredStorageMemoryManager.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/TestingTieredStorageMemoryManager.java @@ -44,6 +44,8 @@ public class TestingTieredStorageMemoryManager implements TieredStorageMemoryMan private final Function getMaxNonReclaimableBuffersFunction; + private final Function ensureCapacityFunction; + private final Function numOwnerRequestedBufferFunction; private final TriConsumer transferBufferOwnershipConsumer; @@ -56,6 +58,7 @@ private TestingTieredStorageMemoryManager( Consumer listenBufferReclaimRequestConsumer, Function requestBufferBlockingFunction, Function getMaxNonReclaimableBuffersFunction, + Function ensureCapacityFunction, Function numOwnerRequestedBufferFunction, TriConsumer transferBufferOwnershipConsumer, Runnable releaseRunnable) { @@ -64,6 +67,7 @@ private TestingTieredStorageMemoryManager( this.listenBufferReclaimRequestConsumer = listenBufferReclaimRequestConsumer; this.requestBufferBlockingFunction = requestBufferBlockingFunction; this.getMaxNonReclaimableBuffersFunction = getMaxNonReclaimableBuffersFunction; + this.ensureCapacityFunction = ensureCapacityFunction; this.numOwnerRequestedBufferFunction = numOwnerRequestedBufferFunction; this.transferBufferOwnershipConsumer = transferBufferOwnershipConsumer; this.releaseRunnable = releaseRunnable; @@ -94,6 +98,11 @@ public int getMaxNonReclaimableBuffers(Object owner) { return getMaxNonReclaimableBuffersFunction.apply(owner); } + @Override + public boolean ensureCapacity(int numAdditionalBuffers) { + return ensureCapacityFunction.apply(numAdditionalBuffers); + } + @Override public int numOwnerRequestedBuffer(Object owner) { return numOwnerRequestedBufferFunction.apply(owner); @@ -123,6 +132,8 @@ public static class Builder { private Function getMaxNonReclaimableBuffersFunction = owner -> 0; + private Function ensureCapacityFunction = num -> true; + private Function numOwnerRequestedBufferFunction = owner -> 0; private TriConsumer transferBufferOwnershipConsumer = @@ -156,6 +167,12 @@ public TestingTieredStorageMemoryManager.Builder setGetMaxNonReclaimableBuffersF return this; } + public TestingTieredStorageMemoryManager.Builder setEnsureCapacityFunction( + Function ensureCapacityFunction) { + this.ensureCapacityFunction = ensureCapacityFunction; + return this; + } + public TestingTieredStorageMemoryManager.Builder setNumOwnerRequestedBufferFunction( Function numOwnerRequestedBufferFunction) { this.numOwnerRequestedBufferFunction = numOwnerRequestedBufferFunction; @@ -181,6 +198,7 @@ public TestingTieredStorageMemoryManager build() { listenBufferReclaimRequestConsumer, requestBufferBlockingFunction, getMaxNonReclaimableBuffersFunction, + ensureCapacityFunction, numOwnerRequestedBufferFunction, transferBufferOwnershipConsumer, releaseRunnable); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredResultPartitionTest.java index c3a237efe09bb..7d6f1dfa88030 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredResultPartitionTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredResultPartitionTest.java @@ -30,9 +30,11 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.ResultPartitionManager; import org.apache.flink.runtime.io.network.partition.ResultPartitionType; +import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView; import org.apache.flink.runtime.io.network.partition.hybrid.tiered.TestingBufferAccumulator; import org.apache.flink.runtime.io.network.partition.hybrid.tiered.TestingTierProducerAgent; import org.apache.flink.runtime.io.network.partition.hybrid.tiered.TestingTieredStorageMemoryManager; +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageConfiguration; import org.apache.flink.runtime.io.network.partition.hybrid.tiered.netty.TieredStorageNettyServiceImpl; import org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageProducerClient; import org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageResourceRegistry; @@ -176,12 +178,7 @@ void testEmitRecords() throws Exception { createTieredStoreResultPartition(2, bufferPool, false)) { partition.emitRecord(ByteBuffer.allocate(bufferSize), 0); partition.broadcastRecord(ByteBuffer.allocate(bufferSize)); - IOMetrics ioMetrics = taskIOMetricGroup.createSnapshot(); - assertThat(ioMetrics.getResultPartitionBytes()).hasSize(1); - ResultPartitionBytes partitionBytes = - ioMetrics.getResultPartitionBytes().values().iterator().next(); - assertThat(partitionBytes.getSubpartitionBytes()) - .containsExactly((long) 2 * bufferSize, bufferSize); + verifySubpartitionBytes((long) 2 * bufferSize, bufferSize); } } @@ -192,15 +189,37 @@ void testMetricsUpdateForBroadcastOnlyResultPartition() throws Exception { try (TieredResultPartition partition = createTieredStoreResultPartition(2, bufferPool, true)) { partition.broadcastRecord(ByteBuffer.allocate(bufferSize)); - IOMetrics ioMetrics = taskIOMetricGroup.createSnapshot(); - assertThat(ioMetrics.getResultPartitionBytes()).hasSize(1); - ResultPartitionBytes partitionBytes = - ioMetrics.getResultPartitionBytes().values().iterator().next(); - assertThat(partitionBytes.getSubpartitionBytes()) - .containsExactly(bufferSize, bufferSize); + verifySubpartitionBytes(bufferSize, bufferSize); } } + @Test + void testRequestBuffersAfterPoolSizeDecreased() throws IOException { + final int numBuffers = 20; + final int numRecords = numBuffers / 2; + + BufferPool bufferPool = globalPool.createBufferPool(1, numBuffers); + TieredResultPartition resultPartition = + createTieredStoreResultPartitionWithStorageManager(1, bufferPool, false); + + ResultSubpartitionView subpartitionView = + resultPartition.createSubpartitionView(0, new NoOpBufferAvailablityListener()); + + // Emits some records to occupy some buffers of memory tier, these buffers would not be + // recycled until the subpartitionView is released manually. + for (int i = 0; i < numRecords; i++) { + resultPartition.emitRecord(ByteBuffer.allocate(NETWORK_BUFFER_SIZE), 0); + } + verifySubpartitionBytes(numRecords * NETWORK_BUFFER_SIZE); + + bufferPool.setNumBuffers(1); + resultPartition.emitRecord(ByteBuffer.allocate(NETWORK_BUFFER_SIZE), 0); + verifySubpartitionBytes((numRecords + 1) * NETWORK_BUFFER_SIZE); + + subpartitionView.releaseAllResources(); + resultPartition.release(); + } + private TieredResultPartition createTieredStoreResultPartition( int numSubpartitions, BufferPool bufferPool, boolean isBroadcastOnly) throws IOException { @@ -234,4 +253,52 @@ private TieredResultPartition createTieredStoreResultPartition( tieredResultPartition.setMetricGroup(taskIOMetricGroup); return tieredResultPartition; } + + private TieredResultPartition createTieredStoreResultPartitionWithStorageManager( + int numSubpartitions, BufferPool bufferPool, boolean isBroadcastOnly) + throws IOException { + TieredStorageConfiguration tieredStorageConfiguration = + TieredStorageConfiguration.builder(null) + .setMemoryTierSubpartitionMaxQueuedBuffers(10) + .build(); + TieredStorageResourceRegistry tieredStorageResourceRegistry = + new TieredStorageResourceRegistry(); + TieredStorageNettyServiceImpl tieredStorageNettyService = + new TieredStorageNettyServiceImpl(tieredStorageResourceRegistry); + TieredResultPartitionFactory tieredResultPartitionFactory = + new TieredResultPartitionFactory( + tieredStorageConfiguration, + tieredStorageNettyService, + tieredStorageResourceRegistry); + + TieredResultPartition resultPartition = + tieredResultPartitionFactory.createTieredResultPartition( + "TieredStoreResultPartitionTest", + 0, + new ResultPartitionID(), + ResultPartitionType.HYBRID_SELECTIVE, + numSubpartitions, + numSubpartitions, + isBroadcastOnly, + new ResultPartitionManager(), + new BufferCompressor(NETWORK_BUFFER_SIZE, "LZ4"), + () -> bufferPool, + fileChannelManager, + readBufferPool, + readIOExecutor); + taskIOMetricGroup = + UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(); + resultPartition.setup(); + resultPartition.setMetricGroup(taskIOMetricGroup); + + return resultPartition; + } + + private void verifySubpartitionBytes(long... expectedNumBytes) { + IOMetrics ioMetrics = taskIOMetricGroup.createSnapshot(); + assertThat(ioMetrics.getResultPartitionBytes()).hasSize(1); + ResultPartitionBytes partitionBytes = + ioMetrics.getResultPartitionBytes().values().iterator().next(); + assertThat(partitionBytes.getSubpartitionBytes()).containsExactly(expectedNumBytes); + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageMemoryManagerImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageMemoryManagerImplTest.java index b395056505a5e..395713e48ab18 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageMemoryManagerImplTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageMemoryManagerImplTest.java @@ -33,6 +33,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.concurrent.CompletableFuture; @@ -84,8 +85,54 @@ void testRequestAndRecycleBuffers() throws IOException { BufferBuilder builder = storageMemoryManager.requestBufferBlocking(this); assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(1); recycleBufferBuilder(builder); + assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(1); + storageMemoryManager.release(); assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(0); + } + + @Test + void testRecycleBuffersAfterPoolSizeDecreased() throws IOException { + int numBuffers = 10; + + BufferPool bufferPool = globalPool.createBufferPool(1, numBuffers); + TieredStorageMemoryManagerImpl storageMemoryManager = + createStorageMemoryManager( + bufferPool, + Collections.singletonList(new TieredStorageMemorySpec(this, 0))); + for (int i = 0; i < numBuffers; i++) { + BufferBuilder builder = storageMemoryManager.requestBufferBlocking(this); + requestedBuffers.add(builder); + } + + bufferPool.setNumBuffers(numBuffers / 2); + + for (int i = 0; i < numBuffers; i++) { + recycleBufferBuilder(requestedBuffers.get(i)); + assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()) + .isEqualTo(Math.max(numBuffers / 2, numBuffers - (i + 1))); + } + } + + @Test + void testRecycleBuffersAfterReleased() throws IOException { + int numBuffers = 10; + + BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers); + TieredStorageMemoryManagerImpl storageMemoryManager = + createStorageMemoryManager( + bufferPool, + Collections.singletonList(new TieredStorageMemorySpec(this, 0))); + for (int i = 0; i < numBuffers; i++) { + BufferBuilder builder = storageMemoryManager.requestBufferBlocking(this); + requestedBuffers.add(builder); + } + assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(numBuffers); + storageMemoryManager.release(); + for (int i = 0; i < numBuffers; i++) { + recycleBufferBuilder(requestedBuffers.get(i)); + } + assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(0); } @Test @@ -242,19 +289,49 @@ void testCanNotTransferOwnershipForEvent() throws IOException { .isInstanceOf(IllegalStateException.class); } + @Test + void testEnsureCapacity() throws IOException { + final int numBuffers = 5; + final int guaranteedReclaimableBuffers = 3; + + BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers); + TieredStorageMemoryManagerImpl storageMemoryManager = + createStorageMemoryManager( + bufferPool, + Arrays.asList( + new TieredStorageMemorySpec( + new Object(), guaranteedReclaimableBuffers, true), + new TieredStorageMemorySpec(this, 0, false))); + assertThat(storageMemoryManager.ensureCapacity(0)).isTrue(); + assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()) + .isEqualTo(guaranteedReclaimableBuffers); + + assertThat(storageMemoryManager.ensureCapacity(numBuffers - guaranteedReclaimableBuffers)) + .isTrue(); + assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(numBuffers); + + assertThat( + storageMemoryManager.ensureCapacity( + numBuffers - guaranteedReclaimableBuffers + 1)) + .isFalse(); + storageMemoryManager.release(); + } + @Test void testRelease() throws IOException { int numBuffers = 5; + BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers); TieredStorageMemoryManagerImpl storageMemoryManager = createStorageMemoryManager( - numBuffers, + bufferPool, Collections.singletonList(new TieredStorageMemorySpec(this, 0))); requestedBuffers.add(storageMemoryManager.requestBufferBlocking(this)); assertThat(storageMemoryManager.numOwnerRequestedBuffer(this)).isOne(); recycleRequestedBuffers(); storageMemoryManager.release(); assertThat(storageMemoryManager.numOwnerRequestedBuffer(this)).isZero(); + assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isZero(); } public void onBufferReclaimRequest() { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/memory/MemoryTierProducerAgentTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/memory/MemoryTierProducerAgentTest.java index 2b8bf2ad872a0..a0f62d6f2559a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/memory/MemoryTierProducerAgentTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/memory/MemoryTierProducerAgentTest.java @@ -104,6 +104,7 @@ void testStartSegmentFailedWithInsufficientMemory() { TestingTieredStorageMemoryManager memoryManager = new TestingTieredStorageMemoryManager.Builder() .setGetMaxNonReclaimableBuffersFunction(ignore -> 1) + .setEnsureCapacityFunction(num -> false) .build(); TestingTieredStorageNettyService nettyService = new TestingTieredStorageNettyService.Builder().build();