From 46d4b743b90312315ea1f06f2c13dfa25ada0d52 Mon Sep 17 00:00:00 2001 From: Rui Fan <1996fanrui@gmail.com> Date: Thu, 21 May 2026 15:26:09 +0200 Subject: [PATCH 01/16] [FLINK-38544][network] Decouple LocalInputChannel recovery wiring from toBeConsumedBuffers Split the single toBeConsumedBuffers queue into two queues with disjoint responsibilities: - recoveredBuffers (new): holds buffers migrated from RecoveredInputChannel during construction; consumed by getNextRecoveredBuffer() which retains the priority-event interleaving and last-buffer dynamic next-data-type detection introduced by FLINK-39018. - toBeConsumedBuffers (existing): reverted to its pre-FLINK-39018 role of holding FullyFilledBuffer partial-buffer splits only. The recovery-aware early branch in getNextBuffer() and the checkpointStarted inflight scan no longer touch this queue. Restores the checkState(toBeConsumedBuffers.isEmpty()) guard in requestSubpartitions() (removed by cebc174a). hasPendingPriorityEvent, notifyPriorityEvent, and the constructor signature are unchanged. Pure refactor: no public API change, no new tests; verified by the 9 existing LocalInputChannelTest regression cases. 
(cherry picked from commit 292cc4b9e2d9191b2e810d756127860ce98eb25d) (cherry picked from commit 7fbfc783d82ebf10d8f3bda58f618eb0986b4050) --- .../partition/consumer/LocalInputChannel.java | 63 ++++++++++++------- 1 file changed, 42 insertions(+), 21 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java index 661e4b063c75f..0503a301ea7f7 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java @@ -80,9 +80,16 @@ public class LocalInputChannel extends InputChannel implements BufferAvailabilit private final Deque<BufferAndBacklog> toBeConsumedBuffers = new ArrayDeque<>(); + /** + * Buffers migrated from {@code RecoveredInputChannel}, kept separately from {@link + * #toBeConsumedBuffers} so that recovery semantics (priority event interleaving, checkpoint + * inflight persistence) do not leak into the FullyFilledBuffer split path. + */ + private final Deque<BufferAndBacklog> recoveredBuffers = new ArrayDeque<>(); + /** * Flag indicating whether there is a pending priority event (e.g., checkpoint barrier) in the - * subpartitionView that should be consumed before toBeConsumedBuffers. This is set by {@link + * subpartitionView that should be consumed before recoveredBuffers. This is set by {@link * #notifyPriorityEvent} and checked in {@link #getNextBuffer()}. 
*/ private volatile boolean hasPendingPriorityEvent = false; @@ -131,13 +138,13 @@ public LocalInputChannel( // buffersInBacklog is set to 0 as these are recovered buffers BufferAndBacklog bufferAndBacklog = new BufferAndBacklog(buffer, 0, nextDataType, seqNum++); - toBeConsumedBuffers.add(bufferAndBacklog); + recoveredBuffers.add(bufferAndBacklog); } checkState( - toBeConsumedBuffers.size() == expectedCount, + recoveredBuffers.size() == expectedCount, "Buffer migration failed: expected %s buffers but got %s", expectedCount, - toBeConsumedBuffers.size()); + recoveredBuffers.size()); } } @@ -146,10 +153,11 @@ public LocalInputChannel( // ------------------------------------------------------------------------ public void checkpointStarted(CheckpointBarrier barrier) throws CheckpointException { - // Collect inflight buffers from toBeConsumedBuffers to be persisted. - // These are buffers that have not been consumed yet when the checkpoint barrier arrives. + // Collect inflight buffers from recoveredBuffers to be persisted. + // These are recovered buffers that have not been consumed yet when the checkpoint + // barrier arrives. 
List<Buffer> inflightBuffers = new ArrayList<>(); - for (BufferAndBacklog bufferAndBacklog : toBeConsumedBuffers) { + for (BufferAndBacklog bufferAndBacklog : recoveredBuffers) { if (bufferAndBacklog.buffer().isBuffer()) { inflightBuffers.add(bufferAndBacklog.buffer().retainBuffer()); } @@ -163,6 +171,8 @@ public void checkpointStopped(long checkpointId) { @Override protected void requestSubpartitions() throws IOException { + checkState(toBeConsumedBuffers.isEmpty()); + boolean retriggerRequest = false; boolean notifyDataAvailable = false; @@ -272,10 +282,14 @@ protected int peekNextBufferSubpartitionIdInternal() throws IOException { public Optional<BufferAndAvailability> getNextBuffer() throws IOException { checkError(); - if (!toBeConsumedBuffers.isEmpty()) { + if (!recoveredBuffers.isEmpty()) { return getNextRecoveredBuffer(); } + if (!toBeConsumedBuffers.isEmpty()) { + return getBufferAndAvailability(toBeConsumedBuffers.removeFirst()); + } + ResultSubpartitionView subpartitionView = this.subpartitionView; if (subpartitionView == null) { // There is a possible race condition between writing a EndOfPartitionEvent (1) and @@ -336,12 +350,12 @@ public Optional<BufferAndAvailability> getNextBuffer() throws IOException { } /** - * Consumes the next buffer from toBeConsumedBuffers (recovered buffers), handling pending - * priority events and dynamic availability detection for the last recovered buffer. + * Consumes the next buffer from recoveredBuffers, handling pending priority events and dynamic + * availability detection for the last recovered buffer. */ private Optional<BufferAndAvailability> getNextRecoveredBuffer() throws IOException { // If there is a pending priority event (e.g., unaligned checkpoint barrier), fetch it - // from subpartitionView first, skipping toBeConsumedBuffers. This ensures priority + // from subpartitionView first, skipping recoveredBuffers. This ensures priority // events are processed immediately even when there are pending recovered buffers. 
if (hasPendingPriorityEvent) { checkState(subpartitionView != null, "No subpartition view available"); @@ -359,10 +373,10 @@ private Optional getNextRecoveredBuffer() throws IOExcept if (!expectedNextDataType.hasPriority()) { // Reset hasPendingPriorityEvent to false if no more priority event hasPendingPriorityEvent = false; - if (!toBeConsumedBuffers.isEmpty()) { - // Correct nextDataType: if toBeConsumedBuffers is not empty, the actual next - // element to consume is from toBeConsumedBuffers, not from subpartitionView - expectedNextDataType = toBeConsumedBuffers.peek().buffer().getDataType(); + if (!recoveredBuffers.isEmpty()) { + // Correct nextDataType: if recoveredBuffers is not empty, the actual next + // element to consume is from recoveredBuffers, not from subpartitionView + expectedNextDataType = recoveredBuffers.peek().buffer().getDataType(); } } @@ -374,13 +388,13 @@ private Optional getNextRecoveredBuffer() throws IOExcept next.getSequenceNumber())); } - BufferAndBacklog next = toBeConsumedBuffers.removeFirst(); + BufferAndBacklog next = recoveredBuffers.removeFirst(); // If this is the last recovered buffer and nextDataType is NONE, // dynamically check if subpartitionView has data available. // The last buffer's nextDataType was preset to NONE during construction, // but subpartitionView may already have data available. - if (toBeConsumedBuffers.isEmpty() + if (recoveredBuffers.isEmpty() && next.getNextDataType() == Buffer.DataType.NONE && subpartitionView != null) { ResultSubpartitionView.AvailabilityWithBacklog availability = @@ -512,8 +526,13 @@ void releaseAllResources() throws IOException { subpartitionView = null; } - // Release any remaining buffers in toBeConsumedBuffers to avoid memory leak. - // These may be recovered buffers or partial buffers from FullyFilledBuffer. 
+ // Release any remaining buffers in recoveredBuffers (migrated recovered buffers + // not yet consumed) and toBeConsumedBuffers (FullyFilledBuffer partial splits) + // to avoid memory leak. + for (BufferAndBacklog bufferAndBacklog : recoveredBuffers) { + bufferAndBacklog.buffer().recycleBuffer(); + } + recoveredBuffers.clear(); for (BufferAndBacklog bufferAndBacklog : toBeConsumedBuffers) { bufferAndBacklog.buffer().recycleBuffer(); } @@ -534,14 +553,16 @@ void announceBufferSize(int newBufferSize) { @Override int getBuffersInUseCount() { ResultSubpartitionView view = this.subpartitionView; - return toBeConsumedBuffers.size() + (view == null ? 0 : view.getNumberOfQueuedBuffers()); + return recoveredBuffers.size() + + toBeConsumedBuffers.size() + + (view == null ? 0 : view.getNumberOfQueuedBuffers()); } @Override public int unsynchronizedGetNumberOfQueuedBuffers() { ResultSubpartitionView view = subpartitionView; - int count = toBeConsumedBuffers.size(); + int count = recoveredBuffers.size() + toBeConsumedBuffers.size(); if (view != null) { count += view.unsynchronizedGetNumberOfQueuedBuffers(); } From 26277691ff62eb8239dc5b23986140348fcaac27 Mon Sep 17 00:00:00 2001 From: Rui Fan <1996fanrui@gmail.com> Date: Fri, 22 May 2026 13:26:49 +0200 Subject: [PATCH 02/16] [FLINK-38544][network] Phase 1: common interfaces & sentinels for spilling v2 - Adds BufferRequester, RecoverableInputChannel, RecoveryCheckpointTrigger interfaces with their final signatures (including getChannelInfo on RecoverableInputChannel and NO_OP singleton on RecoveryCheckpointTrigger). - Adds RecoveryCheckpointBarrier sentinel + DiskSnapshot data class with final 3-arg constructor signature and Chunk / StartPos / empty() helpers. - ChannelStateWriter gains addInputDataFromSpill and peekWriteResult default methods so all callers can compile against the interface without the dispatcher implementation landing in this phase. 
- RecoveredInputChannel#releaseAllResources visibility: package-private -> public References to SpillFile in DiskSnapshot's constructor are forward references; SpillFile itself lands in Phase 3. Each phase commit only needs to compile as a whole tree at the final commit, not in isolation. Design: requirements/38544/phase1_interfaces/design.md (cherry picked from commit 98c7b42a8fbda7a7afd00c50f5df8b61ae5714ff) --- .../channel/RecoveryCheckpointBarrier.java | 69 +++++++++++++++++++ .../channel/RecoveryCheckpointTrigger.java | 36 ++++++++++ .../api/serialization/EventSerializer.java | 20 ++++++ .../RecoveryCheckpointBarrierTest.java | 53 ++++++++++++++ 4 files changed, 178 insertions(+) create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointBarrier.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointTrigger.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointBarrierTest.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointBarrier.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointBarrier.java new file mode 100644 index 0000000000000..984bfd42312f6 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointBarrier.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.runtime.event.RuntimeEvent; + +/** Task-local event marking the recovery-state cut for a recovery checkpoint. */ +@Internal +public final class RecoveryCheckpointBarrier extends RuntimeEvent { + + private final long checkpointId; + + public RecoveryCheckpointBarrier(long checkpointId) { + this.checkpointId = checkpointId; + } + + public long getCheckpointId() { + return checkpointId; + } + + @Override + public void write(DataOutputView out) { + throw new UnsupportedOperationException( + "RecoveryCheckpointBarrier must be serialized via EventSerializer's dedicated" + + " type-tag path, not reflective write()."); + } + + @Override + public void read(DataInputView in) { + throw new UnsupportedOperationException( + "RecoveryCheckpointBarrier must be deserialized via EventSerializer's dedicated" + + " type-tag path, not reflective read()."); + } + + @Override + public int hashCode() { + return Long.hashCode(checkpointId); + } + + @Override + public boolean equals(Object other) { + return other != null + && other.getClass() == RecoveryCheckpointBarrier.class + && ((RecoveryCheckpointBarrier) other).checkpointId == this.checkpointId; + } + + @Override + public String toString() { + return "RecoveryCheckpointBarrier(" + checkpointId + ")"; + } +} diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointTrigger.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointTrigger.java new file mode 100644 index 0000000000000..01053ad836358 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointTrigger.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.annotation.Internal; + +import java.io.IOException; + +@Internal +public interface RecoveryCheckpointTrigger { + + /** + * Atomically snapshots the undrained spill slice and inserts matching {@link + * RecoveryCheckpointBarrier}s into in-recovery channels. Returns an independent reader over the + * remaining segments; the caller owns and must close it. + */ + FetchedChannelStateReader snapshotAndInsertBarriers(long checkpointId) throws IOException; + + /** Returns an empty reader (no spill files, so no segments) and inserts no barriers. 
*/ + RecoveryCheckpointTrigger NO_OP = checkpointId -> FetchedChannelStateReader.emptyReader(); +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializer.java index e73d9168cb273..5711d1640f909 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializer.java @@ -27,6 +27,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointType; import org.apache.flink.runtime.checkpoint.SavepointType; import org.apache.flink.runtime.checkpoint.SnapshotType; +import org.apache.flink.runtime.checkpoint.channel.RecoveryCheckpointBarrier; import org.apache.flink.runtime.event.AbstractEvent; import org.apache.flink.runtime.event.WatermarkEvent; import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker; @@ -43,6 +44,7 @@ import org.apache.flink.runtime.io.network.buffer.BufferConsumer; import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; +import org.apache.flink.runtime.io.network.partition.consumer.EndOfFetchedChannelStateEvent; import org.apache.flink.runtime.io.network.partition.consumer.EndOfInputChannelStateEvent; import org.apache.flink.runtime.io.network.partition.consumer.EndOfOutputChannelStateEvent; import org.apache.flink.runtime.state.CheckpointStorageLocationReference; @@ -87,6 +89,10 @@ public class EventSerializer { private static final int END_OF_INPUT_CHANNEL_STATE_EVENT = 12; + private static final int RECOVERY_CHECKPOINT_BARRIER_EVENT = 13; + + private static final int END_OF_FETCHED_CHANNEL_STATE_EVENT = 14; + private static final byte CHECKPOINT_TYPE_CHECKPOINT = 0; private static final byte CHECKPOINT_TYPE_SAVEPOINT = 1; @@ -116,6 +122,8 @@ public 
static ByteBuffer toSerializedEvent(AbstractEvent event) throws IOExcepti return ByteBuffer.wrap(new byte[] {0, 0, 0, END_OF_OUTPUT_CHANNEL_STATE_EVENT}); } else if (eventClass == EndOfInputChannelStateEvent.class) { return ByteBuffer.wrap(new byte[] {0, 0, 0, END_OF_INPUT_CHANNEL_STATE_EVENT}); + } else if (eventClass == EndOfFetchedChannelStateEvent.class) { + return ByteBuffer.wrap(new byte[] {0, 0, 0, END_OF_FETCHED_CHANNEL_STATE_EVENT}); } else if (eventClass == EndOfData.class) { return ByteBuffer.wrap( new byte[] { @@ -162,6 +170,13 @@ public static ByteBuffer toSerializedEvent(AbstractEvent event) throws IOExcepti buf.putInt(0, RECOVERY_METADATA); buf.putInt(4, recoveryMetadata.getFinalBufferSubpartitionId()); return buf; + } else if (eventClass == RecoveryCheckpointBarrier.class) { + RecoveryCheckpointBarrier barrier = (RecoveryCheckpointBarrier) event; + + ByteBuffer buf = ByteBuffer.allocate(12); + buf.putInt(0, RECOVERY_CHECKPOINT_BARRIER_EVENT); + buf.putLong(4, barrier.getCheckpointId()); + return buf; } else if (eventClass == WatermarkEvent.class) { try { final DataOutputSerializer serializer = new DataOutputSerializer(128); @@ -206,6 +221,8 @@ public static AbstractEvent fromSerializedEvent(ByteBuffer buffer, ClassLoader c return EndOfOutputChannelStateEvent.INSTANCE; } else if (type == END_OF_INPUT_CHANNEL_STATE_EVENT) { return EndOfInputChannelStateEvent.INSTANCE; + } else if (type == END_OF_FETCHED_CHANNEL_STATE_EVENT) { + return EndOfFetchedChannelStateEvent.INSTANCE; } else if (type == END_OF_USER_RECORDS_EVENT) { return new EndOfData(StopMode.values()[buffer.get()]); } else if (type == CANCEL_CHECKPOINT_MARKER_EVENT) { @@ -222,6 +239,9 @@ public static AbstractEvent fromSerializedEvent(ByteBuffer buffer, ClassLoader c } else if (type == RECOVERY_METADATA) { int subpartitionId = buffer.getInt(); return new RecoveryMetadata(subpartitionId); + } else if (type == RECOVERY_CHECKPOINT_BARRIER_EVENT) { + long checkpointId = buffer.getLong(); + return 
new RecoveryCheckpointBarrier(checkpointId); } else if (type == GENERALIZED_WATERMARK_EVENT) { final DataInputDeserializer deserializer = new DataInputDeserializer(buffer); WatermarkEvent watermarkEvent = new WatermarkEvent(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointBarrierTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointBarrierTest.java new file mode 100644 index 0000000000000..0dd2c18286a7a --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointBarrierTest.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.core.memory.DataOutputSerializer; +import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; +import org.apache.flink.runtime.io.network.buffer.Buffer; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link RecoveryCheckpointBarrier}. 
*/ +class RecoveryCheckpointBarrierTest { + + @Test + void testReflectiveWriteIsUnsupported() { + RecoveryCheckpointBarrier barrier = new RecoveryCheckpointBarrier(1L); + + assertThatThrownBy(() -> barrier.write(new DataOutputSerializer(Long.BYTES))) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessageContaining("dedicated type-tag path"); + } + + @Test + void testEventSerializerHandlesRecoveryCheckpointBarrier() throws Exception { + long checkpointId = 123L; + + Buffer buffer = + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(checkpointId), false); + Object deserialized = + EventSerializer.fromBuffer( + buffer, RecoveryCheckpointBarrier.class.getClassLoader()); + + assertThat(deserialized).isEqualTo(new RecoveryCheckpointBarrier(checkpointId)); + } +} From 07971afd4abb57e3aed0a138be8cbaabb0974e84 Mon Sep 17 00:00:00 2001 From: Rui Fan <1996fanrui@gmail.com> Date: Fri, 22 May 2026 13:27:21 +0200 Subject: [PATCH 03/16] [FLINK-38544][network] Phase 2: InputChannel side push-based recovery - Local/Remote InputChannel implement RecoverableInputChannel from Phase 1 - recoveredBuffers reshaped to Deque; allRecoveredBuffersDelivered flag - getNextBuffer() unified under a single inRecovery predicate - checkpointStarted split into mutually-exclusive in-recovery / not-in-recovery - stateConsumedFuture triggered by (allRecoveredBuffersDelivered && queue empty) - RecoveredInputChannel.toInputChannel migrates via the new push interface; the initialRecoveredBuffers constructor parameter is gone. 
- LocalInputChannel.getNextRecoveredBuffer helper deleted Design: requirements/38544/phase2_input_channel/design.md (cherry picked from commit 82904098a54767809ee831aa64b1ff84cf0bc0f3) --- ...reditBasedSequenceNumberingViewReader.java | 9 +- .../io/network/netty/NettyMessage.java | 14 +- .../netty/NettyPartitionRequestClient.java | 3 +- .../netty/PartitionRequestServerHandler.java | 5 +- .../partition/consumer/BufferManager.java | 44 +- .../partition/consumer/IndexedInputGate.java | 8 + .../network/partition/consumer/InputGate.java | 15 +- .../partition/consumer/LocalInputChannel.java | 473 ++++++++++++---- .../consumer/LocalRecoveredInputChannel.java | 8 +- .../consumer/RecoverableInputChannel.java | 63 +++ .../consumer/RemoteInputChannel.java | 352 ++++++++++-- .../consumer/RemoteRecoveredInputChannel.java | 8 +- .../partition/consumer/SingleInputGate.java | 28 +- .../partition/consumer/UnionInputGate.java | 16 +- .../consumer/UnknownInputChannel.java | 36 +- .../taskmanager/InputGateWithMetrics.java | 20 +- .../netty/CancelPartitionRequestTest.java | 6 +- ...asedPartitionRequestClientHandlerTest.java | 3 +- ...tBasedSequenceNumberingViewReaderTest.java | 2 +- ...ttyMessageServerSideSerializationTest.java | 4 +- .../netty/PartitionRequestQueueTest.java | 24 +- .../PartitionRequestRegistrationTest.java | 12 +- .../PartitionRequestServerHandlerTest.java | 4 +- .../ServerTransportErrorHandlingTest.java | 3 +- .../consumer/InputChannelBuilder.java | 12 +- .../consumer/LocalInputChannelTest.java | 398 ++++++++++++-- .../LocalRecoveredInputChannelTest.java | 52 ++ .../consumer/RecoveredInputChannelTest.java | 96 ++-- .../consumer/RemoteInputChannelTest.java | 503 +++++++++++++++++- .../RemoteRecoveredInputChannelTest.java | 52 ++ .../consumer/SingleInputGateBuilder.java | 10 + .../consumer/SingleInputGateTest.java | 30 -- .../partition/consumer/TestInputChannel.java | 44 +- .../consumer/UnionInputGateTest.java | 58 -- .../runtime/io/MockIndexedInputGate.java | 18 +- 
.../streaming/runtime/io/MockInputGate.java | 18 +- .../SingleInputGateBenchmarkFactory.java | 6 +- .../AlignedCheckpointsMassiveRandomTest.java | 18 +- 38 files changed, 2007 insertions(+), 468 deletions(-) create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoverableInputChannel.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannelTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannelTest.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedSequenceNumberingViewReader.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedSequenceNumberingViewReader.java index ae57b39b08d3c..1ac2687bf3946 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedSequenceNumberingViewReader.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedSequenceNumberingViewReader.java @@ -81,12 +81,17 @@ class CreditBasedSequenceNumberingViewReader private int numCreditsAvailable; CreditBasedSequenceNumberingViewReader( - InputChannelID receiverId, int initialCredit, PartitionRequestQueue requestQueue) { + InputChannelID receiverId, + int initialCredit, + boolean needsRecovery, + PartitionRequestQueue requestQueue) { checkArgument(initialCredit >= 0, "Must be non-negative."); this.receiverId = receiverId; this.initialCredit = initialCredit; - this.numCreditsAvailable = initialCredit; + // During spill recovery, exclusive buffers are on loan to the recovery drain; real credit + // is announced only after recovery completes. + this.numCreditsAvailable = needsRecovery ? 
0 : initialCredit; this.requestQueue = requestQueue; this.subpartitionId = -1; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java index ebe596e01f6b5..6747e139a034c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java @@ -583,15 +583,19 @@ static class PartitionRequest extends NettyMessage { final int credit; + final boolean needsRecovery; + PartitionRequest( ResultPartitionID partitionId, ResultSubpartitionIndexSet queueIndexSet, InputChannelID receiverId, - int credit) { + int credit, + boolean needsRecovery) { this.partitionId = checkNotNull(partitionId); this.queueIndexSet = queueIndexSet; this.receiverId = checkNotNull(receiverId); this.credit = credit; + this.needsRecovery = needsRecovery; } @Override @@ -604,6 +608,7 @@ void write(ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator queueIndexSet.writeTo(bb); receiverId.writeTo(bb); bb.writeInt(credit); + bb.writeBoolean(needsRecovery); }; writeToChannel( @@ -616,7 +621,8 @@ void write(ChannelOutboundInvoker out, ChannelPromise promise, ByteBufAllocator + ExecutionAttemptID.getByteBufLength() + ResultSubpartitionIndexSet.getByteBufLength(queueIndexSet) + InputChannelID.getByteBufLength() - + Integer.BYTES); + + Integer.BYTES + + Byte.BYTES); } static PartitionRequest readFrom(ByteBuf buffer) { @@ -628,8 +634,10 @@ static PartitionRequest readFrom(ByteBuf buffer) { ResultSubpartitionIndexSet.fromByteBuf(buffer); InputChannelID receiverId = InputChannelID.fromByteBuf(buffer); int credit = buffer.readInt(); + boolean needsRecovery = buffer.readBoolean(); - return new PartitionRequest(partitionId, queueIndexSet, receiverId, credit); + return new PartitionRequest( + partitionId, queueIndexSet, receiverId, credit, 
needsRecovery); } @Override diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyPartitionRequestClient.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyPartitionRequestClient.java index 7cc52a234e09a..737e3de72690d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyPartitionRequestClient.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyPartitionRequestClient.java @@ -129,7 +129,8 @@ public void requestSubpartition( partitionId, subpartitionIndexSet, inputChannel.getInputChannelId(), - inputChannel.getInitialCredit()); + inputChannel.getInitialCredit(), + inputChannel.needsRecovery()); final ChannelFutureListener listener = future -> { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandler.java index f93651b9f3e8d..6f2cd6b13ad41 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandler.java @@ -85,7 +85,10 @@ protected void channelRead0(ChannelHandlerContext ctx, NettyMessage msg) throws NetworkSequenceViewReader reader; reader = new CreditBasedSequenceNumberingViewReader( - request.receiverId, request.credit, outboundQueue); + request.receiverId, + request.credit, + request.needsRecovery, + outboundQueue); reader.requestSubpartitionViewOrRegisterListener( partitionProvider, request.partitionId, request.queueIndexSet); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/BufferManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/BufferManager.java index db38025def9e3..1eed2a0284fd0 100644 --- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/BufferManager.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/BufferManager.java @@ -70,13 +70,24 @@ public class BufferManager implements BufferListener, BufferRecycler { @GuardedBy("bufferQueue") private int numRequiredBuffers; + /** + * Gates credit announcements while a recovery drain borrows this channel's buffers. Kept under + * {@code bufferQueue} to avoid inverting the queue/recovered-buffer lock order. + */ + @GuardedBy("bufferQueue") + private boolean notifyAvailable; + public BufferManager( - MemorySegmentProvider globalPool, InputChannel inputChannel, int numRequiredBuffers) { + MemorySegmentProvider globalPool, + InputChannel inputChannel, + int numRequiredBuffers, + boolean notifyInitiallyEnabled) { this.globalPool = checkNotNull(globalPool); this.inputChannel = checkNotNull(inputChannel); checkArgument(numRequiredBuffers >= 0); this.numRequiredBuffers = numRequiredBuffers; + this.notifyAvailable = notifyInitiallyEnabled; } // ------------------------------------------------------------------------ @@ -158,23 +169,22 @@ void requestExclusiveBuffers(int numExclusiveBuffers) throws IOException { /** * Requests floating buffers from the buffer pool based on the given required amount, and - * returns the actual requested amount. If the required amount is not fully satisfied, it will - * register as a listener. + * returns the number of buffers that may be announced to the producer as credit. During + * recovery, requested buffers are queued but announced only by {@link #enableNotify()}. */ int requestFloatingBuffers(int numRequired) { - int numRequestedBuffers = 0; synchronized (bufferQueue) { // Similar to notifyBufferAvailable(), make sure that we never add a buffer after // channel // released all buffers via releaseAllResources(). 
if (inputChannel.isReleased()) { - return numRequestedBuffers; + return 0; } numRequiredBuffers = numRequired; - numRequestedBuffers = tryRequestBuffers(); + int numRequestedBuffers = tryRequestBuffers(); + return notifyAvailable ? numRequestedBuffers : 0; } - return numRequestedBuffers; } private int tryRequestBuffers() { @@ -209,6 +219,7 @@ private int tryRequestBuffers() { @Override public void recycle(MemorySegment segment) { @Nullable Buffer releasedFloatingBuffer = null; + boolean announceCredit = false; synchronized (bufferQueue) { try { // Similar to notifyBufferAvailable(), make sure that we never add a buffer @@ -226,11 +237,12 @@ public void recycle(MemorySegment segment) { } finally { bufferQueue.notifyAll(); } + announceCredit = releasedFloatingBuffer == null && notifyAvailable; } if (releasedFloatingBuffer != null) { releasedFloatingBuffer.recycleBuffer(); - } else { + } else if (announceCredit) { try { inputChannel.notifyBufferAvailable(1); } catch (Throwable t) { @@ -344,6 +356,9 @@ public boolean notifyBufferAvailable(Buffer buffer) { isBufferUsed = true; numBuffers += 1 + tryRequestBuffers(); bufferQueue.notifyAll(); + if (!notifyAvailable) { + numBuffers = 0; + } } inputChannel.notifyBufferAvailable(numBuffers); @@ -359,6 +374,19 @@ public void notifyBufferDestroyed() { // Nothing to do actually. } + /** + * Opens the recovery credit gate and announces the queued buffers atomically with respect to + * concurrent recycle/floating-buffer callbacks. 
+ */ + void enableNotify() throws IOException { + int available; + synchronized (bufferQueue) { + notifyAvailable = true; + available = bufferQueue.getAvailableBufferSize(); + } + inputChannel.notifyBufferAvailable(available); + } + // ------------------------------------------------------------------------ // Getter properties // ------------------------------------------------------------------------ diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/IndexedInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/IndexedInputGate.java index 915012924d161..95191f7175573 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/IndexedInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/IndexedInputGate.java @@ -80,4 +80,12 @@ public void convertToPriorityEvent(int channelIndex, int sequenceNumber) throws /** Returns whether unaligned checkpointing during recovery is enabled. */ public abstract boolean isCheckpointingDuringRecoveryEnabled(); + + /** + * Sets whether converted physical channels start in recovery. Must be published before the + * buffer-filtering completion future is completed. + */ + public abstract void setNeedsRecovery(boolean enabled); + + public abstract boolean needsRecovery(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java index dd744bae330ff..b70c284a21a61 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java @@ -141,6 +141,14 @@ public abstract void acknowledgeAllRecordsProcessed(InputChannelInfo channelInfo /** Returns the channel of this gate. 
*/ public abstract InputChannel getChannel(int channelIndex); + /** + * Returns the channel identified by {@code channelInfo}. Unlike {@link #getChannel(int)}, whose + * index is gate-global, this resolves through the full {@code (gateIdx, inputChannelIdx)} pair, + * so it stays correct for {@link UnionInputGate} where the global index differs from a member + * gate's local channel index. + */ + public abstract InputChannel getChannel(InputChannelInfo channelInfo); + /** Returns the channel infos of this gate. */ public List getChannelInfos() { return IntStream.range(0, getNumberOfInputChannels()) @@ -192,12 +200,5 @@ public String toString() { public abstract CompletableFuture getStateConsumedFuture(); - /** - * Returns a future that completes when buffer filtering is complete for all channels. This - * future completes before {@link #getStateConsumedFuture()}, enabling earlier RUNNING state - * transition when unaligned checkpoint during recovery is enabled. - */ - public abstract CompletableFuture getBufferFilteringCompleteFuture(); - public abstract void finishReadRecoveredState() throws IOException; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java index 0503a301ea7f7..e0ce62009d5dc 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java @@ -21,11 +21,15 @@ import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.metrics.Counter; import org.apache.flink.runtime.checkpoint.CheckpointException; +import org.apache.flink.runtime.checkpoint.CheckpointFailureReason; import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; +import 
org.apache.flink.runtime.checkpoint.channel.RecoveryCheckpointBarrier; +import org.apache.flink.runtime.event.AbstractEvent; import org.apache.flink.runtime.event.TaskEvent; import org.apache.flink.runtime.execution.CancelTaskException; import org.apache.flink.runtime.io.network.TaskEventPublisher; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; +import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.CompositeBuffer; import org.apache.flink.runtime.io.network.buffer.FileRegionBuffer; @@ -43,21 +47,26 @@ import org.slf4j.LoggerFactory; import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; import java.io.IOException; import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Collections; import java.util.Deque; +import java.util.Iterator; import java.util.List; import java.util.Optional; import java.util.Timer; import java.util.TimerTask; +import java.util.concurrent.CompletableFuture; import static org.apache.flink.util.Preconditions.checkNotNull; import static org.apache.flink.util.Preconditions.checkState; /** An input channel, which requests a local subpartition. 
*/ -public class LocalInputChannel extends InputChannel implements BufferAvailabilityListener { +public class LocalInputChannel extends InputChannel + implements BufferAvailabilityListener, RecoverableInputChannel { private static final Logger LOG = LoggerFactory.getLogger(LocalInputChannel.class); @@ -81,19 +90,46 @@ public class LocalInputChannel extends InputChannel implements BufferAvailabilit private final Deque toBeConsumedBuffers = new ArrayDeque<>(); /** - * Buffers migrated from {@code RecoveredInputChannel}, kept separately from {@link + * Buffers delivered from {@code RecoveredInputChannel}, kept separately from {@link * #toBeConsumedBuffers} so that recovery semantics (priority event interleaving, checkpoint - * inflight persistence) do not leak into the FullyFilledBuffer split path. + * inflight persistence) do not leak into the FullyFilledBuffer split path. Holds recovered + * buffers plus the {@code RecoveryCheckpointBarrier} and {@code EndOfFetchedChannelStateEvent} + * sentinels. The deque object is its own monitor; {@link #inRecovery} and {@link + * #recoverySequenceNumber} are guarded by it too. */ - private final Deque recoveredBuffers = new ArrayDeque<>(); + private final Deque recoveredBuffers = new ArrayDeque<>(); /** - * Flag indicating whether there is a pending priority event (e.g., checkpoint barrier) in the - * subpartitionView that should be consumed before recoveredBuffers. This is set by {@link - * #notifyPriorityEvent} and checked in {@link #getNextBuffer()}. + * Whether the channel is still replaying recovered state. Starts {@code true} for channels + * that need recovery ({@code false} otherwise) and is flipped to {@code false} the moment the + * consume path polls the {@code EndOfFetchedChannelStateEvent} sentinel appended after the + * last recovered buffer (see {@link #onRecoveredStateConsumed()}). While {@code true} the + * consume path serves recovered buffers and does not poll ordinary upstream data.
+ */ + @GuardedBy("recoveredBuffers") + private boolean inRecovery; + + /** + * Sequence number assigned to recovered buffers, starting at {@link Integer#MIN_VALUE}, + * consistent with {@link RecoveredInputChannel}. + */ + private int recoverySequenceNumber = Integer.MIN_VALUE; + + @Nullable private final BufferManager bufferManager; + + private final int networkBuffersPerChannel; + + private final boolean needsRecovery; + + /** + * Whether a priority event (e.g., checkpoint barrier) is pending in {@code subpartitionView} + * and must be consumed before {@code recoveredBuffers}. Volatile because it is written by the + * network thread and read by the task thread. */ private volatile boolean hasPendingPriorityEvent = false; + private final CompletableFuture upstreamReady = new CompletableFuture<>(); + public LocalInputChannel( SingleInputGate inputGate, int channelIndex, @@ -106,7 +142,8 @@ public LocalInputChannel( Counter numBytesIn, Counter numBuffersIn, ChannelStateWriter stateWriter, - ArrayDeque initialRecoveredBuffers) { + int networkBuffersPerChannel, + boolean needsRecovery) { super( inputGate, @@ -120,49 +157,210 @@ public LocalInputChannel( this.partitionManager = checkNotNull(partitionManager); this.taskEventPublisher = checkNotNull(taskEventPublisher); - this.channelStatePersister = new ChannelStatePersister(stateWriter, getChannelInfo()); - - // Migrate recovered buffers from RecoveredInputChannel if provided. - // These buffers have been filtered but not yet consumed by the Task. - if (!initialRecoveredBuffers.isEmpty()) { - final int expectedCount = initialRecoveredBuffers.size(); - // Sequence number starts at Integer.MIN_VALUE, consistent with RecoveredInputChannel. - int seqNum = Integer.MIN_VALUE; - while (!initialRecoveredBuffers.isEmpty()) { - Buffer buffer = initialRecoveredBuffers.poll(); - // Determine next data type based on the next buffer in the queue - Buffer.DataType nextDataType = - initialRecoveredBuffers.isEmpty() - ? 
Buffer.DataType.NONE - : initialRecoveredBuffers.peek().getDataType(); - // buffersInBacklog is set to 0 as these are recovered buffers - BufferAndBacklog bufferAndBacklog = - new BufferAndBacklog(buffer, 0, nextDataType, seqNum++); - recoveredBuffers.add(bufferAndBacklog); + this.channelStatePersister = + new ChannelStatePersister(checkNotNull(stateWriter), getChannelInfo()); + this.inRecovery = needsRecovery; + this.bufferManager = + needsRecovery + ? new BufferManager(inputGate.getMemorySegmentProvider(), this, 0, true) + : null; + this.networkBuffersPerChannel = networkBuffersPerChannel; + this.needsRecovery = needsRecovery; + } + + @Override + void setup() throws IOException { + if (needsRecovery && networkBuffersPerChannel > 0) { + bufferManager.requestExclusiveBuffers(networkBuffersPerChannel); + } + } + + // ------------------------------------------------------------------------ + // RecoverableInputChannel implementation + // ------------------------------------------------------------------------ + + @Override + public void onRecoveredStateBuffer(Buffer buffer) { + boolean wasEmpty; + synchronized (recoveredBuffers) { + if (isReleased) { + buffer.recycleBuffer(); + return; } - checkState( - recoveredBuffers.size() == expectedCount, - "Buffer migration failed: expected %s buffers but got %s", - expectedCount, - recoveredBuffers.size()); + // Migrate recovered buffers from RecoveredInputChannel. These buffers have been + // filtered but not yet consumed by the Task. + wasEmpty = offerRecoveredBuffer(buffer); + } + if (wasEmpty) { + notifyChannelNonEmpty(); } } + @Override + public void finishRecoveredBufferDelivery() throws IOException { + upstreamReady.join(); + boolean wasEmpty; + synchronized (recoveredBuffers) { + checkState(inRecovery, "Recovery delivery already finished."); + // Append the sentinel after the last recovered buffer. 
The consume path flips out of + // recovery only once it polls this sentinel, guaranteeing all recovered buffers are + // consumed first. + wasEmpty = + offerRecoveredBuffer( + EventSerializer.toBuffer( + EndOfFetchedChannelStateEvent.INSTANCE, false)); + } + if (wasEmpty) { + notifyChannelNonEmpty(); + } + } + + @Override + public Buffer requestRecoveryBufferBlocking() throws InterruptedException, IOException { + checkState( + bufferManager != null, + "requestRecoveryBufferBlocking called on a Local channel constructed with" + + " needsRecovery=false"); + upstreamReady.join(); + return bufferManager.requestBufferBlocking(); + } + + @Override + public void insertRecoveryCheckpointBarrierIfInRecovery(long checkpointId) throws IOException { + boolean wasEmpty = false; + synchronized (recoveredBuffers) { + if (!isReleased && inRecovery) { + wasEmpty = + offerRecoveredBuffer( + EventSerializer.toBuffer( + new RecoveryCheckpointBarrier(checkpointId), false)); + } + } + if (wasEmpty) { + notifyChannelNonEmpty(); + } + } + + /** + * Flips out of recovery the moment the consume path polls the {@code + * EndOfFetchedChannelStateEvent} sentinel, i.e. once all recovered buffers have been consumed. + * Live upstream data may flow again afterwards. + */ + @Override + public void onRecoveredStateConsumed() { + synchronized (recoveredBuffers) { + checkState(inRecovery, "Recovery already finished."); + inRecovery = false; + } + notifyChannelNonEmpty(); + } + + /** + * Appends a recovered buffer (or {@code RecoveryCheckpointBarrier} / {@code + * EndOfFetchedChannelStateEvent} sentinel) to {@link #recoveredBuffers}. + * + * @return {@code true} iff {@link #recoveredBuffers} transitioned from empty to non-empty. 
+ */ + private boolean offerRecoveredBuffer(Buffer buffer) { + assert Thread.holdsLock(recoveredBuffers); + checkState(inRecovery, "Push into recovered buffers after recovery finished."); + boolean wasEmpty = recoveredBuffers.isEmpty(); + recoveredBuffers.add(buffer); + return wasEmpty; + } + + private int nextRecoverySequenceNumber() { + assert Thread.holdsLock(recoveredBuffers); + return recoverySequenceNumber++; + } + + /** + * Walks {@link #recoveredBuffers} up to the {@link RecoveryCheckpointBarrier} sentinel matching + * {@code checkpointId}, retaining each pre-barrier recovered data buffer and removing the + * sentinel. + * + * @throws IOException if no sentinel matching {@code checkpointId} is found (the snapshot + * protocol guarantees one must be present while the channel is in recovery). + */ + private List collectPreRecoveryBarrier(long checkpointId) throws IOException { + assert Thread.holdsLock(recoveredBuffers); + List retained = new ArrayList<>(); + try { + Iterator it = recoveredBuffers.iterator(); + while (it.hasNext()) { + Buffer b = it.next(); + if (isRecoveryCheckpointBarrier(b, checkpointId)) { + it.remove(); + b.recycleBuffer(); + return retained; + } + if (b.isBuffer()) { + retained.add(b.retainBuffer()); + } + } + } catch (IOException e) { + releaseRetainedBuffers(retained); + throw e; + } + releaseRetainedBuffers(retained); + throw new IOException( + "Missing RecoveryCheckpointBarrier for checkpoint " + + checkpointId + + " in recoveredBuffers for channel " + + getChannelInfo()); + } + + private static void releaseRetainedBuffers(List retained) { + for (Buffer buffer : retained) { + buffer.recycleBuffer(); + } + } + + private static boolean isRecoveryCheckpointBarrier(Buffer b, long checkpointId) + throws IOException { + if (b.isBuffer()) { + return false; + } + AbstractEvent event = + EventSerializer.fromBuffer(b, RecoveryCheckpointBarrier.class.getClassLoader()); + b.setReaderIndex(0); + return event instanceof RecoveryCheckpointBarrier 
+ && ((RecoveryCheckpointBarrier) event).getCheckpointId() == checkpointId; + } + // ------------------------------------------------------------------------ // Consume // ------------------------------------------------------------------------ + @Override public void checkpointStarted(CheckpointBarrier barrier) throws CheckpointException { - // Collect inflight buffers from recoveredBuffers to be persisted. - // These are recovered buffers that have not been consumed yet when the checkpoint - // barrier arrives. - List inflightBuffers = new ArrayList<>(); - for (BufferAndBacklog bufferAndBacklog : recoveredBuffers) { - if (bufferAndBacklog.buffer().isBuffer()) { - inflightBuffers.add(bufferAndBacklog.buffer().retainBuffer()); + try { + List toPersist; + boolean stopPersisting = false; + synchronized (recoveredBuffers) { + if (inRecovery) { + // Collect inflight buffers from recoveredBuffers to be persisted. These are + // recovered buffers that have not been consumed yet when the checkpoint barrier + // arrives. + toPersist = collectPreRecoveryBarrier(barrier.getId()); + stopPersisting = true; + } else { + toPersist = Collections.emptyList(); + } } + channelStatePersister.startPersisting(barrier.getId(), toPersist); + if (stopPersisting) { + // Recovered inflight buffers are collected in one shot and no upstream data flows + // during recovery, so close the persist window immediately to keep the persister + // from carrying a pending state into later checkpoints. 
+ channelStatePersister.stopPersisting(barrier.getId()); + } + } catch (IOException e) { + throw new CheckpointException( + "Failed to extract recovered buffers for checkpoint " + barrier.getId(), + CheckpointFailureReason.CHECKPOINT_DECLINED, + e); } - channelStatePersister.startPersisting(barrier.getId(), inflightBuffers); } public void checkpointStopped(long checkpointId) { @@ -206,6 +404,7 @@ protected void requestSubpartitions() throws IOException { this.subpartitionView = null; } else { notifyDataAvailable = true; + upstreamReady.complete(null); } } catch (PartitionNotFoundException notFound) { if (increaseBackoff()) { @@ -282,8 +481,31 @@ protected int peekNextBufferSubpartitionIdInternal() throws IOException { public Optional getNextBuffer() throws IOException { checkError(); - if (!recoveredBuffers.isEmpty()) { - return getNextRecoveredBuffer(); + // Read inRecovery and poll the recovered buffer under a single lock acquisition to avoid + // grabbing the monitor twice on the hot path. + boolean inRecovery; + Buffer recoveredBuf = null; + synchronized (recoveredBuffers) { + inRecovery = this.inRecovery; + if (inRecovery && !hasPendingPriorityEvent && !recoveredBuffers.isEmpty()) { + recoveredBuf = recoveredBuffers.poll(); + } + } + + if (inRecovery) { + // Always return an already-polled recovered buffer first: hasPendingPriorityEvent may + // be flipped to true by a concurrent notifyPriorityEvent() after the poll, and + // re-reading + // it here would otherwise drop this buffer. A pending priority event is served on the + // next getNextBuffer() call instead. + if (recoveredBuf != null) { + return wrapRecoveredBufferAsAvailability(recoveredBuf); + } + if (hasPendingPriorityEvent) { + return pullPriorityFromSubpartitionView(); + } + // Drain not finished yet; block normal upstream data until delivery completes. 
+ return Optional.empty(); } if (!toBeConsumedBuffers.isEmpty()) { @@ -349,66 +571,93 @@ public Optional getNextBuffer() throws IOException { return getBufferAndAvailability(next); } - /** - * Consumes the next buffer from recoveredBuffers, handling pending priority events and dynamic - * availability detection for the last recovered buffer. - */ - private Optional getNextRecoveredBuffer() throws IOException { - // If there is a pending priority event (e.g., unaligned checkpoint barrier), fetch it - // from subpartitionView first, skipping recoveredBuffers. This ensures priority - // events are processed immediately even when there are pending recovered buffers. - if (hasPendingPriorityEvent) { - checkState(subpartitionView != null, "No subpartition view available"); - BufferAndBacklog next = subpartitionView.getNextBuffer(); - checkState( - next != null && next.buffer().getDataType().hasPriority(), - "Expected priority event, but got %s", - next == null ? "null" : next.buffer().getDataType()); - - // Check for barrier to update channel state persister. - // Note: maybePersist is not needed for barriers as they are not regular data buffers. - channelStatePersister.checkForBarrier(next.buffer()); - - Buffer.DataType expectedNextDataType = next.getNextDataType(); - if (!expectedNextDataType.hasPriority()) { - // Reset hasPendingPriorityEvent to false if no more priority event - hasPendingPriorityEvent = false; - if (!recoveredBuffers.isEmpty()) { - // Correct nextDataType: if recoveredBuffers is not empty, the actual next - // element to consume is from recoveredBuffers, not from subpartitionView - expectedNextDataType = recoveredBuffers.peek().buffer().getDataType(); - } - } + private Optional pullPriorityFromSubpartitionView() throws IOException { + // If there is a pending priority event (e.g., unaligned checkpoint barrier), fetch it from + // subpartitionView first, skipping recoveredBuffers. 
This ensures priority events are + // processed immediately even when there are pending recovered buffers. + checkState(subpartitionView != null, "No subpartition view available"); + BufferAndBacklog next = subpartitionView.getNextBuffer(); + checkState( + next != null && next.buffer().getDataType().hasPriority(), + "Expected priority event, but got %s", + next == null ? "null" : next.buffer().getDataType()); + + // Check for barrier to update channel state persister. Note: maybePersist is not needed for + // barriers as they are not regular data buffers. + channelStatePersister.checkForBarrier(next.buffer()); + + Buffer.DataType expectedNextDataType = next.getNextDataType(); + if (!expectedNextDataType.hasPriority()) { + // Reset hasPendingPriorityEvent to false if no more priority event. + hasPendingPriorityEvent = false; + // Correct nextDataType: if recoveredBuffers is not empty, the actual next element to + // consume is from recoveredBuffers, not from subpartitionView. + expectedNextDataType = peekNextDataType(next.getNextDataType()); + } - return getBufferAndAvailability( - new BufferAndBacklog( - next.buffer(), - next.buffersInBacklog(), - expectedNextDataType, - next.getSequenceNumber())); - } - - BufferAndBacklog next = recoveredBuffers.removeFirst(); - - // If this is the last recovered buffer and nextDataType is NONE, - // dynamically check if subpartitionView has data available. - // The last buffer's nextDataType was preset to NONE during construction, - // but subpartitionView may already have data available. 
- if (recoveredBuffers.isEmpty() - && next.getNextDataType() == Buffer.DataType.NONE - && subpartitionView != null) { - ResultSubpartitionView.AvailabilityWithBacklog availability = - subpartitionView.getAvailabilityAndBacklog(true); - if (availability.isAvailable()) { - next = - new BufferAndBacklog( - next.buffer(), - availability.getBacklog(), - Buffer.DataType.DATA_BUFFER, - next.getSequenceNumber()); + return Optional.of( + new BufferAndAvailability( + next.buffer(), + expectedNextDataType, + next.buffersInBacklog(), + next.getSequenceNumber())); + } + + private Optional wrapRecoveredBufferAsAvailability(Buffer buf) + throws IOException { + if (buf instanceof FileRegionBuffer) { + buf = ((FileRegionBuffer) buf).readInto(inputGate.getUnpooledSegment()); + } + if (buf instanceof CompositeBuffer) { + buf = ((CompositeBuffer) buf).getFullBufferData(inputGate.getUnpooledSegment()); + } + + numBytesIn.inc(buf.readableBytes()); + numBuffersIn.inc(); + + ResultSubpartitionView view = subpartitionView; + Buffer.DataType upstreamProbe; + if (view != null && view.getAvailabilityAndBacklog(true).isAvailable()) { + upstreamProbe = Buffer.DataType.DATA_BUFFER; + } else { + upstreamProbe = Buffer.DataType.NONE; + } + + int sequenceNumber; + synchronized (recoveredBuffers) { + Buffer.DataType nextDataType = peekNextDataType(upstreamProbe); + sequenceNumber = nextRecoverySequenceNumber(); + NetworkActionsLogger.traceInput( + "LocalInputChannel#getNextBuffer", + buf, + inputGate.getOwningTaskName(), + channelInfo, + channelStatePersister, + sequenceNumber); + // buffersInBacklog is set to 0 as these are recovered buffers. 
+ return Optional.of(new BufferAndAvailability(buf, nextDataType, 0, sequenceNumber)); + } + } + + private Buffer.DataType peekNextDataType(Buffer.DataType nextDataTypeOnUpstream) { + synchronized (recoveredBuffers) { + if (!recoveredBuffers.isEmpty()) { + return recoveredBuffers.peek().getDataType(); + } + if (inRecovery) { + // If this is the last currently available recovered buffer, hide upstream data + // until the EndOfFetchedChannelStateEvent sentinel flips the channel out of + // recovery. The last buffer's nextDataType is effectively NONE while the drain can + // still append more recovered buffers. + return Buffer.DataType.NONE; } } - return getBufferAndAvailability(next); + // If this is the last recovered buffer after delivery finished, dynamically check if + // subpartitionView has data available. The last buffer's nextDataType may have been NONE + // while recovered data was still being delivered, but subpartitionView may already have + // data + // available now. + return nextDataTypeOnUpstream; } private Optional getBufferAndAvailability(BufferAndBacklog next) @@ -449,7 +698,7 @@ public void notifyDataAvailable(ResultSubpartitionView view) { @Override public void notifyPriorityEvent(int prioritySequenceNumber) { // Set flag so that getNextBuffer() knows to fetch priority event from subpartitionView - // before consuming toBeConsumedBuffers. + // before consuming recoveredBuffers. 
hasPendingPriorityEvent = true; super.notifyPriorityEvent(prioritySequenceNumber); } @@ -520,23 +769,30 @@ void releaseAllResources() throws IOException { if (!isReleased) { isReleased = true; + upstreamReady.completeExceptionally(new CancelTaskException("Channel released.")); + ResultSubpartitionView view = subpartitionView; if (view != null) { view.releaseAllResources(); subpartitionView = null; } - // Release any remaining buffers in recoveredBuffers (migrated recovered buffers - // not yet consumed) and toBeConsumedBuffers (FullyFilledBuffer partial splits) - // to avoid memory leak. - for (BufferAndBacklog bufferAndBacklog : recoveredBuffers) { - bufferAndBacklog.buffer().recycleBuffer(); + // Release any remaining buffers in recoveredBuffers (migrated recovered buffers not yet + // consumed) and toBeConsumedBuffers (FullyFilledBuffer partial splits) to avoid memory + // leak. + synchronized (recoveredBuffers) { + for (Buffer buffer : recoveredBuffers) { + buffer.recycleBuffer(); + } + recoveredBuffers.clear(); } - recoveredBuffers.clear(); for (BufferAndBacklog bufferAndBacklog : toBeConsumedBuffers) { bufferAndBacklog.buffer().recycleBuffer(); } toBeConsumedBuffers.clear(); + if (bufferManager != null) { + bufferManager.releaseAllBuffers(new ArrayDeque<>()); + } } } @@ -590,4 +846,9 @@ public String toString() { ResultSubpartitionView getSubpartitionView() { return subpartitionView; } + + @VisibleForTesting + void completeUpstreamReadyForTest() { + upstreamReady.complete(null); + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannel.java index bdde2244f38ef..de058bd3723ce 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannel.java +++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannel.java @@ -19,14 +19,11 @@ package org.apache.flink.runtime.io.network.partition.consumer; import org.apache.flink.runtime.io.network.TaskEventPublisher; -import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.metrics.InputChannelMetrics; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.ResultPartitionManager; import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet; -import java.util.ArrayDeque; - import static org.apache.flink.util.Preconditions.checkNotNull; /** @@ -64,7 +61,7 @@ public class LocalRecoveredInputChannel extends RecoveredInputChannel { } @Override - protected InputChannel toInputChannelInternal(ArrayDeque remainingBuffers) { + protected InputChannel toInputChannelInternal(boolean needsRecovery) { return new LocalInputChannel( inputGate, getChannelIndex(), @@ -77,6 +74,7 @@ protected InputChannel toInputChannelInternal(ArrayDeque remainingBuffer numBytesIn, numBuffersIn, channelStateWriter, - remainingBuffers); + networkBuffersPerChannel, + needsRecovery); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoverableInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoverableInputChannel.java new file mode 100644 index 0000000000000..40a4ab27a0908 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoverableInputChannel.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.consumer; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; +import org.apache.flink.runtime.io.network.buffer.Buffer; + +import java.io.IOException; + +/** Physical input channel that can receive recovered buffers pushed by the spill drain. */ +@Internal +public interface RecoverableInputChannel { + + InputChannelInfo getChannelInfo(); + + /** + * Appends a recovered buffer or recovery-checkpoint sentinel. Released channels recycle the + * buffer silently. + */ + void onRecoveredStateBuffer(Buffer buffer); + + /** + * Marks producer-side recovery delivery complete. Implementations wait for upstream readiness + * before flipping this state so channels without spill entries still observe the same handoff. + */ + void finishRecoveredBufferDelivery() throws IOException, InterruptedException; + + /** + * Inserts a {@code RecoveryCheckpointBarrier} for {@code checkpointId} into this channel's + * recovery queue if the channel is still in recovery. + */ + void insertRecoveryCheckpointBarrierIfInRecovery(long checkpointId) throws IOException; + + /** + * Blocks until a buffer is available from this channel's own buffer pool. Implementations must + * first await upstream readiness and must be invoked outside the drainer lock. 
+ */ + Buffer requestRecoveryBufferBlocking() throws InterruptedException, IOException; + + /** + * Invoked by the consume path the moment it polls the {@code EndOfFetchedChannelStateEvent} + * sentinel, i.e. once all recovered buffers have been consumed. Implementations flip out of + * recovery, release any upstream events held back during recovery, and reopen the upstream so + * live data may flow again. + */ + void onRecoveredStateConsumed() throws IOException; +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java index 645446120d09e..0f7b582ab3c23 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java @@ -25,6 +25,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointException; import org.apache.flink.runtime.checkpoint.CheckpointFailureReason; import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; +import org.apache.flink.runtime.checkpoint.channel.RecoveryCheckpointBarrier; import org.apache.flink.runtime.event.AbstractEvent; import org.apache.flink.runtime.event.TaskEvent; import org.apache.flink.runtime.execution.CancelTaskException; @@ -62,9 +63,11 @@ import java.util.List; import java.util.Optional; import java.util.OptionalLong; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.apache.flink.runtime.io.network.buffer.Buffer.DataType.RECOVERY_METADATA; import static org.apache.flink.util.Preconditions.checkArgument; @@ -72,7 +75,7 @@ import static org.apache.flink.util.Preconditions.checkState; /** An input 
channel, which requests a remote partition queue. */ -public class RemoteInputChannel extends InputChannel { +public class RemoteInputChannel extends InputChannel implements RecoverableInputChannel { private static final Logger LOG = LoggerFactory.getLogger(RemoteInputChannel.class); private static final int NONE = -1; @@ -107,6 +110,8 @@ public class RemoteInputChannel extends InputChannel { /** The initial number of exclusive buffers assigned to this channel. */ private final int initialCredit; + private final boolean needsRecovery; + /** The milliseconds timeout for partition request listener in result partition manager. */ private final int partitionRequestListenerTimeout; @@ -123,6 +128,43 @@ public class RemoteInputChannel extends InputChannel { private final ChannelStatePersister channelStatePersister; + /** + * Whether the channel is still replaying recovered state. Recovered buffers delivered by the + * spill drain are appended directly to {@link #receivedBuffers}, so the consume path needs no + * recovery-specific branch. Starts {@code true} only for channels that need recovery and is + * flipped to {@code false} the moment the consume path polls the {@code + * EndOfFetchedChannelStateEvent} sentinel that the drain appended after the last recovered + * buffer (see {@link #onRecoveredStateConsumed()}). + */ + @GuardedBy("receivedBuffers") + private boolean inRecovery; + + /** + * Sequence number assigned to recovered buffers, starting at {@link Integer#MIN_VALUE}, + * consistent with {@link RecoveredInputChannel}. + */ + @GuardedBy("receivedBuffers") + private int recoverySequenceNumber = Integer.MIN_VALUE; + + /** + * Ordinary (non-priority) upstream events received while recovery is still in progress. They + * cannot enter {@link #receivedBuffers} ahead of the recovered buffers, so they are stashed + * here and appended once recovery delivery finishes.
Credit is suppressed during recovery, so + * the upstream can only send events (never data buffers) before {@link + * #finishRecoveredBufferDelivery()}. + */ + @GuardedBy("receivedBuffers") + private final ArrayDeque recoveryEventStash = new ArrayDeque<>(); + + /** + * One-shot latch that opens once the upstream reader is registered and the connection is live + * (signalled by the first {@link #onBuffer} or by {@link #releaseAllResources()}). + * Recovery-side awaiters block on it before handing off; once open, {@link + * CountDownLatch#countDown()} on the hot path is a cheap idempotent no-op, unlike completing a + * {@code CompletableFuture}. + */ + private final CountDownLatch upstreamReady = new CountDownLatch(1); + private long totalQueueSizeInBytes; public RemoteInputChannel( @@ -139,7 +181,7 @@ public RemoteInputChannel( Counter numBytesIn, Counter numBuffersIn, ChannelStateWriter stateWriter, - ArrayDeque initialRecoveredBuffers) { + boolean needsRecovery) { super( inputGate, @@ -156,31 +198,12 @@ public RemoteInputChannel( this.initialCredit = networkBuffersPerChannel; this.connectionId = checkNotNull(connectionId); this.connectionManager = checkNotNull(connectionManager); - this.bufferManager = new BufferManager(inputGate.getMemorySegmentProvider(), this, 0); - this.channelStatePersister = new ChannelStatePersister(stateWriter, getChannelInfo()); - - // Migrate recovered buffers from RecoveredInputChannel if provided. - // These buffers have been filtered but not yet consumed by the Task. - if (!initialRecoveredBuffers.isEmpty()) { - final int expectedCount = initialRecoveredBuffers.size(); - // Sequence number starts at Integer.MIN_VALUE, consistent with RecoveredInputChannel. - int seqNum = Integer.MIN_VALUE; - for (Buffer buffer : initialRecoveredBuffers) { - // subpartitionId is set to 0 for recovered buffers. This is correct because: - // 1) For single-subpartition channels, the only valid subpartition is 0. 
- // 2) For multi-subpartition channels (consumedSubpartitionIndexSet.size() > 1), - // RecoveryMetadata events embedded in the recovered buffer sequence track - // the actual subpartition context for proper routing. - SequenceBuffer sequenceBuffer = new SequenceBuffer(buffer, seqNum++, 0); - receivedBuffers.add(sequenceBuffer); - totalQueueSizeInBytes += buffer.getSize(); - } - checkState( - receivedBuffers.size() == expectedCount, - "Buffer migration failed: expected %s buffers but got %s", - expectedCount, - receivedBuffers.size()); - } + this.needsRecovery = needsRecovery; + this.bufferManager = + new BufferManager(inputGate.getMemorySegmentProvider(), this, 0, !needsRecovery); + this.channelStatePersister = + new ChannelStatePersister(checkNotNull(stateWriter), getChannelInfo()); + this.inRecovery = needsRecovery; } @VisibleForTesting @@ -188,6 +211,11 @@ void setExpectedSequenceNumber(int expectedSequenceNumber) { this.expectedSequenceNumber = expectedSequenceNumber; } + @VisibleForTesting + void completeUpstreamReadyForTest() { + upstreamReady.countDown(); + } + /** * Setup includes assigning exclusive buffers to this input channel, and this method should be * called only once after this input channel is created. @@ -201,6 +229,113 @@ void setup() throws IOException { bufferManager.requestExclusiveBuffers(initialCredit); } + // ------------------------------------------------------------------------ + // RecoverableInputChannel implementation + // ------------------------------------------------------------------------ + + @Override + public void onRecoveredStateBuffer(Buffer buffer) { + boolean wasEmpty; + synchronized (receivedBuffers) { + if (isReleased.get()) { + buffer.recycleBuffer(); + return; + } + // Migrate recovered buffers from RecoveredInputChannel. These buffers have been + // filtered but not yet consumed by the Task. They are appended to receivedBuffers so + // the consume path stays identical to the non-recovery case. 
+ wasEmpty = appendRecoveredBuffer(buffer); + } + if (wasEmpty) { + notifyChannelNonEmpty(); + } + } + + @Override + public void finishRecoveredBufferDelivery() throws IOException, InterruptedException { + upstreamReady.await(); + boolean wasEmpty; + synchronized (receivedBuffers) { + // A release may have opened the latch instead of the first buffer; bail out so we never + // append to a queue that releaseAllResources() already cleared. + if (isReleased.get()) { + return; + } + checkState(inRecovery, "Recovery delivery already finished."); + // Append the sentinel after the last recovered buffer. The consume path flips out of + // recovery (unstash + reopen credit) only once it polls this sentinel, guaranteeing all + // recovered buffers are consumed first. + wasEmpty = + appendRecoveredBuffer( + EventSerializer.toBuffer( + EndOfFetchedChannelStateEvent.INSTANCE, false)); + } + if (wasEmpty) { + notifyChannelNonEmpty(); + } + } + + /** + * Flips out of recovery once the consume path polls the {@code EndOfFetchedChannelStateEvent} + * sentinel: releases the upstream events stashed during recovery so they are consumed after the + * recovered buffers, then reopens the suppressed credit notifications. + */ + @Override + public void onRecoveredStateConsumed() throws IOException { + synchronized (receivedBuffers) { + checkState(inRecovery, "Recovery already finished."); + inRecovery = false; + recoveryEventStash.forEach(receivedBuffers::add); + recoveryEventStash.clear(); + } + notifyChannelNonEmpty(); + // Credit notifications are suppressed while recovery borrows the exclusive buffers. + bufferManager.enableNotify(); + } + + @Override + public Buffer requestRecoveryBufferBlocking() throws InterruptedException, IOException { + upstreamReady.await(); + // If a release opened the latch instead of the first buffer, requestBufferBlocking() + // detects + // the released channel and throws CancelTaskException. 
+ return bufferManager.requestBufferBlocking(); + } + + @Override + public void insertRecoveryCheckpointBarrierIfInRecovery(long checkpointId) throws IOException { + boolean wasEmpty = false; + synchronized (receivedBuffers) { + if (!isReleased.get() && inRecovery) { + wasEmpty = + appendRecoveredBuffer( + EventSerializer.toBuffer( + new RecoveryCheckpointBarrier(checkpointId), false)); + } + } + if (wasEmpty) { + notifyChannelNonEmpty(); + } + } + + /** + * Appends a recovered buffer (or {@code RecoveryCheckpointBarrier} sentinel) to {@link + * #receivedBuffers} with a recovery sequence number. + * + * @return {@code true} iff {@code receivedBuffers} transitioned from empty to non-empty. + */ + @GuardedBy("receivedBuffers") + private boolean appendRecoveredBuffer(Buffer buffer) { + boolean wasEmpty = receivedBuffers.isEmpty(); + // Recovered buffers carry no per-buffer subpartition id (NONE): they are snapshotted via + // the + // recovery path, never via getInflightBuffersUnsafe which is the only consumer of that + // field. + receivedBuffers.add(new SequenceBuffer(buffer, recoverySequenceNumber++, NONE)); + totalQueueSizeInBytes += buffer.getSize(); + return wasEmpty; + } + // ------------------------------------------------------------------------ // Consume // ------------------------------------------------------------------------ @@ -343,14 +478,18 @@ public boolean isReleased() { @Override void releaseAllResources() throws IOException { if (isReleased.compareAndSet(false, true)) { + // Unblock any thread awaiting upstreamReady (drain still in flight) so it falls + // through and observes the released state instead of deadlocking. 
+ upstreamReady.countDown(); final ArrayDeque releasedBuffers; synchronized (receivedBuffers) { releasedBuffers = - receivedBuffers.stream() + Stream.concat(receivedBuffers.stream(), recoveryEventStash.stream()) .map(sb -> sb.buffer) .collect(Collectors.toCollection(ArrayDeque::new)); receivedBuffers.clear(); + recoveryEventStash.clear(); } bufferManager.releaseAllBuffers(releasedBuffers); @@ -558,6 +697,10 @@ public int getInitialCredit() { return initialCredit; } + public boolean needsRecovery() { + return needsRecovery; + } + public BufferProvider getBufferProvider() throws IOException { if (isReleased.get()) { return null; @@ -596,6 +739,11 @@ public void onBuffer(Buffer buffer, int sequenceNumber, int backlog, int subpart throws IOException { boolean recycleBuffer = true; + // The first buffer from the producer proves the upstream reader is registered and the + // connection is live; release any recovery-side awaiter. On later buffers this is a cheap + // idempotent no-op (the latch count is already zero). + upstreamReady.countDown(); + try { if (expectedSequenceNumber != sequenceNumber) { onError(new BufferReorderingException(expectedSequenceNumber, sequenceNumber)); @@ -633,7 +781,20 @@ public void onBuffer(Buffer buffer, int sequenceNumber, int backlog, int subpart firstPriorityEvent = addPriorityBuffer(sequenceBuffer); recycleBuffer = false; } else { - receivedBuffers.add(sequenceBuffer); + if (inRecovery) { + // The upstream has no credit until recovery delivery finishes, so it can + // only + // send events here, never data buffers. Stash ordinary events so they are + // consumed after the recovered buffers; data buffers are a protocol + // violation. 
+ checkState( + !buffer.isBuffer(), + "Received live data buffer during recovery on channel %s", + getChannelInfo()); + recoveryEventStash.add(sequenceBuffer); + } else { + receivedBuffers.add(sequenceBuffer); + } recycleBuffer = false; if (dataType.requiresAnnouncement()) { firstPriorityEvent = addPriorityBuffer(announce(sequenceBuffer)); @@ -713,30 +874,117 @@ private void checkAnnouncedOnlyOnce(SequenceBuffer sequenceBuffer) { } /** - * Spills all queued buffers on checkpoint start. If barrier has already been received (and - * reordered), spill only the overtaken buffers. + * Persists inflight data on checkpoint start. During recovery, persists recovered buffers + * before the matching RecoveryCheckpointBarrier sentinel; after recovery, uses the normal + * remote-channel barrier sequence tracking and persists overtaken live buffers. */ public void checkpointStarted(CheckpointBarrier barrier) throws CheckpointException { - synchronized (receivedBuffers) { - if (barrier.getId() < lastBarrierId) { - throw new CheckpointException( - String.format( - "Sequence number for checkpoint %d is not known (it was likely been overwritten by a newer checkpoint %d)", - barrier.getId(), lastBarrierId), - CheckpointFailureReason - .CHECKPOINT_SUBSUMED); // currently, at most one active unaligned - // checkpoint is possible - } else if (barrier.getId() > lastBarrierId) { - // This channel has received some obsolete barrier, older compared to the - // checkpointId - // which we are processing right now, and we should ignore that obsoleted checkpoint - // barrier sequence number. - resetLastBarrier(); + try { + List toPersist; + synchronized (receivedBuffers) { + if (inRecovery) { + toPersist = collectPreRecoveryBarrier(barrier.getId()); + } else { + if (barrier.getId() < lastBarrierId) { + // Currently, at most one active unaligned checkpoint is possible. 
+ throw new CheckpointException( + String.format( + "Sequence number for checkpoint %d is not known (it was likely been overwritten by a newer checkpoint %d)", + barrier.getId(), lastBarrierId), + CheckpointFailureReason.CHECKPOINT_SUBSUMED); + } else if (barrier.getId() > lastBarrierId) { + // This channel has received some obsolete barrier, older compared to the + // checkpointId which we are processing right now, and we should ignore that + // obsoleted checkpoint barrier sequence number. + resetLastBarrier(); + } + toPersist = getInflightBuffersUnsafe(barrier.getId()); + } + channelStatePersister.startPersisting(barrier.getId(), toPersist); + if (inRecovery) { + // Recovered inflight buffers are collected in one shot and the upstream sends + // no + // data during recovery, so close the persist window immediately to keep the + // persister from carrying a pending state into later checkpoints. + channelStatePersister.stopPersisting(barrier.getId()); + } } + } catch (IOException e) { + throw new CheckpointException( + "Failed to extract recovered buffers for checkpoint " + barrier.getId(), + CheckpointFailureReason.CHECKPOINT_DECLINED, + e); + } + } - channelStatePersister.startPersisting( - barrier.getId(), getInflightBuffersUnsafe(barrier.getId())); + /** + * Walks {@link #receivedBuffers} (skipping priority events) up to the {@link + * RecoveryCheckpointBarrier} sentinel matching {@code checkpointId}, retaining each pre-barrier + * recovered data buffer and removing the sentinel. During recovery the upstream has no credit, + * so {@code receivedBuffers} holds only recovered buffers, sentinels, and priority events — no + * live data buffers. + * + * @throws IOException if no sentinel matching {@code checkpointId} is found (the snapshot + * protocol guarantees one must be present while the channel is in recovery). 
+ */ + @GuardedBy("receivedBuffers") + private List collectPreRecoveryBarrier(long checkpointId) throws IOException { + assert Thread.holdsLock(receivedBuffers); + List retained = new ArrayList<>(); + SequenceBuffer sentinel = null; + try { + Iterator it = receivedBuffers.iterator(); + // Priority events are stored separately at the head and never carry recovered data. + Iterators.advance(it, receivedBuffers.getNumPriorityElements()); + while (it.hasNext()) { + SequenceBuffer sb = it.next(); + if (isRecoveryCheckpointBarrier(sb.buffer, checkpointId)) { + sentinel = sb; + break; + } + // Skip non-data events (e.g. the EndOfFetchedChannelStateEvent sentinel appended + // after the recovered buffers): only recovered data buffers are snapshotted. + if (sb.buffer.isBuffer()) { + retained.add(sb.buffer.retainBuffer()); + } + } + } catch (IOException e) { + releaseRetainedBuffers(retained); + throw e; + } + if (sentinel == null) { + releaseRetainedBuffers(retained); + throw new IOException( + "Missing RecoveryCheckpointBarrier for checkpoint " + + checkpointId + + " in receivedBuffers for channel " + + getChannelInfo()); + } + // receivedBuffers is a PrioritizedDeque whose iterator() is read-only; remove the matched + // sentinel by identity through its mutable removal API. 
+ final SequenceBuffer matched = sentinel; + receivedBuffers.getAndRemove(sb -> sb == matched); + totalQueueSizeInBytes -= matched.buffer.getSize(); + matched.buffer.recycleBuffer(); + return retained; + } + + private static void releaseRetainedBuffers(List retained) { + for (Buffer buffer : retained) { + buffer.recycleBuffer(); + } + } + + private static boolean isRecoveryCheckpointBarrier(Buffer b, long checkpointId) + throws IOException { + if (b.isBuffer()) { + return false; } + AbstractEvent event = + EventSerializer.fromBuffer(b, RecoveryCheckpointBarrier.class.getClassLoader()); + b.setReaderIndex(0); + return event instanceof RecoveryCheckpointBarrier + && ((RecoveryCheckpointBarrier) event).getCheckpointId() == checkpointId; } public void checkpointStopped(long checkpointId) { @@ -909,13 +1157,13 @@ public void onError(Throwable cause) { } /** - * When receivedBuffers contains migrated buffers from RecoveredInputChannel, they can be read - * before requestSubpartitions(). In that case only check for errors. Once migrated buffers are - * drained, require full client initialization check. + * Allows reads while recovery data or already queued network data is available before the + * remote partition request is fully initialized. If neither recovery nor queued data can + * satisfy the read, require the partition request client to be initialized. 
*/ private void checkReadability() throws IOException { assert Thread.holdsLock(receivedBuffers); - if (receivedBuffers.isEmpty()) { + if (!inRecovery && receivedBuffers.isEmpty()) { checkPartitionRequestQueueInitialized(); } else { checkError(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannel.java index 2cfff6f5e7972..b76aa347fe644 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannel.java @@ -20,13 +20,11 @@ import org.apache.flink.runtime.io.network.ConnectionID; import org.apache.flink.runtime.io.network.ConnectionManager; -import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.metrics.InputChannelMetrics; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet; import java.io.IOException; -import java.util.ArrayDeque; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -68,8 +66,7 @@ public class RemoteRecoveredInputChannel extends RecoveredInputChannel { } @Override - protected InputChannel toInputChannelInternal(ArrayDeque remainingBuffers) - throws IOException { + protected InputChannel toInputChannelInternal(boolean needsRecovery) throws IOException { RemoteInputChannel remoteInputChannel = new RemoteInputChannel( inputGate, @@ -85,8 +82,7 @@ protected InputChannel toInputChannelInternal(ArrayDeque remainingBuffer numBytesIn, numBuffersIn, channelStateWriter, - remainingBuffers); - remoteInputChannel.setup(); + needsRecovery); return remoteInputChannel; } } diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java index 438efa2f58bd5..50e359b4d4e32 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java @@ -245,6 +245,8 @@ public class SingleInputGate extends IndexedInputGate { private volatile boolean checkpointingDuringRecoveryEnabled = false; + private volatile boolean needsRecovery = false; + public SingleInputGate( String owningTaskName, int gateIndex, @@ -342,18 +344,13 @@ public boolean isCheckpointingDuringRecoveryEnabled() { } @Override - public CompletableFuture getBufferFilteringCompleteFuture() { - synchronized (requestLock) { - List> futures = new ArrayList<>(numberOfInputChannels); - for (InputChannel inputChannel : inputChannels()) { - if (inputChannel instanceof RecoveredInputChannel) { - futures.add( - ((RecoveredInputChannel) inputChannel) - .getBufferFilteringCompleteFuture()); - } - } - return CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])); - } + public void setNeedsRecovery(boolean enabled) { + this.needsRecovery = enabled; + } + + @Override + public boolean needsRecovery() { + return needsRecovery; } @Override @@ -413,7 +410,7 @@ public void convertRecoveredInputChannels() { // order with onRecoveredStateBuffer() which acquires receivedBuffers // first and then inputChannelsWithData. 
InputChannel realInputChannel = - ((RecoveredInputChannel) inputChannel).toInputChannel(); + ((RecoveredInputChannel) inputChannel).toInputChannel(needsRecovery); inputChannel.releaseAllResources(); int buffersInUseCount = realInputChannel.getBuffersInUseCount(); @@ -595,6 +592,11 @@ public InputChannel getChannel(int channelIndex) { return channels[channelIndex]; } + @Override + public InputChannel getChannel(InputChannelInfo channelInfo) { + return channels[channelInfo.getInputChannelIdx()]; + } + // ------------------------------------------------------------------------ // Setup/Life-cycle // ------------------------------------------------------------------------ diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java index dda71c63be38f..fff12185156d4 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java @@ -176,6 +176,13 @@ public InputChannel getChannel(int channelIndex) { .getChannel(channelIndex - inputGateChannelIndexOffsets[gateIndex]); } + @Override + public InputChannel getChannel(InputChannelInfo channelInfo) { + // The member gate's local channel index is carried directly in channelInfo, so resolve + // through the owning gate instead of the gate-global getChannel(int) addressing. 
+ return inputGatesByGateIndex.get(channelInfo.getGateIdx()).getChannel(channelInfo); + } + @Override public boolean isFinished() { return inputGatesWithRemainingData.isEmpty(); @@ -350,15 +357,6 @@ public CompletableFuture getStateConsumedFuture() { .toArray(new CompletableFuture[] {})); } - @Override - public CompletableFuture getBufferFilteringCompleteFuture() { - return CompletableFuture.allOf( - inputGatesByGateIndex.values().stream() - .map(InputGate::getBufferFilteringCompleteFuture) - .collect(Collectors.toList()) - .toArray(new CompletableFuture[] {})); - } - @Override public void requestPartitions() throws IOException { for (InputGate inputGate : inputGatesByGateIndex.values()) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnknownInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnknownInputChannel.java index 15182cedadb9f..2d717f66fe8a8 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnknownInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnknownInputChannel.java @@ -35,7 +35,6 @@ import javax.annotation.Nullable; import java.io.IOException; -import java.util.ArrayDeque; import java.util.Optional; import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHECKPOINT_DECLINED_TASK_NOT_READY; @@ -171,21 +170,23 @@ public String toString() { public RemoteInputChannel toRemoteInputChannel( ConnectionID producerAddress, ResultPartitionID resultPartitionID) { - return new RemoteInputChannel( - inputGate, - getChannelIndex(), - resultPartitionID, - consumedSubpartitionIndexSet, - checkNotNull(producerAddress), - connectionManager, - initialBackoff, - maxBackoff, - partitionRequestListenerTimeout, - networkBuffersPerChannel, - metrics.getNumBytesInRemoteCounter(), - metrics.getNumBuffersInRemoteCounter(), - channelStateWriter == null ? 
ChannelStateWriter.NO_OP : channelStateWriter, - new ArrayDeque<>()); + RemoteInputChannel channel = + new RemoteInputChannel( + inputGate, + getChannelIndex(), + resultPartitionID, + consumedSubpartitionIndexSet, + checkNotNull(producerAddress), + connectionManager, + initialBackoff, + maxBackoff, + partitionRequestListenerTimeout, + networkBuffersPerChannel, + metrics.getNumBytesInRemoteCounter(), + metrics.getNumBuffersInRemoteCounter(), + channelStateWriter == null ? ChannelStateWriter.NO_OP : channelStateWriter, + false); + return channel; } public LocalInputChannel toLocalInputChannel(ResultPartitionID resultPartitionID) { @@ -201,7 +202,8 @@ public LocalInputChannel toLocalInputChannel(ResultPartitionID resultPartitionID metrics.getNumBytesInLocalCounter(), metrics.getNumBuffersInLocalCounter(), channelStateWriter == null ? ChannelStateWriter.NO_OP : channelStateWriter, - new ArrayDeque<>()); + networkBuffersPerChannel, + false); } @Override diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java index bff412f53b330..9b3abd611b78f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java @@ -80,6 +80,11 @@ public InputChannel getChannel(int channelIndex) { return inputGate.getChannel(channelIndex); } + @Override + public InputChannel getChannel(InputChannelInfo channelInfo) { + return inputGate.getChannel(channelInfo); + } + @Override public int getGateIndex() { return inputGate.getGateIndex(); @@ -120,11 +125,6 @@ public CompletableFuture getStateConsumedFuture() { return inputGate.getStateConsumedFuture(); } - @Override - public CompletableFuture getBufferFilteringCompleteFuture() { - return inputGate.getBufferFilteringCompleteFuture(); - } - @Override public void 
requestPartitions() throws IOException { inputGate.requestPartitions(); @@ -175,6 +175,16 @@ public boolean isCheckpointingDuringRecoveryEnabled() { return inputGate.isCheckpointingDuringRecoveryEnabled(); } + @Override + public void setNeedsRecovery(boolean enabled) { + inputGate.setNeedsRecovery(enabled); + } + + @Override + public boolean needsRecovery() { + return inputGate.needsRecovery(); + } + private BufferOrEvent updateMetrics(BufferOrEvent bufferOrEvent) { int incomingDataSize = bufferOrEvent.getSize(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CancelPartitionRequestTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CancelPartitionRequestTest.java index 2adbded7d9e90..d165f6a50cb4d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CancelPartitionRequestTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CancelPartitionRequestTest.java @@ -106,7 +106,8 @@ void testCancelPartitionRequest() throws Exception { pid, new ResultSubpartitionIndexSet(0), new InputChannelID(), - Integer.MAX_VALUE)) + Integer.MAX_VALUE, + false)) .await(); // Wait for the notification @@ -170,7 +171,8 @@ void testDuplicateCancel() throws Exception { pid, new ResultSubpartitionIndexSet(0), inputChannelId, - Integer.MAX_VALUE)) + Integer.MAX_VALUE, + false)) .await(); // Wait for the notification diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java index d4b162304b8e6..d58018036961a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java @@ -68,7 
+68,6 @@ import java.io.IOException; import java.net.InetSocketAddress; -import java.util.ArrayDeque; import java.util.stream.Stream; import static org.apache.flink.runtime.io.network.netty.PartitionRequestQueueTest.blockChannel; @@ -956,7 +955,7 @@ private static class TestRemoteInputChannelForError extends RemoteInputChannel { new SimpleCounter(), new SimpleCounter(), ChannelStateWriter.NO_OP, - new ArrayDeque<>()); + false); this.expectedMessage = expectedMessage; } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedSequenceNumberingViewReaderTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedSequenceNumberingViewReaderTest.java index cd4a3103240ee..ba686adce0495 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedSequenceNumberingViewReaderTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedSequenceNumberingViewReaderTest.java @@ -88,7 +88,7 @@ private static CreditBasedSequenceNumberingViewReader createNetworkSequenceViewR channel.close(); CreditBasedSequenceNumberingViewReader reader = new CreditBasedSequenceNumberingViewReader( - new InputChannelID(), initialCredit, queue); + new InputChannelID(), initialCredit, false, queue); reader.notifySubpartitionsCreated( TestingResultPartition.newBuilder() .setCreateSubpartitionViewFunction( diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageServerSideSerializationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageServerSideSerializationTest.java index 22b5420b616d2..f2e0e8146c8da 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageServerSideSerializationTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageServerSideSerializationTest.java @@ -67,7 +67,8 @@ void 
testPartitionRequest() { new ResultPartitionID(), new ResultSubpartitionIndexSet(queueIndex), new InputChannelID(), - random.nextInt()); + random.nextInt(), + random.nextBoolean()); NettyMessage.PartitionRequest actual = encodeAndDecode(expected, channel); @@ -75,6 +76,7 @@ void testPartitionRequest() { assertThat(actual.queueIndexSet).isEqualTo(expected.queueIndexSet); assertThat(actual.receiverId).isEqualTo(expected.receiverId); assertThat(actual.credit).isEqualTo(expected.credit); + assertThat(actual.needsRecovery).isEqualTo(expected.needsRecovery); } @Test diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueueTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueueTest.java index 7046a8518093b..a4e15e4cf90d6 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueueTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueueTest.java @@ -90,7 +90,8 @@ public void testNotifyReaderPartitionTimeout() throws Exception { ResultPartitionManager resultPartitionManager = new ResultPartitionManager(); ResultPartitionID resultPartitionId = new ResultPartitionID(); CreditBasedSequenceNumberingViewReader reader = - new CreditBasedSequenceNumberingViewReader(new InputChannelID(0, 0), 10, queue); + new CreditBasedSequenceNumberingViewReader( + new InputChannelID(0, 0), 10, false, queue); reader.requestSubpartitionViewOrRegisterListener( resultPartitionManager, resultPartitionId, new ResultSubpartitionIndexSet(0)); @@ -132,9 +133,11 @@ void testNotifyReaderNonEmptyOnEmptyReaders() throws Exception { EmbeddedChannel channel = new EmbeddedChannel(queue); CreditBasedSequenceNumberingViewReader reader1 = - new CreditBasedSequenceNumberingViewReader(new InputChannelID(0, 0), 10, queue); + new CreditBasedSequenceNumberingViewReader( + new InputChannelID(0, 0), 10, false, queue); 
CreditBasedSequenceNumberingViewReader reader2 = - new CreditBasedSequenceNumberingViewReader(new InputChannelID(1, 1), 10, queue); + new CreditBasedSequenceNumberingViewReader( + new InputChannelID(1, 1), 10, false, queue); ResultSubpartitionView view1 = new EmptyAlwaysAvailableResultSubpartitionView(); reader1.notifySubpartitionsCreated( @@ -192,7 +195,8 @@ private void testBufferWriting(ResultSubpartitionView view) throws IOException { final InputChannelID receiverId = new InputChannelID(); final PartitionRequestQueue queue = new PartitionRequestQueue(); final CreditBasedSequenceNumberingViewReader reader = - new CreditBasedSequenceNumberingViewReader(receiverId, Integer.MAX_VALUE, queue); + new CreditBasedSequenceNumberingViewReader( + receiverId, Integer.MAX_VALUE, false, queue); final EmbeddedChannel channel = new EmbeddedChannel(queue); reader.notifySubpartitionsCreated(partition, new ResultSubpartitionIndexSet(0)); @@ -287,7 +291,7 @@ void testEnqueueReaderByNotifyingEventBuffer() throws Exception { final InputChannelID receiverId = new InputChannelID(); final PartitionRequestQueue queue = new PartitionRequestQueue(); final CreditBasedSequenceNumberingViewReader reader = - new CreditBasedSequenceNumberingViewReader(receiverId, 0, queue); + new CreditBasedSequenceNumberingViewReader(receiverId, 0, false, queue); final EmbeddedChannel channel = new EmbeddedChannel(queue); reader.notifySubpartitionsCreated(partition, new ResultSubpartitionIndexSet(0)); @@ -340,7 +344,7 @@ void testEnqueueReaderByNotifyingBufferAndCredit() throws Exception { final InputChannelID receiverId = new InputChannelID(); final PartitionRequestQueue queue = new PartitionRequestQueue(); final CreditBasedSequenceNumberingViewReader reader = - new CreditBasedSequenceNumberingViewReader(receiverId, 2, queue); + new CreditBasedSequenceNumberingViewReader(receiverId, 2, false, queue); final EmbeddedChannel channel = new EmbeddedChannel(queue); reader.addCredit(-2); @@ -421,7 +425,7 @@ void 
testEnqueueReaderByResumingConsumption() throws Exception { InputChannelID receiverId = new InputChannelID(); PartitionRequestQueue queue = new PartitionRequestQueue(); CreditBasedSequenceNumberingViewReader reader = - new CreditBasedSequenceNumberingViewReader(receiverId, 2, queue); + new CreditBasedSequenceNumberingViewReader(receiverId, 2, false, queue); EmbeddedChannel channel = new EmbeddedChannel(queue); reader.notifySubpartitionsCreated(partition, new ResultSubpartitionIndexSet(0)); @@ -461,7 +465,7 @@ void testAnnounceBacklog() throws Exception { PartitionRequestQueue queue = new PartitionRequestQueue(); InputChannelID receiverId = new InputChannelID(); CreditBasedSequenceNumberingViewReader reader = - new CreditBasedSequenceNumberingViewReader(receiverId, 0, queue); + new CreditBasedSequenceNumberingViewReader(receiverId, 0, false, queue); EmbeddedChannel channel = new EmbeddedChannel(queue); reader.notifySubpartitionsCreated(partition, new ResultSubpartitionIndexSet(0)); @@ -498,7 +502,7 @@ private void testCancelPartitionRequest(boolean isAvailableView) throws Exceptio final InputChannelID receiverId = new InputChannelID(); final PartitionRequestQueue queue = new PartitionRequestQueue(); final CreditBasedSequenceNumberingViewReader reader = - new CreditBasedSequenceNumberingViewReader(receiverId, 2, queue); + new CreditBasedSequenceNumberingViewReader(receiverId, 2, false, queue); final EmbeddedChannel channel = new EmbeddedChannel(queue); reader.notifySubpartitionsCreated(partition, new ResultSubpartitionIndexSet(0)); @@ -546,7 +550,7 @@ void testNotifyNewBufferSize() throws Exception { InputChannelID receiverId = new InputChannelID(); PartitionRequestQueue queue = new PartitionRequestQueue(); CreditBasedSequenceNumberingViewReader reader = - new CreditBasedSequenceNumberingViewReader(receiverId, 2, queue); + new CreditBasedSequenceNumberingViewReader(receiverId, 2, false, queue); EmbeddedChannel channel = new EmbeddedChannel(queue); 
reader.notifySubpartitionsCreated(partition, new ResultSubpartitionIndexSet(0)); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestRegistrationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestRegistrationTest.java index e3cfb55e3400f..86d9588094f9a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestRegistrationTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestRegistrationTest.java @@ -44,7 +44,6 @@ import org.junit.jupiter.api.Test; -import java.util.ArrayDeque; import java.util.Optional; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -96,7 +95,8 @@ void testRegisterResultPartitionBeforeRequest() throws Exception { resultPartition.getPartitionId(), new ResultSubpartitionIndexSet(0), new InputChannelID(), - Integer.MAX_VALUE)) + Integer.MAX_VALUE, + false)) .await(); // Wait for the notification @@ -144,7 +144,8 @@ void testRegisterResultPartitionAfterRequest() throws Exception { resultPartition.getPartitionId(), new ResultSubpartitionIndexSet(0), new InputChannelID(), - Integer.MAX_VALUE)) + Integer.MAX_VALUE, + false)) .await(); // Register result partition after partition request @@ -213,7 +214,8 @@ public void releasePartitionRequestListener( pid, new ResultSubpartitionIndexSet(0), remoteInputChannel.getInputChannelId(), - Integer.MAX_VALUE)) + Integer.MAX_VALUE, + false)) .await(); // Wait for the notification @@ -250,7 +252,7 @@ private static class TestRemoteInputChannelForPartitionNotFound extends RemoteIn new SimpleCounter(), new SimpleCounter(), ChannelStateWriter.NO_OP, - new ArrayDeque<>()); + false); this.latch = latch; } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandlerTest.java 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandlerTest.java index 7f19b199582e7..1fe0c00e79a18 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandlerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandlerTest.java @@ -81,7 +81,7 @@ void testAcknowledgeAllRecordsProcessed() throws IOException { // Creates and registers the view to netty. NetworkSequenceViewReader viewReader = new CreditBasedSequenceNumberingViewReader( - inputChannelID, 2, partitionRequestQueue); + inputChannelID, 2, false, partitionRequestQueue); viewReader.notifySubpartitionsCreated(resultPartition, new ResultSubpartitionIndexSet(0)); partitionRequestQueue.notifyReaderCreated(viewReader); @@ -149,7 +149,7 @@ private static class TestViewReader extends CreditBasedSequenceNumberingViewRead TestViewReader( InputChannelID receiverId, int initialCredit, PartitionRequestQueue requestQueue) { - super(receiverId, initialCredit, requestQueue); + super(receiverId, initialCredit, false, requestQueue); } @Override diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/ServerTransportErrorHandlingTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/ServerTransportErrorHandlingTest.java index a4f753fe1e7d2..bfc4e01a153cc 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/ServerTransportErrorHandlingTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/ServerTransportErrorHandlingTest.java @@ -103,7 +103,8 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { new ResultPartitionID(), new ResultSubpartitionIndexSet(0), new InputChannelID(), - Integer.MAX_VALUE)); + Integer.MAX_VALUE, + false)); // Wait for the notification assertThat(sync.await(TestingUtils.TESTING_DURATION.toMillis(), TimeUnit.MILLISECONDS)) diff 
--git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannelBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannelBuilder.java index 08f65d9fe7265..b171dc41c74ec 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannelBuilder.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannelBuilder.java @@ -34,7 +34,6 @@ import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet; import java.net.InetSocketAddress; -import java.util.ArrayDeque; import static org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateTest.TestingResultPartitionManager; @@ -56,6 +55,7 @@ public class InputChannelBuilder { private int maxBackoff = 0; private int partitionRequestListenerTimeout = 0; private int networkBuffersPerChannel = 2; + private boolean needsRecovery = false; private InputChannelMetrics metrics = InputChannelTestUtils.newUnregisteredInputChannelMetrics(); @@ -115,6 +115,11 @@ public InputChannelBuilder setNetworkBuffersPerChannel(int networkBuffersPerChan return this; } + public InputChannelBuilder setNeedsRecovery(boolean needsRecovery) { + this.needsRecovery = needsRecovery; + return this; + } + public InputChannelBuilder setMetrics(InputChannelMetrics metrics) { this.metrics = metrics; return this; @@ -166,7 +171,8 @@ public LocalInputChannel buildLocalChannel(SingleInputGate inputGate) { metrics.getNumBytesInLocalCounter(), metrics.getNumBuffersInLocalCounter(), stateWriter, - new ArrayDeque<>()); + networkBuffersPerChannel, + needsRecovery); } public RemoteInputChannel buildRemoteChannel(SingleInputGate inputGate) { @@ -184,7 +190,7 @@ public RemoteInputChannel buildRemoteChannel(SingleInputGate inputGate) { metrics.getNumBytesInRemoteCounter(), metrics.getNumBuffersInRemoteCounter(), stateWriter, - new ArrayDeque<>()); + needsRecovery); } 
public LocalRecoveredInputChannel buildLocalRecoveredChannel(SingleInputGate inputGate) { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java index 86bda9866d204..9861d28024c07 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java @@ -19,10 +19,12 @@ package org.apache.flink.runtime.io.network.partition.consumer; import org.apache.flink.metrics.SimpleCounter; +import org.apache.flink.runtime.checkpoint.CheckpointException; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.CheckpointType; import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; import org.apache.flink.runtime.checkpoint.channel.RecordingChannelStateWriter; +import org.apache.flink.runtime.checkpoint.channel.RecoveryCheckpointBarrier; import org.apache.flink.runtime.execution.CancelTaskException; import org.apache.flink.runtime.io.disk.NoOpFileChannelManager; import org.apache.flink.runtime.io.network.TaskEventDispatcher; @@ -61,7 +63,6 @@ import org.mockito.stubbing.Answer; import java.io.IOException; -import java.util.ArrayDeque; import java.util.Collections; import java.util.List; import java.util.Optional; @@ -669,7 +670,7 @@ void testReceivingBuffersInUseBeforeSubpartitionViewInitialization() throws Exce @Test void testGetBuffersInUseCountIncludesToBeConsumedBuffers() throws Exception { - // given: Local input channel with recovered buffers in toBeConsumedBuffers + // given: Local input channel with recovered buffers in the recovery queue ResultSubpartitionView subpartitionView = InputChannelTestUtils.createResultSubpartitionView( 
createFilledFinishedBufferConsumer(4096), @@ -678,12 +679,6 @@ void testGetBuffersInUseCountIncludesToBeConsumedBuffers() throws Exception { new TestingResultPartitionManager(subpartitionView); final SingleInputGate inputGate = createSingleInputGate(1); - // Create 3 recovered buffers - ArrayDeque recoveredBuffers = new ArrayDeque<>(); - recoveredBuffers.add(TestBufferFactory.createBuffer(32)); - recoveredBuffers.add(TestBufferFactory.createBuffer(32)); - recoveredBuffers.add(TestBufferFactory.createBuffer(32)); - final LocalInputChannel localChannel = new LocalInputChannel( inputGate, @@ -697,10 +692,16 @@ void testGetBuffersInUseCountIncludesToBeConsumedBuffers() throws Exception { new SimpleCounter(), new SimpleCounter(), ChannelStateWriter.NO_OP, - recoveredBuffers); + 2, + true); inputGate.setInputChannels(localChannel); + // Create 3 recovered buffers + localChannel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(32)); + localChannel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(32)); + localChannel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(32)); + // then: Before requesting subpartitions, buffers in use should include recovered buffers assertThat(localChannel.getBuffersInUseCount()).isEqualTo(3); assertThat(localChannel.unsynchronizedGetNumberOfQueuedBuffers()).isEqualTo(3); @@ -718,10 +719,6 @@ void testGetNextBufferWithMigratedRecoveredBuffers() throws Exception { // given: LocalInputChannel with recovered buffers migrated from RecoveredInputChannel SingleInputGate inputGate = createSingleInputGate(1); - ArrayDeque recoveredBuffers = new ArrayDeque<>(); - recoveredBuffers.add(TestBufferFactory.createBuffer(10)); - recoveredBuffers.add(TestBufferFactory.createBuffer(20)); - LocalInputChannel channel = new LocalInputChannel( inputGate, @@ -735,10 +732,16 @@ void testGetNextBufferWithMigratedRecoveredBuffers() throws Exception { new SimpleCounter(), new SimpleCounter(), ChannelStateWriter.NO_OP, - recoveredBuffers); + 2, + true); 
inputGate.setInputChannels(channel); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(10)); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(20)); + channel.completeUpstreamReadyForTest(); + channel.finishRecoveredBufferDelivery(); + // then: Can read recovered buffers even before requestSubpartitions() Optional first = channel.getNextBuffer(); assertThat(first).isPresent(); @@ -755,11 +758,6 @@ void testCheckpointStartedPersistsRecoveredBuffers() throws Exception { // given: Local input channel with recovered buffers SingleInputGate inputGate = new SingleInputGateBuilder().build(); - ArrayDeque recoveredBuffers = new ArrayDeque<>(); - recoveredBuffers.add(TestBufferFactory.createBuffer(10)); - recoveredBuffers.add(TestBufferFactory.createBuffer(20)); - recoveredBuffers.add(TestBufferFactory.createBuffer(30)); - RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); LocalInputChannel channel = @@ -775,19 +773,28 @@ void testCheckpointStartedPersistsRecoveredBuffers() throws Exception { new SimpleCounter(), new SimpleCounter(), stateWriter, - recoveredBuffers); + 2, + true); inputGate.setInputChannels(channel); - // when: Checkpoint is started + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(10)); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(20)); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(30)); + // Mirror snapshotAndInsertBarriers: push the sentinel before + // checkpointStarted scans recoveredQueue. 
+ channel.onRecoveredStateBuffer( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(1L), false)); + CheckpointOptions options = CheckpointOptions.unaligned(CheckpointType.CHECKPOINT, getDefault()); stateWriter.start(1L, options); CheckpointBarrier barrier = new CheckpointBarrier(1L, 0L, options); + // when: Checkpoint is started channel.checkpointStarted(barrier); - // then: All 3 recovered buffers should be persisted as inflight data List persistedBuffers = stateWriter.getAddedInput().get(channel.getChannelInfo()); + // then: All 3 recovered buffers should be persisted as inflight data assertThat(persistedBuffers).isNotNull().hasSize(3); assertThat(persistedBuffers.stream().mapToInt(Buffer::getSize).toArray()) .containsExactly(10, 20, 30); @@ -824,9 +831,6 @@ void testPriorityEventFailsFastWhenSubpartitionViewIsNull() throws Exception { // given: Local input channel with recovered buffers but NO subpartition view initialized SingleInputGate inputGate = new SingleInputGateBuilder().build(); - ArrayDeque recoveredBuffers = new ArrayDeque<>(); - recoveredBuffers.add(TestBufferFactory.createBuffer(10)); - LocalInputChannel channel = new LocalInputChannel( inputGate, @@ -840,9 +844,11 @@ void testPriorityEventFailsFastWhenSubpartitionViewIsNull() throws Exception { new SimpleCounter(), new SimpleCounter(), ChannelStateWriter.NO_OP, - recoveredBuffers); + 2, + true); inputGate.setInputChannels(channel); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(10)); // Do NOT call channel.requestSubpartitions() — subpartitionView stays null channel.notifyPriorityEvent(0); @@ -962,11 +968,6 @@ private static ChannelAndSubpartition createChannelWithRecoveredBuffers( TestingResultPartitionManager partitionManager = new TestingResultPartitionManager(subpartitionView); - ArrayDeque recoveredBuffers = new ArrayDeque<>(); - for (int size : recoveredBufferSizes) { - recoveredBuffers.add(TestBufferFactory.createBuffer(size)); - } - LocalInputChannel channel = new 
LocalInputChannel( inputGate, @@ -980,9 +981,15 @@ private static ChannelAndSubpartition createChannelWithRecoveredBuffers( new SimpleCounter(), new SimpleCounter(), stateWriter, - recoveredBuffers); + 2, + true); inputGate.setInputChannels(channel); + + for (int size : recoveredBufferSizes) { + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(size)); + } + channel.requestSubpartitions(); return new ChannelAndSubpartition(channel, subpartition); @@ -998,6 +1005,333 @@ private static class ChannelAndSubpartition { } } + // --------------------------------------------------------------------------------------------- + // RecoverableInputChannel push-based recovery tests + // --------------------------------------------------------------------------------------------- + + private static LocalInputChannel newPushOnlyLocalChannel( + SingleInputGate inputGate, ChannelStateWriter stateWriter) { + return new LocalInputChannel( + inputGate, + 0, + new ResultPartitionID(), + new ResultSubpartitionIndexSet(0), + new ResultPartitionManager(), + new TaskEventDispatcher(), + 0, + 0, + new SimpleCounter(), + new SimpleCounter(), + stateWriter, + 2, + true); + } + + @Test + void testOnRecoveredStateBufferEnqueues() throws Exception { + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, ChannelStateWriter.NO_OP); + inputGate.setInputChannels(channel); + + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(11)); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(22)); + + Optional first = channel.getNextBuffer(); + assertThat(first).isPresent(); + assertThat(first.get().buffer().getSize()).isEqualTo(11); + Optional second = channel.getNextBuffer(); + assertThat(second).isPresent(); + assertThat(second.get().buffer().getSize()).isEqualTo(22); + } + + @Test + void testOnRecoveredStateBufferOnReleasedChannelIsSilentlyRecycled() throws Exception { + SingleInputGate 
inputGate = new SingleInputGateBuilder().build(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, ChannelStateWriter.NO_OP); + inputGate.setInputChannels(channel); + channel.releaseAllResources(); + + Buffer b = TestBufferFactory.createBuffer(33); + channel.onRecoveredStateBuffer(b); + assertThat(b.isRecycled()).isTrue(); + } + + @Test + void testOnRecoveredStateBufferNotifiesChannelNonEmptyOnEmptyToNonEmptyTransition() + throws Exception { + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, ChannelStateWriter.NO_OP); + inputGate.setInputChannels(channel); + + CompletableFuture availability = inputGate.getAvailableFuture(); + assertThat(availability).isNotDone(); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(1)); + assertThat(availability).isDone(); + } + + @Test + void testInRecoveryBoundaryFlagFalseQueueEmptyReturnsEmpty() throws Exception { + // Drive into the (flag=false, queue=empty) boundary by pushing one buffer, polling it, + // and verifying the channel does not yet expose a master-path result. + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, ChannelStateWriter.NO_OP); + inputGate.setInputChannels(channel); + + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(1)); + channel.getNextBuffer(); + // Queue is empty; no finishRecoveredBufferDelivery was called. Without the subpartitionView + // active, the channel returns empty. 
+ Optional result = channel.getNextBuffer(); + assertThat(result).isNotPresent(); + } + + @Test + void testInRecoveryBoundaryFlagFalseQueueNonEmptyPolls() throws Exception { + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, ChannelStateWriter.NO_OP); + inputGate.setInputChannels(channel); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(7)); + Optional r = channel.getNextBuffer(); + assertThat(r).isPresent(); + assertThat(r.get().buffer().getSize()).isEqualTo(7); + } + + @Test + void testInRecoveryBoundaryFlagTrueQueueNonEmptyPolls() throws Exception { + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, ChannelStateWriter.NO_OP); + inputGate.setInputChannels(channel); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(8)); + channel.completeUpstreamReadyForTest(); + channel.finishRecoveredBufferDelivery(); + Optional r = channel.getNextBuffer(); + assertThat(r).isPresent(); + assertThat(r.get().buffer().getSize()).isEqualTo(8); + } + + @Test + void testFinishWithNoRecoveredBuffersEmitsSentinelThenFallsToMasterPath() throws Exception { + // With no recovered buffers, finish still appends the EndOfFetchedChannelStateEvent + // sentinel; getNextBuffer returns it. Consuming the sentinel flips the channel out of + // recovery, after which the master path takes over: without a subpartition view it raises + // IllegalStateException via checkAndWaitForSubpartitionView, proving the recovery branch is + // no longer swallowing the call. 
+ SingleInputGate inputGate = new SingleInputGateBuilder().build(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, ChannelStateWriter.NO_OP); + inputGate.setInputChannels(channel); + channel.completeUpstreamReadyForTest(); + channel.finishRecoveredBufferDelivery(); + + Optional sentinel = channel.getNextBuffer(); + assertThat(sentinel).isPresent(); + assertThat(EventSerializer.fromBuffer(sentinel.get().buffer(), getClass().getClassLoader())) + .isInstanceOf(EndOfFetchedChannelStateEvent.class); + + channel.onRecoveredStateConsumed(); + assertThatThrownBy(channel::getNextBuffer) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Queried for a buffer before requesting the subpartition"); + } + + @Test + void testPriorityEventDuringRecoveryFetchedFromSubpartitionView() throws Exception { + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + ChannelAndSubpartition ctx = createChannelWithRecoveredBuffers(stateWriter, 10, 20); + + CheckpointOptions options = + CheckpointOptions.unaligned(CheckpointType.CHECKPOINT, getDefault()); + ctx.subpartition.add( + EventSerializer.toBufferConsumer(new CheckpointBarrier(1L, 0L, options), true)); + ctx.channel.notifyPriorityEvent(0); + + Optional first = ctx.channel.getNextBuffer(); + assertThat(first).isPresent(); + assertThat(first.get().buffer().getDataType().hasPriority()).isTrue(); + } + + @Test + void testPriorityEventDuringRecoveryResetAfterNonPriority() throws Exception { + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + ChannelAndSubpartition ctx = createChannelWithRecoveredBuffers(stateWriter, 10); + + CheckpointOptions options = + CheckpointOptions.unaligned(CheckpointType.CHECKPOINT, getDefault()); + ctx.subpartition.add( + EventSerializer.toBufferConsumer(new CheckpointBarrier(1L, 0L, options), true)); + ctx.subpartition.add(createFilledFinishedBufferConsumer(32)); + ctx.channel.notifyPriorityEvent(0); + + Optional priority = 
ctx.channel.getNextBuffer(); + assertThat(priority).isPresent(); + Optional recovered = ctx.channel.getNextBuffer(); + assertThat(recovered).isPresent(); + assertThat(recovered.get().buffer().isBuffer()).isTrue(); + assertThat(recovered.get().buffer().getSize()).isEqualTo(10); + } + + @Test + void testCheckpointStartedScansRecoveredBuffersUpToBarrier() throws Exception { + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, stateWriter); + inputGate.setInputChannels(channel); + + Buffer b1 = TestBufferFactory.createBuffer(1); + Buffer b2 = TestBufferFactory.createBuffer(2); + Buffer b3 = TestBufferFactory.createBuffer(3); + channel.onRecoveredStateBuffer(b1); + channel.onRecoveredStateBuffer(b2); + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(1L), false)); + channel.onRecoveredStateBuffer(b3); + + CheckpointOptions options = + CheckpointOptions.unaligned(CheckpointType.CHECKPOINT, getDefault()); + stateWriter.start(1L, options); + channel.checkpointStarted(new CheckpointBarrier(1L, 0L, options)); + + List persisted = stateWriter.getAddedInput().get(channel.getChannelInfo()); + assertThat(persisted).hasSize(2); + assertThat(persisted.stream().mapToInt(Buffer::getSize).toArray()).containsExactly(1, 2); + } + + @Test + void testCheckpointStartedFailsWhenRecoveryBarrierIsMissing() throws Exception { + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, stateWriter); + inputGate.setInputChannels(channel); + + Buffer b1 = TestBufferFactory.createBuffer(1); + channel.onRecoveredStateBuffer(b1); + int refCntBefore = b1.refCnt(); + + CheckpointOptions options = + CheckpointOptions.unaligned(CheckpointType.CHECKPOINT, 
getDefault()); + stateWriter.start(1L, options); + + assertThatThrownBy(() -> channel.checkpointStarted(new CheckpointBarrier(1L, 0L, options))) + .isInstanceOf(CheckpointException.class) + .hasMessageContaining("Failed to extract recovered buffers for checkpoint 1") + .hasRootCauseMessage( + "Missing RecoveryCheckpointBarrier for checkpoint 1 in recoveredBuffers for channel " + + channel.getChannelInfo()); + assertThat(b1.refCnt()).isEqualTo(refCntBefore); + assertThat(stateWriter.getAddedInput().get(channel.getChannelInfo())).isEmpty(); + } + + @Test + void testCheckpointStartedRetainsPreBarrierBuffers() throws Exception { + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, stateWriter); + inputGate.setInputChannels(channel); + + Buffer b1 = TestBufferFactory.createBuffer(1); + channel.onRecoveredStateBuffer(b1); + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(1L), false)); + + int before = b1.refCnt(); + CheckpointOptions options = + CheckpointOptions.unaligned(CheckpointType.CHECKPOINT, getDefault()); + stateWriter.start(1L, options); + channel.checkpointStarted(new CheckpointBarrier(1L, 0L, options)); + // Pre-barrier buffers are retained for the writer; their ref-count stays at or above the + // pre-checkpoint value. 
+ assertThat(b1.refCnt()).isGreaterThanOrEqualTo(before); + } + + @Test + void testCheckpointStartedRemovesSentinel() throws Exception { + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, stateWriter); + inputGate.setInputChannels(channel); + + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(1)); + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(1L), false)); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(2)); + + CheckpointOptions options = + CheckpointOptions.unaligned(CheckpointType.CHECKPOINT, getDefault()); + stateWriter.start(1L, options); + channel.checkpointStarted(new CheckpointBarrier(1L, 0L, options)); + + Optional h1 = channel.getNextBuffer(); + assertThat(h1).isPresent(); + Optional h2 = channel.getNextBuffer(); + assertThat(h2).isPresent(); + assertThat(h2.get().buffer().getSize()).isEqualTo(2); + } + + @Test + void testCheckpointStartedNestedCpIds() throws Exception { + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, stateWriter); + inputGate.setInputChannels(channel); + + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(1)); + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(1L), false)); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(2)); + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(2L), false)); + + CheckpointOptions options = + CheckpointOptions.unaligned(CheckpointType.CHECKPOINT, getDefault()); + stateWriter.start(1L, options); + channel.checkpointStarted(new CheckpointBarrier(1L, 0L, options)); + + List persisted = 
stateWriter.getAddedInput().get(channel.getChannelInfo()); + assertThat(persisted).hasSize(1); + assertThat(persisted.get(0).getSize()).isEqualTo(1); + } + + @Test + void testCheckpointStartedNotInRecoveryUsesMasterPath() throws Exception { + SingleInputGate inputGate = new SingleInputGateBuilder().build(); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, stateWriter); + inputGate.setInputChannels(channel); + // Channel starts in-recovery; finish appends the sentinel and consuming it flips the + // channel to not-in-recovery so checkpointStarted exercises the master path instead of the + // recovery branch. + channel.completeUpstreamReadyForTest(); + channel.finishRecoveredBufferDelivery(); + channel.getNextBuffer(); + channel.onRecoveredStateConsumed(); + + CheckpointOptions options = + CheckpointOptions.unaligned(CheckpointType.CHECKPOINT, getDefault()); + stateWriter.start(1L, options); + channel.checkpointStarted(new CheckpointBarrier(1L, 0L, options)); + + assertThat(stateWriter.getAddedInput().get(channel.getChannelInfo())).isNullOrEmpty(); + } + + @Test + void testReceivedBuffersHasNoLiveDataBufferIsTrueOnLocal() throws Exception { + // Local has no receivedBuffers; the helper is trivially true. We exercise the + // in-recovery checkpointStarted branch without any "live data" infrastructure to confirm + // the channel does not crash. 
+ SingleInputGate inputGate = new SingleInputGateBuilder().build(); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, stateWriter); + inputGate.setInputChannels(channel); + + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(1)); + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(1L), false)); + + CheckpointOptions options = + CheckpointOptions.unaligned(CheckpointType.CHECKPOINT, getDefault()); + stateWriter.start(1L, options); + channel.checkpointStarted(new CheckpointBarrier(1L, 0L, options)); + } + // --------------------------------------------------------------------------------------------- /** Returns the configured number of buffers for each channel in a random order. */ diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannelTest.java new file mode 100644 index 0000000000000..a2083468b811a --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannelTest.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.consumer; + +import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.util.TestBufferFactory; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class LocalRecoveredInputChannelTest { + + @Test + void testToInputChannelRequiresEmptyRecoveredBuffers() throws Exception { + SingleInputGate inputGate = + new SingleInputGateBuilder().setCheckpointingDuringRecoveryEnabled(true).build(); + LocalRecoveredInputChannel recoveredChannel = + InputChannelBuilder.newBuilder() + .setStateWriter(ChannelStateWriter.NO_OP) + .buildLocalRecoveredChannel(inputGate); + + Buffer buffer = TestBufferFactory.createBuffer(11); + recoveredChannel.onRecoveredStateBuffer(buffer); + + try { + recoveredChannel.finishReadRecoveredState(); + assertThatThrownBy(() -> recoveredChannel.toInputChannel(true)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Received buffer should be empty"); + } finally { + recoveredChannel.releaseAllResources(); + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java index f40fd09702ede..f606bb3eec37c 100644 --- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java @@ -22,14 +22,13 @@ import org.apache.flink.runtime.checkpoint.CheckpointException; import org.apache.flink.runtime.checkpoint.CheckpointType; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; -import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet; import org.junit.jupiter.api.Test; import java.io.IOException; -import java.util.ArrayDeque; import static org.apache.flink.runtime.checkpoint.CheckpointOptions.unaligned; import static org.apache.flink.runtime.state.CheckpointStorageLocationReference.getDefault; @@ -39,16 +38,6 @@ /** Tests for {@link RecoveredInputChannel}. 
*/ class RecoveredInputChannelTest { - @Test - void testConversionOnlyPossibleAfterBufferFilteringComplete() { - // toInputChannel() always checks bufferFilteringCompleteFuture regardless of config - for (boolean configEnabled : new boolean[] {true, false}) { - assertThatThrownBy(() -> buildChannel(configEnabled).toInputChannel()) - .isInstanceOf(IllegalStateException.class) - .hasMessageContaining("buffer filtering is not complete"); - } - } - @Test void testRequestPartitionsImpossible() { assertThatThrownBy(() -> buildChannel(false).requestSubpartitions()) @@ -72,83 +61,74 @@ void testCheckpointStartImpossible() { @Test void testToInputChannelAllowedWhenBufferFilteringCompleteAndConfigEnabled() throws IOException { - // When config is enabled, conversion is allowed when bufferFilteringCompleteFuture is done + // When config is enabled, conversion is allowed after finishReadRecoveredState() + // without requiring stateConsumedFuture to be done. TestableRecoveredInputChannel channel = buildTestableChannel(true); - // Initially, conversion should fail - assertThatThrownBy(() -> channel.toInputChannel()) - .isInstanceOf(IllegalStateException.class) - .hasMessageContaining("buffer filtering is not complete"); - - // After finishReadRecoveredState(), bufferFilteringCompleteFuture should be done channel.finishReadRecoveredState(); - assertThat(channel.getBufferFilteringCompleteFuture()).isDone(); assertThat(channel.getStateConsumedFuture()).isNotDone(); // Conversion should now succeed (no exception) - InputChannel converted = channel.toInputChannel(); + InputChannel converted = channel.toInputChannel(true); assertThat(converted).isNotNull(); } @Test void testToInputChannelAllowedWhenStateConsumedAndConfigDisabled() throws IOException { - // When config is disabled, conversion requires both bufferFilteringCompleteFuture - // and stateConsumedFuture to be done + // When config is disabled, conversion requires stateConsumedFuture to be done. 
TestableRecoveredInputChannel channel = buildTestableChannel(false); - // Initially, conversion should fail (buffer filtering not complete) - assertThatThrownBy(() -> channel.toInputChannel()) - .isInstanceOf(IllegalStateException.class) - .hasMessageContaining("buffer filtering is not complete"); - - // After finishReadRecoveredState(), bufferFilteringCompleteFuture is done - // but stateConsumedFuture is not channel.finishReadRecoveredState(); - assertThat(channel.getBufferFilteringCompleteFuture()).isDone(); assertThat(channel.getStateConsumedFuture()).isNotDone(); - // Conversion should still fail because stateConsumedFuture is not done - assertThatThrownBy(() -> channel.toInputChannel()) + // Conversion should fail because stateConsumedFuture is not done + assertThatThrownBy(() -> channel.toInputChannel(true)) .isInstanceOf(IllegalStateException.class) .hasMessageContaining("recovered state is not fully consumed"); - // Consume the EndOfInputChannelStateEvent to complete stateConsumedFuture + // Consuming the EndOfInputChannelStateEvent should complete the future. + // getNextBuffer() returns empty when it encounters the event internally. 
assertThat(channel.getNextBuffer()).isNotPresent(); assertThat(channel.getStateConsumedFuture()).isDone(); // Now conversion should succeed - InputChannel converted = channel.toInputChannel(); + InputChannel converted = channel.toInputChannel(true); assertThat(converted).isNotNull(); } @Test - void testBufferFilteringCompleteFutureAlwaysCompletes() throws IOException { - // finishReadRecoveredState() unconditionally completes bufferFilteringCompleteFuture - for (boolean configEnabled : new boolean[] {true, false}) { - RecoveredInputChannel channel = buildChannel(configEnabled); - assertThat(channel.getBufferFilteringCompleteFuture()).isNotDone(); - channel.finishReadRecoveredState(); - assertThat(channel.getBufferFilteringCompleteFuture()).isDone(); - } + void testToInputChannelRequiresEmptyRecoveredBuffers() throws IOException { + TestableRecoveredInputChannel channel = buildTestableChannel(true); + + channel.onRecoveredStateBuffer(BufferBuilderTestUtils.buildSomeBuffer()); + channel.finishReadRecoveredState(); + + assertThatThrownBy(() -> channel.toInputChannel(true)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Received buffer should be empty"); } @Test - void testStateConsumedFutureCompletesAfterConsumingAllBuffers() throws IOException { - // This test verifies that stateConsumedFuture completes after consuming - // EndOfInputChannelStateEvent regardless of the config setting - for (boolean configEnabled : new boolean[] {true, false}) { - RecoveredInputChannel channel = buildChannel(configEnabled); + void testStateConsumedFutureCompletesAfterLegacySentinelIsConsumed() throws IOException { + RecoveredInputChannel channel = buildChannel(false); - assertThat(channel.getStateConsumedFuture()).isNotDone(); + assertThat(channel.getStateConsumedFuture()).isNotDone(); - channel.finishReadRecoveredState(); - assertThat(channel.getStateConsumedFuture()).isNotDone(); + channel.finishReadRecoveredState(); + 
assertThat(channel.getStateConsumedFuture()).isNotDone(); - // Consuming the EndOfInputChannelStateEvent should complete the future. - // getNextBuffer() returns empty when it encounters the event internally. - assertThat(channel.getNextBuffer()).isNotPresent(); - assertThat(channel.getStateConsumedFuture()).isDone(); - } + assertThat(channel.getNextBuffer()).isNotPresent(); + assertThat(channel.getStateConsumedFuture()).isDone(); + } + + @Test + void testStateConsumedFutureDoesNotCompleteWithoutLegacySentinel() throws IOException { + RecoveredInputChannel channel = buildChannel(true); + + channel.finishReadRecoveredState(); + + assertThat(channel.getNextBuffer()).isNotPresent(); + assertThat(channel.getStateConsumedFuture()).isNotDone(); } private RecoveredInputChannel buildChannel(boolean checkpointingDuringRecoveryEnabled) { @@ -169,7 +149,7 @@ private RecoveredInputChannel buildChannel(boolean checkpointingDuringRecoveryEn new SimpleCounter(), 10) { @Override - protected InputChannel toInputChannelInternal(ArrayDeque remainingBuffers) { + protected InputChannel toInputChannelInternal(boolean needsRecovery) { throw new AssertionError("channel conversion succeeded"); } }; @@ -210,7 +190,7 @@ private static class TestableRecoveredInputChannel extends RecoveredInputChannel } @Override - protected InputChannel toInputChannelInternal(ArrayDeque remainingBuffers) { + protected InputChannel toInputChannelInternal(boolean needsRecovery) { return new TestInputChannel(inputGate, 0); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java index e47de93c9e8bd..58223dea191ae 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java +++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java @@ -27,6 +27,9 @@ import org.apache.flink.runtime.checkpoint.CheckpointType; import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; +import org.apache.flink.runtime.checkpoint.channel.RecordingChannelStateWriter; +import org.apache.flink.runtime.checkpoint.channel.RecoveryCheckpointBarrier; +import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.execution.CancelTaskException; import org.apache.flink.runtime.execution.ExecutionState; import org.apache.flink.runtime.io.network.ConnectionID; @@ -37,6 +40,7 @@ import org.apache.flink.runtime.io.network.TestingConnectionManager; import org.apache.flink.runtime.io.network.TestingPartitionRequestClient; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; +import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.Buffer.DataType; @@ -72,6 +76,7 @@ import javax.annotation.Nullable; import java.io.IOException; +import java.net.InetSocketAddress; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.List; @@ -2079,15 +2084,8 @@ void testGetNextBufferWithMigratedRecoveredBuffers() throws Exception { // given: RemoteInputChannel with recovered buffers migrated from RecoveredInputChannel SingleInputGate inputGate = createSingleInputGate(1); - ArrayDeque recoveredBuffers = new ArrayDeque<>(); - recoveredBuffers.add(TestBufferFactory.createBuffer(10)); - recoveredBuffers.add(TestBufferFactory.createBuffer(20)); - ConnectionID connectionId = - new ConnectionID( - org.apache.flink.runtime.clusterframework.types.ResourceID.generate(), - new 
java.net.InetSocketAddress("localhost", 0), - 0); + new ConnectionID(ResourceID.generate(), new InetSocketAddress("localhost", 0), 0); RemoteInputChannel channel = new RemoteInputChannel( inputGate, @@ -2104,10 +2102,15 @@ void testGetNextBufferWithMigratedRecoveredBuffers() throws Exception { new SimpleCounter(), new SimpleCounter(), ChannelStateWriter.NO_OP, - recoveredBuffers); + true); inputGate.setInputChannels(channel); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(10)); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(20)); + channel.completeUpstreamReadyForTest(); + channel.finishRecoveredBufferDelivery(); + // then: Can read recovered buffers even before requestSubpartitions() Optional first = channel.getNextBuffer(); assertThat(first).isPresent(); @@ -2119,6 +2122,488 @@ void testGetNextBufferWithMigratedRecoveredBuffers() throws Exception { assertThat(second.get().buffer().getSize()).isEqualTo(20); } + // --------------------------------------------------------------------------------------------- + // RecoverableInputChannel push-based recovery tests + // --------------------------------------------------------------------------------------------- + + @Test + void testOnRecoveredStateBufferEnqueues() throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(ChannelStateWriter.NO_OP) + .setNeedsRecovery(true) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + + Buffer b1 = TestBufferFactory.createBuffer(11); + Buffer b2 = TestBufferFactory.createBuffer(22); + channel.onRecoveredStateBuffer(b1); + channel.onRecoveredStateBuffer(b2); + + Optional first = channel.getNextBuffer(); + assertThat(first).isPresent(); + assertThat(first.get().buffer().getSize()).isEqualTo(11); + Optional second = channel.getNextBuffer(); + assertThat(second).isPresent(); + 
assertThat(second.get().buffer().getSize()).isEqualTo(22); + } + + @Test + void testRecoveredBuffersConsumedBeforeStashedEventsThenSentinelFlipsRecovery() + throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(ChannelStateWriter.NO_OP) + .setNeedsRecovery(true) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + + // Recovered buffer arrives via the drain; an ordinary upstream event arrives via onBuffer + // while still in recovery and must be stashed (events carry no credit). + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(11)); + // backlog=-1: events carry no backlog, and this channel has no floating-buffer pool wired. + channel.onBuffer(EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE, false), 0, -1, 0); + channel.completeUpstreamReadyForTest(); + channel.finishRecoveredBufferDelivery(); + + // Recovered buffer is consumed first. + Optional recovered = channel.getNextBuffer(); + assertThat(recovered).isPresent(); + assertThat(recovered.get().buffer().getSize()).isEqualTo(11); + + // Then the sentinel; the stashed event is not yet visible (still in recovery). + Optional sentinel = channel.getNextBuffer(); + assertThat(sentinel).isPresent(); + assertThat(EventSerializer.fromBuffer(sentinel.get().buffer(), getClass().getClassLoader())) + .isInstanceOf(EndOfFetchedChannelStateEvent.class); + + // The gate consumes the sentinel externally: flips out of recovery and unstashes the event. 
+ channel.onRecoveredStateConsumed(); + Optional stashed = channel.getNextBuffer(); + assertThat(stashed).isPresent(); + assertThat(EventSerializer.fromBuffer(stashed.get().buffer(), getClass().getClassLoader())) + .isInstanceOf(EndOfPartitionEvent.class); + } + + @Test + void testOnBufferRejectsLiveDataBufferDuringRecovery() throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(ChannelStateWriter.NO_OP) + .setNeedsRecovery(true) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + + assertThatThrownBy(() -> channel.onBuffer(TestBufferFactory.createBuffer(1), 0, 0, 0)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Received live data buffer during recovery"); + } + + @Test + void testOnRecoveredStateBufferOnReleasedChannelIsSilentlyRecycled() throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(ChannelStateWriter.NO_OP) + .setNeedsRecovery(true) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + channel.releaseAllResources(); + + Buffer b = TestBufferFactory.createBuffer(33); + channel.onRecoveredStateBuffer(b); + + assertThat(b.isRecycled()).isTrue(); + } + + @Test + void testOnRecoveredStateBufferNotifiesChannelNonEmptyOnEmptyToNonEmptyTransition() + throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(ChannelStateWriter.NO_OP) + .setNeedsRecovery(true) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + + CompletableFuture availability = inputGate.getAvailableFuture(); + assertThat(availability).isNotDone(); + + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(1)); + assertThat(availability).isDone(); + } + + @Test + void 
testInRecoveryBoundaryFlagFalseQueueEmptyReturnsEmpty() throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(ChannelStateWriter.NO_OP) + .setNeedsRecovery(true) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + + // Force in-recovery + empty queue: push then poll a buffer, then push another while + // delaying finish. After the consumer drains the staged buffer, queue=empty and + // flag=false (no finishRecoveredBufferDelivery called yet). getNextBuffer must return + // empty. + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(1)); + channel.getNextBuffer(); + // Simulate an explicit recovery context where the producer signals "not done yet". + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(2)); + channel.getNextBuffer(); + // The boundary case "flag=false (drain still running) + queue empty" should return empty. + // To set this state explicitly, we deliberately do not call finishRecoveredBufferDelivery. 
+ Optional result = channel.getNextBuffer(); + assertThat(result).isNotPresent(); + } + + @Test + void testInRecoveryBoundaryFlagFalseQueueNonEmptyPolls() throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(ChannelStateWriter.NO_OP) + .setNeedsRecovery(true) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(7)); + + Optional r = channel.getNextBuffer(); + assertThat(r).isPresent(); + assertThat(r.get().buffer().getSize()).isEqualTo(7); + } + + @Test + void testInRecoveryBoundaryFlagTrueQueueNonEmptyPolls() throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(ChannelStateWriter.NO_OP) + .setNeedsRecovery(true) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(8)); + channel.completeUpstreamReadyForTest(); + channel.finishRecoveredBufferDelivery(); + + Optional r = channel.getNextBuffer(); + assertThat(r).isPresent(); + assertThat(r.get().buffer().getSize()).isEqualTo(8); + } + + @Test + void testFinishWithNoRecoveredBuffersEmitsSentinelThenFallsToMasterPath() throws Exception { + // Wire a real network pool so requestSubpartitions() can succeed and the master path can + // poll receivedBuffers. 
+ final NetworkBufferPool networkBufferPool = new NetworkBufferPool(4, 4096); + try { + SingleInputGate inputGate = + new SingleInputGateBuilder() + .setBufferPoolFactory(networkBufferPool.createBufferPool(1, 4)) + .setSegmentProvider(networkBufferPool) + .setChannelFactory( + (builder, gate) -> + builder.setNeedsRecovery(true).buildRemoteChannel(gate)) + .build(); + inputGate.setup(); + RemoteInputChannel channel = (RemoteInputChannel) inputGate.getChannel(0); + channel.completeUpstreamReadyForTest(); + // Even with no recovered buffers, finish appends the EndOfFetchedChannelStateEvent + // sentinel so the consume path can flip out of recovery in order. + channel.finishRecoveredBufferDelivery(); + inputGate.requestPartitions(); + + Optional sentinel = channel.getNextBuffer(); + assertThat(sentinel).isPresent(); + assertThat( + EventSerializer.fromBuffer( + sentinel.get().buffer(), getClass().getClassLoader())) + .isInstanceOf(EndOfFetchedChannelStateEvent.class); + + // Consuming the sentinel (done externally by the gate) flips the channel out of + // recovery; afterwards the master path is taken and there is no more queued data. + channel.onRecoveredStateConsumed(); + assertThat(channel.getNextBuffer()).isNotPresent(); + } finally { + networkBufferPool.destroy(); + } + } + + @Test + void testMoreAvailableNoneWhenLastRecoveredBufferAndDrainNotFinished() throws Exception { + // While the channel is in recovery and the drain has not finished, the last currently + // queued recovered buffer must report NONE as its next data type: no live data can enter + // receivedBuffers (the upstream has no credit), so there is nothing else to expose yet. 
+ final NetworkBufferPool networkBufferPool = new NetworkBufferPool(4, 4096); + try { + SingleInputGate inputGate = + new SingleInputGateBuilder() + .setBufferPoolFactory(networkBufferPool.createBufferPool(1, 4)) + .setSegmentProvider(networkBufferPool) + .setNeedsRecovery(true) + .setChannelFactory( + (builder, gate) -> + builder.setNeedsRecovery(true).buildRemoteChannel(gate)) + .build(); + inputGate.setup(); + RemoteInputChannel channel = (RemoteInputChannel) inputGate.getChannel(0); + + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(11)); + + Optional recoveredBuf = channel.getNextBuffer(); + assertThat(recoveredBuf).isPresent(); + assertThat(recoveredBuf.get().buffer().getSize()).isEqualTo(11); + // Drain not finished and queue now empty: nothing more is available yet. + assertThat(recoveredBuf.get().moreAvailable()).isFalse(); + } finally { + networkBufferPool.destroy(); + } + } + + @Test + void testPriorityEventDuringRecoveryViaAddPriorityBuffer() throws Exception { + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(4, 4096); + try { + SingleInputGate inputGate = + new SingleInputGateBuilder() + .setBufferPoolFactory(networkBufferPool.createBufferPool(1, 4)) + .setSegmentProvider(networkBufferPool) + .setChannelFactory( + (builder, gate) -> + builder.setNeedsRecovery(true).buildRemoteChannel(gate)) + .build(); + inputGate.setup(); + RemoteInputChannel channel = (RemoteInputChannel) inputGate.getChannel(0); + + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(11)); + + CheckpointBarrier barrier = new CheckpointBarrier(1L, 0L, UNALIGNED); + channel.onBuffer(toBuffer(barrier, true), 0, 0, 0); + + Optional first = channel.getNextBuffer(); + assertThat(first).isPresent(); + assertThat(first.get().buffer().getDataType().hasPriority()).isTrue(); + + Optional second = channel.getNextBuffer(); + assertThat(second).isPresent(); + assertThat(second.get().buffer().getSize()).isEqualTo(11); + } finally { + 
networkBufferPool.destroy(); + } + } + + @Test + void testCheckpointStartedScansRecoveredBuffersUpToBarrier() throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(stateWriter) + .setNeedsRecovery(true) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + + Buffer b1 = TestBufferFactory.createBuffer(1); + Buffer b2 = TestBufferFactory.createBuffer(2); + Buffer b3 = TestBufferFactory.createBuffer(3); + channel.onRecoveredStateBuffer(b1); + channel.onRecoveredStateBuffer(b2); + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(1L), false)); + channel.onRecoveredStateBuffer(b3); + + stateWriter.start(1L, UNALIGNED); + channel.checkpointStarted(new CheckpointBarrier(1L, 0L, UNALIGNED)); + + List persisted = stateWriter.getAddedInput().get(channel.getChannelInfo()); + assertThat(persisted).hasSize(2); + assertThat(persisted.stream().mapToInt(Buffer::getSize).toArray()).containsExactly(1, 2); + } + + @Test + void testCheckpointStartedFailsWhenRecoveryBarrierIsMissing() throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(stateWriter) + .setNeedsRecovery(true) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + + Buffer b1 = TestBufferFactory.createBuffer(1); + channel.onRecoveredStateBuffer(b1); + int refCntBefore = b1.refCnt(); + + stateWriter.start(1L, UNALIGNED); + + assertThatThrownBy( + () -> channel.checkpointStarted(new CheckpointBarrier(1L, 0L, UNALIGNED))) + .isInstanceOf(CheckpointException.class) + .hasMessageContaining("Failed to extract recovered buffers for checkpoint 1") + .hasRootCauseMessage( + "Missing 
RecoveryCheckpointBarrier for checkpoint 1 in receivedBuffers for channel " + + channel.getChannelInfo()); + assertThat(b1.refCnt()).isEqualTo(refCntBefore); + assertThat(stateWriter.getAddedInput().get(channel.getChannelInfo())).isEmpty(); + } + + @Test + void testCheckpointStartedRetainsPreBarrierBuffers() throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(stateWriter) + .setNeedsRecovery(true) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + + Buffer b1 = TestBufferFactory.createBuffer(1); + channel.onRecoveredStateBuffer(b1); + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(1L), false)); + + int before = b1.refCnt(); + stateWriter.start(1L, UNALIGNED); + channel.checkpointStarted(new CheckpointBarrier(1L, 0L, UNALIGNED)); + // After retainBuffer the buffer remains live for both the queue read and the writer copy. 
+ assertThat(b1.refCnt()).isGreaterThanOrEqualTo(before); + } + + @Test + void testCheckpointStartedRemovesSentinel() throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(stateWriter) + .setNeedsRecovery(true) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(1)); + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(1L), false)); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(2)); + + stateWriter.start(1L, UNALIGNED); + channel.checkpointStarted(new CheckpointBarrier(1L, 0L, UNALIGNED)); + + Optional head = channel.getNextBuffer(); + assertThat(head).isPresent(); + Optional nextHead = channel.getNextBuffer(); + assertThat(nextHead).isPresent(); + assertThat(nextHead.get().buffer().getSize()).isEqualTo(2); + } + + @Test + void testCheckpointStartedNestedCpIds() throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(stateWriter) + .setNeedsRecovery(true) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(1)); + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(1L), false)); + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(2)); + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(2L), false)); + + stateWriter.start(1L, UNALIGNED); + channel.checkpointStarted(new CheckpointBarrier(1L, 0L, UNALIGNED)); + + List persisted1 = stateWriter.getAddedInput().get(channel.getChannelInfo()); + 
assertThat(persisted1).hasSize(1); + assertThat(persisted1.get(0).getSize()).isEqualTo(1); + } + + @Test + void testCheckpointStartedNotInRecoveryUsesMasterPath() throws Exception { + SingleInputGate inputGate = createSingleInputGate(1); + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + RemoteInputChannel channel = + InputChannelBuilder.newBuilder() + .setStateWriter(stateWriter) + .buildRemoteChannel(inputGate); + inputGate.setInputChannels(channel); + channel.requestSubpartitions(); + stateWriter.start(7L, UNALIGNED); + channel.checkpointStarted(new CheckpointBarrier(7L, 0L, UNALIGNED)); + + assertThat(stateWriter.getAddedInput().get(channel.getChannelInfo())).isNullOrEmpty(); + } + + @Test + void testReceivedBuffersHasNoLiveDataBufferDetectsLiveData() throws Exception { + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(4, 4096); + try { + SingleInputGate inputGate = + new SingleInputGateBuilder() + .setBufferPoolFactory(networkBufferPool.createBufferPool(1, 4)) + .setSegmentProvider(networkBufferPool) + .setChannelFactory( + (builder, gate) -> + builder.setNeedsRecovery(true).buildRemoteChannel(gate)) + .build(); + inputGate.setup(); + RemoteInputChannel channel = (RemoteInputChannel) inputGate.getChannel(0); + + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(1)); + // During recovery the upstream has no credit and can only send events. A live data + // buffer is a protocol violation that onBuffer must reject at the entry point. 
+ assertThatThrownBy(() -> channel.onBuffer(TestBufferFactory.createBuffer(1), 0, 0, 0)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Received live data buffer during recovery"); + } finally { + networkBufferPool.destroy(); + } + } + + @Test + void testReceivedBuffersHasNoLiveDataBufferAcceptsPriorityOnly() throws Exception { + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(4, 4096); + try { + RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); + SingleInputGate inputGate = + new SingleInputGateBuilder() + .setBufferPoolFactory(networkBufferPool.createBufferPool(1, 4)) + .setSegmentProvider(networkBufferPool) + .setChannelFactory( + (builder, gate) -> + builder.setStateWriter(stateWriter) + .setNeedsRecovery(true) + .buildRemoteChannel(gate)) + .build(); + inputGate.setup(); + RemoteInputChannel channel = (RemoteInputChannel) inputGate.getChannel(0); + + channel.onRecoveredStateBuffer(TestBufferFactory.createBuffer(1)); + // Mirror what snapshotAndInsertBarriers does: push the + // RecoveryCheckpointBarrier sentinel so collectPreRecoveryBarrier finds it. + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(1L), false)); + // Priority event in receivedBuffers is OK (!isBuffer()). 
+ CheckpointBarrier priorityBarrier = new CheckpointBarrier(1L, 0L, UNALIGNED); + channel.onBuffer(toBuffer(priorityBarrier, true), 0, 0, 0); + + stateWriter.start(1L, UNALIGNED); + channel.checkpointStarted(new CheckpointBarrier(1L, 0L, UNALIGNED)); + } finally { + networkBufferPool.destroy(); + } + } + private static final class TestBufferPool extends NoOpBufferPool { @Override diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannelTest.java new file mode 100644 index 0000000000000..938d040093e32 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannelTest.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.io.network.partition.consumer; + +import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.util.TestBufferFactory; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class RemoteRecoveredInputChannelTest { + + @Test + void testToInputChannelRequiresEmptyRecoveredBuffers() throws Exception { + SingleInputGate inputGate = + new SingleInputGateBuilder().setCheckpointingDuringRecoveryEnabled(true).build(); + RemoteRecoveredInputChannel recoveredChannel = + InputChannelBuilder.newBuilder() + .setStateWriter(ChannelStateWriter.NO_OP) + .buildRemoteRecoveredChannel(inputGate); + + Buffer buffer = TestBufferFactory.createBuffer(13); + recoveredChannel.onRecoveredStateBuffer(buffer); + + try { + recoveredChannel.finishReadRecoveredState(); + assertThatThrownBy(() -> recoveredChannel.toInputChannel(true)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Received buffer should be empty"); + } finally { + recoveredChannel.releaseAllResources(); + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java index a4da811f8a32a..a1d1afbfc27c7 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java @@ -85,6 +85,8 @@ public class SingleInputGateBuilder { private boolean isCheckpointingDuringRecoveryEnabled = false; + private boolean isNeedsRecovery = false; + public SingleInputGateBuilder setPartitionProducerStateProvider( PartitionProducerStateProvider partitionProducerStateProvider) 
{ @@ -174,6 +176,11 @@ public SingleInputGateBuilder setCheckpointingDuringRecoveryEnabled(boolean enab return this; } + public SingleInputGateBuilder setNeedsRecovery(boolean enabled) { + this.isNeedsRecovery = enabled; + return this; + } + public SingleInputGate build() { SingleInputGate gate = new SingleInputGate( @@ -189,6 +196,9 @@ public SingleInputGate build() { bufferSize, createThroughputCalculator.apply(bufferDebloatConfiguration), maybeCreateBufferDebloater(gateIndex)); + // Propagate before channel construction so RecoverableInputChannel implementations read + // the intended flag in their constructor and initialise their recovery state correctly. + gate.setNeedsRecovery(isNeedsRecovery); if (channelFactory != null) { gate.setInputChannels( IntStream.range(0, numberOfChannels) diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java index b2cc9d7ce3c9c..f7f0b744fb9fd 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java @@ -142,36 +142,6 @@ void testCheckpointsDeclinedUnlessStateConsumed() { .isInstanceOf(CheckpointException.class); } - @Test - void testBufferFilteringCompleteFutureAggregation() throws Exception { - final NettyShuffleEnvironment environment = createNettyShuffleEnvironment(); - final SingleInputGate inputGate = createInputGate(environment); - try (Closer closer = Closer.create()) { - closer.register(environment::close); - closer.register(inputGate::close); - - // Enable unaligned during recovery for this test so that - // bufferFilteringCompleteFuture is completed by finishReadRecoveredState() - inputGate.setCheckpointingDuringRecoveryEnabled(true); - inputGate.setup(); - - // Initially, 
the aggregated future should not be completed - assertThat(inputGate.getBufferFilteringCompleteFuture()).isNotDone(); - - // After finishing read recovered state, bufferFilteringCompleteFuture should be - // completed (only when config is enabled) - inputGate.finishReadRecoveredState(); - assertThat(inputGate.getBufferFilteringCompleteFuture()).isDone(); - - // stateConsumedFuture should not be completed until data is consumed - assertThat(inputGate.getStateConsumedFuture()).isNotDone(); - - // Consuming the EndOfInputChannelStateEvent should complete stateConsumedFuture - inputGate.pollNext(); - assertThat(inputGate.getStateConsumedFuture()).isDone(); - } - } - /** * Tests {@link InputGate#setup()} should create the respective {@link BufferPool} and assign * exclusive buffers for {@link RemoteInputChannel}s, but should not request partitions. diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestInputChannel.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestInputChannel.java index 14a3654a66639..1759d10cd5c23 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestInputChannel.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestInputChannel.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.io.network.partition.consumer; import org.apache.flink.metrics.SimpleCounter; +import org.apache.flink.runtime.checkpoint.channel.RecoveryCheckpointBarrier; import org.apache.flink.runtime.event.TaskEvent; import org.apache.flink.runtime.io.network.api.EndOfData; import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; @@ -31,8 +32,10 @@ import javax.annotation.Nullable; import java.io.IOException; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collection; +import java.util.Deque; import java.util.Optional; import java.util.Queue; import 
java.util.concurrent.CompletableFuture; @@ -45,7 +48,7 @@ import static org.assertj.core.api.Assertions.assertThat; /** A mocked input channel. */ -public class TestInputChannel extends InputChannel { +public class TestInputChannel extends InputChannel implements RecoverableInputChannel { private final Queue buffers = new ConcurrentLinkedQueue<>(); @@ -259,6 +262,45 @@ public void notifyRequiredSegmentId(int subpartitionId, int segmentId) { requiredSegmentIdFuture.complete(segmentId); } + private final Deque recoveredBuffersSpy = new ArrayDeque<>(); + private boolean finishRecoveredBufferDeliveryCalled = false; + + @Override + public void onRecoveredStateBuffer(Buffer buffer) { + recoveredBuffersSpy.add(buffer); + } + + @Override + public void finishRecoveredBufferDelivery() { + finishRecoveredBufferDeliveryCalled = true; + } + + @Override + public void insertRecoveryCheckpointBarrierIfInRecovery(long checkpointId) throws IOException { + if (!finishRecoveredBufferDeliveryCalled || !recoveredBuffersSpy.isEmpty()) { + recoveredBuffersSpy.add( + EventSerializer.toBuffer(new RecoveryCheckpointBarrier(checkpointId), false)); + } + } + + @Override + public Buffer requestRecoveryBufferBlocking() { + throw new UnsupportedOperationException("TestInputChannel does not back recovery drain"); + } + + @Override + public void onRecoveredStateConsumed() { + // No-op in this test stub. 
+ } + + public Deque getRecoveredBuffersSpy() { + return recoveredBuffersSpy; + } + + public boolean isFinishRecoveredBufferDeliveryCalled() { + return finishRecoveredBufferDeliveryCalled; + } + public void assertReturnedEventsAreRecycled() { assertReturnedBuffersAreRecycled(false, true); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java index 1ed1a42a66ea0..419246137e8c2 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java @@ -18,14 +18,12 @@ package org.apache.flink.runtime.io.network.partition.consumer; -import org.apache.flink.metrics.SimpleCounter; import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.io.PullingAsyncDataInput; import org.apache.flink.runtime.io.network.api.StopMode; import org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils; import org.apache.flink.runtime.io.network.partition.NoOpResultSubpartitionView; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; -import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet; import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateTest.TestingResultPartitionManager; import org.junit.jupiter.api.Test; @@ -277,62 +275,6 @@ void testGetChannelWithShiftedGateIndexes() { assertThat(unionInputGate.getChannel(1)).isEqualTo(inputChannel2); } - @Test - void testBufferFilteringCompleteFutureAggregation() throws IOException { - // Create 2 SingleInputGates, each with 1 RecoveredInputChannel - SingleInputGate ig1 = - new SingleInputGateBuilder().setCheckpointingDuringRecoveryEnabled(true).build(); - RecoveredInputChannel channel1 = 
buildRecoveredChannel(ig1); - ig1.setInputChannels(channel1); - - SingleInputGate ig2 = - new SingleInputGateBuilder() - .setSingleInputGateIndex(1) - .setCheckpointingDuringRecoveryEnabled(true) - .build(); - RecoveredInputChannel channel2 = buildRecoveredChannel(ig2); - ig2.setInputChannels(channel2); - - UnionInputGate union = new UnionInputGate(ig1, ig2); - - // Initially, bufferFilteringCompleteFuture should not be done - assertThat(union.getBufferFilteringCompleteFuture()).isNotDone(); - assertThat(union.getStateConsumedFuture()).isNotDone(); - - // Complete buffer filtering on first gate only - channel1.finishReadRecoveredState(); - assertThat(ig1.getBufferFilteringCompleteFuture()).isDone(); - assertThat(union.getBufferFilteringCompleteFuture()).isNotDone(); - - // Complete buffer filtering on second gate - channel2.finishReadRecoveredState(); - assertThat(ig2.getBufferFilteringCompleteFuture()).isDone(); - assertThat(union.getBufferFilteringCompleteFuture()).isDone(); - - // State consumed futures should still NOT be done (state not consumed yet) - assertThat(union.getStateConsumedFuture()).isNotDone(); - } - - private static RecoveredInputChannel buildRecoveredChannel(SingleInputGate inputGate) { - return new RecoveredInputChannel( - inputGate, - 0, - new ResultPartitionID(), - new ResultSubpartitionIndexSet(0), - 0, - 0, - new SimpleCounter(), - new SimpleCounter(), - 10) { - @Override - protected InputChannel toInputChannelInternal( - java.util.ArrayDeque - remainingBuffers) { - throw new UnsupportedOperationException(); - } - }; - } - @Test void testEmptyPull() throws IOException, InterruptedException { final SingleInputGate inputGate1 = createInputGate(1); diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java index 584aeb7eb9089..1c58310cdbef1 100644 --- 
a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java @@ -57,11 +57,6 @@ public CompletableFuture getStateConsumedFuture() { return CompletableFuture.completedFuture(null); } - @Override - public CompletableFuture getBufferFilteringCompleteFuture() { - return CompletableFuture.completedFuture(null); - } - @Override public void finishReadRecoveredState() {} @@ -86,6 +81,11 @@ public InputChannel getChannel(int channelIndex) { throw new UnsupportedOperationException(); } + @Override + public InputChannel getChannel(InputChannelInfo channelInfo) { + throw new UnsupportedOperationException(); + } + @Override public void setChannelStateWriter(ChannelStateWriter channelStateWriter) {} @@ -150,4 +150,12 @@ public void setCheckpointingDuringRecoveryEnabled(boolean enabled) {} public boolean isCheckpointingDuringRecoveryEnabled() { return false; } + + @Override + public void setNeedsRecovery(boolean enabled) {} + + @Override + public boolean needsRecovery() { + return false; + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java index 71b2c43f3306a..1d5729231378a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java @@ -80,11 +80,6 @@ public CompletableFuture getStateConsumedFuture() { return CompletableFuture.completedFuture(null); } - @Override - public CompletableFuture getBufferFilteringCompleteFuture() { - return CompletableFuture.completedFuture(null); - } - @Override public void finishReadRecoveredState() {} @@ -101,6 +96,11 @@ public InputChannel getChannel(int channelIndex) { throw new UnsupportedOperationException(); } + @Override + public InputChannel 
getChannel(InputChannelInfo channelInfo) { + throw new UnsupportedOperationException(); + } + @Override public List getChannelInfos() { return IntStream.range(0, numberOfChannels) @@ -212,4 +212,12 @@ public void setCheckpointingDuringRecoveryEnabled(boolean enabled) {} public boolean isCheckpointingDuringRecoveryEnabled() { return false; } + + @Override + public void setNeedsRecovery(boolean enabled) {} + + @Override + public boolean needsRecovery() { + return false; + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/SingleInputGateBenchmarkFactory.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/SingleInputGateBenchmarkFactory.java index b850a7cc55370..b3f0aec9dc5e5 100644 --- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/SingleInputGateBenchmarkFactory.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/SingleInputGateBenchmarkFactory.java @@ -37,7 +37,6 @@ import org.apache.flink.runtime.taskmanager.NettyShuffleEnvironmentConfiguration; import java.io.IOException; -import java.util.ArrayDeque; /** * A benchmark-specific input gate factory which overrides the respective methods of creating {@link @@ -130,7 +129,8 @@ public TestLocalInputChannel( metrics.getNumBytesInLocalCounter(), metrics.getNumBuffersInLocalCounter(), ChannelStateWriter.NO_OP, - new ArrayDeque<>()); + 0, + false); } @Override @@ -186,7 +186,7 @@ public TestRemoteInputChannel( metrics.getNumBytesInRemoteCounter(), metrics.getNumBuffersInRemoteCounter(), ChannelStateWriter.NO_OP, - new ArrayDeque<>()); + false); } @Override diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java index 619873c387d08..2448a2ac524f6 100644 --- 
a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java @@ -178,6 +178,11 @@ public InputChannel getChannel(int channelIndex) { throw new UnsupportedOperationException(); } + @Override + public InputChannel getChannel(InputChannelInfo channelInfo) { + throw new UnsupportedOperationException(); + } + @Override public List getChannelInfos() { return IntStream.range(0, numberOfChannels) @@ -263,11 +268,6 @@ public CompletableFuture getStateConsumedFuture() { return CompletableFuture.completedFuture(null); } - @Override - public CompletableFuture getBufferFilteringCompleteFuture() { - return CompletableFuture.completedFuture(null); - } - @Override public void finishReadRecoveredState() {} @@ -294,5 +294,13 @@ public void setCheckpointingDuringRecoveryEnabled(boolean enabled) {} public boolean isCheckpointingDuringRecoveryEnabled() { return false; } + + @Override + public void setNeedsRecovery(boolean enabled) {} + + @Override + public boolean needsRecovery() { + return false; + } } } From 4f8a1be5d0879ee527278fb6ae4dca65f04c7e56 Mon Sep 17 00:00:00 2001 From: Rui Fan <1996fanrui@gmail.com> Date: Fri, 22 May 2026 13:27:53 +0200 Subject: [PATCH 04/16] [FLINK-38544][checkpoint] Phase 3: SpillFile + filter writer phase - New SpillFile: append-only segmented disk store with 64 MiB segments, reference counter + cleanedUp guard, and Snapshot view over segments and entries. All public signatures (append, snapshot, readBytesAt, acquire, release, isClosed) land in this commit; later phases only fill in bodies. - New FilteredBufferWriter: prefilter + postfilter buffer accumulator, flushing the post-filter buffer to disk on rotation. - New SpillFileWriter: thin facade exposing SpillFile lifecycle to filter callers. 
- RecoveredChannelStateHandler.recover filter branch routes output to a SpillFile instead of channel.onRecoveredStateBuffer; the accumulator's prefilter and postfilter buffers are sourced from the source channel's exclusive pool (no heap fallback). - InputChannelRecoveredStateHandler exposes getProducedSpillFile so Phase 4 drain wiring can pick up the frozen file after filter completes; spill-tmp-directories argument is required (no backward-compat shim). Design: requirements/38544/phase3_spill_writer/design.md (cherry picked from commit 2cbbbd67808d4f440a2e228b5c13e4ba4aa34ac7) --- .../core/fs/OffsetAwareOutputStream.java | 2 +- .../channel/ChannelStateFilteringHandler.java | 178 +----- .../channel/ChannelStateSerializer.java | 14 + .../channel/RecoveredChannelStateHandler.java | 532 ++++++++++++++---- .../SequentialChannelStateReaderImpl.java | 61 +- .../channel/AbstractSpillingHandlerTest.java | 162 ++++++ .../ChannelStateFilteringHandlerTest.java | 79 +++ .../GateFilterHandlerBufferOwnershipTest.java | 112 ++-- .../channel/GateFilterHandlerTest.java | 124 ++-- ...InputChannelRecoveredStateHandlerTest.java | 80 ++- ...dChannelStateHandlerFilterRoutingTest.java | 308 ++++++++++ .../RecoveredChannelStateHandlerTest.java | 2 +- .../SequentialChannelStateReaderImplTest.java | 9 +- .../checkpoint/channel/TestSpillWriter.java | 98 ++++ 14 files changed, 1322 insertions(+), 439 deletions(-) create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/AbstractSpillingHandlerTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandlerTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandlerFilterRoutingTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/TestSpillWriter.java diff --git 
a/flink-core/src/main/java/org/apache/flink/core/fs/OffsetAwareOutputStream.java b/flink-core/src/main/java/org/apache/flink/core/fs/OffsetAwareOutputStream.java index 3ee4b761e1b30..375c95da25afa 100644 --- a/flink-core/src/main/java/org/apache/flink/core/fs/OffsetAwareOutputStream.java +++ b/flink-core/src/main/java/org/apache/flink/core/fs/OffsetAwareOutputStream.java @@ -35,7 +35,7 @@ public final class OffsetAwareOutputStream implements Closeable { private long position; - OffsetAwareOutputStream(OutputStream currentOut, long position) { + public OffsetAwareOutputStream(OutputStream currentOut, long position) { this.currentOut = checkNotNull(currentOut); this.position = position; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java index b257c3b40544e..0b6976068d7ab 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java @@ -101,24 +101,18 @@ public static ChannelStateFilteringHandler createFromContext( } /** - * Filters a recovered buffer from the specified virtual channel, returning new buffers - * containing only the records that belong to the current subtask. - * - *

<p>One source buffer may produce 0 to N result buffers: 0 if all records are filtered out, - * and potentially more than 1 when a spanning record completes in this buffer. The deserializer - * caches partial record data from previous buffers, so the output may contain data that was not - * in the current source buffer, causing the total output size to exceed one buffer capacity. - * This can happen with any spanning record regardless of its size. - * - * @return filtered buffers, possibly empty if all records were filtered out. + * Filters {@code sourceBuffer} through the virtual channel identified by {@code gateIndex} / + * {@code oldChannelIndex}, appending each surviving record (length-prefixed) into {@code + * outputSerializer}. One call may emit 0..N records depending on the filter result and whether + * records spanning previous buffers complete here. The caller owns the segment boundary. */ - public List filterAndRewrite( + public void filterAndRewrite( int gateIndex, int oldSubtaskIndex, int oldChannelIndex, Buffer sourceBuffer, - BufferSupplier bufferSupplier) - throws IOException, InterruptedException { + DataOutputSerializer outputSerializer) + throws IOException { if (gateIndex < 0 || gateIndex >= gateHandlers.length) { throw new IllegalStateException( @@ -135,8 +129,8 @@ public List filterAndRewrite( + gateIndex + ". This gate is not a network input and should not have recovered buffers."); } - return gateHandler.filterAndRewrite( - oldSubtaskIndex, oldChannelIndex, sourceBuffer, bufferSupplier); + gateHandler.filterAndRewrite( + oldSubtaskIndex, oldChannelIndex, sourceBuffer, outputSerializer); } /** Returns {@code true} if any virtual channel has a partial (spanning) record pending. 
*/ @@ -215,7 +209,8 @@ private static GateFilterHandler createGateHandler( : VirtualChannelRecordFilterFactory.createPassThroughFilter(); RecordDeserializer> deserializer = - createDeserializer(filterContext.getTmpDirectories()); + new SpillingAdaptiveSpanningRecordDeserializer<>( + filterContext.getTmpDirectories()); VirtualChannel vc = new VirtualChannel<>(deserializer, recordFilter); gateVirtualChannels.put(key, vc); @@ -246,26 +241,10 @@ private static int[] getOldChannelIndexes(RescaleMappings channelMapping, int nu return oldIndexes.stream().mapToInt(Integer::intValue).toArray(); } - private static RecordDeserializer> createDeserializer( - String[] tmpDirectories) { - if (tmpDirectories != null && tmpDirectories.length > 0) { - return new SpillingAdaptiveSpanningRecordDeserializer<>(tmpDirectories); - } else { - String[] defaultDirs = new String[] {System.getProperty("java.io.tmpdir")}; - return new SpillingAdaptiveSpanningRecordDeserializer<>(defaultDirs); - } - } - // ------------------------------------------------------------------------------------------- // Inner classes // ------------------------------------------------------------------------------------------- - /** Provides buffers for re-serializing filtered records. Implementations may block. */ - @FunctionalInterface - public interface BufferSupplier { - Buffer requestBufferBlocking() throws IOException, InterruptedException; - } - /** * Handles record filtering for a single input gate. Each gate has its own serializer and set of * virtual channels, allowing different gates to handle different record types independently. 
@@ -275,8 +254,6 @@ static class GateFilterHandler { private final Map> virtualChannels; private final StreamElementSerializer serializer; private final DeserializationDelegate deserializationDelegate; - private final DataOutputSerializer outputSerializer; - private final byte[] lengthBuffer = new byte[4]; GateFilterHandler( Map> virtualChannels, @@ -284,23 +261,21 @@ static class GateFilterHandler { this.virtualChannels = checkNotNull(virtualChannels); this.serializer = checkNotNull(serializer); this.deserializationDelegate = new NonReusingDeserializationDelegate<>(serializer); - this.outputSerializer = new DataOutputSerializer(128); } /** * Deserializes records from {@code sourceBuffer}, applies the virtual channel's record - * filter, and immediately re-serializes each surviving record into output buffers. + * filter, and re-serializes each surviving record into {@code outputSerializer}. No + * intermediate network buffer is used; the caller owns the segment boundary. */ - List filterAndRewrite( + void filterAndRewrite( int oldSubtaskIndex, int oldChannelIndex, Buffer sourceBuffer, - BufferSupplier bufferSupplier) - throws IOException, InterruptedException { + DataOutputSerializer outputSerializer) + throws IOException { boolean sourceBufferOwnershipTransferred = false; - List resultBuffers = new ArrayList<>(); - Buffer currentBuffer = null; try { SubtaskConnectionDescriptor key = new SubtaskConnectionDescriptor(oldSubtaskIndex, oldChannelIndex); @@ -319,132 +294,33 @@ List filterAndRewrite( while (true) { DeserializationResult result = vc.getNextRecord(deserializationDelegate); if (result.isFullRecord()) { - if (currentBuffer == null) { - currentBuffer = bufferSupplier.requestBufferBlocking(); - } - currentBuffer = - serializeElement( - deserializationDelegate.getInstance(), - currentBuffer, - resultBuffers, - bufferSupplier); + serializeElement(deserializationDelegate.getInstance(), outputSerializer); } if (result.isBufferConsumed()) { break; } } - - if 
(currentBuffer != null) { - if (currentBuffer.readableBytes() > 0) { - resultBuffers.add(currentBuffer); - } else { - currentBuffer.recycleBuffer(); - } - currentBuffer = null; - } - - return resultBuffers; } catch (Throwable t) { if (!sourceBufferOwnershipTransferred) { sourceBuffer.recycleBuffer(); } - // Avoid double-recycle: currentBuffer may already be the last element in - // resultBuffers if serializeElement added it before the exception. - if (currentBuffer != null - && (resultBuffers.isEmpty() - || resultBuffers.get(resultBuffers.size() - 1) != currentBuffer)) { - currentBuffer.recycleBuffer(); - } - for (Buffer buf : resultBuffers) { - buf.recycleBuffer(); - } - resultBuffers.clear(); throw t; } } /** - * Serializes a single stream element into the current buffer using the length-prefixed - * format (4-byte big-endian length + record bytes) expected by Flink's record - * deserializers. Spills into new buffers from {@code bufferSupplier} when needed. - * - * @return the buffer to continue writing into (may differ from the input buffer). + * Appends one stream element as a length-prefixed record. Reserves the 4B prefix, + * serializes the element, then backfills the length, because {@code outputSerializer} + * already holds the segment header and earlier records, so the prefix cannot be written + * from a fixed offset. 
*/ - private Buffer serializeElement( - StreamElement element, - Buffer currentBuffer, - List resultBuffers, - BufferSupplier bufferSupplier) - throws IOException, InterruptedException { - outputSerializer.clear(); + private void serializeElement(StreamElement element, DataOutputSerializer outputSerializer) + throws IOException { + int startPos = outputSerializer.length(); + outputSerializer.writeInt(0); // length placeholder serializer.serialize(element, outputSerializer); - int recordLength = outputSerializer.length(); - - writeLengthToBuffer(recordLength); - currentBuffer = - writeDataToBuffer( - lengthBuffer, 0, 4, currentBuffer, resultBuffers, bufferSupplier); - - byte[] serializedData = outputSerializer.getSharedBuffer(); - currentBuffer = - writeDataToBuffer( - serializedData, - 0, - recordLength, - currentBuffer, - resultBuffers, - bufferSupplier); - return currentBuffer; - } - - private void writeLengthToBuffer(int length) { - lengthBuffer[0] = (byte) (length >> 24); - lengthBuffer[1] = (byte) (length >> 16); - lengthBuffer[2] = (byte) (length >> 8); - lengthBuffer[3] = (byte) length; - } - - /** - * Writes data to the current buffer, spilling into new buffers from {@code bufferSupplier} - * when the current one is full. - * - * @return the buffer to continue writing into (may differ from the input buffer). 
- */ - private Buffer writeDataToBuffer( - byte[] data, - int dataOffset, - int dataLength, - Buffer currentBuffer, - List resultBuffers, - BufferSupplier bufferSupplier) - throws IOException, InterruptedException { - int offset = dataOffset; - int remaining = dataLength; - - while (remaining > 0) { - int writableBytes = currentBuffer.getMaxCapacity() - currentBuffer.getSize(); - - if (writableBytes == 0) { - // Buffer is full, transfer ownership to resultBuffers - resultBuffers.add(currentBuffer); - currentBuffer = bufferSupplier.requestBufferBlocking(); - writableBytes = currentBuffer.getMaxCapacity(); - } - - int bytesToWrite = Math.min(remaining, writableBytes); - currentBuffer - .getMemorySegment() - .put( - currentBuffer.getMemorySegmentOffset() + currentBuffer.getSize(), - data, - offset, - bytesToWrite); - currentBuffer.setSize(currentBuffer.getSize() + bytesToWrite); - - offset += bytesToWrite; - remaining -= bytesToWrite; - } - return currentBuffer; + int recordLength = outputSerializer.length() - startPos - Integer.BYTES; + outputSerializer.writeIntUnsafe(recordLength, startPos); } boolean hasPartialData() { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateSerializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateSerializer.java index 252d25c2e29fd..ec858460dd89b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateSerializer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateSerializer.java @@ -43,6 +43,8 @@ interface ChannelStateSerializer { void writeData(DataOutputStream stream, Buffer... 
flinkBuffers) throws IOException; + void writeData(DataOutputStream stream, InputStream input, int length) throws IOException; + void readHeader(InputStream stream) throws IOException; int readLength(InputStream stream) throws IOException; @@ -165,6 +167,18 @@ public void writeData(DataOutputStream stream, Buffer... flinkBuffers) throws IO } } + @Override + public void writeData(DataOutputStream stream, InputStream input, int length) + throws IOException { + Preconditions.checkArgument(length >= 0, "negative state size"); + stream.writeInt(length); + long copied = input.transferTo(stream); + if (copied != length) { + throw new java.io.EOFException( + "Unexpected EOF: expected " + length + " bytes of segment body, got " + copied); + } + } + private int getSize(Buffer[] buffers) { int len = 0; for (Buffer buffer : buffers) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java index ca01ff37bd369..3475169ff82b4 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java @@ -18,6 +18,8 @@ package org.apache.flink.runtime.checkpoint.channel; import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.core.fs.OffsetAwareOutputStream; +import org.apache.flink.core.memory.DataOutputSerializer; import org.apache.flink.core.memory.MemorySegment; import org.apache.flink.core.memory.MemorySegmentFactory; import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor; @@ -38,13 +40,21 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; +import java.io.BufferedOutputStream; +import java.io.FileOutputStream; import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; 
+import java.nio.file.Paths; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.UUID; import static org.apache.flink.runtime.checkpoint.channel.ChannelStateByteBuffer.wrap; import static org.apache.flink.util.Preconditions.checkArgument; +import static org.apache.flink.util.Preconditions.checkNotNull; import static org.apache.flink.util.Preconditions.checkState; interface RecoveredChannelStateHandler extends AutoCloseable { @@ -73,28 +83,396 @@ void recover(Info info, int oldSubtaskIndex, BufferWithContext bufferWi throws IOException, InterruptedException; } -class InputChannelRecoveredStateHandler +/** + * Abstract base for all input-channel recovery handlers. Holds the channel mapping logic shared by + * all three variants (no-spilling, spilling-no-filtering, spilling-with-filtering). + * + *

<p>Subclasses implement {@link #recover} according to their specific recovery mode and override + * {@link #closeInternal()} to release mode-specific resources. + * + *

<p>Use the static {@link #create} factory to obtain the correct concrete subclass. + */ +abstract class AbstractInputChannelRecoveredStateHandler implements RecoveredChannelStateHandler { - private final InputGate[] inputGates; - private final InflightDataRescalingDescriptor channelMapping; + final InputGate[] inputGates; + final InflightDataRescalingDescriptor channelMapping; + final Map rescaledChannels = new HashMap<>(); + final Map oldToNewMappings = new HashMap<>(); + + AbstractInputChannelRecoveredStateHandler( + InputGate[] inputGates, InflightDataRescalingDescriptor channelMapping) { + this.inputGates = inputGates; + this.channelMapping = channelMapping; + } + + /** + * Factory that selects the correct subclass based on {@code checkpointingDuringRecoveryEnabled} + * and whether a {@code filteringHandler} is present. + * + *

    + *
<ul>
+ *   <li>{@code false} → {@link NoSpillingHandler}
+ *   <li>{@code true} and {@code filteringHandler == null} → {@link SpillingNoFilteringHandler}
+ *   <li>{@code true} and {@code filteringHandler != null} → {@link
+ *       SpillingWithFilteringHandler}
+ * </ul>
+ */ + static AbstractInputChannelRecoveredStateHandler create( + InputGate[] inputGates, + InflightDataRescalingDescriptor channelMapping, + boolean checkpointingDuringRecoveryEnabled, + @Nullable ChannelStateFilteringHandler filteringHandler, + int memorySegmentSize, + String[] spillTmpDirectories) { + if (!checkpointingDuringRecoveryEnabled) { + return new NoSpillingHandler(inputGates, channelMapping); + } + if (filteringHandler == null) { + return new SpillingNoFilteringHandler(inputGates, channelMapping, spillTmpDirectories); + } + return new SpillingWithFilteringHandler( + inputGates, + channelMapping, + filteringHandler, + memorySegmentSize, + spillTmpDirectories); + } + + /** Default buffer allocation from the network buffer pool, used by non-filtering modes. */ + @Override + public BufferWithContext getBuffer(InputChannelInfo channelInfo) + throws IOException, InterruptedException { + RecoveredInputChannel channel = getMappedChannels(channelInfo); + Buffer buffer = channel.requestBufferBlocking(); + return new BufferWithContext<>(wrap(buffer), buffer); + } + + /** + * Returns the {@link FetchedChannelState} produced during spilling, or {@code null} if spilling + * was not active (i.e., {@link NoSpillingHandler}). + */ + @Nullable + FetchedChannelState getProducedChannelState() { + return null; + } + + @Override + public void close() throws IOException { + closeInternal(); + } + + /** Hook for subclasses to release their own resources. Called by {@link #close()}. 
*/ + void closeInternal() throws IOException {} + + RecoveredInputChannel getMappedChannels(InputChannelInfo channelInfo) { + return rescaledChannels.computeIfAbsent(channelInfo, this::calculateMapping); + } + + @Nonnull + private RecoveredInputChannel calculateMapping(InputChannelInfo info) { + final RescaleMappings oldToNewMapping = + oldToNewMappings.computeIfAbsent( + info.getGateIdx(), idx -> channelMapping.getChannelMapping(idx).invert()); + int[] mappedIndexes = oldToNewMapping.getMappedIndexes(info.getInputChannelIdx()); + checkState( + mappedIndexes.length == 1, + "One buffer is only distributed to one target InputChannel since " + + "one buffer is expected to be processed once by the same task."); + return getChannel(info.getGateIdx(), mappedIndexes[0]); + } + + private RecoveredInputChannel getChannel(int gateIndex, int subPartitionIndex) { + final InputChannel inputChannel = inputGates[gateIndex].getChannel(subPartitionIndex); + if (!(inputChannel instanceof RecoveredInputChannel)) { + throw new IllegalStateException( + "Cannot restore state to a non-recovered input channel: " + inputChannel); + } + return (RecoveredInputChannel) inputChannel; + } +} + +/** + * Recovery handler for the case where checkpointing during recovery is disabled. Delivers recovered + * buffers directly into the input channel via {@code onRecoveredStateBuffer}, with no spill file. 
+ */ +class NoSpillingHandler extends AbstractInputChannelRecoveredStateHandler { + + NoSpillingHandler(InputGate[] inputGates, InflightDataRescalingDescriptor channelMapping) { + super(inputGates, channelMapping); + } + + @Override + public void recover( + InputChannelInfo channelInfo, + int oldSubtaskIndex, + BufferWithContext bufferWithContext) + throws IOException, InterruptedException { + Buffer buffer = bufferWithContext.context; + try { + if (buffer.readableBytes() > 0) { + RecoveredInputChannel channel = getMappedChannels(channelInfo); + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer( + new SubtaskConnectionDescriptor( + oldSubtaskIndex, channelInfo.getInputChannelIdx()), + false)); + channel.onRecoveredStateBuffer(buffer.retainBuffer()); + } + } finally { + buffer.recycleBuffer(); + } + } +} + +/** + * Intermediate abstract base for the two spilling variants. Owns the on-disk spill format end to + * end: a single reusable {@link DataOutputSerializer} accumulates one channel's segment, the + * segment header is backfilled with the body length at seal time, and sealed segments are flushed + * to the current file stream with 64 MB-bounded rotation. + * + *

<h2>Disk format</h2>
+ *
+ * <pre>
+ * [ 4B BE int: gate idx      ]   segment header: written once per channel segment
+ * [ 4B BE int: channel idx   ]
+ * [ 4B BE int: buffer length ]   segment body byte count (backfilled at segment seal)
+ *   [ 4B BE int: record length ]  repeated for every record in this segment
+ *   [ N bytes: serialized record ]
+ * [ 4B BE int: gate idx      ]   next segment header (channel switch or post-rotation)
+ * ...
+ * </pre>
+ *
+ * <p>
The body byte count is only known after the whole segment is written, so each segment is first + * accumulated in {@link #segmentSerializer} (header written at open with a zero placeholder) and + * {@link DataOutputSerializer#writeIntUnsafe} backfills the length at seal. A segment is one + * uninterrupted run of records for a single channel; file rotation happens only after a segment is + * fully sealed, so a segment never crosses a file boundary. + */ +abstract class AbstractSpillingHandler extends AbstractInputChannelRecoveredStateHandler { + + /** Byte offset of the {@code bufferLength} field within a segment's header. */ + static final int BUFFER_LENGTH_HEADER_OFFSET = 2 * Integer.BYTES; + + /** Total size of the segment header in bytes: gateIdx + channelIdx + bufferLength. */ + static final int SEGMENT_HEADER_BYTES = 3 * Integer.BYTES; + + final String[] spillTmpDirectories; + + public static final long DEFAULT_SPILL_FILE_SIZE_BYTES = 64L * 1024 * 1024; + + /** Soft per-file size bound that triggers rotation between segments. */ + private final long maxFileSizeBytes; + + /** + * Accumulates the current segment: the header followed by the body, which is either + * length-prefixed filtered records or verbatim pass-through bytes, depending on the subclass. + * Reused across segments via {@code clear()}. + */ + private final DataOutputSerializer segmentSerializer = new DataOutputSerializer(256); + + /** + * Spill files written so far, in order. The {@link FetchedChannelState} handoff is built from + * this list once writing is sealed; an empty list means the handler never spilled any bytes, so + * it produces no state. + */ + private final List files = new ArrayList<>(); + + /** + * Unique directory for this handler's spill files; created lazily when the first file opens. 
+ */ + private final Path baseDir; + + /** + * Output stream to the current spill file; tracks the bytes written so far via {@link + * OffsetAwareOutputStream#getLength()} to decide when to rotate. Null before the first segment + * is flushed. + */ + @Nullable private OffsetAwareOutputStream currentStream; + + /** Channel whose segment is currently open; null when no segment is in progress. */ + @Nullable private InputChannelInfo currentChannel; + + @Nullable private FetchedChannelState producedChannelState; + + AbstractSpillingHandler( + InputGate[] inputGates, + InflightDataRescalingDescriptor channelMapping, + String[] spillTmpDirectories, + long maxFileSizeBytes) { + super(inputGates, channelMapping); + checkArgument( + checkNotNull(spillTmpDirectories).length > 0, + "spillTmpDirectories must not be empty"); + checkArgument( + maxFileSizeBytes > 0, "maxFileSizeBytes must be positive: %s", maxFileSizeBytes); + this.spillTmpDirectories = spillTmpDirectories; + this.maxFileSizeBytes = maxFileSizeBytes; + this.baseDir = + Paths.get(spillTmpDirectories[0], "flink-channel-spill-" + UUID.randomUUID()); + } + + /** + * Opens (or switches to) the segment for {@code channelInfo} and returns its buffer for the + * caller to append the body into. The caller must not seal the segment. 
+ */ + DataOutputSerializer segmentSerializerFor(InputChannelInfo channelInfo) throws IOException { + switchChannelIfNeeded(channelInfo); + return segmentSerializer; + } + + private void switchChannelIfNeeded(InputChannelInfo channelInfo) throws IOException { + if (channelInfo.equals(currentChannel)) { + return; + } + if (currentChannel != null) { + sealCurrentSegment(); + } + segmentSerializer.clear(); + segmentSerializer.writeInt(channelInfo.getGateIdx()); + segmentSerializer.writeInt(channelInfo.getInputChannelIdx()); + segmentSerializer.writeInt(0); // bufferLength placeholder + currentChannel = channelInfo; + } + + /** + * Backfills the body length into the segment header and flushes the whole segment to the file + * stream. Empty segments (filtered out entirely, or a zero-byte pass-through) are dropped + * without opening a file, so no empty file is created. + */ + private void sealCurrentSegment() throws IOException { + if (currentChannel == null) { + return; + } + currentChannel = null; + int totalBytes = segmentSerializer.length(); + int bodyBytes = totalBytes - SEGMENT_HEADER_BYTES; + if (bodyBytes == 0) { + return; + } + // bodyBytes is already an int, so Math.toIntExact cannot throw here; it is a defensive no-op. + segmentSerializer.writeIntUnsafe(Math.toIntExact(bodyBytes), BUFFER_LENGTH_HEADER_OFFSET); + ensureFileOpen(); + currentStream.write(segmentSerializer.getSharedBuffer(), 0, totalBytes); + } + + /** + * Ensures an output stream is ready for the next segment, rotating to a fresh file first if the + * current one reached the size bound. Rotation happens here, between sealed segments, so a + * segment is never split across files. 
+ */ + private void ensureFileOpen() throws IOException { + if (currentStream != null && currentStream.getLength() >= maxFileSizeBytes) { + currentStream.flush(); + currentStream.close(); + currentStream = null; + } + if (currentStream != null) { + return; + } + // create the spill dir on the first file; no-op afterwards + Files.createDirectories(baseDir); + Path filePath = baseDir.resolve("spill-segment-" + files.size() + ".bin"); + currentStream = + new OffsetAwareOutputStream( + new BufferedOutputStream(new FileOutputStream(filePath.toFile())), 0L); + files.add(filePath); + } + + @Override + @Nullable + FetchedChannelState getProducedChannelState() { + return producedChannelState; + } - private final Map rescaledChannels = new HashMap<>(); - private final Map oldToNewMappings = new HashMap<>(); + /** Spill files written so far; empty if this handler never spilled any bytes. */ + @VisibleForTesting + List peekSpillFilesForTesting() { + return files; + } /** - * Optional filtering handler for filtering recovered buffers. When non-null, filtering is - * performed during recovery in the channel-state-unspilling thread. + * Seals the open segment and the file stream, then builds the {@link FetchedChannelState} + * handoff from the written files. Produces nothing if no bytes were ever spilled. */ - @Nullable private final ChannelStateFilteringHandler filteringHandler; + @Override + void closeInternal() throws IOException { + if (currentChannel != null) { + sealCurrentSegment(); + } + if (currentStream != null) { + currentStream.flush(); + currentStream.close(); // OffsetAwareOutputStream closes the wrapped stream quietly + currentStream = null; + } + if (files.isEmpty()) { + return; + } + producedChannelState = new FetchedChannelState(files); + // Keep the files alive between close() and drain-reader construction. 
+ producedChannelState.acquire(); + } +} + +/** + * Recovery handler for the case where checkpointing during recovery is enabled but no filtering + * handler is present. Appends recovered buffer bytes verbatim into the current segment. + */ +class SpillingNoFilteringHandler extends AbstractSpillingHandler { + + SpillingNoFilteringHandler( + InputGate[] inputGates, + InflightDataRescalingDescriptor channelMapping, + String[] spillTmpDirectories) { + super(inputGates, channelMapping, spillTmpDirectories, DEFAULT_SPILL_FILE_SIZE_BYTES); + } + + @Override + public void recover( + InputChannelInfo channelInfo, + int oldSubtaskIndex, + BufferWithContext bufferWithContext) + throws IOException, InterruptedException { + Buffer buffer = bufferWithContext.context; + try { + if (buffer.readableBytes() > 0) { + recoverPassThroughToSpill(getMappedChannels(channelInfo).getChannelInfo(), buffer); + } + } finally { + buffer.recycleBuffer(); + } + } + + private void recoverPassThroughToSpill(InputChannelInfo channelInfo, Buffer source) + throws IOException { + // The recovered bytes are already a length-prefixed record sequence, so append them + // verbatim into the segment without re-framing. Writing straight from the backing + // MemorySegment lets it absorb the heap/off-heap distinction, avoiding both a branch on the + // NIO buffer kind and the intermediate copy a direct buffer would otherwise require. + segmentSerializerFor(channelInfo) + .write( + source.getMemorySegment(), + source.getMemorySegmentOffset() + source.getReaderIndex(), + source.readableBytes()); + } +} + +/** + * Recovery handler for the case where checkpointing during recovery is enabled and a filtering + * handler is present. Uses a reusable heap-backed pre-filter buffer (isolated from the Network + * Buffer Pool) and writes filtered/rewritten output to the spill file via {@link + * ChannelStateFilteringHandler#filterAndRewrite}. 
+ */ +class SpillingWithFilteringHandler extends AbstractSpillingHandler { + + private final ChannelStateFilteringHandler filteringHandler; /** Network buffer memory segment size in bytes. Used to size the reusable pre-filter buffer. */ private final int memorySegmentSize; /** * Reusable heap memory segment backing the pre-filter buffer in filtering mode. Lazily - * allocated on the first {@link #getPreFilterBuffer} call, reused for every subsequent call, - * and freed in {@link #close()}. + * allocated on the first {@link #getBuffer} call, reused for every subsequent call, and freed + * in {@link #closeInternal()}. * *

Reuse is safe because at most one pre-filter buffer is in flight per task at any moment. * This invariant is enforced at runtime by {@link #preFilterBufferInUse}. @@ -108,39 +486,27 @@ class InputChannelRecoveredStateHandler */ private boolean preFilterBufferInUse; - InputChannelRecoveredStateHandler( + SpillingWithFilteringHandler( InputGate[] inputGates, InflightDataRescalingDescriptor channelMapping, - @Nullable ChannelStateFilteringHandler filteringHandler, - int memorySegmentSize) { - this.inputGates = inputGates; - this.channelMapping = channelMapping; + ChannelStateFilteringHandler filteringHandler, + int memorySegmentSize, + String[] spillTmpDirectories) { + super(inputGates, channelMapping, spillTmpDirectories, DEFAULT_SPILL_FILE_SIZE_BYTES); this.filteringHandler = filteringHandler; checkArgument( memorySegmentSize > 0, "memorySegmentSize must be positive: %s", memorySegmentSize); this.memorySegmentSize = memorySegmentSize; } - @Override - public BufferWithContext getBuffer(InputChannelInfo channelInfo) - throws IOException, InterruptedException { - if (filteringHandler != null) { - return getPreFilterBuffer(); - } - // Non-filtering mode: use existing network buffer pool allocation. - RecoveredInputChannel channel = getMappedChannels(channelInfo); - Buffer buffer = channel.requestBufferBlocking(); - return new BufferWithContext<>(wrap(buffer), buffer); - } - /** * Allocates a pre-filter buffer from a reusable heap segment (isolated from the Network Buffer * Pool) in filtering mode. * *

Memory management: a single {@link MemorySegment} per task is lazily allocated on first * invocation and reused across every subsequent call. The custom {@link BufferRecycler} does - * not free the segment — it only flips {@link #preFilterBufferInUse} back to {@code false} so - * the next call can reuse it. The segment itself is freed in {@link #close()}. + * not free the segment; it only flips {@link #preFilterBufferInUse} back to {@code false} so + * the next call can reuse it. The segment itself is freed in {@link #closeInternal()}. * *

Runtime invariant check: the one-at-a-time invariant on pre-filter buffers is guaranteed * by Flink's serial recovery loop and the deserializer's ownership contract. This method @@ -148,7 +514,8 @@ public BufferWithContext getBuffer(InputChannelInfo channelInfo) * recycled, it throws {@link IllegalStateException} so any future regression fails loudly * instead of silently corrupting memory. */ - private BufferWithContext getPreFilterBuffer() { + @Override + public BufferWithContext getBuffer(InputChannelInfo channelInfo) { checkState( !preFilterBufferInUse, "Previous pre-filter buffer has not been recycled. This violates the " @@ -165,17 +532,6 @@ private BufferWithContext getPreFilterBuffer() { return new BufferWithContext<>(wrap(buffer), buffer); } - @VisibleForTesting - boolean isPreFilterBufferInUse() { - return preFilterBufferInUse; - } - - @VisibleForTesting - @Nullable - MemorySegment getPreFilterSegmentForTesting() { - return preFilterSegment; - } - @Override public void recover( InputChannelInfo channelInfo, @@ -185,90 +541,40 @@ public void recover( Buffer buffer = bufferWithContext.context; try { if (buffer.readableBytes() > 0) { - RecoveredInputChannel channel = getMappedChannels(channelInfo); - - if (filteringHandler != null) { - recoverWithFiltering( - channel, channelInfo, oldSubtaskIndex, buffer.retainBuffer()); - } else { - channel.onRecoveredStateBuffer( - EventSerializer.toBuffer( - new SubtaskConnectionDescriptor( - oldSubtaskIndex, channelInfo.getInputChannelIdx()), - false)); - channel.onRecoveredStateBuffer(buffer.retainBuffer()); - } - } - } finally { - buffer.recycleBuffer(); - } - } - - private void recoverWithFiltering( - RecoveredInputChannel channel, - InputChannelInfo channelInfo, - int oldSubtaskIndex, - Buffer retainedBuffer) - throws IOException, InterruptedException { - checkState(filteringHandler != null, "filtering handler not set."); - List filteredBuffers = filteringHandler.filterAndRewrite( channelInfo.getGateIdx(), 
oldSubtaskIndex, channelInfo.getInputChannelIdx(), - retainedBuffer, - channel::requestBufferBlocking); - - int i = 0; - try { - for (; i < filteredBuffers.size(); i++) { - channel.onRecoveredStateBuffer(filteredBuffers.get(i)); + buffer.retainBuffer(), + segmentSerializerFor(getMappedChannels(channelInfo).getChannelInfo())); } - } catch (Throwable t) { - for (int j = i; j < filteredBuffers.size(); j++) { - filteredBuffers.get(j).recycleBuffer(); - } - throw t; - } - } - - @Override - public void close() throws IOException { - // note that we need to finish all RecoveredInputChannels, not just those with state - for (final InputGate inputGate : inputGates) { - inputGate.finishReadRecoveredState(); - } - if (preFilterSegment != null) { - preFilterSegment.free(); - preFilterSegment = null; - preFilterBufferInUse = false; + } finally { + buffer.recycleBuffer(); } } - private RecoveredInputChannel getChannel(int gateIndex, int subPartitionIndex) { - final InputChannel inputChannel = inputGates[gateIndex].getChannel(subPartitionIndex); - if (!(inputChannel instanceof RecoveredInputChannel)) { - throw new IllegalStateException( - "Cannot restore state to a non-recovered input channel: " + inputChannel); - } - return (RecoveredInputChannel) inputChannel; + @VisibleForTesting + boolean isPreFilterBufferInUse() { + return preFilterBufferInUse; } - private RecoveredInputChannel getMappedChannels(InputChannelInfo channelInfo) { - return rescaledChannels.computeIfAbsent(channelInfo, this::calculateMapping); + @VisibleForTesting + @Nullable + MemorySegment getPreFilterSegmentForTesting() { + return preFilterSegment; } - @Nonnull - private RecoveredInputChannel calculateMapping(InputChannelInfo info) { - final RescaleMappings oldToNewMapping = - oldToNewMappings.computeIfAbsent( - info.getGateIdx(), idx -> channelMapping.getChannelMapping(idx).invert()); - int[] mappedIndexes = oldToNewMapping.getMappedIndexes(info.getInputChannelIdx()); - checkState( - mappedIndexes.length == 
1, - "One buffer is only distributed to one target InputChannel since " - + "one buffer is expected to be processed once by the same task."); - return getChannel(info.getGateIdx(), mappedIndexes[0]); + @Override + void closeInternal() throws IOException { + try { + super.closeInternal(); + } finally { + if (preFilterSegment != null) { + preFilterSegment.free(); + preFilterSegment = null; + preFilterBufferInUse = false; + } + } } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java index c52572e52faec..161968b6d4739 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java @@ -30,12 +30,15 @@ import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext; +import javax.annotation.Nullable; + import java.io.Closeable; import java.io.IOException; import java.util.Collection; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.function.Consumer; import java.util.function.Function; @@ -53,6 +56,8 @@ public class SequentialChannelStateReaderImpl implements SequentialChannelStateR private final ChannelStateSerializer serializer; private final ChannelStateChunkReader chunkReader; + @Nullable private FetchedChannelState producedChannelState; + public SequentialChannelStateReaderImpl(TaskStateSnapshot taskStateSnapshot) { this.taskStateSnapshot = taskStateSnapshot; serializer = new ChannelStateSerializerImpl(); @@ -69,32 +74,46 @@ public void readInputData(InputGate[] inputGates, RecordFilterContext filterCont ? 
ChannelStateFilteringHandler.createFromContext(filterContext, inputGates) : null; - try (ChannelStateFilteringHandler ignored = filteringHandler; - InputChannelRecoveredStateHandler stateHandler = - new InputChannelRecoveredStateHandler( - inputGates, - taskStateSnapshot.getInputRescalingDescriptor(), - filteringHandler, - filterContext.getMemorySegmentSize())) { - read( - stateHandler, - groupByDelegate( - streamSubtaskStates(), - ChannelStateHelper::extractUnmergedInputHandles)); - read( - stateHandler, - groupByDelegate( - streamSubtaskStates(), - OperatorSubtaskState::getUpstreamOutputBufferState)); + // Manual close ordering so the produced spill file can be published after + // stateHandler.close() flushes the filter writer. + AbstractInputChannelRecoveredStateHandler stateHandler = + AbstractInputChannelRecoveredStateHandler.create( + inputGates, + taskStateSnapshot.getInputRescalingDescriptor(), + filterContext.isCheckpointingDuringRecoveryEnabled(), + filteringHandler, + filterContext.getMemorySegmentSize(), + filterContext.getTmpDirectories()); + try (ChannelStateFilteringHandler ignored = filteringHandler) { + try { + read( + stateHandler, + groupByDelegate( + streamSubtaskStates(), + ChannelStateHelper::extractUnmergedInputHandles)); + read( + stateHandler, + groupByDelegate( + streamSubtaskStates(), + OperatorSubtaskState::getUpstreamOutputBufferState)); - if (filteringHandler != null) { - checkState( - !filteringHandler.hasPartialData(), - "Not all data has been fully consumed during filtering"); + if (filteringHandler != null) { + checkState( + !filteringHandler.hasPartialData(), + "Not all data has been fully consumed during filtering"); + } + } finally { + stateHandler.close(); } + this.producedChannelState = stateHandler.getProducedChannelState(); } } + @Override + public Optional getProducedChannelState() { + return Optional.ofNullable(producedChannelState); + } + @Override public void readOutputData(ResultPartitionWriter[] writers, boolean 
notifyAndBlockOnCompletion) throws IOException, InterruptedException { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/AbstractSpillingHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/AbstractSpillingHandlerTest.java new file mode 100644 index 0000000000000..812bcfebabf91 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/AbstractSpillingHandlerTest.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.DataInputStream; +import java.io.FileInputStream; +import java.nio.file.Path; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests the on-disk spill format produced by {@link AbstractSpillingHandler}: the 12-byte segment + * header with backfilled buffer length, channel switching, file rotation, and empty-segment + * dropping. Records are appended through {@link TestSpillWriter}, mirroring how the filtering and + * pass-through subclasses feed the segment buffer. 
+ */ +class AbstractSpillingHandlerTest { + + @TempDir Path tempDir; + + @Test + void testChannelSwitchProducesTwoSegmentsInOneFile() throws Exception { + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.writeRecord(new InputChannelInfo(0, 0), bytes(0xAA, 0xBB), 2); + writer.writeRecord(new InputChannelInfo(0, 1), bytes(0xCC, 0xDD, 0xEE), 3); + FetchedChannelState state = writer.getChannelState(); + + assertThat(state.files()).hasSize(1); + int seg0Body = Integer.BYTES + 2; + int seg1Body = Integer.BYTES + 3; + assertThat(state.files().get(0).toFile().length()) + .isEqualTo( + (long) (AbstractSpillingHandler.SEGMENT_HEADER_BYTES + seg0Body) + + (AbstractSpillingHandler.SEGMENT_HEADER_BYTES + seg1Body)); + } + } + + @Test + void testSameChannelContinuousRecordsMergeIntoOneSegment() throws Exception { + InputChannelInfo ch = new InputChannelInfo(0, 0); + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.writeRecord(ch, bytes(1), 1); + writer.writeRecord(ch, bytes(2, 3), 2); + writer.writeRecord(ch, bytes(4, 5, 6), 3); + FetchedChannelState state = writer.getChannelState(); + + long expectedBody = 3L * Integer.BYTES + (1 + 2 + 3); + assertThat(state.files()).hasSize(1); + assertThat(state.files().get(0).toFile().length()) + .isEqualTo(AbstractSpillingHandler.SEGMENT_HEADER_BYTES + expectedBody); + } + } + + @Test + void testPassThroughBytesAreWrittenVerbatim() throws Exception { + InputChannelInfo ch = new InputChannelInfo(1, 2); + byte[] data = bytes(0x01, 0x02, 0x03, 0x04, 0x05); + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.writePassThrough(ch, data, 0, data.length); + FetchedChannelState state = writer.getChannelState(); + + assertThat(state.files()).hasSize(1); + assertThat(state.files().get(0).toFile().length()) + .isEqualTo(AbstractSpillingHandler.SEGMENT_HEADER_BYTES + data.length); + } + } + + // ------------------------------------------------------------------------------------------- + 
// Empty segments: opening a segment without writing a body must not create a file + // ------------------------------------------------------------------------------------------- + + @Test + void testOpenSegmentWithoutBodyProducesNoFile() throws Exception { + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.openSegment(new InputChannelInfo(0, 0)); + // No body was ever spilled, so no state (and therefore no file) is produced. + assertThat(writer.getChannelState()).isNull(); + } + } + + @Test + void testEmptySegmentsAroundANonEmptyOneProduceExactlyTheNonEmptyFile() throws Exception { + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.openSegment(new InputChannelInfo(0, 0)); + writer.writeRecord(new InputChannelInfo(0, 1), bytes(7, 8, 9), 3); + writer.openSegment(new InputChannelInfo(0, 2)); + FetchedChannelState state = writer.getChannelState(); + + assertThat(state.files()).hasSize(1); + assertThat(state.files().get(0).toFile().length()) + .isEqualTo(AbstractSpillingHandler.SEGMENT_HEADER_BYTES + Integer.BYTES + 3); + } + } + + // ------------------------------------------------------------------------------------------- + // File rotation + // ------------------------------------------------------------------------------------------- + + @Test + void testRotationProducesOneFilePerSegmentWhenBoundIsTiny() throws Exception { + try (TestSpillWriter writer = new TestSpillWriter(tempDir, 1L)) { + writer.writeRecord(new InputChannelInfo(0, 0), bytes(1), 1); + writer.writeRecord(new InputChannelInfo(0, 1), bytes(2, 3), 2); + writer.writeRecord(new InputChannelInfo(0, 2), bytes(4), 1); + FetchedChannelState state = writer.getChannelState(); + + assertThat(state.files()).hasSize(3); + for (Path file : state.files()) { + assertThat(file.toFile().length()).isGreaterThan(0); + } + } + } + + // ------------------------------------------------------------------------------------------- + // Disk-format verification + // 
------------------------------------------------------------------------------------------- + + @Test + void testDiskFormatMatchesSpec() throws Exception { + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.writeRecord(new InputChannelInfo(7, 3), bytes(0xAB, 0xCD, 0xEF), 3); + Path file = writer.getChannelState().files().get(0); + try (DataInputStream in = new DataInputStream(new FileInputStream(file.toFile()))) { + assertThat(in.readInt()).isEqualTo(7); // gateIdx + assertThat(in.readInt()).isEqualTo(3); // channelIdx + assertThat(in.readInt()).isEqualTo(Integer.BYTES + 3); // bufferLength + assertThat(in.readInt()).isEqualTo(3); // record length prefix + assertThat(in.read()).isEqualTo(0xAB); + assertThat(in.read()).isEqualTo(0xCD); + assertThat(in.read()).isEqualTo(0xEF); + assertThat(in.read()).isEqualTo(-1); // EOF + } + } + } + + private static byte[] bytes(int... values) { + byte[] out = new byte[values.length]; + for (int i = 0; i < values.length; i++) { + out[i] = (byte) values[i]; + } + return out; + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandlerTest.java new file mode 100644 index 0000000000000..e9581fa34514d --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandlerTest.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor; +import org.apache.flink.runtime.checkpoint.RescaleMappings; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; +import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateBuilder; +import org.apache.flink.runtime.memory.MemoryManager; +import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext; +import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.nio.file.Path; +import java.util.HashSet; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link ChannelStateFilteringHandler}. 
*/ +class ChannelStateFilteringHandlerTest { + + @TempDir Path tempDir; + + @Test + void testCreateFromContextUsesProvidedSpillDirectories() { + InputGate inputGate = new SingleInputGateBuilder().setNumberOfChannels(1).build(); + RecordFilterContext context = createRecordFilterContext(new String[] {tempDir.toString()}); + + ChannelStateFilteringHandler handler = + ChannelStateFilteringHandler.createFromContext( + context, new InputGate[] {inputGate}); + + assertThat(handler).isNotNull(); + handler.close(); + } + + private static RecordFilterContext createRecordFilterContext(String[] tmpDirectories) { + return new RecordFilterContext( + new RecordFilterContext.InputFilterConfig[] { + new RecordFilterContext.InputFilterConfig( + LongSerializer.INSTANCE, new ForwardPartitioner<>(), 1) + }, + new InflightDataRescalingDescriptor( + new InflightDataRescalingDescriptor + .InflightDataGateOrPartitionRescalingDescriptor[] { + new InflightDataRescalingDescriptor + .InflightDataGateOrPartitionRescalingDescriptor( + new int[] {0}, + RescaleMappings.identity(1, 1), + new HashSet<>(), + InflightDataRescalingDescriptor + .InflightDataGateOrPartitionRescalingDescriptor + .MappingType.IDENTITY) + }), + 0, + 128, + tmpDirectories, + true, + MemoryManager.DEFAULT_PAGE_SIZE); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerBufferOwnershipTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerBufferOwnershipTest.java index 85b4fd1d48ef1..ae7b722f683b7 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerBufferOwnershipTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerBufferOwnershipTest.java @@ -38,16 +38,14 @@ import java.io.IOException; import java.util.HashMap; -import java.util.List; import java.util.Map; -import java.util.concurrent.atomic.AtomicInteger; import static 
org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; /** * Tests buffer ownership semantics of {@link ChannelStateFilteringHandler.GateFilterHandler}. Each - * test verifies that buffers are properly recycled on both success and failure paths. + * test verifies that source buffers are properly recycled on both success and failure paths. */ class GateFilterHandlerBufferOwnershipTest { @@ -60,13 +58,9 @@ void testSourceBufferRecycledOnSuccess() throws Exception { createHandler(RecordFilter.acceptAll()); Buffer sourceBuffer = createBufferWithRecords(1L, 2L); - List result = handler.filterAndRewrite(0, 0, sourceBuffer, this::createEmptyBuffer); + handler.filterAndRewrite(0, 0, sourceBuffer, new DataOutputSerializer(BUFFER_SIZE)); - // sourceBuffer should be recycled by the deserializer after consumption assertThat(sourceBuffer.isRecycled()).isTrue(); - - // Clean up result buffers - result.forEach(Buffer::recycleBuffer); } @Test @@ -75,102 +69,60 @@ void testSourceBufferRecycledWhenAllRecordsFilteredOut() throws Exception { ChannelStateFilteringHandler.GateFilterHandler handler = createHandler(rejectAll); Buffer sourceBuffer = createBufferWithRecords(1L, 2L); - List result = handler.filterAndRewrite(0, 0, sourceBuffer, this::createEmptyBuffer); + handler.filterAndRewrite(0, 0, sourceBuffer, new DataOutputSerializer(BUFFER_SIZE)); - assertThat(result).isEmpty(); - // sourceBuffer should still be recycled even though no output was produced assertThat(sourceBuffer.isRecycled()).isTrue(); } @Test void testSourceBufferRecycledOnInvalidVirtualChannel() { - // Create handler with KEY=(0,0) but call with (1,1) to trigger IllegalStateException + // Create handler with KEY=(0,0) but call with (1,1) to trigger IllegalStateException. 
ChannelStateFilteringHandler.GateFilterHandler handler = createHandler(RecordFilter.acceptAll()); Buffer sourceBuffer = createBufferWithRecords(1L); assertThatThrownBy( - () -> handler.filterAndRewrite(1, 1, sourceBuffer, this::createEmptyBuffer)) + () -> + handler.filterAndRewrite( + 1, 1, sourceBuffer, new DataOutputSerializer(BUFFER_SIZE))) .isInstanceOf(IllegalStateException.class); - // sourceBuffer must be recycled even when lookup fails before setNextBuffer - assertThat(sourceBuffer.isRecycled()).isTrue(); - } - - @Test - void testResultBuffersAndCurrentBufferRecycledOnSerializationError() throws Exception { - // Use a small buffer so that records span multiple buffers. The supplier fails on the - // second request, after the first output buffer has been filled and added to resultBuffers. - AtomicInteger bufferRequestCount = new AtomicInteger(0); - ChannelStateFilteringHandler.BufferSupplier failingSupplier = - () -> { - if (bufferRequestCount.incrementAndGet() > 1) { - throw new IOException("Simulated buffer allocation failure"); - } - return createEmptyBuffer(13); - }; - - ChannelStateFilteringHandler.GateFilterHandler handler = - createHandler(RecordFilter.acceptAll()); - - Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L, 4L, 5L); - - // The exception should propagate; no buffer leak (no IllegalReferenceCountException - // from double-recycle). - assertThatThrownBy(() -> handler.filterAndRewrite(0, 0, sourceBuffer, failingSupplier)) - .isInstanceOf(IOException.class) - .hasMessage("Simulated buffer allocation failure"); - - // sourceBuffer ownership was transferred to the deserializer via setNextBuffer(). - // The deserializer may still hold it if it hasn't fully consumed the buffer before the - // error. Calling clear() triggers the cleanup chain: - // GateFilterHandler#clear() -> VirtualChannel#clear() -> deserializer.clear() - handler.clear(); + // sourceBuffer must be recycled even when lookup fails before setNextBuffer. 
assertThat(sourceBuffer.isRecycled()).isTrue(); } /** - * Tests the production cleanup path: when filterAndRewrite throws mid-processing, the - * deserializer may still hold sourceBuffer. In production, ChannelStateFilteringHandler is used - * in a try-with-resources block (see {@code SequentialChannelStateReaderImpl#readInputData}), - * so its close() is guaranteed to be called, which triggers clear() on all GateFilterHandlers - * and their deserializers. This test simulates that exact pattern. + * When filterAndRewrite throws mid-processing, the deserializer may still hold sourceBuffer. In + * production, ChannelStateFilteringHandler is used in a try-with-resources block (see {@code + * SequentialChannelStateReaderImpl#readInputData}), so its close() is guaranteed to be called, + * which triggers clear() on all GateFilterHandlers and their deserializers. This test simulates + * that exact pattern. */ @Test void testCloseRecyclesDeserializerHeldBufferAfterError() throws Exception { - AtomicInteger bufferRequestCount = new AtomicInteger(0); - ChannelStateFilteringHandler.BufferSupplier failingSupplier = - () -> { - if (bufferRequestCount.incrementAndGet() > 1) { - throw new IOException("Simulated buffer allocation failure"); - } - return createEmptyBuffer(13); - }; - ChannelStateFilteringHandler.GateFilterHandler gateHandler = createHandler(RecordFilter.acceptAll()); - // Wrap in ChannelStateFilteringHandler, the production-level owner ChannelStateFilteringHandler filteringHandler = new ChannelStateFilteringHandler( new ChannelStateFilteringHandler.GateFilterHandler[] {gateHandler}); + // A serializer that throws while writing the second record's length prefix, triggering a + // mid-processing failure after the first record has already been emitted. 
+ DataOutputSerializer failingSerializer = new FailingAfterFirstRecordSerializer(); Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L, 4L, 5L); - // Simulate the production try-with-resources pattern assertThatThrownBy( () -> { try (ChannelStateFilteringHandler ignored = filteringHandler) { filteringHandler.filterAndRewrite( - 0, 0, 0, sourceBuffer, failingSupplier); + 0, 0, 0, sourceBuffer, failingSerializer); } }) .isInstanceOf(IOException.class) - .hasMessage("Simulated buffer allocation failure"); + .hasMessage("Simulated write failure"); - // After close(), the entire cleanup chain has fired: - // ChannelStateFilteringHandler.close() -> GateFilterHandler.clear() - // -> VirtualChannel.clear() -> deserializer.clear() -> sourceBuffer.recycleBuffer() + // After close(), the entire cleanup chain has fired. assertThat(sourceBuffer.isRecycled()).isTrue(); } @@ -219,12 +171,26 @@ private Buffer createBufferWithRecords(Long... values) { } } - private Buffer createEmptyBuffer() { - return createEmptyBuffer(BUFFER_SIZE); - } + /** + * A {@link DataOutputSerializer} that throws an IOException while writing the second record's + * length prefix, simulating a failure mid-stream to verify that the source buffer is still + * recycled via the filtering handler's close() cleanup chain. Each surviving record begins with + * a {@code writeInt} placeholder for its length, so the second {@code writeInt} marks the start + * of the second record. 
+ */ + private static final class FailingAfterFirstRecordSerializer extends DataOutputSerializer { + private int writeIntCount = 0; + + FailingAfterFirstRecordSerializer() { + super(BUFFER_SIZE); + } - private Buffer createEmptyBuffer(int size) { - MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(size); - return new NetworkBuffer(segment, FreeingBufferRecycler.INSTANCE); + @Override + public void writeInt(int v) throws IOException { + if (++writeIntCount > 1) { + throw new IOException("Simulated write failure"); + } + super.writeInt(v); + } } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerTest.java index f02ce35fd867d..1646908727a88 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerTest.java @@ -57,10 +57,10 @@ void testAllRecordsPassFilter() throws Exception { createHandler(RecordFilter.acceptAll()); Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L); - List result = handler.filterAndRewrite(0, 0, sourceBuffer, this::createEmptyBuffer); + DataOutputSerializer output = new DataOutputSerializer(BUFFER_SIZE); + handler.filterAndRewrite(0, 0, sourceBuffer, output); - // deserializeBuffers consumes (recycles) each buffer via the deserializer - List values = deserializeBuffers(result); + List values = readRecordsFromSerializer(output); assertThat(values).containsExactly(1L, 2L, 3L); } @@ -70,9 +70,11 @@ void testAllRecordsFilteredOut() throws Exception { ChannelStateFilteringHandler.GateFilterHandler handler = createHandler(rejectAll); Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L); - List result = handler.filterAndRewrite(0, 0, sourceBuffer, this::createEmptyBuffer); + DataOutputSerializer output = new 
DataOutputSerializer(BUFFER_SIZE); + handler.filterAndRewrite(0, 0, sourceBuffer, output); - assertThat(result).isEmpty(); + // No bytes should be written when all records are filtered out. + assertThat(output.length()).isZero(); } @Test @@ -81,42 +83,50 @@ void testPartialFiltering() throws Exception { ChannelStateFilteringHandler.GateFilterHandler handler = createHandler(keepEven); Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L, 4L, 5L); - List result = handler.filterAndRewrite(0, 0, sourceBuffer, this::createEmptyBuffer); + DataOutputSerializer output = new DataOutputSerializer(BUFFER_SIZE); + handler.filterAndRewrite(0, 0, sourceBuffer, output); - List values = deserializeBuffers(result); + List values = readRecordsFromSerializer(output); assertThat(values).containsExactly(2L, 4L); } @Test - void testSmallOutputBufferProducesMultipleBuffers() throws Exception { - // Use a very small output buffer size so records must span multiple buffers - int smallBufferSize = 8; + void testEmptyBuffer() throws Exception { ChannelStateFilteringHandler.GateFilterHandler handler = createHandler(RecordFilter.acceptAll()); - Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L); - List result = - handler.filterAndRewrite( - 0, 0, sourceBuffer, () -> createEmptyBuffer(smallBufferSize)); + Buffer emptyBuffer = createEmptyBuffer(); + emptyBuffer.setSize(0); - // Each Long record needs 4 bytes length + ~9 bytes data > 8-byte buffer - assertThat(result.size()).isGreaterThan(1); + DataOutputSerializer output = new DataOutputSerializer(BUFFER_SIZE); + handler.filterAndRewrite(0, 0, emptyBuffer, output); - List values = deserializeBuffers(result); - assertThat(values).containsExactly(1L, 2L, 3L); + // No data written for an empty source buffer. 
+ assertThat(output.length()).isZero(); } @Test - void testEmptyBuffer() throws Exception { + void testSourceBufferRecycledOnSuccess() throws Exception { ChannelStateFilteringHandler.GateFilterHandler handler = createHandler(RecordFilter.acceptAll()); - Buffer emptyBuffer = createEmptyBuffer(); - emptyBuffer.setSize(0); + Buffer sourceBuffer = createBufferWithRecords(1L, 2L); + DataOutputSerializer output = new DataOutputSerializer(BUFFER_SIZE); + handler.filterAndRewrite(0, 0, sourceBuffer, output); + + assertThat(sourceBuffer.isRecycled()).isTrue(); + } + + @Test + void testSourceBufferRecycledWhenAllRecordsFilteredOut() throws Exception { + RecordFilter rejectAll = record -> false; + ChannelStateFilteringHandler.GateFilterHandler handler = createHandler(rejectAll); - List result = handler.filterAndRewrite(0, 0, emptyBuffer, this::createEmptyBuffer); + Buffer sourceBuffer = createBufferWithRecords(1L, 2L); + DataOutputSerializer output = new DataOutputSerializer(BUFFER_SIZE); + handler.filterAndRewrite(0, 0, sourceBuffer, output); - assertThat(result).isEmpty(); + assertThat(sourceBuffer.isRecycled()).isTrue(); } // ------------------------------------------------------------------------------------------- @@ -141,23 +151,13 @@ private ChannelStateFilteringHandler.GateFilterHandler createHandler( private Buffer createBufferWithRecords(Long... values) throws IOException { StreamElementSerializer serializer = new StreamElementSerializer<>(LongSerializer.INSTANCE); - return serializeRecordsToBuffer(serializer, values); - } - - /** Serializes records into a buffer using Flink's length-prefixed format. */ - private Buffer serializeRecordsToBuffer( - StreamElementSerializer serializer, Long... 
values) throws IOException { DataOutputSerializer output = new DataOutputSerializer(BUFFER_SIZE); for (Long value : values) { - // Serialize using the same length-prefixed format as Flink DataOutputSerializer recordOutput = new DataOutputSerializer(64); serializer.serialize(new StreamRecord<>(value), recordOutput); int recordLength = recordOutput.length(); - - // Write 4-byte big-endian length prefix output.writeInt(recordLength); - // Write record bytes output.write(recordOutput.getSharedBuffer(), 0, recordLength); } @@ -171,43 +171,49 @@ private Buffer serializeRecordsToBuffer( } private Buffer createEmptyBuffer() { - return createEmptyBuffer(BUFFER_SIZE); - } - - private Buffer createEmptyBuffer(int size) { - MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(size); + MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(BUFFER_SIZE); return new NetworkBuffer(segment, FreeingBufferRecycler.INSTANCE); } - private List deserializeBuffers(List buffers) throws IOException { + /** + * Deserializes the records the handler appended into {@code output}. The body format is + * repeated (4B recordLen + N bytes of serialized StreamElement), which the deserializer reads + * directly. 
+ */ + private List readRecordsFromSerializer(DataOutputSerializer output) throws Exception { + List values = new ArrayList<>(); StreamElementSerializer serializer = new StreamElementSerializer<>(LongSerializer.INSTANCE); + DeserializationDelegate delegate = + new NonReusingDeserializationDelegate<>(serializer); + + byte[] bodyBytes = output.getCopyOfBuffer(); + if (bodyBytes.length == 0) { + return values; + } + MemorySegment memSeg = MemorySegmentFactory.allocateUnpooledSegment(bodyBytes.length); + memSeg.put(0, bodyBytes); + NetworkBuffer buf = new NetworkBuffer(memSeg, FreeingBufferRecycler.INSTANCE); + buf.setSize(bodyBytes.length); + SpillingAdaptiveSpanningRecordDeserializer> deserializer = new SpillingAdaptiveSpanningRecordDeserializer<>( new String[] {System.getProperty("java.io.tmpdir")}); - DeserializationDelegate delegate = - new NonReusingDeserializationDelegate<>(serializer); - - List values = new ArrayList<>(); - for (Buffer buffer : buffers) { - deserializer.setNextBuffer(buffer); - while (true) { - RecordDeserializer.DeserializationResult result = - deserializer.getNextRecord(delegate); - if (result.isFullRecord()) { - StreamElement element = delegate.getInstance(); - if (element.isRecord()) { - @SuppressWarnings("unchecked") - StreamRecord record = (StreamRecord) element; - values.add(record.getValue()); - } - } - if (result.isBufferConsumed()) { - break; + deserializer.setNextBuffer(buf); + + RecordDeserializer.DeserializationResult result; + do { + result = deserializer.getNextRecord(delegate); + if (result.isFullRecord()) { + StreamElement element = delegate.getInstance(); + if (element.isRecord()) { + @SuppressWarnings("unchecked") + StreamRecord record = (StreamRecord) element; + values.add(record.getValue()); } } - } + } while (!result.isBufferConsumed()); return values; } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java index 9c4aab0bc7a5d..d82fb1866a688 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java @@ -32,7 +32,9 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import java.nio.file.Path; import java.util.HashSet; import static org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptorUtil.mappings; @@ -40,12 +42,15 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -/** Test of different implementation of {@link InputChannelRecoveredStateHandler}. */ +/** Test of different implementation of {@link AbstractInputChannelRecoveredStateHandler}. 
*/ class InputChannelRecoveredStateHandlerTest extends RecoveredChannelStateHandlerTest { + @TempDir private Path tmpDir; + private static final int preAllocatedSegments = 3; private NetworkBufferPool networkBufferPool; private SingleInputGate inputGate; - private InputChannelRecoveredStateHandler icsHandler; + // NoSpillingHandler: checkpointingDuringRecoveryEnabled=false, filteringHandler=null + private AbstractInputChannelRecoveredStateHandler icsHandler; private InputChannelInfo channelInfo; @BeforeEach @@ -61,14 +66,14 @@ void setUp() { .setSegmentProvider(networkBufferPool) .build(); - icsHandler = buildInputChannelStateHandler(inputGate); + icsHandler = buildNoSpillingHandler(inputGate); channelInfo = new InputChannelInfo(0, 0); } - private InputChannelRecoveredStateHandler buildInputChannelStateHandler( + private AbstractInputChannelRecoveredStateHandler buildNoSpillingHandler( SingleInputGate inputGate) { - return new InputChannelRecoveredStateHandler( + return AbstractInputChannelRecoveredStateHandler.create( new InputGate[] {inputGate}, new InflightDataRescalingDescriptor( new InflightDataRescalingDescriptor @@ -82,11 +87,13 @@ private InputChannelRecoveredStateHandler buildInputChannelStateHandler( .InflightDataGateOrPartitionRescalingDescriptor .MappingType.IDENTITY) }), + false, null, - MemoryManager.DEFAULT_PAGE_SIZE); + MemoryManager.DEFAULT_PAGE_SIZE, + null); } - private InputChannelRecoveredStateHandler buildMultiChannelHandler() { + private AbstractInputChannelRecoveredStateHandler buildMultiChannelHandler() { // Setup multi-channel scenario to trigger distribution constraint validation SingleInputGate multiChannelGate = new SingleInputGateBuilder() @@ -95,7 +102,7 @@ private InputChannelRecoveredStateHandler buildMultiChannelHandler() { .setSegmentProvider(networkBufferPool) .build(); - return new InputChannelRecoveredStateHandler( + return AbstractInputChannelRecoveredStateHandler.create( new InputGate[] {multiChannelGate}, new 
InflightDataRescalingDescriptor( new InflightDataRescalingDescriptor @@ -110,18 +117,43 @@ private InputChannelRecoveredStateHandler buildMultiChannelHandler() { .InflightDataGateOrPartitionRescalingDescriptor .MappingType.RESCALING) }), + false, null, - MemoryManager.DEFAULT_PAGE_SIZE); + MemoryManager.DEFAULT_PAGE_SIZE, + null); } /** Builds a handler in filtering mode (non-null filtering handler, no-op stub). */ - private InputChannelRecoveredStateHandler buildFilteringInputChannelStateHandler() { + private SpillingWithFilteringHandler buildFilteringInputChannelStateHandler() { // Empty GateFilterHandler array: filtering is "enabled" structurally, but no gate-level // filter logic runs. Suitable for exercising getBuffer() routing only. ChannelStateFilteringHandler stubFilteringHandler = new ChannelStateFilteringHandler( new ChannelStateFilteringHandler.GateFilterHandler[0]); - return new InputChannelRecoveredStateHandler( + return (SpillingWithFilteringHandler) + AbstractInputChannelRecoveredStateHandler.create( + new InputGate[] {inputGate}, + new InflightDataRescalingDescriptor( + new InflightDataRescalingDescriptor + .InflightDataGateOrPartitionRescalingDescriptor[] { + new InflightDataRescalingDescriptor + .InflightDataGateOrPartitionRescalingDescriptor( + new int[] {1}, + RescaleMappings.identity(1, 1), + new HashSet<>(), + InflightDataRescalingDescriptor + .InflightDataGateOrPartitionRescalingDescriptor + .MappingType.IDENTITY) + }), + true, + stubFilteringHandler, + MemoryManager.DEFAULT_PAGE_SIZE, + new String[] {tmpDir.toAbsolutePath().toString()}); + } + + private AbstractInputChannelRecoveredStateHandler buildSpillingNoFilteringHandler( + String[] spillTmpDirectories) { + return AbstractInputChannelRecoveredStateHandler.create( new InputGate[] {inputGate}, new InflightDataRescalingDescriptor( new InflightDataRescalingDescriptor @@ -135,14 +167,16 @@ private InputChannelRecoveredStateHandler buildFilteringInputChannelStateHandler 
.InflightDataGateOrPartitionRescalingDescriptor .MappingType.IDENTITY) }), - stubFilteringHandler, - MemoryManager.DEFAULT_PAGE_SIZE); + true, + null, + MemoryManager.DEFAULT_PAGE_SIZE, + spillTmpDirectories); } @Test void testBufferDistributedToMultipleInputChannelsThrowsException() throws Exception { // Test constraint that prevents buffer distribution to multiple channels - try (InputChannelRecoveredStateHandler handler = buildMultiChannelHandler()) { + try (AbstractInputChannelRecoveredStateHandler handler = buildMultiChannelHandler()) { assertThatThrownBy(() -> handler.getBuffer(channelInfo)) .isInstanceOf(IllegalStateException.class) .hasMessageContaining( @@ -187,7 +221,7 @@ void testRecycleBufferAfterRecoverWasCalled() throws Exception { @Test void testPreFilterBufferIsolationFromNetworkBufferPool() throws Exception { - try (InputChannelRecoveredStateHandler filteringHandler = + try (SpillingWithFilteringHandler filteringHandler = buildFilteringInputChannelStateHandler()) { int availableBefore = networkBufferPool.getNumberOfAvailableMemorySegments(); @@ -229,7 +263,7 @@ void testNonFilteringModeUsesNetworkBufferPool() throws Exception { @Test void testPreFilterSegmentReusedAcrossCalls() throws Exception { - try (InputChannelRecoveredStateHandler filteringHandler = + try (SpillingWithFilteringHandler filteringHandler = buildFilteringInputChannelStateHandler()) { // First getBuffer() lazily allocates the segment. 
RecoveredChannelStateHandler.BufferWithContext first = @@ -259,7 +293,7 @@ void testPreFilterSegmentReusedAcrossCalls() throws Exception { @Test void testGetBufferThrowsWhenPriorBufferNotRecycled() throws Exception { - try (InputChannelRecoveredStateHandler filteringHandler = + try (SpillingWithFilteringHandler filteringHandler = buildFilteringInputChannelStateHandler()) { RecoveredChannelStateHandler.BufferWithContext first = filteringHandler.getBuffer(channelInfo); @@ -283,8 +317,7 @@ void testGetBufferThrowsWhenPriorBufferNotRecycled() throws Exception { @Test void testPreFilterSegmentFreedOnClose() throws Exception { - InputChannelRecoveredStateHandler filteringHandler = - buildFilteringInputChannelStateHandler(); + SpillingWithFilteringHandler filteringHandler = buildFilteringInputChannelStateHandler(); RecoveredChannelStateHandler.BufferWithContext bufferWithContext = filteringHandler.getBuffer(channelInfo); bufferWithContext.context.recycleBuffer(); @@ -298,4 +331,13 @@ void testPreFilterSegmentFreedOnClose() throws Exception { assertThat(segment.isFreed()).isTrue(); assertThat(filteringHandler.getPreFilterSegmentForTesting()).isNull(); } + + @Test + void testSpillingHandlerRequiresSpillDirectories() { + assertThatThrownBy(() -> buildSpillingNoFilteringHandler(null)) + .isInstanceOf(NullPointerException.class); + assertThatThrownBy(() -> buildSpillingNoFilteringHandler(new String[0])) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("spillTmpDirectories must not be empty"); + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandlerFilterRoutingTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandlerFilterRoutingTest.java new file mode 100644 index 0000000000000..493e88026dbee --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandlerFilterRoutingTest.java @@ 
-0,0 +1,308 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.core.memory.DataOutputSerializer; +import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor; +import org.apache.flink.runtime.checkpoint.RescaleMappings; +import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor; +import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer; +import org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; +import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; +import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; +import org.apache.flink.runtime.io.network.partition.consumer.InputChannel; +import org.apache.flink.runtime.io.network.partition.consumer.InputChannelBuilder; +import 
org.apache.flink.runtime.io.network.partition.consumer.InputGate; +import org.apache.flink.runtime.io.network.partition.consumer.RecoveredInputChannel; +import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; +import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateBuilder; +import org.apache.flink.runtime.memory.MemoryManager; +import org.apache.flink.runtime.plugable.DeserializationDelegate; +import org.apache.flink.streaming.runtime.io.recovery.RecordFilter; +import org.apache.flink.streaming.runtime.io.recovery.VirtualChannel; +import org.apache.flink.streaming.runtime.streamrecord.StreamElement; +import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Optional; + +import static org.assertj.core.api.Assertions.assertThat; + +class RecoveredChannelStateHandlerFilterRoutingTest { + + @TempDir Path tempDir; + + private NetworkBufferPool networkBufferPool; + private SingleInputGate inputGate; + private InputChannelInfo channelInfo; + + @BeforeEach + void setUp() { + // Plenty of segments so the filter-on path does not deadlock on the bounded pool. 
+ networkBufferPool = new NetworkBufferPool(64, MemoryManager.DEFAULT_PAGE_SIZE); + inputGate = + new SingleInputGateBuilder() + .setChannelFactory(InputChannelBuilder::buildLocalRecoveredChannel) + .setSegmentProvider(networkBufferPool) + .build(); + channelInfo = new InputChannelInfo(0, 0); + } + + @AfterEach + void tearDown() throws Exception { + inputGate.close(); + networkBufferPool.destroy(); + } + + @Test + void testFilterOnRoutesOutputToChannelState() throws Exception { + ChannelStateFilteringHandler filteringHandler = newPassThroughFilteringHandler(); + try (ChannelStateFilteringHandler ignored = filteringHandler) { + SpillingWithFilteringHandler handler = newFilterOnHandler(filteringHandler); + invokeRecoverWithRecords(handler, 1L, 2L, 3L); + + // Surviving records accumulate in the in-memory segment serializer; they are only + // sealed and flushed to a spill file on channel switch or close. With a single channel + // and no switch, close() is what seals the segment, so the assertion must follow it. + handler.close(); + + assertThat(handler.peekSpillFilesForTesting()) + .as("filter-on path must spill the surviving records to a file") + .isNotEmpty(); + } + } + + @Test + void testFilterOnAccumulatorBuffersComeFromHeapNotPool() throws Exception { + // The accumulator's prefilter + postfilter buffers are unpooled heap segments owned by + // the handler — invoking filter recovery must NOT consume any network buffer pool + // segments for the accumulator path. Channel-side exclusive buffer reservation, if any, + // happens at handler construction and is already reflected in the pre-recover snapshot. 
+ ChannelStateFilteringHandler filteringHandler = newPassThroughFilteringHandler(); + SpillingWithFilteringHandler handler = newFilterOnHandler(filteringHandler); + try (ChannelStateFilteringHandler ignored = filteringHandler) { + int availableBeforeRecover = networkBufferPool.getNumberOfAvailableMemorySegments(); + + invokeRecoverWithRecords(handler, 1L, 2L, 3L); + + int availableAfterRecover = networkBufferPool.getNumberOfAvailableMemorySegments(); + assertThat(availableAfterRecover) + .as("filter accumulator buffers must not be sourced from the network pool") + .isEqualTo(availableBeforeRecover); + + handler.close(); + inputGate.close(); + assertThat(networkBufferPool.getNumberOfAvailableMemorySegments()) + .as("pool count after close must match pre-recover (filter took nothing)") + .isEqualTo(availableBeforeRecover); + } + } + + @Test + void testFilterOnDoesNotInvokeChannelOnRecoveredStateBuffer() throws Exception { + ChannelStateFilteringHandler filteringHandler = newPassThroughFilteringHandler(); + try (ChannelStateFilteringHandler ignored = filteringHandler; + SpillingWithFilteringHandler handler = newFilterOnHandler(filteringHandler)) { + invokeRecoverWithRecords(handler, 1L, 2L, 3L); + + int queuedDuringRecovery = countQueuedRecoveredBuffers(); + assertThat(queuedDuringRecovery) + .as("filter-on must not enqueue buffers into the channel during recovery") + .isEqualTo(0); + } + } + + @Test + void testFilterOffDoesNotCreateChannelState() throws Exception { + try (NoSpillingHandler handler = newFilterOffHandler()) { + invokeRecoverWithRawBytes(handler, new byte[] {1, 2, 3, 4}); + + // NoSpillingHandler has no channelStateWriter, so peek always returns null. 
+ assertThat(handler.getProducedChannelState()) + .as("filter-off close() must not publish a FetchedChannelState either") + .isNull(); + } + } + + @Test + void testFilterOffMaintainsMasterBehavior() throws Exception { + try (NoSpillingHandler handler = newFilterOffHandler()) { + invokeRecoverWithRawBytes(handler, new byte[] {1, 2, 3, 4}); + + // Filter-off path enqueues the SubtaskConnectionDescriptor event plus the data + // buffer directly into the channel's recoveredBuffers. + int queued = countQueuedRecoveredBuffers(); + assertThat(queued) + .as("filter-off must enqueue the descriptor + data buffer into the channel") + .isGreaterThanOrEqualTo(2); + } + } + + private SpillingWithFilteringHandler newFilterOnHandler( + ChannelStateFilteringHandler filteringHandler) { + return (SpillingWithFilteringHandler) + AbstractInputChannelRecoveredStateHandler.create( + new InputGate[] {inputGate}, + identityRescalingForOneGate(), + true, + filteringHandler, + MemoryManager.DEFAULT_PAGE_SIZE, + new String[] {tempDir.toString()}); + } + + private NoSpillingHandler newFilterOffHandler() { + return (NoSpillingHandler) + AbstractInputChannelRecoveredStateHandler.create( + new InputGate[] {inputGate}, + identityRescalingForOneGate(), + false, + null, + MemoryManager.DEFAULT_PAGE_SIZE, + null); + } + + private static InflightDataRescalingDescriptor identityRescalingForOneGate() { + return new InflightDataRescalingDescriptor( + new InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor + [] { + new InflightDataRescalingDescriptor + .InflightDataGateOrPartitionRescalingDescriptor( + new int[] {1}, + RescaleMappings.identity(1, 1), + new HashSet<>(), + InflightDataRescalingDescriptor + .InflightDataGateOrPartitionRescalingDescriptor.MappingType + .IDENTITY) + }); + } + + private ChannelStateFilteringHandler newPassThroughFilteringHandler() { + StreamElementSerializer serializer = + new StreamElementSerializer<>(LongSerializer.INSTANCE); + RecordDeserializer> 
deserializer = + new SpillingAdaptiveSpanningRecordDeserializer<>( + new String[] {System.getProperty("java.io.tmpdir")}); + VirtualChannel vc = new VirtualChannel<>(deserializer, RecordFilter.acceptAll()); + Map> channels = new HashMap<>(); + // The handler will invoke filterAndRewrite with oldSubtaskIndex=1 — keep the key aligned. + channels.put(new SubtaskConnectionDescriptor(1, channelInfo.getInputChannelIdx()), vc); + + ChannelStateFilteringHandler.GateFilterHandler gateHandler = + new ChannelStateFilteringHandler.GateFilterHandler<>(channels, serializer); + return new ChannelStateFilteringHandler( + new ChannelStateFilteringHandler.GateFilterHandler[] {gateHandler}); + } + + private void invokeRecoverWithRecords( + AbstractInputChannelRecoveredStateHandler handler, Long... values) throws Exception { + Buffer source = createRecordBuffer(values); + invokeRecoverWithBuffer(handler, source); + } + + private void invokeRecoverWithRawBytes( + AbstractInputChannelRecoveredStateHandler handler, byte[] data) throws Exception { + MemorySegment seg = MemorySegmentFactory.allocateUnpooledSegment(data.length); + seg.put(0, data); + NetworkBuffer source = new NetworkBuffer(seg, FreeingBufferRecycler.INSTANCE); + source.setSize(data.length); + invokeRecoverWithBuffer(handler, source); + } + + /** + * Mirrors the chunkReader's getBuffer + recover sequence: handler-issued buffer is filled with + * the source data, then handed back to recover. 
+ */ + private void invokeRecoverWithBuffer( + AbstractInputChannelRecoveredStateHandler handler, Buffer source) throws Exception { + RecoveredChannelStateHandler.BufferWithContext bwc = handler.getBuffer(channelInfo); + try { + Buffer dest = bwc.context; + int len = source.readableBytes(); + source.getMemorySegment() + .copyTo( + source.getMemorySegmentOffset(), + dest.getMemorySegment(), + dest.getMemorySegmentOffset(), + len); + dest.setSize(len); + } finally { + source.recycleBuffer(); + } + // oldSubtaskIndex=1 matches the pass-through filter's virtual channel key. The filter-off + // path ignores this argument's mapping (it only flows into a SubtaskConnectionDescriptor). + handler.recover(channelInfo, 1, bwc); + } + + private static Buffer createRecordBuffer(Long... values) throws IOException { + StreamElementSerializer serializer = + new StreamElementSerializer<>(LongSerializer.INSTANCE); + DataOutputSerializer output = new DataOutputSerializer(256); + for (Long v : values) { + DataOutputSerializer rec = new DataOutputSerializer(64); + serializer.serialize(new StreamRecord<>(v), rec); + int recordLength = rec.length(); + output.writeInt(recordLength); + output.write(rec.getSharedBuffer(), 0, recordLength); + } + byte[] data = output.getCopyOfBuffer(); + MemorySegment segment = + MemorySegmentFactory.allocateUnpooledSegment(MemoryManager.DEFAULT_PAGE_SIZE); + segment.put(0, data, 0, data.length); + NetworkBuffer buf = new NetworkBuffer(segment, FreeingBufferRecycler.INSTANCE); + buf.setSize(data.length); + return buf; + } + + /** + * Counts the number of buffers currently queued in the only recovered input channel by draining + * via the public {@code getNextBuffer()} entry point. After this helper returns the channel + * queue is empty by definition; tests should not call it twice expecting the same answer. 
+ */ + private int countQueuedRecoveredBuffers() throws IOException { + RecoveredInputChannel ch = (RecoveredInputChannel) inputGate.getChannel(0); + int count = 0; + while (true) { + Optional next = ch.getNextBuffer(); + if (!next.isPresent()) { + break; + } + count++; + next.get().buffer().recycleBuffer(); + if (count > 1000) { + throw new IllegalStateException("Unexpected unbounded queue contents"); + } + } + return count; + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandlerTest.java index 01f7c43920ae9..db302d23f7337 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandlerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandlerTest.java @@ -22,7 +22,7 @@ /** * Base class which contains all tests which should be implemented for every implementation of - * {@link InputChannelRecoveredStateHandler}. + * {@link AbstractInputChannelRecoveredStateHandler}. 
*/ abstract class RecoveredChannelStateHandlerTest { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java index d80442b8a06f8..b3dedda0c72b9 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java @@ -51,10 +51,12 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.TestTemplate; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; import java.io.ByteArrayOutputStream; import java.io.DataOutputStream; import java.io.IOException; +import java.nio.file.Path; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -90,6 +92,8 @@ public static List parameters() { new Object[] {"ReadPermutedStateWithIncreasedBuffer", 10, 10, 10, 10, 20}); } + @TempDir static Path tmpDir; + @Parameter public String desc; @Parameter(value = 1) @@ -144,7 +148,10 @@ void testReadPermutedState() throws Exception { withInputGates( gates -> { - reader.readInputData(gates, RecordFilterContext.disabled()); + reader.readInputData( + gates, + RecordFilterContext.disabled( + new String[] {tmpDir.toAbsolutePath().toString()})); assertBuffersEquals(inputChannelsData, collectBuffers(gates)); assertConsumed(gates); }); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/TestSpillWriter.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/TestSpillWriter.java new file mode 100644 index 0000000000000..4b3f1b9aa6496 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/TestSpillWriter.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.core.memory.DataOutputSerializer; +import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.file.Path; + +/** + * Test-only helper that produces spill files in the {@link AbstractSpillingHandler} on-disk format, + * exposing the same {@code writeRecord} / {@code writePassThrough} surface the readers and drainers + * are tested against. It drives a minimal concrete {@link AbstractSpillingHandler} so tests need + * not stand up real input gates or run the recovery loop. + */ +final class TestSpillWriter implements Closeable { + + private final FormatHandler handler; + + TestSpillWriter(Path baseDir) { + this(baseDir, AbstractSpillingHandler.DEFAULT_SPILL_FILE_SIZE_BYTES); + } + + TestSpillWriter(Path baseDir, long maxFileSizeBytes) { + this.handler = new FormatHandler(new String[] {baseDir.toString()}, maxFileSizeBytes); + } + + /** Appends one length-prefixed record, mirroring the filtering path. 
*/ + void writeRecord(InputChannelInfo channelInfo, byte[] record, int recordLength) + throws IOException { + DataOutputSerializer segment = handler.segmentSerializerFor(channelInfo); + segment.writeInt(recordLength); + segment.write(record, 0, recordLength); + } + + /** Appends verbatim bytes, mirroring the pass-through path. */ + void writePassThrough(InputChannelInfo channelInfo, byte[] data, int offset, int length) + throws IOException { + handler.segmentSerializerFor(channelInfo).write(data, offset, length); + } + + /** Opens (or switches to) a segment without writing any body, to exercise empty segments. */ + void openSegment(InputChannelInfo channelInfo) throws IOException { + handler.segmentSerializerFor(channelInfo); + } + + /** + * Seals the spilled segments and returns the produced state, already holding one lifecycle + * grant for the caller. Returns {@code null} if nothing was ever written. + */ + FetchedChannelState getChannelState() throws IOException { + handler.close(); + return handler.getProducedChannelState(); + } + + @Override + public void close() throws IOException { + handler.close(); + } + + private static final class FormatHandler extends AbstractSpillingHandler { + + FormatHandler(String[] spillTmpDirectories, long maxFileSizeBytes) { + super( + new InputGate[0], + InflightDataRescalingDescriptor.NO_RESCALE, + spillTmpDirectories, + maxFileSizeBytes); + } + + @Override + public void recover( + InputChannelInfo info, + int oldSubtaskIndex, + BufferWithContext ctx) { + throw new UnsupportedOperationException("not used in format tests"); + } + } +} From abe52fed9fae09b076f23394c392e5f3542d22a8 Mon Sep 17 00:00:00 2001 From: Rui Fan <1996fanrui@gmail.com> Date: Fri, 22 May 2026 13:28:25 +0200 Subject: [PATCH 05/16] [FLINK-38544][checkpoint] Phase 4: spill reader drain + heap fallback removal - New SpillFileReader implements RecoveryCheckpointTrigger + Closeable. 
drain(): buffer alloc + disk read outside lock; deliver + offset advance inside lock. snapshotAndInsertBarriers(cpId): atomic startPos snapshot + per-channel barrier insert. Constructor derives the InputChannelInfo map internally; bodies pair acquire/release against SpillFile's ref counter. - New RecoveredChannelBufferRequester delegates to RecoveredInputChannel pool. - RecoveredInputChannel.requestBufferBlocking heap fallback removed (no more MemorySegmentFactory.allocateUnpooledSegment; OOM path eliminated). - channelIOExecutor wired: filter-on submits drain after conversion completes; exceptions bubble via StreamTask.asyncExceptionHandler. Design: requirements/38544/phase4_spill_reader/design.md (cherry picked from commit 1315d38a5196249d017e3cd2d86ba54865934524) --- .../channel/FetchedChannelState.java | 149 +++++ .../channel/FetchedChannelStateDrainer.java | 202 +++++++ .../channel/FetchedChannelStateReader.java | 127 +++++ .../FetchedChannelStateReaderImpl.java | 533 ++++++++++++++++++ .../channel/FetchedChannelStateSnapshot.java | 91 +++ .../channel/SequentialChannelStateReader.java | 9 + .../network/logger/NetworkActionsLogger.java | 12 +- .../EndOfFetchedChannelStateEvent.java | 75 +++ .../consumer/RecoveredInputChannel.java | 96 +--- .../io/AbstractStreamTaskNetworkInput.java | 3 +- .../checkpointing/CheckpointedInputGate.java | 15 + .../io/recovery/RecordFilterContext.java | 24 +- .../streaming/runtime/tasks/StreamTask.java | 238 +++++++- .../ChannelIOExecutorDrainSubmissionTest.java | 204 +++++++ ...hedChannelStateDrainerConcurrencyTest.java | 184 ++++++ .../FetchedChannelStateDrainerTest.java | 475 ++++++++++++++++ .../FetchedChannelStateReaderTest.java | 532 +++++++++++++++++ .../FetchedChannelStateRefCountTest.java | 153 +++++ .../channel/FetchedChannelStateTest.java | 118 ++++ ...BufferBlockingHeapFallbackRemovedTest.java | 140 +++++ .../state/ChannelPersistenceITCase.java | 9 +- .../io/recovery/RecordFilterContextTest.java | 52 +- 22 files changed, 
3300 insertions(+), 141 deletions(-) create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelState.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainer.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateReader.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateReaderImpl.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateSnapshot.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/EndOfFetchedChannelStateEvent.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelIOExecutorDrainSubmissionTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainerConcurrencyTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainerTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateReaderTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateRefCountTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelRequestBufferBlockingHeapFallbackRemovedTest.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelState.java new file mode 100644 index 0000000000000..68116cad9fec4 --- 
/dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelState.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.annotation.VisibleForTesting; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * Sealed container for recovered channel-state data written to spill files. + * + *

Holds a list of file paths (in write order). Segment boundaries are self-described in disk + * segment headers ([4B gateIdx][4B channelIdx][4B bufferLength]), so no in-memory segment locator + * table is maintained. The reader scans files sequentially, reading each 12-byte header to obtain + * the channel info and body length. + * + *

The file list is written by the spill writer (one rotation per 64 MB soft limit) and handed + * to this class already complete; it is sealed at construction and never grows. + * + *

File lifecycle is managed by {@link #acquire()} / {@link #release()} reference counting. Files + * are deleted when the last lifecycle grant is released (i.e. when both the drain reader and all + * snapshot readers have finished), or earlier if {@link #close()} forces cleanup. + * + *

The file list is copied at construction and never mutated afterwards. {@link #close()} and + * {@link #release()} may be called from different threads. + */ +@Internal +public final class FetchedChannelState implements Closeable { + + /** Ordered list of spill file paths, one entry per physical file. Sealed at construction. */ + private final List files; + + // close() and release() may be called from different threads; volatile ensures visibility. + private volatile boolean closed = false; + + private final AtomicInteger refCount = new AtomicInteger(0); + + private final AtomicBoolean cleanedUp = new AtomicBoolean(false); + + /** + * Wraps an already-written, ordered list of spill files. The list is sealed; it never grows. + */ + FetchedChannelState(List files) { + this.files = new ArrayList<>(checkNotNull(files)); + } + + // ------------------------------------------------------------------------------------------- + // Read-phase API (called by the reader after the writer is sealed) + // ------------------------------------------------------------------------------------------- + + /** + * Opens a root reader covering all segments from the beginning. The returned reader holds one + * lifecycle grant and must be closed when done. + */ + public FetchedChannelStateReader reader() { + return new FetchedChannelStateSnapshot( + this, FetchedChannelStateReaderImpl.Position.atStart()) + .reader(); + } + + /** Returns the ordered list of spill file paths. Read-only view. */ + public List files() { + return Collections.unmodifiableList(files); + } + + // ------------------------------------------------------------------------------------------- + // Lifecycle + // ------------------------------------------------------------------------------------------- + + /** Acquires a lifecycle grant for a reader or handoff owner. */ + public void acquire() { + refCount.incrementAndGet(); + } + + /** + * Releases a lifecycle grant. 
When the last grant is released (refCount reaches zero), all + * spill files are deleted. This preserves the invariant that files exist for the lifetime of + * all readers (drain + snapshot) and are cleaned up exactly once when the last reader finishes. + */ + public void release() throws IOException { + if (refCount.decrementAndGet() == 0) { + if (cleanedUp.compareAndSet(false, true)) { + closed = true; + deleteAllFiles(); + } + } + } + + /** Forces cleanup even when lifecycle grants are still outstanding. */ + @Override + public void close() throws IOException { + if (closed) { + return; + } + closed = true; + if (cleanedUp.compareAndSet(false, true)) { + deleteAllFiles(); + } + } + + private void deleteAllFiles() throws IOException { + IOException firstError = null; + for (Path file : files) { + try { + Files.deleteIfExists(file); + } catch (IOException e) { + if (firstError == null) { + firstError = e; + } else { + firstError.addSuppressed(e); + } + } + } + if (firstError != null) { + throw firstError; + } + } + + @VisibleForTesting + public boolean isClosed() { + return closed; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainer.java new file mode 100644 index 0000000000000..ce58cd4cea8de --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainer.java @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.checkpoint.channel.FetchedChannelStateReader.SpillSegment; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.partition.consumer.RecoverableInputChannel; + +import java.io.Closeable; +import java.io.IOException; +import java.io.InputStream; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * Drains a {@link FetchedChannelState} into recovered-buffer queues and snapshots remaining + * segments when a checkpoint fires during recovery. + * + *

The drainer lock pairs channel delivery with reader-cursor advancement and also protects
+ * snapshot creation plus barrier insertion. Disk reads and buffer allocation stay outside that
+ * lock.
+ */
+@Internal
+public final class FetchedChannelStateDrainer implements RecoveryCheckpointTrigger, Closeable {
+
+    private final FetchedChannelStateReader rootReader;
+
+    private final CompletableFuture<ResolvedChannels> resolvedChannelsFuture;
+
+    private final Object lock = new Object();
+
+    /**
+     * Set under {@link #lock} once {@link #drain()} has consumed every segment. After that the
+     * {@link #rootReader} is closed by {@link #close()}, so a later {@link
+     * #snapshotAndInsertBarriers} must not derive from it; it returns an empty reader instead.
+     * Guarded by the lock so the check is atomic with barrier insertion.
+     */
+    private boolean drainFinished;
+
+    /**
+     * Opens the root reader over {@code channelState} and prepares the channel lookup, which is
+     * resolved lazily when {@code channelsFuture} completes.
+     *
+     * @param channelState spilled channel state to drain; must not be null
+     * @param channelsFuture future yielding the physical channels to deliver into; must not be
+     *     null
+     */
+    public FetchedChannelStateDrainer(
+            FetchedChannelState channelState,
+            CompletableFuture<List<RecoverableInputChannel>> channelsFuture) {
+        this.rootReader = checkNotNull(channelState).reader();
+        this.resolvedChannelsFuture = checkNotNull(channelsFuture).thenApply(ResolvedChannels::new);
+    }
+
+    /** The resolved channels, both as a list and indexed by their channel info. */
+    private static final class ResolvedChannels {
+        final List<RecoverableInputChannel> allChannels;
+        final Map<InputChannelInfo, RecoverableInputChannel> channelByInfo;
+
+        ResolvedChannels(List<RecoverableInputChannel> all) {
+            this.allChannels = all;
+            Map<InputChannelInfo, RecoverableInputChannel> byInfo = new HashMap<>();
+            for (RecoverableInputChannel ch : all) {
+                byInfo.put(ch.getChannelInfo(), ch);
+            }
+            this.channelByInfo = byInfo;
+        }
+    }
+
+    /**
+     * Drains all segments from the spill file into the corresponding recovery buffer queues. Each
+     * segment is split into chunks of at most {@code memorySegmentSize} bytes; a full chunk is
+     * delivered under the drainer lock paired with a segment commit. After all segments are
+     * drained, every channel's {@link RecoverableInputChannel#finishRecoveredBufferDelivery()} is
+     * called.
+     *
+     * <p>Disk reads and buffer allocations happen outside the lock; only the "deliver + commit"
+     * pair is locked to guarantee atomicity with snapshot.
+     *
+     * @throws IllegalStateException if a segment names a channel that is not among the resolved
+     *     channels
+     */
+    public void drain() throws IOException, InterruptedException {
+        ResolvedChannels channels = resolvedChannelsFuture.join();
+        Optional<SpillSegment> next;
+        while ((next = rootReader.nextSegment()).isPresent()) {
+            SpillSegment seg = next.get();
+            RecoverableInputChannel ch = channels.channelByInfo.get(seg.channelInfo());
+            if (ch == null) {
+                throw new IllegalStateException(
+                        "Drain: no physical channel found for " + seg.channelInfo());
+            }
+            drainSegment(seg, ch);
+        }
+
+        // Mark drain done before rootReader is closed, so a concurrent snapshot returns empty
+        // rather than deriving from the soon-to-be-closed rootReader. Under the lock to stay atomic
+        // with snapshotAndInsertBarriers' check.
+        synchronized (lock) {
+            drainFinished = true;
+        }
+        for (RecoverableInputChannel ch : channels.allChannels) {
+            ch.finishRecoveredBufferDelivery();
+        }
+    }
+
+    /**
+     * Drains one segment into the given channel. Fills buffers from the segment's opaque byte
+     * stream in chunks of at most {@code memorySegmentSize} bytes. A full buffer is delivered under
+     * the lock and a fresh one is requested; a partial tail buffer (if non-empty) is also
+     * delivered.
+     *
+     * <p>A buffer that was requested but never handed to the channel is recycled on every exit
+     * path, including when reading the segment body fails.
+     */
+    private void drainSegment(SpillSegment seg, RecoverableInputChannel ch)
+            throws IOException, InterruptedException {
+        InputStream in = seg.bodyStream();
+        Buffer buf = ch.requestRecoveryBufferBlocking();
+        // True while this method still holds buf, i.e. it has not been passed to the channel.
+        boolean bufOwned = true;
+        try {
+            int cap = buf.getMaxCapacity();
+            while (fill(buf, in, cap - buf.getSize()) > 0) {
+                if (buf.getSize() == cap) {
+                    // Buffer is full: deliver under lock and request a fresh one.
+                    // NOTE(review): assumes onRecoveredStateBuffer takes over the buffer even when
+                    // it throws; confirm against the channel implementations.
+                    bufOwned = false;
+                    deliverAndCommit(seg, ch, buf);
+                    buf = ch.requestRecoveryBufferBlocking();
+                    bufOwned = true;
+                    cap = buf.getMaxCapacity();
+                }
+                // If buf is not full yet, the fill returned > 0 bytes but segment is not
+                // exhausted; loop and keep filling the same buffer.
+            }
+
+            if (buf.getSize() > 0) {
+                // Deliver the partial tail buffer.
+                bufOwned = false;
+                deliverAndCommit(seg, ch, buf);
+            }
+        } finally {
+            // Covers both the empty tail buffer and a failure while filling.
+            if (bufOwned) {
+                buf.recycleBuffer();
+            }
+        }
+    }
+
+    /** Hands {@code buf} to the channel and commits the segment as one step under the lock. */
+    private void deliverAndCommit(SpillSegment seg, RecoverableInputChannel ch, Buffer buf)
+            throws IOException {
+        synchronized (lock) {
+            ch.onRecoveredStateBuffer(buf);
+            seg.commit();
+        }
+    }
+
+    /**
+     * Fills up to {@code remaining} bytes from {@code in} into {@code buf}. Returns the number of
+     * bytes actually written; returns 0 if the stream is at EOF. Does not close or recycle {@code
+     * buf}; ownership stays with the caller.
+     */
+    private static int fill(Buffer buf, InputStream in, int remaining) throws IOException {
+        if (remaining == 0) {
+            return 0;
+        }
+        // Do not use try-with-resources: ChannelStateByteBuffer.close() recycles the buffer,
+        // but the buffer is still owned by the caller here.
+        ChannelStateByteBuffer view = ChannelStateByteBuffer.wrap(buf);
+        return view.writeBytes(in, remaining);
+    }
+
+    /**
+     * Atomically snapshots the undrained portion of the spill and inserts {@link
+     * RecoveryCheckpointBarrier}s into all in-recovery channels. Returns an independent reader over
+     * the remaining segments for replay into the checkpoint stream; the caller owns and must close
+     * it.
+     *
+     * <p>If the drain has already finished, the root reader is closed and there is nothing left to
+     * snapshot; an empty reader is returned so the caller's normal flow handles it uniformly.
+     */
+    @Override
+    public FetchedChannelStateReader snapshotAndInsertBarriers(long checkpointId)
+            throws IOException {
+        ResolvedChannels channels = resolvedChannelsFuture.join();
+
+        // Barrier insertion and snapshot must occur within the same critical section so that the
+        // snapshot's committed position reflects exactly the drain position at the moment barriers
+        // were inserted, with no window for the drain thread to advance between.
+        synchronized (lock) {
+            for (RecoverableInputChannel ch : channels.allChannels) {
+                ch.insertRecoveryCheckpointBarrierIfInRecovery(checkpointId);
+            }
+            if (drainFinished) {
+                // Drain consumed everything and rootReader is (being) closed; nothing left to
+                // snapshot. Return an empty reader so the caller's normal flow handles it.
+                return FetchedChannelStateReader.emptyReader();
+            }
+            return rootReader.snapshot().reader();
+        }
+    }
+
+    /** Closes the root reader. */
+    @Override
+    public void close() throws IOException {
+        rootReader.close();
+    }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateReader.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateReader.java
new file mode 100644
index 0000000000000..5e65f7f6b79c2
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateReader.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.apache.flink.annotation.Internal;
+
+import java.io.Closeable;
+import java.io.InputStream;
+import java.util.Collections;
+import java.util.Optional;
+
+/**
+ * Forward reader over a {@link FetchedChannelState}'s spill files. This is our own segment reader,
+ * on purpose not a Java {@link java.util.Iterator}: our access pattern ("a body must be
+ * fully read before the next segment", "body ownership is handed to the consumer", "consume and
+ * commit are separate steps") does not fit the {@code hasNext/next} contract.
+ *
+ * <p>This interface is the contract callers depend on; {@link FetchedChannelStateReaderImpl} holds
+ * the implementation (the live file stream, the two progress positions, the bounded body view).
+ *
+ * <p>Reading is strictly sequential: a reader is positioned once (offset 0 for the root reader, or
+ * the committed position for a {@link #snapshot()}), then consumes forward only via {@link
+ * #nextSegment()}. It never seeks backward and never re-positions mid-iteration.
+ *
+ * <p>The drain thread reads the root reader front to back and commits via {@link
+ * SpillSegment#commit()}; each checkpoint derives a fresh {@link #snapshot()} that resumes from the
+ * committed position. {@link #snapshot()} and {@link SpillSegment#commit()} must be called under
+ * the drainer lock; disk reads happen outside it.
+ */
+@Internal
+public interface FetchedChannelStateReader extends Closeable {
+
+    /**
+     * Advances to the next segment and returns it, or {@link Optional#empty()} when no segment
+     * remains. Advancing and probing are one step; there is no separate {@code hasNext}.
+     *
+     * <p>Entry rule (the first call is exempt): the previous segment's body must be fully read,
+     * otherwise this is a contract violation and fails loud (no skip-ahead).
+     */
+    Optional<SpillSegment> nextSegment();
+
+    /**
+     * Derives an independent resume point starting from the committed position. The snapshot holds
+     * its own {@link FetchedChannelState} lifecycle grant; the caller must open a reader from it
+     * via {@link FetchedChannelStateSnapshot#reader()} and close that reader when done.
+     *
+     * <p>Must be called under the drainer lock so that the copied position reflects the latest
+     * committed state.
+     *
+     * @return a snapshot capturing the current committed position; caller must open and close a
+     *     reader from it
+     */
+    FetchedChannelStateSnapshot snapshot();
+
+    /**
+     * Returns a reader with no segments — its first {@link #nextSegment()} is empty. Each call
+     * hands out a fresh instance: readers have independent lifecycle and {@link #close()} is
+     * single-use, so a shared instance would let one consumer's close break later consumers. Used
+     * wherever there is nothing to snapshot (e.g. after drain finished, or the no-op recovery
+     * trigger).
+     */
+    static FetchedChannelStateReader emptyReader() {
+        return new FetchedChannelState(Collections.emptyList()).reader();
+    }
+
+    /**
+     * One per-channel segment produced by {@link #nextSegment()}.
+     *
+     * <p>The segment body bytes are opaque to the reader; record framing is handled by the
+     * consumer's deserializer. A consumer reads {@link #bodyStream()} to EOF (after {@link
+     * #length()} bytes), and the drain consumer additionally calls {@link #commit()} under the
+     * drainer lock after each delivery so that a later {@link FetchedChannelStateReader#snapshot()}
+     * resumes from the delivered boundary.
+     *
+     * <p>Ownership of {@link #bodyStream()} passes to the consumer: the reader no longer tracks how
+     * far it has been read. The "previous body must be fully read" rule (no skip-ahead) is enforced
+     * at the next {@link FetchedChannelStateReader#nextSegment()} call, not here.
+     *
+     * <p>A segment is valid only until the next {@code nextSegment()} call on the parent reader.
+     */
+    interface SpillSegment {
+
+        /** The channel whose data this segment contains. */
+        InputChannelInfo channelInfo();
+
+        /**
+         * Returns an {@link InputStream} bounded to this segment's body. Reading returns {@code -1}
+         * (EOF) after {@link #length()} bytes; it never reads into the next segment or the next
+         * file.
+         *
+         * <p>The stream is single-use, not thread-safe, and must be fully consumed before the next
+         * {@link FetchedChannelStateReader#nextSegment()}.
+         */
+        InputStream bodyStream();
+
+        /**
+         * Number of body bytes this segment hands out before EOF. For the snapshot path this is the
+         * not-yet-delivered remainder used as the length prefix when writing to the checkpoint
+         * stream. Bounded by the spill file size limit, so it always fits in an {@code int}.
+         */
+        int length();
+
+        /**
+         * Advances the reader's committed position to match how many body bytes have been read from
+         * {@link #bodyStream()} so far. Must be called under the drainer lock after each buffer
+         * delivery so that a subsequent {@link FetchedChannelStateReader#snapshot()} sees the
+         * correct delivered boundary. Only the drain (root) reader commits; the snapshot reader
+         * never does.
+         */
+        void commit();
+    }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateReaderImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateReaderImpl.java
new file mode 100644
index 0000000000000..cb8b6dca4c246
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateReaderImpl.java
@@ -0,0 +1,533 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.runtime.checkpoint.channel.FetchedChannelStateReader.SpillSegment;
+
+import javax.annotation.Nullable;
+
+import java.io.ByteArrayInputStream;
+import java.io.DataInputStream;
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.List;
+import java.util.Optional;
+
+import static org.apache.flink.runtime.checkpoint.channel.AbstractSpillingHandler.SEGMENT_HEADER_BYTES;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * The single {@link FetchedChannelStateReader} implementation over a {@link FetchedChannelState}'s
+ * spill files.
+ *
+ * <p>Reading is strictly sequential and never seeks mid-iteration. There is exactly one place that
+ * skips bytes: the very first {@link #nextSegment()} call, where a snapshot reader started mid-body
+ * discards the already-delivered prefix to land on the not-yet-delivered remainder. Every later
+ * call does no skipping at all — the previous body was read to its end, so the stream already sits
+ * on the next segment's header. This "skip only on first positioning" rule is what keeps the
+ * steady-state path free of any seek/skip.
+ *
+ * <p>The reader holds two {@link Position}s and nothing else duplicates them:
+ *
+ * <ul>
+ *   <li>{@code current} — the live read position; its {@code readOffset} is exactly where the open
+ *       file stream sits, advancing as the header and the consumer's body reads consume bytes (the
+ *       latter outside the drainer lock).
+ *   <li>{@code committed} — the delivered boundary; {@link SpillSegment#commit()} advances it from
+ *       {@code current} (under the drainer lock). {@link #snapshot()} derives a new reader from it.
+ * </ul>
+ *
+ * <p>The "previous body fully read before advancing" rule is checked at the {@link #nextSegment()}
+ * entry (the first call is exempt — there is no previous segment). Body ownership is handed to the
+ * consumer, so the reader does not track body progress except through {@code current}.
+ */
+@Internal
+final class FetchedChannelStateReaderImpl implements FetchedChannelStateReader {
+
+    private final FetchedChannelStateSnapshot snapshot;
+    private final FetchedChannelState channelState;
+    private final List<Path> files;
+
+    /** Live read position; {@code readOffset} is where the open stream physically sits. */
+    private final Position current;
+
+    /** Delivered boundary; {@link SpillSegment#commit()} advances it from {@link #current}. */
+    private final Position committed;
+
+    /** Open stream over {@code current.fileIndex}, or {@code null} before the first read. */
+    @Nullable private InputStream fileStream;
+
+    /** Size of the file currently open. */
+    private long currentFileSize;
+
+    /** Body view of the segment returned by the last {@link #nextSegment()}, or {@code null}. */
+    @Nullable private BoundedSegmentStream currentBody;
+
+    private boolean positioned;
+    private boolean closed;
+
+    FetchedChannelStateReaderImpl(FetchedChannelStateSnapshot snapshot) {
+        this.snapshot = snapshot;
+        this.channelState = snapshot.channelState();
+        this.files = channelState.files();
+        // Must copy the position so that this reader's commits do not mutate the snapshot's state.
+        this.committed = snapshot.position().copy();
+        this.current = committed.copy();
+    }
+
+    @Override
+    public Optional<SpillSegment> nextSegment() {
+        checkState(!closed, "FetchedChannelStateReader is closed");
+        checkState(
+                currentBody == null || currentBody.remaining() == 0,
+                "Previous segment body not fully consumed before advancing: %s bytes left",
+                currentBody == null ? 0 : currentBody.remaining());
+        try {
+            if (!positioned) {
+                positioned = true;
+                return firstSegment();
+            }
+            return followingSegment();
+        } catch (IOException e) {
+            throw new RuntimeException("Failed to read segment", e);
+        }
+    }
+
+    /**
+     * First positioning — the only path that may skip bytes. Opens the file at the committed header
+     * offset and reads the header. A snapshot may resume in the middle of a segment: the committed
+     * {@code readOffset} says how many body bytes were already delivered, and that prefix is
+     * skipped so the returned body starts at the not-yet-delivered remainder. If the segment was
+     * already fully delivered (prefix == whole body), it is exhausted here and we move on to the
+     * next one.
+     */
+    private Optional<SpillSegment> firstSegment() throws IOException {
+        // The committed read offset may sit mid-body (after a partial commit), but the header lives
+        // at segmentStartOffset. Capture how much was already delivered, then rewind the live read
+        // offset to the header so we open the file there and read the header, not mid-body.
+        int deliveredPrefix = (int) current.deliveredBodyBytes();
+        current.rewindToSegmentStart();
+
+        if (!openCurrentFile()) {
+            return Optional.empty();
+        }
+        SegmentHeader header = readHeaderAtCurrent();
+        checkState(
+                deliveredPrefix <= header.bufferLength,
+                "Delivered offset %s exceeds segment length %s",
+                deliveredPrefix,
+                header.bufferLength);
+
+        if (deliveredPrefix == header.bufferLength) {
+            // This segment was already fully delivered before the snapshot; nothing remains in it.
+            // Skip its whole body to reach the next segment's header, then take the steady path.
+            skipBody(header.bufferLength);
+            return followingSegment();
+        }
+
+        // Discard the already-delivered prefix (the one and only skip in this class), then hand out
+        // the remainder. alreadyDelivered is carried so commit() records the boundary from the
+        // head.
+        skipBody(deliveredPrefix);
+        currentBody =
+                new BoundedSegmentStream(header.bufferLength - deliveredPrefix, deliveredPrefix);
+        return Optional.of(new Segment(header.channelInfo, currentBody));
+    }
+
+    /**
+     * Steady-state path — no skipping. The previous body was read to its end, so the stream sits
+     * exactly on this segment's header (or at the current file's end, in which case we roll to the
+     * next file). Reads the header and returns the whole-body view.
+     */
+    private Optional<SpillSegment> followingSegment() throws IOException {
+        if (!openCurrentFile()) {
+            return Optional.empty();
+        }
+        SegmentHeader header = readHeaderAtCurrent();
+        currentBody = new BoundedSegmentStream(header.bufferLength);
+        return Optional.of(new Segment(header.channelInfo, currentBody));
+    }
+
+    @Override
+    public FetchedChannelStateSnapshot snapshot() {
+        checkState(!closed, "FetchedChannelStateReader is closed");
+        return new FetchedChannelStateSnapshot(channelState, committed.copy());
+    }
+
+    @Override
+    public void close() throws IOException {
+        if (closed) {
+            return;
+        }
+        closed = true;
+        try {
+            closeFileStream();
+        } finally {
+            snapshot.release();
+        }
+    }
+
+    // -------------------------------------------------------------------------------------------
+    // Sequential IO over the spill files; all of it advances current.readOffset / current.fileIndex
+    // -------------------------------------------------------------------------------------------
+
+    /**
+     * Ensures a file is open with the stream positioned at {@code current}'s read offset, ready to
+     * read this segment's header. Rolls to the next file when the current one is exhausted. Returns
+     * false when no segment remains.
+     *
+     * <p>{@code current.segmentStartOffset} is set to where the header begins, so a later {@link
+     * SpillSegment#commit()} records the right segment for the snapshot to resume from.
+     */
+    private boolean openCurrentFile() throws IOException {
+        if (current.fileIndex >= files.size()) {
+            return false;
+        }
+        openFileAndSeek();
+        if (current.readOffset < currentFileSize) {
+            current.startSegmentHere();
+            return true;
+        }
+        // Current file fully read: move to the next file's first segment.
+        closeFileStream();
+        current.rollToNextFile();
+        if (current.fileIndex >= files.size()) {
+            return false;
+        }
+        openFileAndSeek();
+        if (current.readOffset < currentFileSize) {
+            current.startSegmentHere();
+            return true;
+        }
+        return false;
+    }
+
+    /** Reads the 12-byte header at the current read offset; advances past it. */
+    private SegmentHeader readHeaderAtCurrent() throws IOException {
+        byte[] headerBytes = new byte[SEGMENT_HEADER_BYTES];
+        readFully(headerBytes);
+        DataInputStream h = new DataInputStream(new ByteArrayInputStream(headerBytes));
+        int gateIdx = h.readInt();
+        int channelIdx = h.readInt();
+        int bufferLength = h.readInt();
+        checkState(bufferLength >= 0, "negative segment length: %s", bufferLength);
+        return new SegmentHeader(new InputChannelInfo(gateIdx, channelIdx), bufferLength);
+    }
+
+    /**
+     * Ensures the file at {@code current.fileIndex} is open with the stream positioned at {@code
+     * current.readOffset}. If a stream is already open it is left as-is: sequential reading
+     * guarantees it is already there.
+     */
+    private void openFileAndSeek() throws IOException {
+        if (fileStream != null) {
+            return;
+        }
+        Path path = files.get(current.fileIndex);
+        currentFileSize = Files.size(path);
+        InputStream in = Files.newInputStream(path);
+        try {
+            skipOnStream(in, current.readOffset, path);
+        } catch (IOException e) {
+            in.close();
+            throw e;
+        }
+        fileStream = in;
+    }
+
+    /** Skips {@code count} body bytes on the open stream, advancing the read offset. */
+    private void skipBody(long count) throws IOException {
+        if (count > 0) {
+            skipOnStream(fileStream, count, files.get(current.fileIndex));
+            current.advanceReadOffset(count);
+        }
+    }
+
+    /** Skips exactly {@code count} bytes on {@code in}, failing loud if the file ends early. */
+    private void skipOnStream(InputStream in, long count, Path path) throws IOException {
+        long skipped = 0;
+        while (skipped < count) {
+            long s = in.skip(count - skipped);
+            if (s <= 0) {
+                // skip can return 0 near EOF; read-and-discard as a fallback.
+                if (in.read() < 0) {
+                    throw new EOFException(
+                            "Cannot position to offset " + count + " in spill file " + path);
+                }
+                skipped++;
+            } else {
+                skipped += s;
+            }
+        }
+    }
+
+    private void readFully(byte[] buf) throws IOException {
+        int read = 0;
+        while (read < buf.length) {
+            int n = fileStream.read(buf, read, buf.length - read);
+            if (n < 0) {
+                throw new EOFException(
+                        "Truncated segment header in file "
+                                + files.get(current.fileIndex)
+                                + " at offset "
+                                + current.readOffset
+                                + ": expected "
+                                + buf.length
+                                + " bytes, got "
+                                + read);
+            }
+            read += n;
+            current.advanceReadOffset(n);
+        }
+    }
+
+    /** Reads up to {@code len} body bytes from the current file; called only by the body view. */
+    private int readBody(byte[] buf, int off, int len) throws IOException {
+        int n = fileStream.read(buf, off, len);
+        if (n > 0) {
+            current.advanceReadOffset(n);
+        }
+        return n;
+    }
+
+    private void closeFileStream() throws IOException {
+        if (fileStream != null) {
+            fileStream.close();
+            fileStream = null;
+        }
+    }
+
+    // -------------------------------------------------------------------------------------------
+    // Position: the three progress values locating where a snapshot resumes
+    // -------------------------------------------------------------------------------------------
+
+    /**
+     * One progress point, expressed with the three values that fully locate where a snapshot
+     * resumes. The reader holds two of these: {@code current} (live read position) and {@code
+     * committed} (delivered boundary). They are the same shape but different in meaning; neither
+     * keeps a shadow copy of the other.
+     *
+     * <ul>
+     *   <li>{@code fileIndex} — file holding the current segment.
+     *   <li>{@code segmentStartOffset} — byte offset of the current segment's header. A snapshot
+     *       reads the segment metadata (channel, full body length) from here, because the segment
+     *       may already be partially delivered yet the snapshot still needs the whole-segment
+     *       header.
+     *   <li>{@code readOffset} — for {@code current}, exactly where the open stream sits (the live
+     *       read offset); for {@code committed}, the delivered boundary. For a brand-new segment
+     *       {@code readOffset == segmentStartOffset} (header not yet passed); after the header and
+     *       n body bytes it is {@code segmentStartOffset + SEGMENT_HEADER_BYTES + n}.
+     * </ul>
+     */
+    static final class Position {
+        private int fileIndex;
+        private long segmentStartOffset;
+        private long readOffset;
+
+        Position(int fileIndex, long segmentStartOffset, long readOffset) {
+            this.fileIndex = fileIndex;
+            this.segmentStartOffset = segmentStartOffset;
+            this.readOffset = readOffset;
+        }
+
+        static Position atStart() {
+            return new Position(0, 0L, 0L);
+        }
+
+        Position copy() {
+            return new Position(fileIndex, segmentStartOffset, readOffset);
+        }
+
+        /**
+         * Already-delivered body bytes of the current segment, clamped to 0 before the header is
+         * crossed (where {@code readOffset <= segmentStartOffset}). Only the first-positioning path
+         * uses it, to size the prefix a snapshot must discard.
+         */
+        long deliveredBodyBytes() {
+            return Math.max(0L, readOffset - segmentStartOffset - SEGMENT_HEADER_BYTES);
+        }
+
+        /** Advances the live read offset by {@code delta} bytes just read/skipped. */
+        void advanceReadOffset(long delta) {
+            readOffset += delta;
+        }
+
+        /**
+         * Rewinds the live read offset back to the current segment's header. Used only on first
+         * positioning: a snapshot's committed offset may sit mid-body, but the header must be read
+         * first, so the read offset returns to {@code segmentStartOffset} before opening the file.
+         */
+        void rewindToSegmentStart() {
+            readOffset = segmentStartOffset;
+        }
+
+        /**
+         * Marks the current read offset as the start of the segment about to be read (its header
+         * begins here). Called right before reading a header.
+         */
+        void startSegmentHere() {
+            segmentStartOffset = readOffset;
+        }
+
+        /** Rolls to the start of the next file once the current one is exhausted. */
+        void rollToNextFile() {
+            fileIndex++;
+            segmentStartOffset = 0L;
+            readOffset = 0L;
+        }
+
+        /**
+         * Copies this position into {@code target} but pins {@code target}'s read offset to {@code
+         * deliveredBody} body bytes of the current segment — i.e. {@code commit} records "delivered
+         * up to here", not "physically read up to here".
+         */
+        void copyAsDelivered(Position target, long deliveredBody) {
+            target.fileIndex = fileIndex;
+            target.segmentStartOffset = segmentStartOffset;
+            target.readOffset = segmentStartOffset + SEGMENT_HEADER_BYTES + deliveredBody;
+        }
+    }
+
+    /** Parsed segment header: channel and full body length. */
+    private static final class SegmentHeader {
+        private final InputChannelInfo channelInfo;
+        private final int bufferLength;
+
+        private SegmentHeader(InputChannelInfo channelInfo, int bufferLength) {
+            this.channelInfo = channelInfo;
+            this.bufferLength = bufferLength;
+        }
+    }
+
+    /**
+     * The single {@link SpillSegment} implementation. Exposes one segment's channel, body, and
+     * length; {@link #commit()} advances the reader's {@code committed} position to however many
+     * body bytes have been read. Reading the body and committing are separate steps so the consumer
+     * can read outside the drainer lock and commit inside it.
+     *
+     * <p>Only the root (drain) reader commits.
+     */
+    private final class Segment implements SpillSegment {
+        private final InputChannelInfo channelInfo;
+        private final BoundedSegmentStream body;
+
+        private Segment(InputChannelInfo channelInfo, BoundedSegmentStream body) {
+            this.channelInfo = channelInfo;
+            this.body = body;
+        }
+
+        @Override
+        public InputChannelInfo channelInfo() {
+            return channelInfo;
+        }
+
+        @Override
+        public InputStream bodyStream() {
+            return body;
+        }
+
+        @Override
+        public int length() {
+            return body.deliverableLength();
+        }
+
+        @Override
+        public void commit() {
+            current.copyAsDelivered(committed, body.deliveredFromSegmentHead());
+        }
+    }
+
+    /**
+     * A forward-only, bounded view over one segment's not-yet-delivered body remainder. It hands
+     * out {@code remainingLength} bytes and reaches EOF after them; it never reads into the next
+     * segment or file. The underlying stream is already positioned at the first byte this view
+     * hands out (the prefix, if any, was skipped before construction). If the file ends before the
+     * segment end, an {@link EOFException} is thrown (fail-loud). Closing this view does not close
+     * the underlying file; the reader owns it.
+     */
+    private final class BoundedSegmentStream extends InputStream {
+
+        /** Body bytes already delivered before this view (the skipped prefix); 0 for the root. */
+        private final int alreadyDelivered;
+
+        private final int remainingLength;
+        private int read;
+
+        private BoundedSegmentStream(int remainingLength) {
+            this(remainingLength, 0);
+        }
+
+        private BoundedSegmentStream(int remainingLength, int alreadyDelivered) {
+            this.remainingLength = remainingLength;
+            this.alreadyDelivered = alreadyDelivered;
+        }
+
+        /** Body bytes not yet handed out (between the read position and the segment end). */
+        int remaining() {
+            return remainingLength - read;
+        }
+
+        /** Number of body bytes this view will hand out: the not-yet-delivered remainder. */
+        int deliverableLength() {
+            return remainingLength;
+        }
+
+        /**
+         * Total delivered body bytes measured from the segment head: the skipped prefix plus what
+         * has been read through this view. This is what {@link Segment#commit()} records.
+         */
+        int deliveredFromSegmentHead() {
+            return alreadyDelivered + read;
+        }
+
+        @Override
+        public int read() throws IOException {
+            byte[] one = new byte[1];
+            int n = read(one, 0, 1);
+            return n < 0 ? -1 : (one[0] & 0xFF);
+        }
+
+        @Override
+        public int read(byte[] buf, int off, int len) throws IOException {
+            if (read >= remainingLength) {
+                return -1;
+            }
+            int toRead = Math.min(len, remainingLength - read);
+            int n = readBody(buf, off, toRead);
+            if (n < 0) {
+                throw new EOFException(
+                        "Unexpected EOF in segment body after "
+                                + read
+                                + "/"
+                                + remainingLength
+                                + " bytes");
+            }
+            read += n;
+            return n;
+        }
+
+        @Override
+        public void close() {
+            // Do not close the underlying file; it is owned by the reader.
+        }
+    }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateSnapshot.java
new file mode 100644
index 0000000000000..cb54006ece40d
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateSnapshot.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import java.io.IOException;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Resume point for reading a {@link FetchedChannelState}: pairs the channel state with the {@link
+ * FetchedChannelStateReaderImpl.Position} to start from. Constructing one acquires a lifecycle
+ * grant on the channel state, which {@link #release()} gives back.
+ *
+ * <p>Each snapshot serves exactly one reader:
+ *
+ * <ul>
+ *   <li>{@link #reader()} hands out the single {@link FetchedChannelStateReader} for this snapshot.
+ *   <li>{@link #release()} is the hook that reader calls when it is closed.
+ * </ul>
+ *
+ * <p>The one-reader rule is checked, not assumed: a second {@link #reader()} call throws an
+ * {@link IllegalStateException}.
+ */
+final class FetchedChannelStateSnapshot {
+
+    private final FetchedChannelState channelState;
+    private final FetchedChannelStateReaderImpl.Position position;
+
+    /** Flips to true on the first {@link #reader()} call and blocks any further one. */
+    private boolean readerHandedOut;
+
+    /**
+     * Binds {@code position} to {@code channelState} and takes one lifecycle grant on the channel
+     * state by calling its {@code acquire()}. {@link #release()} returns that grant; it is meant
+     * to be invoked by the reader obtained from {@link #reader()}.
+     */
+    FetchedChannelStateSnapshot(
+            FetchedChannelState channelState, FetchedChannelStateReaderImpl.Position position) {
+        this.position = position;
+        this.channelState = channelState;
+        this.channelState.acquire();
+    }
+
+    /**
+     * Creates the reader bound to this snapshot. Only the first call succeeds; any later call
+     * throws {@link IllegalStateException}, keeping snapshot and reader strictly one-to-one.
+     *
+     * @return the reader positioned at this snapshot's start position; the caller closes it
+     */
+    FetchedChannelStateReader reader() {
+        checkState(!readerHandedOut, "A reader has already been opened from this snapshot");
+        readerHandedOut = true;
+        return new FetchedChannelStateReaderImpl(this);
+    }
+
+    /**
+     * Gives back the lifecycle grant taken in the constructor. Invoked by the reader when it is
+     * closed; no other caller should use it.
+     */
+    void release() throws IOException {
+        channelState.release();
+    }
+
+    /** Channel state this snapshot points into; package-private accessor for the reader. */
+    FetchedChannelState channelState() {
+        return channelState;
+    }
+
+    /**
+     * Start position handed to the reader; returned as-is, so the reader takes its own copy
+     * before mutating.
+     */
+    FetchedChannelStateReaderImpl.Position position() {
+        return position;
+    }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java
index 547b60ef93aee..7fea56b5ace06 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java
@@ -23,6 +23,7 @@
 import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext;
 
 import java.io.IOException;
+import java.util.Optional;
 
 /** Reads channel state saved during checkpoint/savepoint. */
 @Internal
@@ -40,6 +41,9 @@ void readInputData(InputGate[] inputGates, RecordFilterContext filterContext)
     void readOutputData(ResultPartitionWriter[] writers, boolean notifyAndBlockOnCompletion)
             throws IOException, InterruptedException;
 
+    /** Returns the {@link FetchedChannelState} produced by {@link #readInputData}, if any.
*/ + Optional getProducedChannelState(); + @Override void close() throws Exception; @@ -54,6 +58,11 @@ public void readInputData( public void readOutputData( ResultPartitionWriter[] writers, boolean notifyAndBlockOnCompletion) {} + @Override + public Optional getProducedChannelState() { + return Optional.empty(); + } + @Override public void close() {} }; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/logger/NetworkActionsLogger.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/logger/NetworkActionsLogger.java index 204f3d3c074b6..137334ae14fc3 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/logger/NetworkActionsLogger.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/logger/NetworkActionsLogger.java @@ -98,13 +98,13 @@ public static void traceRecover( public static void tracePersist( String action, Buffer buffer, Object channelInfo, long checkpointId) { + tracePersist(action, buffer.toDebugString(INCLUDE_HASH), channelInfo, checkpointId); + } + + public static void tracePersist( + String action, Object persisted, Object channelInfo, long checkpointId) { if (LOG.isTraceEnabled()) { - LOG.trace( - "{} {}, checkpoint {} @ {}", - action, - buffer.toDebugString(INCLUDE_HASH), - checkpointId, - channelInfo); + LOG.trace("{} {}, checkpoint {} @ {}", action, persisted, checkpointId, channelInfo); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/EndOfFetchedChannelStateEvent.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/EndOfFetchedChannelStateEvent.java new file mode 100644 index 0000000000000..9900a0cdc660d --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/EndOfFetchedChannelStateEvent.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.consumer; + +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.runtime.event.RuntimeEvent; + +/** + * Marks the tail of recovered buffers that the spill drain pushed into a {@link + * RecoverableInputChannel}. The consume path polls this sentinel to learn the exact moment all + * recovered buffers have been consumed; it is never delivered to the operator. It is distinct from + * {@link EndOfInputChannelStateEvent} (which terminates the {@link RecoveredInputChannel} read + * stream) so the two recovery handoffs cannot be confused. + */ +public class EndOfFetchedChannelStateEvent extends RuntimeEvent { + + /** The singleton instance of this event. 
*/ + public static final EndOfFetchedChannelStateEvent INSTANCE = + new EndOfFetchedChannelStateEvent(); + + // ------------------------------------------------------------------------ + + // not instantiable + private EndOfFetchedChannelStateEvent() {} + + // ------------------------------------------------------------------------ + + @Override + public void write(DataOutputView out) { + throw new UnsupportedOperationException( + "EndOfFetchedChannelStateEvent must be serialized via EventSerializer's dedicated" + + " type-tag path, not reflective write()."); + } + + @Override + public void read(DataInputView in) { + throw new UnsupportedOperationException( + "EndOfFetchedChannelStateEvent must be deserialized via EventSerializer's dedicated" + + " type-tag path, not reflective read()."); + } + + // ------------------------------------------------------------------------ + + @Override + public int hashCode() { + return 20250814; + } + + @Override + public boolean equals(Object obj) { + return obj != null && obj.getClass() == EndOfFetchedChannelStateEvent.class; + } + + @Override + public String toString() { + return getClass().getSimpleName(); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java index d9b7885815bd1..82ae84a6c91f2 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java @@ -19,8 +19,6 @@ package org.apache.flink.runtime.io.network.partition.consumer; import org.apache.flink.annotation.VisibleForTesting; -import org.apache.flink.core.memory.MemorySegment; -import org.apache.flink.core.memory.MemorySegmentFactory; import org.apache.flink.metrics.Counter; import 
org.apache.flink.runtime.checkpoint.CheckpointException; import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; @@ -29,13 +27,10 @@ import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; import org.apache.flink.runtime.io.network.buffer.Buffer; -import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; -import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; import org.apache.flink.runtime.io.network.logger.NetworkActionsLogger; import org.apache.flink.runtime.io.network.partition.ChannelStateHolder; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet; -import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.util.Preconditions; import org.slf4j.Logger; @@ -62,13 +57,6 @@ public abstract class RecoveredInputChannel extends InputChannel implements Chan private final CompletableFuture stateConsumedFuture = new CompletableFuture<>(); protected final BufferManager bufferManager; - /** - * Future that completes when recovered buffers have been filtered for this channel. This - * completes before stateConsumedFuture, enabling earlier RUNNING state transition when - * unaligned checkpoint during recovery is enabled. 
- */ - private final CompletableFuture bufferFilteringCompleteFuture = new CompletableFuture<>(); - @GuardedBy("receivedBuffers") private boolean isReleased; @@ -105,7 +93,7 @@ public abstract class RecoveredInputChannel extends InputChannel implements Chan numBytesIn, numBuffersIn); - bufferManager = new BufferManager(inputGate.getMemorySegmentProvider(), this, 0); + bufferManager = new BufferManager(inputGate.getMemorySegmentProvider(), this, 0, true); this.networkBuffersPerChannel = networkBuffersPerChannel; } @@ -115,23 +103,19 @@ public void setChannelStateWriter(ChannelStateWriter channelStateWriter) { this.channelStateWriter = checkNotNull(channelStateWriter); } - public final InputChannel toInputChannel() throws IOException { - Preconditions.checkState( - bufferFilteringCompleteFuture.isDone(), "buffer filtering is not complete"); + public final InputChannel toInputChannel(boolean needsRecovery) throws IOException { if (!inputGate.isCheckpointingDuringRecoveryEnabled()) { Preconditions.checkState( stateConsumedFuture.isDone(), "recovered state is not fully consumed"); } - // Extract remaining buffers before conversion. - // These buffers have been filtered but not yet consumed by the Task. - final ArrayDeque remainingBuffers; + // With checkpointing-during-recovery, data is spilled instead of queued here. synchronized (receivedBuffers) { - remainingBuffers = new ArrayDeque<>(receivedBuffers); - receivedBuffers.clear(); + Preconditions.checkState(receivedBuffers.isEmpty(), "Received buffer should be empty."); } - final InputChannel inputChannel = toInputChannelInternal(remainingBuffers); + final InputChannel inputChannel = toInputChannelInternal(needsRecovery); + inputChannel.setup(); inputChannel.checkpointStopped(lastStoppedCheckpointId); return inputChannel; } @@ -142,23 +126,12 @@ public void checkpointStopped(long checkpointId) { } /** - * Creates the physical InputChannel from this recovered channel. 
- * - * @param remainingBuffers buffers that have been filtered but not yet consumed by the Task. - * These buffers will be migrated to the new physical channel. - * @return the physical InputChannel (LocalInputChannel or RemoteInputChannel) + * Creates the physical {@link InputChannel}; {@code needsRecovery} controls whether it starts + * in recovery. */ - protected abstract InputChannel toInputChannelInternal(ArrayDeque remainingBuffers) + protected abstract InputChannel toInputChannelInternal(boolean needsRecovery) throws IOException; - /** - * Returns the future that completes when buffer filtering is complete. This future completes - * before stateConsumedFuture, at the point when finishReadRecoveredState() is called. - */ - CompletableFuture getBufferFilteringCompleteFuture() { - return bufferFilteringCompleteFuture; - } - CompletableFuture getStateConsumedFuture() { return stateConsumedFuture; } @@ -166,10 +139,7 @@ CompletableFuture getStateConsumedFuture() { public void onRecoveredStateBuffer(Buffer buffer) { boolean recycleBuffer = true; NetworkActionsLogger.traceRecover( - "InputChannelRecoveredStateHandler#recover", - buffer, - inputGate.getOwningTaskName(), - channelInfo); + "NoSpillingHandler#recover", buffer, inputGate.getOwningTaskName(), channelInfo); try { final boolean wasEmpty; synchronized (receivedBuffers) { @@ -195,21 +165,14 @@ public void onRecoveredStateBuffer(Buffer buffer) { } public void finishReadRecoveredState() throws IOException { - // Adding the event and completing the future must be atomic under receivedBuffers lock. - // Without this, either ordering has a race: - // - event first: task thread consumes EndOfInputChannelStateEvent, which completes - // stateConsumedFuture. When checkpointing during recovery is disabled, - // stateConsumedFuture triggers requestPartitions -> toInputChannel(), which - // fails because bufferFilteringCompleteFuture is not yet done. 
- // - future first: toInputChannel() extracts buffers before the event is added, - // losing the EndOfInputChannelStateEvent. - // Both toInputChannel() and getNextRecoveredStateBuffer() synchronize on - // receivedBuffers, so holding the same lock here guarantees - // bufferFilteringCompleteFuture is always done before stateConsumedFuture. + // In legacy recovery, adding the sentinel must be atomic under receivedBuffers lock to + // ensure the sentinel is enqueued before any concurrent reader can observe an empty queue + // and miss the EndOfInputChannelStateEvent that completes stateConsumedFuture. synchronized (receivedBuffers) { - onRecoveredStateBuffer( - EventSerializer.toBuffer(EndOfInputChannelStateEvent.INSTANCE, false)); - bufferFilteringCompleteFuture.complete(null); + if (!inputGate.isCheckpointingDuringRecoveryEnabled()) { + onRecoveredStateBuffer( + EventSerializer.toBuffer(EndOfInputChannelStateEvent.INSTANCE, false)); + } } bufferManager.releaseFloatingBuffers(); LOG.debug("{}/{} finished recovering input.", inputGate.getOwningTaskName(), channelInfo); @@ -229,8 +192,6 @@ private BufferAndAvailability getNextRecoveredStateBuffer() throws IOException { if (next == null) { return null; } else if (isEndOfInputChannelStateEvent(next)) { - Preconditions.checkState( - bufferFilteringCompleteFuture.isDone(), "buffer filtering is not complete"); stateConsumedFuture.complete(null); return null; } else { @@ -307,7 +268,7 @@ boolean isReleased() { } } - void releaseAllResources() throws IOException { + public void releaseAllResources() throws IOException { ArrayDeque releasedBuffers = new ArrayDeque<>(); boolean shouldRelease = false; @@ -338,26 +299,7 @@ public Buffer requestBufferBlocking() throws InterruptedException, IOException { bufferManager.requestExclusiveBuffers(networkBuffersPerChannel); exclusiveBuffersAssigned = true; } - if (!inputGate.isCheckpointingDuringRecoveryEnabled()) { - // When checkpoint-during-recovery is not enabled, the original 
blocking allocation - // is used as-is — no heap buffer fallback, no behavior change from the legacy path. - return bufferManager.requestBufferBlocking(); - } - // Use heap buffer fallback to avoid deadlock during filtering recovery: the filtering - // thread first requests buffers to read state (pre-filter), then requests more buffers - // to write filtered output (post-filter). If pre-filter buffers exhaust the pool, - // post-filter allocation blocks, stalling the thread so pre-filter buffers can never - // be consumed and released — the thread deadlocks itself. Heap buffers bypass the pool - // so post-filter writes always proceed. Both call sites (getBuffer and filterAndRewrite) - // go through this method, so the fallback applies uniformly. - // TODO: replace heap fallback with disk spilling to bound memory usage in FLINK-38544. - Buffer buffer = bufferManager.requestBuffer(); - if (buffer != null) { - return buffer; - } - MemorySegment memorySegment = - MemorySegmentFactory.allocateUnpooledSegment(MemoryManager.DEFAULT_PAGE_SIZE); - return new NetworkBuffer(memorySegment, FreeingBufferRecycler.INSTANCE); + return bufferManager.requestBufferBlocking(); } @Override diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/AbstractStreamTaskNetworkInput.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/AbstractStreamTaskNetworkInput.java index a3743edec7f94..26f0b5f8dfff0 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/AbstractStreamTaskNetworkInput.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/AbstractStreamTaskNetworkInput.java @@ -326,8 +326,7 @@ public void close() throws IOException { // terminate the TM Exception err = null; for (InputChannelInfo channelInfo : new ArrayList<>(recordDeserializers.keySet())) { - final boolean hadError = - checkpointedInputGate.getChannel(channelInfo.getInputChannelIdx()).hasError(); + final boolean hadError = 
checkpointedInputGate.getChannel(channelInfo).hasError(); try { releaseDeserializer(channelInfo); } catch (Exception e) { diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/CheckpointedInputGate.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/CheckpointedInputGate.java index ad0dcb5b0bb74..26f335386a198 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/CheckpointedInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/CheckpointedInputGate.java @@ -29,9 +29,11 @@ import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; import org.apache.flink.runtime.io.network.api.EventAnnouncement; import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent; +import org.apache.flink.runtime.io.network.partition.consumer.EndOfFetchedChannelStateEvent; import org.apache.flink.runtime.io.network.partition.consumer.EndOfOutputChannelStateEvent; import org.apache.flink.runtime.io.network.partition.consumer.InputChannel; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; +import org.apache.flink.runtime.io.network.partition.consumer.RecoverableInputChannel; import org.apache.flink.streaming.runtime.io.StreamTaskNetworkInput; import org.slf4j.Logger; @@ -202,6 +204,15 @@ private Optional handleEvent(BufferOrEvent bufferOrEvent) throws bufferOrEvent.getChannelInfo()); } else if (bufferOrEvent.getEvent().getClass() == EndOfOutputChannelStateEvent.class) { upstreamRecoveryTracker.handleEndOfRecovery(bufferOrEvent.getChannelInfo()); + } else if (eventClass == EndOfFetchedChannelStateEvent.class) { + // Tail of the recovered buffers: only a RecoverableInputChannel can produce this + // sentinel, so anything else here is a bug rather than something to tolerate. 
+ InputChannel channel = inputGate.getChannel(bufferOrEvent.getChannelInfo()); + checkState( + channel instanceof RecoverableInputChannel, + "EndOfFetchedChannelStateEvent received on a non-recoverable channel %s", + bufferOrEvent.getChannelInfo()); + ((RecoverableInputChannel) channel).onRecoveredStateConsumed(); } return Optional.of(bufferOrEvent); } @@ -296,6 +307,10 @@ public InputChannel getChannel(int channelIndex) { return inputGate.getChannel(channelIndex); } + public InputChannel getChannel(InputChannelInfo channelInfo) { + return inputGate.getChannel(channelInfo); + } + public List getChannelInfos() { return inputGate.getChannelInfos(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java index e207eb6213edc..165e66a278d9d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java @@ -91,7 +91,7 @@ public int getNumberOfChannels() { /** Maximum parallelism for configuring partitioners. */ private final int maxParallelism; - /** Temporary directories for spilling spanning records. Can be empty but never null. */ + /** Temporary directories for spilling spanning records. Never null or empty. */ private final String[] tmpDirectories; /** Whether unaligned checkpointing during recovery is enabled. */ @@ -99,8 +99,8 @@ public int getNumberOfChannels() { /** * Network buffer memory segment size in bytes (from taskmanager.memory.segment-size). Used to - * size the reusable heap source buffer in {@code InputChannelRecoveredStateHandler} so it - * matches the network buffer size. + * size the reusable heap source buffer in {@code SpillingWithFilteringHandler} so it matches + * the network buffer size. 
*/ private final int memorySegmentSize; @@ -112,8 +112,9 @@ public int getNumberOfChannels() { * @param rescalingDescriptor Descriptor containing rescaling information. Not null. * @param subtaskIndex Current subtask index. * @param maxParallelism Maximum parallelism. - * @param tmpDirectories Temporary directories for spilling spanning records. Can be null - * (converted to empty array). + * @param tmpDirectories Temporary directories for spilling spanning records. Not null or empty; + * sourced from {@code IOManager.getSpillingDirectoriesPaths()}, which is guaranteed + * non-empty by the task manager I/O configuration. * @param checkpointingDuringRecoveryEnabled Whether unaligned checkpointing during recovery is * enabled. * @param memorySegmentSize Network buffer memory segment size in bytes. Must be positive. @@ -130,7 +131,8 @@ public RecordFilterContext( this.rescalingDescriptor = checkNotNull(rescalingDescriptor); this.subtaskIndex = subtaskIndex; this.maxParallelism = maxParallelism; - this.tmpDirectories = tmpDirectories != null ? tmpDirectories : new String[0]; + checkArgument(checkNotNull(tmpDirectories).length > 0, "tmpDirectories must not be empty"); + this.tmpDirectories = tmpDirectories.clone(); this.checkpointingDuringRecoveryEnabled = checkpointingDuringRecoveryEnabled; checkArgument( memorySegmentSize > 0, "memorySegmentSize must be positive: %s", memorySegmentSize); @@ -202,7 +204,7 @@ public int getMaxParallelism() { /** * Gets the temporary directories for spilling spanning records. * - * @return The temporary directories, never null (may be empty array). + * @return The temporary directories, never null or empty. */ public String[] getTmpDirectories() { return tmpDirectories; @@ -235,15 +237,19 @@ public boolean isAmbiguous(int gateIndex, int oldSubtaskIndex) { *

The returned context has empty inputConfigs and enabled=false, so {@link * #isCheckpointingDuringRecoveryEnabled()} will always return false. * + * @param tmpDirectories Temporary directories for spilling spanning records. Not null or empty; + * callers pass the real I/O manager spilling directories so the resulting context satisfies + * the same invariant as an enabled context, regardless of which downstream path consumes + * it. * @return A disabled RecordFilterContext. */ - public static RecordFilterContext disabled() { + public static RecordFilterContext disabled(String[] tmpDirectories) { return new RecordFilterContext( new InputFilterConfig[0], InflightDataRescalingDescriptor.NO_RESCALE, 0, 0, - new String[0], + tmpDirectories, false, MemoryManager.DEFAULT_PAGE_SIZE); } diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index e06164125fd41..7c30746ba4bf3 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -43,7 +43,10 @@ import org.apache.flink.runtime.checkpoint.SnapshotType; import org.apache.flink.runtime.checkpoint.SubTaskInitializationMetricsBuilder; import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; +import org.apache.flink.runtime.checkpoint.channel.FetchedChannelState; +import org.apache.flink.runtime.checkpoint.channel.FetchedChannelStateDrainer; import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; +import org.apache.flink.runtime.checkpoint.channel.RecoveryCheckpointTrigger; import org.apache.flink.runtime.checkpoint.channel.SequentialChannelStateReader; import org.apache.flink.runtime.checkpoint.filemerging.FileMergingSnapshotManager; import org.apache.flink.runtime.execution.CancelTaskException; @@ -62,7 +65,9 @@ import 
org.apache.flink.runtime.io.network.partition.ChannelStateHolder; import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate; +import org.apache.flink.runtime.io.network.partition.consumer.InputChannel; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; +import org.apache.flink.runtime.io.network.partition.consumer.RecoverableInputChannel; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.CheckpointableTask; import org.apache.flink.runtime.jobgraph.tasks.CoordinatedTask; @@ -304,6 +309,17 @@ public abstract class StreamTask> /** TODO it might be replaced by the global IO executor on TaskManager level future. */ private final ExecutorService channelIOExecutor; + /** + * Completed (on the {@code channelIOExecutor}) once recovery setup finishes, carrying the + * resolved checkpoint trigger: the spill drainer when recovery carries channel state, otherwise + * {@link RecoveryCheckpointTrigger#NO_OP}. Two consumers ride on this single completion: the + * barrier handler, built before the drainer exists, holds the future and reads the trigger + * lazily via {@code getNow} once a checkpoint fires; and gate conversion waits on its + * completion to run {@code requestPartitions()} (buffer filtering is done by then). + */ + private final CompletableFuture recoverySetupCompleteFuture = + new CompletableFuture<>(); + // ======================================================== // Final checkpoint / savepoint // ======================================================== @@ -885,44 +901,36 @@ private CompletableFuture restoreStateAndGates( CheckpointingOptions.isCheckpointingDuringRecoveryEnabled(getJobConfiguration()); // Must set the flag on input gates BEFORE starting the async read task, because - // finishReadRecoveredState() checks this flag to complete bufferFilteringCompleteFuture. 
+ // finishReadRecoveredState() reads this flag to decide whether to enqueue the legacy + // end-of-state sentinel. for (IndexedInputGate inputGate : inputGates) { inputGate.setCheckpointingDuringRecoveryEnabled(checkpointingDuringRecoveryEnabled); } + final CompletableFuture> physicalChannelsFuture = + new CompletableFuture<>(); + channelIOExecutor.execute( - () -> { - try { - reader.readInputData(inputGates, createRecordFilterContext()); - } catch (Exception e) { - asyncExceptionHandler.handleAsyncException( - "Unable to read channel state", e); - } - }); + () -> + recoverChannelState( + reader, + inputGates, + checkpointingDuringRecoveryEnabled, + physicalChannelsFuture)); // We wait for all input channel state to recover before we go into RUNNING state, and thus // start checkpointing. If we implement incremental checkpointing of input channel state // we must make sure it supports CheckpointType#FULL_CHECKPOINT. - List> recoveredFutures = new ArrayList<>(inputGates.length); - for (InputGate inputGate : inputGates) { - CompletableFuture requestPartitionsTrigger = - checkpointingDuringRecoveryEnabled - ? inputGate.getBufferFilteringCompleteFuture() - : inputGate.getStateConsumedFuture(); - - recoveredFutures.add(requestPartitionsTrigger); - - requestPartitionsTrigger.thenRun( - () -> - mainMailboxExecutor.execute( - inputGate::requestPartitions, "Input gate request partitions")); - } + List> recoveredFutures = + checkpointingDuringRecoveryEnabled + ? wireGateConversionWithCheckpointing(inputGates, physicalChannelsFuture) + : wireGateConversion(inputGates); // Return allOf future instead of thenRun future. thenRun() returns a NEW future that // completes only after the callback finishes. CompletableFuture executes thenRun callbacks // synchronously on the thread that calls complete(). 
When recoveredFutures contains - // bufferFilteringCompleteFuture (checkpointingDuringRecovery enabled), complete() is called - // on channelIOExecutor (in finishReadRecoveredState), so thenRun(suspend) also runs on + // recoverySetupCompleteFuture (checkpointingDuringRecovery enabled), complete() is called + // on channelIOExecutor (in recoverChannelState), so thenRun(suspend) also runs on // channelIOExecutor. suspend() sends a poison mail, and the mailbox thread can pick it up // and exit runMailboxLoop() before the thenRun future completes — causing // checkState(isDone) to fail. With stateConsumedFuture (the default), complete() runs on @@ -935,6 +943,183 @@ private CompletableFuture restoreStateAndGates( return allRecoveredFuture; } + /** + * Runs on the {@code channelIOExecutor}: reads input channel state, wires the spill drainer + * when checkpointing-during-recovery is enabled, and drains recovered buffers into the physical + * channels. Setup failures complete {@code physicalChannelsFuture} exceptionally so the + * recovery mailbox loop stops waiting; the drain phase keeps its own handler because the future + * is already completed by then. 
+ */ + private void recoverChannelState( + SequentialChannelStateReader reader, + IndexedInputGate[] inputGates, + boolean checkpointingDuringRecoveryEnabled, + @Nullable CompletableFuture> physicalChannelsFuture) { + FetchedChannelStateDrainer drainer = null; + try { + reader.readInputData(inputGates, createRecordFilterContext()); + + if (checkpointingDuringRecoveryEnabled) { + Optional producedChannelState = + reader.getProducedChannelState(); + boolean needsRecovery = producedChannelState.isPresent(); + for (IndexedInputGate gate : inputGates) { + gate.setNeedsRecovery(needsRecovery); + } + if (needsRecovery) { + FetchedChannelState channelState = producedChannelState.get(); + drainer = new FetchedChannelStateDrainer(channelState, physicalChannelsFuture); + channelState.release(); + } + } + + for (IndexedInputGate gate : inputGates) { + gate.finishReadRecoveredState(); + } + // Recovery setup is done: resolve the trigger for the barrier handler (the drainer when + // recovery carries channel state, NO_OP otherwise) and, by the same completion, release + // gate conversion. Completed before any checkpoint can fire during recovery, so the + // handler reads it via getNow. + recoverySetupCompleteFuture.complete( + drainer != null ? drainer : RecoveryCheckpointTrigger.NO_OP); + } catch (Throwable t) { + asyncExceptionHandler.handleAsyncException( + "Unable to set up recovered channel state", t); + recoverySetupCompleteFuture.completeExceptionally(t); + if (checkpointingDuringRecoveryEnabled) { + if (drainer == null) { + try { + Optional producedChannelState = + reader.getProducedChannelState(); + if (producedChannelState.isPresent()) { + producedChannelState.get().release(); + } + } catch (Throwable ignored) { + // Preserve the original recovery failure. 
+ } + } + physicalChannelsFuture.completeExceptionally(t); + } + return; + } + + if (drainer == null) { + return; + } + try { + drainer.drain(); + } catch (Throwable t) { + asyncExceptionHandler.handleAsyncException( + "Unable to drain recovered channel state", t); + } finally { + try { + drainer.close(); + } catch (Throwable closeError) { + asyncExceptionHandler.handleAsyncException( + "Unable to close FetchedChannelStateDrainer after drain", closeError); + } + } + } + + /** + * Wires each gate's {@code requestPartitions()} to run on the mailbox once its state-consumed + * trigger fires. Used when checkpointing-during-recovery is disabled, so no physical-channel + * conversion needs to be tracked. + * + *

Returns the futures the recovery mailbox loop must await before transitioning to RUNNING. + */ + private List> wireGateConversion(IndexedInputGate[] inputGates) { + List> recoveredFutures = new ArrayList<>(inputGates.length); + for (InputGate inputGate : inputGates) { + CompletableFuture requestPartitionsTrigger = inputGate.getStateConsumedFuture(); + recoveredFutures.add(requestPartitionsTrigger); + requestPartitionsTrigger.thenRun( + () -> + mainMailboxExecutor.execute( + inputGate::requestPartitions, "Input gate request partitions")); + } + return recoveredFutures; + } + + /** + * Wires each gate's {@code requestPartitions()} to run on the mailbox once recovery setup + * completes, and aggregates the per-gate completions into {@code physicalChannelsFuture}. Used + * when checkpointing-during-recovery is enabled. The trigger stays synchronous (no {@code + * *Async}): completing on the {@code channelIOExecutor} that fired {@code + * recoverySetupCompleteFuture} would let the poison mail outrun the suspend callback. + * + *

Returns the futures the recovery mailbox loop must await before transitioning to RUNNING. + */ + private List> wireGateConversionWithCheckpointing( + IndexedInputGate[] inputGates, + CompletableFuture> physicalChannelsFuture) { + List> recoveredFutures = new ArrayList<>(inputGates.length); + // Keep the recovery mailbox loop alive until physical channels are converted; otherwise a + // checkpoint barrier mail could block on the channels future that only a later conversion + // mail can complete. NOTE(review): thenRun below is skipped if recovery setup fails - confirm. + if (inputGates.length > 0) { + recoveredFutures.add(physicalChannelsFuture); + } + recoveredFutures.add(recoverySetupCompleteFuture); + CompletableFuture gateConverted = new CompletableFuture<>(); + recoverySetupCompleteFuture.thenRun( + () -> + mainMailboxExecutor.execute( + () -> { + try { + for (InputGate inputGate : inputGates) { + inputGate.requestPartitions(); + } + gateConverted.complete(null); + } catch (Throwable t) { + gateConverted.completeExceptionally(t); + throw t; + } + }, + "Input gate request partitions")); + gateConverted + .thenApply(ignored -> collectPhysicalChannels(inputGates)) + .whenComplete( + (physicalChannels, failure) -> { + if (failure != null) { + physicalChannelsFuture.completeExceptionally(failure); + } else { + physicalChannelsFuture.complete(physicalChannels); + } + }); + return recoveredFutures; + } + + private static List collectPhysicalChannels(InputGate[] inputGates) { + List channels = new ArrayList<>(); + for (InputGate gate : inputGates) { + int numberOfInputChannels = gate.getNumberOfInputChannels(); + for (int i = 0; i < numberOfInputChannels; i++) { + InputChannel channel = gate.getChannel(i); + if (channel instanceof RecoverableInputChannel) { + channels.add((RecoverableInputChannel) channel); + } + } + } + return channels; + } + + /** + * Returns a trigger that resolves the real implementation lazily: the barrier handler is built + * before the spill drainer exists, so this defers reading {@link #recoverySetupCompleteFuture} +
* until a checkpoint actually fires, by which point recovery setup has completed it. Resolving + * at construction time would block; an unresolved future at snapshot time means an invariant + * broke, so {@code checkState} fails loudly instead of waiting. + */ + public RecoveryCheckpointTrigger getRecoveryCheckpointTrigger() { + return cpId -> { + checkState( + recoverySetupCompleteFuture.isDone(), + "Recovery checkpoint trigger is not resolved at checkpoint start."); + return recoverySetupCompleteFuture.getNow(null).snapshotAndInsertBarriers(cpId); + }; + } + private void ensureNotCanceled() { if (canceled) { throw new CancelTaskException(); @@ -1996,7 +2181,8 @@ protected RecordFilterContext createRecordFilterContext() { boolean checkpointingDuringRecoveryEnabled = CheckpointingOptions.isCheckpointingDuringRecoveryEnabled(getJobConfiguration()); if (!checkpointingDuringRecoveryEnabled) { - return RecordFilterContext.disabled(); + return RecordFilterContext.disabled( + getEnvironment().getIOManager().getSpillingDirectoriesPaths()); } ClassLoader cl = getUserCodeClassLoader(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelIOExecutorDrainSubmissionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelIOExecutorDrainSubmissionTest.java new file mode 100644 index 0000000000000..b011d282cb78c --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelIOExecutorDrainSubmissionTest.java @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; +import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; +import org.apache.flink.runtime.io.network.partition.consumer.RecoverableInputChannel; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import static org.assertj.core.api.Assertions.assertThat; + +class ChannelIOExecutorDrainSubmissionTest { + + @TempDir Path tempDir; + + @Test + void testFilterOnSubmitsDrainAfterConversion() throws Exception { + InputChannelInfo cInfo = new InputChannelInfo(0, 0); + FetchedChannelState state = writeRecords(cInfo, new byte[] {1, 2, 3}); + + CapturingChannel chan = new CapturingChannel(cInfo); + List all = new ArrayList<>(); + all.add(chan); + FetchedChannelStateDrainer drainer = + new FetchedChannelStateDrainer(state, CompletableFuture.completedFuture(all)); + + ExecutorService channelIOExecutor = Executors.newSingleThreadExecutor(); + try { +
CompletableFuture done = new CompletableFuture<>(); + channelIOExecutor.execute( + () -> { + try { + drainer.drain(); + done.complete(null); + } catch (Throwable t) { + done.completeExceptionally(t); + } finally { + try { + drainer.close(); + } catch (IOException ignore) { // best-effort close, failure ignored + } + } + }); + + done.get(5, TimeUnit.SECONDS); + assertThat(chan.dataDeliveries).isGreaterThan(0); + assertThat(chan.finishCalled).isTrue(); + } finally { + channelIOExecutor.shutdownNow(); + assertThat(channelIOExecutor.awaitTermination(5, TimeUnit.SECONDS)).isTrue(); + } + } + + @Test + void testDrainExceptionBubblesViaAsyncExceptionHandler() throws Exception { + InputChannelInfo cInfo = new InputChannelInfo(0, 0); + FetchedChannelState state = writeRecords(cInfo, new byte[] {1, 2, 3}); + + RecoverableInputChannel chan = + new RecoverableInputChannel() { + @Override + public InputChannelInfo getChannelInfo() { + return cInfo; + } + + @Override + public void onRecoveredStateBuffer(Buffer buffer) { + throw new RuntimeException("boom"); + } + + @Override + public void finishRecoveredBufferDelivery() {} + + @Override + public void insertRecoveryCheckpointBarrierIfInRecovery(long checkpointId) { + throw new RuntimeException("boom"); + } + + @Override + public Buffer requestRecoveryBufferBlocking() { + MemorySegment seg = MemorySegmentFactory.allocateUnpooledSegment(64); + return new NetworkBuffer(seg, FreeingBufferRecycler.INSTANCE); + } + + @Override + public void onRecoveredStateConsumed() {} + }; + + List all = new ArrayList<>(); + all.add(chan); + FetchedChannelStateDrainer drainer = + new FetchedChannelStateDrainer(state, CompletableFuture.completedFuture(all)); + + CountDownLatch handlerCalled = new CountDownLatch(1); + AtomicReference captured = new AtomicReference<>(); + ExecutorService channelIOExecutor = Executors.newSingleThreadExecutor(); + try { + channelIOExecutor.execute( + () -> { + try { + drainer.drain(); + } catch (Throwable t) { // test stand-in for the async exception handler + captured.set(t); + handlerCalled.countDown(); +
} finally { + try { + drainer.close(); + } catch (IOException ignore) { // best-effort close, failure ignored + } + } + }); + + assertThat(handlerCalled.await(5, TimeUnit.SECONDS)).isTrue(); + assertThat(captured.get()).isInstanceOf(RuntimeException.class); + assertThat(captured.get().getMessage()).isEqualTo("boom"); + } finally { + channelIOExecutor.shutdownNow(); + assertThat(channelIOExecutor.awaitTermination(5, TimeUnit.SECONDS)).isTrue(); + } + } + + // ------------------------------------------------------------------------------------------- + // Helpers + // ------------------------------------------------------------------------------------------- + + private FetchedChannelState writeRecords(InputChannelInfo ch, byte[] payload) + throws IOException { + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.writeRecord(ch, payload, payload.length); + return writer.getChannelState(); + } + } + + private static final class CapturingChannel implements RecoverableInputChannel { + private final InputChannelInfo channelInfo; + int dataDeliveries = 0; + boolean finishCalled = false; + + CapturingChannel(InputChannelInfo channelInfo) { + this.channelInfo = channelInfo; + } + + @Override + public InputChannelInfo getChannelInfo() { + return channelInfo; + } + + @Override + public void onRecoveredStateBuffer(Buffer buffer) { + if (buffer.isBuffer()) { + dataDeliveries++; + } + } + + @Override + public void finishRecoveredBufferDelivery() { + finishCalled = true; + } + + @Override + public void insertRecoveryCheckpointBarrierIfInRecovery(long checkpointId) {} + + @Override + public Buffer requestRecoveryBufferBlocking() { + MemorySegment seg = MemorySegmentFactory.allocateUnpooledSegment(64); + return new NetworkBuffer(seg, FreeingBufferRecycler.INSTANCE); + } + + @Override + public void onRecoveredStateConsumed() {} + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainerConcurrencyTest.java
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainerConcurrencyTest.java new file mode 100644 index 0000000000000..32f83760c0104 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainerConcurrencyTest.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; +import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; +import org.apache.flink.runtime.io.network.partition.consumer.RecoverableInputChannel; + +import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Stress test for {@link FetchedChannelStateDrainer} drain / snapshot critical-section atomicity. + * Runs one drain thread while the test thread repeatedly calls {@code snapshotAndInsertBarriers}, + * then asserts that the drain finished without error and that both channels recorded the same + * number of barriers. NOTE(review): exactly-once delivery per buffer is not asserted here.
+ */ +class FetchedChannelStateDrainerConcurrencyTest { + + private static final int RECORD_COUNT = 500; + private static final int SNAPSHOTS = 50; + private static final int CHANNEL_COUNT = 2; + + @TempDir Path tempDir; + + @RepeatedTest(3) + void testDrainAndSnapshotConcurrentAtomicity() throws Exception { + Path runDir = Files.createTempDirectory(tempDir, "drain-stress-"); + InputChannelInfo c0 = new InputChannelInfo(0, 0); + InputChannelInfo c1 = new InputChannelInfo(0, 1); + + FetchedChannelState state; + try (TestSpillWriter writer = new TestSpillWriter(runDir)) { + for (int i = 0; i < RECORD_COUNT; i++) { + InputChannelInfo ch = (i % 2 == 0) ? c0 : c1; + writer.writeRecord(ch, payloadFor(i), 8); + } + state = writer.getChannelState(); + } + + ThreadSafeRecordingChannel chan0 = new ThreadSafeRecordingChannel(c0); + ThreadSafeRecordingChannel chan1 = new ThreadSafeRecordingChannel(c1); + List all = new ArrayList<>(); + all.add(chan0); + all.add(chan1); + + FetchedChannelStateDrainer drainer = + new FetchedChannelStateDrainer(state, CompletableFuture.completedFuture(all)); + + ExecutorService io = Executors.newSingleThreadExecutor(); + AtomicReference drainError = new AtomicReference<>(); + + Future drainFuture = + io.submit( + () -> { + try { + drainer.drain(); + } catch (Throwable t) { + drainError.set(t); + } + }); + + // Take snapshots concurrently while drain runs.
+ List snapshots = new ArrayList<>(); + for (int i = 0; i < SNAPSHOTS; i++) { + snapshots.add(drainer.snapshotAndInsertBarriers(i + 1)); + Thread.yield(); + } + + drainFuture.get(60, TimeUnit.SECONDS); + io.shutdown(); + assertThat(io.awaitTermination(10, TimeUnit.SECONDS)).isTrue(); + if (drainError.get() != null) { + throw new AssertionError("drain failed", drainError.get()); + } + + // Close all snapshots + for (FetchedChannelStateReader snap : snapshots) { + snap.close(); + } + + // Both channels must have recorded the same number of barriers + assertThat(chan0.barrierCount()).isEqualTo(chan1.barrierCount()); + + drainer.close(); + } + + // ------------------------------------------------------------------------------------------- + // Helpers + // ------------------------------------------------------------------------------------------- + + private static byte[] payloadFor(int id) { + byte[] out = new byte[8]; + out[0] = (byte) (id & 0xff); + out[1] = (byte) ((id >> 8) & 0xff); + out[2] = (byte) ((id >> 16) & 0xff); + out[3] = (byte) ((id >> 24) & 0xff); + Arrays.fill(out, 4, 8, (byte) 0xCC); + return out; + } + + private static final class ThreadSafeRecordingChannel implements RecoverableInputChannel { + private final InputChannelInfo channelInfo; + private final List data = new ArrayList<>(); + private final List barriers = new ArrayList<>(); + + ThreadSafeRecordingChannel(InputChannelInfo channelInfo) { + this.channelInfo = channelInfo; + } + + @Override + public InputChannelInfo getChannelInfo() { + return channelInfo; + } + + @Override + public synchronized void onRecoveredStateBuffer(Buffer buffer) { + if (buffer.isBuffer()) { + data.add(buffer); + } else { + barriers.add(buffer); + } + } + + @Override + public synchronized void finishRecoveredBufferDelivery() {} + + @Override + public synchronized void insertRecoveryCheckpointBarrierIfInRecovery(long checkpointId) + throws IOException { + barriers.add( + EventSerializer.toBuffer(new
RecoveryCheckpointBarrier(checkpointId), false)); + } + + @Override + public Buffer requestRecoveryBufferBlocking() { + // Use a buffer large enough for a full segment + MemorySegment seg = MemorySegmentFactory.allocateUnpooledSegment(4096); + return new NetworkBuffer(seg, FreeingBufferRecycler.INSTANCE); + } + + @Override + public synchronized void onRecoveredStateConsumed() {} + + synchronized int barrierCount() { + return barriers.size(); + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainerTest.java new file mode 100644 index 0000000000000..c78d5b67f4770 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainerTest.java @@ -0,0 +1,475 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.checkpoint.channel.FetchedChannelStateReader.SpillSegment; +import org.apache.flink.runtime.event.AbstractEvent; +import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; +import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; +import org.apache.flink.runtime.io.network.partition.consumer.RecoverableInputChannel; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link FetchedChannelStateDrainer}: drain demux, finish ordering, snapshot start + * position, barrier insertion, and edge cases (drain-finished, channel-not-in-recovery). + */ +class FetchedChannelStateDrainerTest { + + @TempDir Path tempDir; + + @Test + void testDrainEndToEnd() throws Exception { + InputChannelInfo cInfo = new InputChannelInfo(0, 0); + FetchedChannelState state = writeRecords(cInfo, payload(1), payload(2), payload(3)); + + RecordingChannel rec = new RecordingChannel(cInfo); + FetchedChannelStateDrainer drainer = newDrainer(state, cInfo, rec); + + drainer.drain(); + drainer.close(); + + // All segment bodies must be delivered as buffer(s). The 3 records may be batched into + // fewer buffers, so this only checks that something was delivered and finish ran once.
+ assertThat(rec.recovered).isNotEmpty(); + assertThat(rec.finishCalls).isEqualTo(1); + } + + @Test + void testDrainSegmentLargerThanBufferSplitsIntoFullChunksThenPartialTail() throws Exception { + InputChannelInfo cInfo = new InputChannelInfo(0, 0); + + // Buffer capacity deliberately smaller than the segment body so the drainer must fill + // multiple buffers and a final partial tail. 50 bytes over a 16-byte buffer => 16+16+16+2. + int bufferCapacity = 16; + byte[] body = sequentialBytes(50); + + FetchedChannelState state; + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + // Pass-through so the segment body equals the verbatim bytes (no length framing). + writer.writePassThrough(cInfo, body, 0, body.length); + state = writer.getChannelState(); + } + + RecordingChannel rec = new RecordingChannel(cInfo, bufferCapacity); + FetchedChannelStateDrainer drainer = newDrainer(state, cInfo, rec); + + drainer.drain(); + drainer.close(); + + // ceil(50 / 16) = 4 buffers delivered. + assertThat(rec.recovered).hasSize(4); + // Every buffer except the last is filled to capacity; the last carries the remainder. + for (int i = 0; i < rec.recovered.size() - 1; i++) { + assertThat(rec.recovered.get(i).getSize()).isEqualTo(bufferCapacity); + } + assertThat(rec.recovered.get(rec.recovered.size() - 1).getSize()) + .isEqualTo(body.length % bufferCapacity); + + // Buffers concatenated in delivery order must reproduce the segment body byte-for-byte. + assertThat(concat(rec.recovered)).isEqualTo(body); + assertThat(rec.finishCalls).isEqualTo(1); + } + + @Test + void testDrainSegmentExactMultipleOfBufferHasNoPartialTail() throws Exception { + InputChannelInfo cInfo = new InputChannelInfo(0, 0); + + // Body length is an exact multiple of the buffer capacity: the final read hits EOF on a + // freshly requested buffer, which must be recycled rather than delivered empty.
+ int bufferCapacity = 16; + byte[] body = sequentialBytes(bufferCapacity * 3); + + FetchedChannelState state; + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.writePassThrough(cInfo, body, 0, body.length); + state = writer.getChannelState(); + } + + RecordingChannel rec = new RecordingChannel(cInfo, bufferCapacity); + FetchedChannelStateDrainer drainer = newDrainer(state, cInfo, rec); + + drainer.drain(); + drainer.close(); + + // Exactly 3 full buffers, no trailing empty buffer. + assertThat(rec.recovered).hasSize(3); + for (Buffer b : rec.recovered) { + assertThat(b.getSize()).isEqualTo(bufferCapacity); + } + assertThat(concat(rec.recovered)).isEqualTo(body); + assertThat(rec.finishCalls).isEqualTo(1); + } + + @Test + void testDrainDemuxByChannelInfo() throws Exception { + InputChannelInfo c0 = new InputChannelInfo(0, 0); + InputChannelInfo c1 = new InputChannelInfo(0, 1); + + FetchedChannelState state; + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.writeRecord(c0, payload(11), payload(11).length); + writer.writeRecord(c1, payload(22), payload(22).length); + writer.writeRecord(c0, payload(33), payload(33).length); + writer.writeRecord(c1, payload(44), payload(44).length); + state = writer.getChannelState(); + } + + RecordingChannel chan0 = new RecordingChannel(c0); + RecordingChannel chan1 = new RecordingChannel(c1); + FetchedChannelStateDrainer drainer = newDrainer(state, c0, chan0, c1, chan1); + + drainer.drain(); + drainer.close(); + + // Each channel must receive some data buffers + assertThat(chan0.recovered).isNotEmpty(); + assertThat(chan1.recovered).isNotEmpty(); + // Both channels must have finish called + assertThat(chan0.finishCalls).isEqualTo(1); + assertThat(chan1.finishCalls).isEqualTo(1); + } + + @Test + void testDrainCallsFinishAfterAllBufferDeliveries() throws Exception { + InputChannelInfo c0 = new InputChannelInfo(0, 0); + InputChannelInfo c1 = new InputChannelInfo(0, 1); + +
FetchedChannelState state; + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.writeRecord(c0, payload(1), payload(1).length); + writer.writeRecord(c1, payload(2), payload(2).length); + state = writer.getChannelState(); + } + + int[] seq = {0}; + RecordingChannel chan0 = new RecordingChannel(c0, seq); + RecordingChannel chan1 = new RecordingChannel(c1, seq); + FetchedChannelStateDrainer drainer = newDrainer(state, c0, chan0, c1, chan1); + + drainer.drain(); + drainer.close(); + + int maxDataSeq = Math.max(chan0.maxDataSeq, chan1.maxDataSeq); + int minFinishSeq = Math.min(chan0.finishSeq, chan1.finishSeq); + assertThat(maxDataSeq).isLessThan(minFinishSeq); + } + + @Test + void testSnapshotCoversAllSegmentsBeforeDrain() throws Exception { + InputChannelInfo cInfo = new InputChannelInfo(0, 0); + FetchedChannelState state = writeRecords(cInfo, payload(5), payload(6)); + + RecordingChannel chan = new RecordingChannel(cInfo); + FetchedChannelStateDrainer drainer = newDrainer(state, cInfo, chan); + + long cpId = 42L; + FetchedChannelStateReader snap = drainer.snapshotAndInsertBarriers(cpId); + + // Snapshot must cover all segments (at least 1 segment for cInfo). + // The sequential reader requires each segment body to be fully consumed before advancing, + // mirroring the real consumer (ChannelStateCheckpointWriter#writeInputFromSpill).
+ int count = 0; + Optional next; + while ((next = snap.nextSegment()).isPresent()) { + drainBody(next.get().bodyStream()); + count++; + } + snap.close(); + assertThat(count).isGreaterThan(0); + drainer.close(); + } + + @Test + void testSnapshotInsertsBarrierPerChannel() throws Exception { + InputChannelInfo c0 = new InputChannelInfo(0, 0); + InputChannelInfo c1 = new InputChannelInfo(0, 1); + + FetchedChannelState state; + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.writeRecord(c0, payload(1), payload(1).length); + writer.writeRecord(c1, payload(2), payload(2).length); + state = writer.getChannelState(); + } + + RecordingChannel chan0 = new RecordingChannel(c0); + RecordingChannel chan1 = new RecordingChannel(c1); + FetchedChannelStateDrainer drainer = newDrainer(state, c0, chan0, c1, chan1); + + long cpId = 7L; + FetchedChannelStateReader snap = drainer.snapshotAndInsertBarriers(cpId); + snap.close(); + + assertThat(chan0.recovered).hasSize(1); + assertThat(chan1.recovered).hasSize(1); + assertThat(extractRecoveryBarrierCheckpointId(chan0.recovered.get(0))).isEqualTo(cpId); + assertThat(extractRecoveryBarrierCheckpointId(chan1.recovered.get(0))).isEqualTo(cpId); + drainer.close(); + } + + @Test + void testSnapshotInsertsBarrierWhenChannelInRecoveryEvenIfDiskSliceEmpty() throws Exception { + InputChannelInfo cInfo = new InputChannelInfo(0, 0); + FetchedChannelState state = writeRecords(cInfo, payload(1)); + + RecordingChannel chan = new RecordingChannel(cInfo); + FetchedChannelStateDrainer drainer = newDrainer(state, cInfo, chan); + + drainer.drain(); + // Drain finished; simulate the channel still in recovery + chan.inRecovery = true; + int recoveredBefore = chan.recovered.size(); + + long cpId = 6L; + FetchedChannelStateReader snap = drainer.snapshotAndInsertBarriers(cpId); + assertThat(snap.nextSegment()).isEmpty(); + snap.close(); + + // Barrier must be inserted even though disk slice is empty +
assertThat(chan.recovered).hasSize(recoveredBefore + 1); + assertThat(extractRecoveryBarrierCheckpointId(chan.recovered.get(recoveredBefore))) + .isEqualTo(cpId); + drainer.close(); + } + + @Test + void testSnapshotInsertsBarrierOnlyForChannelsStillInRecovery() throws Exception { + InputChannelInfo c0 = new InputChannelInfo(0, 0); + InputChannelInfo c1 = new InputChannelInfo(0, 1); + + FetchedChannelState state; + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + writer.writeRecord(c0, payload(1), payload(1).length); + state = writer.getChannelState(); + } + + RecordingChannel chan0 = new RecordingChannel(c0); + RecordingChannel chan1 = new RecordingChannel(c1); + chan1.inRecovery = false; + + FetchedChannelStateDrainer drainer = newDrainer(state, c0, chan0, c1, chan1); + + long cpId = 11L; + FetchedChannelStateReader snap = drainer.snapshotAndInsertBarriers(cpId); + snap.close(); + + assertThat(chan0.recovered).hasSize(1); + assertThat(extractRecoveryBarrierCheckpointId(chan0.recovered.get(0))).isEqualTo(cpId); + assertThat(chan1.recovered).isEmpty(); + drainer.close(); + } + + @Test + void testSnapshotReturnsEmptyWhenDrainFinishedAndNotInRecovery() throws Exception { + InputChannelInfo cInfo = new InputChannelInfo(0, 0); + FetchedChannelState state = writeRecords(cInfo, payload(1), payload(2)); + + RecordingChannel chan = new RecordingChannel(cInfo); + FetchedChannelStateDrainer drainer = newDrainer(state, cInfo, chan); + + drainer.drain(); + chan.inRecovery = false; + int recoveredBefore = chan.recovered.size(); + + FetchedChannelStateReader snap = drainer.snapshotAndInsertBarriers(99L); + assertThat(snap.nextSegment()).isEmpty(); + snap.close(); + + // No barrier added since channel left recovery + assertThat(chan.recovered).hasSize(recoveredBefore); + drainer.close(); + } + + @Test + void testSnapshotAfterDrainerClosedReturnsEmptyWithoutTouchingClosedRootReader() + throws Exception { + // Mirrors production order: drain() then close() (which
closes the root reader) + // run before a late checkpoint fires snapshotAndInsertBarriers. + // The drain-finished flag must short-circuit the snapshot, so that the closed + // root reader is never snapshotted and the returned reader yields no + // segments. + InputChannelInfo cInfo = new InputChannelInfo(0, 0); + FetchedChannelState state = writeRecords(cInfo, payload(1), payload(2)); + + RecordingChannel chan = new RecordingChannel(cInfo); + FetchedChannelStateDrainer drainer = newDrainer(state, cInfo, chan); + + drainer.drain(); + drainer.close(); + + FetchedChannelStateReader snap = drainer.snapshotAndInsertBarriers(99L); + assertThat(snap.nextSegment()).isEmpty(); + snap.close(); + } + + // ------------------------------------------------------------------------------------------- + // Helpers + // ------------------------------------------------------------------------------------------- + + private FetchedChannelState writeRecords(InputChannelInfo ch, byte[]... payloads) + throws IOException { + try (TestSpillWriter writer = new TestSpillWriter(tempDir)) { + for (byte[] p : payloads) { + writer.writeRecord(ch, p, p.length); + } + return writer.getChannelState(); + } + } + + private FetchedChannelStateDrainer newDrainer( + FetchedChannelState state, Object...
infoChannelPairs) { + List all = new ArrayList<>(); + for (int i = 0; i < infoChannelPairs.length; i += 2) { // (info, channel) pairs + all.add((RecoverableInputChannel) infoChannelPairs[i + 1]); + } + return new FetchedChannelStateDrainer(state, CompletableFuture.completedFuture(all)); + } + + private static long extractRecoveryBarrierCheckpointId(Buffer buffer) throws IOException { + AbstractEvent event = + EventSerializer.fromBuffer( + buffer, RecoveryCheckpointBarrier.class.getClassLoader()); + buffer.setReaderIndex(0); // rewind; presumably fromBuffer moved the reader index + assertThat(event).isInstanceOf(RecoveryCheckpointBarrier.class); + return ((RecoveryCheckpointBarrier) event).getCheckpointId(); + } + + private static byte[] payload(int id) { + return new byte[] {(byte) (id & 0xff), (byte) ((id >> 8) & 0xff), (byte) 0xAB, (byte) 0xCD}; + } + + /** Builds {@code n} bytes whose values count up modulo 256, so order mismatches are visible. */ + private static byte[] sequentialBytes(int n) { + byte[] out = new byte[n]; + for (int i = 0; i < n; i++) { + out[i] = (byte) i; + } + return out; + } + + /** Concatenates the readable bytes of the given buffers in order. */ + private static byte[] concat(List buffers) { + java.io.ByteArrayOutputStream out = new java.io.ByteArrayOutputStream(); + for (Buffer b : buffers) { + java.nio.ByteBuffer nio = b.getNioBufferReadable(); + byte[] chunk = new byte[nio.remaining()]; + nio.get(chunk); + out.write(chunk, 0, chunk.length); + } + return out.toByteArray(); + } + + /** Fully consumes a segment body so the sequential reader may advance to the next segment.
*/ + private static void drainBody(InputStream body) throws IOException { + byte[] buf = new byte[256]; + while (body.read(buf) != -1) { + // discard + } + } + + // ------------------------------------------------------------------------------------------- + // RecordingChannel stub: records every delivered buffer; inRecovery gates barrier insertion + // ------------------------------------------------------------------------------------------- + + private static final int DEFAULT_RECOVERY_BUFFER_CAPACITY = 4096; + + private static final class RecordingChannel implements RecoverableInputChannel { + private final InputChannelInfo channelInfo; + final List recovered = new ArrayList<>(); + int finishCalls = 0; + private final int[] sequence; + private final int bufferCapacity; + int maxDataSeq = Integer.MIN_VALUE; + int finishSeq = -1; + boolean inRecovery = true; + + RecordingChannel(InputChannelInfo channelInfo) { + this(channelInfo, null, DEFAULT_RECOVERY_BUFFER_CAPACITY); + } + + RecordingChannel(InputChannelInfo channelInfo, int[] sharedSequence) { + this(channelInfo, sharedSequence, DEFAULT_RECOVERY_BUFFER_CAPACITY); + } + + RecordingChannel(InputChannelInfo channelInfo, int bufferCapacity) { + this(channelInfo, null, bufferCapacity); + } + + RecordingChannel(InputChannelInfo channelInfo, int[] sharedSequence, int bufferCapacity) { + this.channelInfo = channelInfo; + this.sequence = sharedSequence; + this.bufferCapacity = bufferCapacity; + } + + @Override + public InputChannelInfo getChannelInfo() { + return channelInfo; + } + + @Override + public void onRecoveredStateBuffer(Buffer buffer) { + recovered.add(buffer); + if (sequence != null) { + maxDataSeq = Math.max(maxDataSeq, ++sequence[0]); + } + } + + @Override + public void finishRecoveredBufferDelivery() { + finishCalls++; + if (sequence != null) { + finishSeq = ++sequence[0]; + } + } + + @Override + public void insertRecoveryCheckpointBarrierIfInRecovery(long checkpointId) + throws IOException { + if (inRecovery) { + recovered.add( + EventSerializer.toBuffer( + new
RecoveryCheckpointBarrier(checkpointId), false)); + } + } + + @Override + public Buffer requestRecoveryBufferBlocking() { + MemorySegment seg = MemorySegmentFactory.allocateUnpooledSegment(bufferCapacity); + return new NetworkBuffer(seg, FreeingBufferRecycler.INSTANCE); + } + + @Override + public void onRecoveredStateConsumed() {} + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateReaderTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateReaderTest.java new file mode 100644 index 0000000000000..a8bb14160f14e --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateReaderTest.java @@ -0,0 +1,532 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.apache.flink.runtime.checkpoint.channel.FetchedChannelStateReader.SpillSegment;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardOpenOption;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/**
+ * Unit tests for {@link FetchedChannelStateReader}: sequential segment scanning, body boundedness,
+ * cross-file transparency, snapshot derivation, and fail-loud on truncated segments.
+ *
+ * <p>Segment boundaries are self-described in disk headers; no in-memory segment locator table is
+ * used.
+ */
+class FetchedChannelStateReaderTest {
+
+    @TempDir Path tempDir;
+
+    // -------------------------------------------------------------------------------------------
+    // Segment iteration
+    // -------------------------------------------------------------------------------------------
+
+    @Test
+    void testIteratorEmptyWhenNoDataWritten() throws Exception {
+        // A writer that never spills produces no state; an empty state has no segments.
+        FetchedChannelState state = new FetchedChannelState(Collections.emptyList());
+        try (FetchedChannelStateReader reader = state.reader()) {
+            assertThat(reader.nextSegment()).isEmpty();
+        }
+    }
+
+    @Test
+    void testMultipleIteratorIteratedInOrder() throws Exception {
+        InputChannelInfo c0 = new InputChannelInfo(0, 0);
+        InputChannelInfo c1 = new InputChannelInfo(0, 1);
+
+        FetchedChannelState state;
+        try (TestSpillWriter writer = new TestSpillWriter(tempDir)) {
+            writer.writeRecord(c0, bytes(10), 1);
+            writer.writeRecord(c1, bytes(20), 1);
+            writer.writeRecord(c0, bytes(30), 1);
+            state = writer.getChannelState();
+        }
+
+        List<InputChannelInfo> channels = new ArrayList<>();
+        try (FetchedChannelStateReader reader = state.reader()) {
+            Optional<SpillSegment> next;
+            while ((next = reader.nextSegment()).isPresent()) {
+                SpillSegment seg = next.get();
+                channels.add(seg.channelInfo());
+                readAll(seg.bodyStream());
+            }
+        }
+
+        // Segments are produced at channel switches: c0, c1, c0
+        assertThat(channels).containsExactly(c0, c1, c0);
+    }
+
+    // -------------------------------------------------------------------------------------------
+    // Body boundedness: body() stops exactly at segment end
+    // -------------------------------------------------------------------------------------------
+
+    @Test
+    void testBodyReturnsMinus1AtSegmentEnd() throws Exception {
+        InputChannelInfo ch = new InputChannelInfo(0, 0);
+
+        FetchedChannelState state;
+        try (TestSpillWriter writer = new TestSpillWriter(tempDir)) {
+            writer.writeRecord(ch, bytes(1, 2), 2);
+            state = writer.getChannelState();
+        }
+
+        try (FetchedChannelStateReader reader = state.reader()) {
+            SpillSegment seg = reader.nextSegment().orElseThrow(AssertionError::new);
+            InputStream body = seg.bodyStream();
+            // Read exactly length bytes
+            byte[] data = new byte[seg.length()];
+            int totalRead = 0;
+            while (totalRead < data.length) {
+                int n = body.read(data, totalRead, data.length - totalRead);
+                assertThat(n).isGreaterThan(0);
+                totalRead += n;
+            }
+            // Next read must return EOF
+            assertThat(body.read()).isEqualTo(-1);
+        }
+    }
+
+    // -------------------------------------------------------------------------------------------
+    // Cross-file transparency
+    // -------------------------------------------------------------------------------------------
+
+    @Test
+    void testCrossFileTransparencyWhenRotationOccurs() throws Exception {
+        InputChannelInfo c0 = new InputChannelInfo(0, 0);
+        InputChannelInfo c1 = new InputChannelInfo(0, 1);
+
+        // Use tiny rotation threshold so first segment triggers a file rotation.
+        FetchedChannelState state;
+        try (TestSpillWriter writer = new TestSpillWriter(tempDir, 1L /* 1 byte threshold */)) {
+            writer.writeRecord(c0, bytes(10, 11, 12), 3);
+            writer.writeRecord(c1, bytes(20, 21), 2);
+            state = writer.getChannelState();
+        }
+
+        // Two segments in different files.
+        assertThat(state.files()).hasSize(2);
+
+        List<InputChannelInfo> channels = new ArrayList<>();
+        try (FetchedChannelStateReader reader = state.reader()) {
+            Optional<SpillSegment> next;
+            while ((next = reader.nextSegment()).isPresent()) {
+                SpillSegment seg = next.get();
+                channels.add(seg.channelInfo());
+                // Body read must not throw even if the segment is in a different file.
+                readAll(seg.bodyStream());
+            }
+        }
+
+        assertThat(channels).containsExactly(c0, c1);
+    }
+
+    // -------------------------------------------------------------------------------------------
+    // Snapshot: independent reader with correct start position
+    // -------------------------------------------------------------------------------------------
+
+    @Test
+    void testSnapshotCoversAllIteratorWhenNothingConsumed() throws Exception {
+        InputChannelInfo c0 = new InputChannelInfo(0, 0);
+        InputChannelInfo c1 = new InputChannelInfo(0, 1);
+
+        FetchedChannelState state;
+        try (TestSpillWriter writer = new TestSpillWriter(tempDir)) {
+            writer.writeRecord(c0, bytes(1), 1);
+            writer.writeRecord(c1, bytes(2), 1);
+            state = writer.getChannelState();
+        }
+
+        try (FetchedChannelStateReader root = state.reader()) {
+            // Snapshot before consuming anything
+            try (FetchedChannelStateReader snap = root.snapshot().reader()) {
+                List<InputChannelInfo> channels = new ArrayList<>();
+                Optional<SpillSegment> next;
+                while ((next = snap.nextSegment()).isPresent()) {
+                    SpillSegment seg = next.get();
+                    channels.add(seg.channelInfo());
+                    readAll(seg.bodyStream());
+                }
+                assertThat(channels).containsExactly(c0, c1);
+            }
+        }
+    }
+
+    @Test
+    void testSnapshotAfterFullSegmentConsumedSkipsThatSegment() throws Exception {
+        InputChannelInfo c0 = new InputChannelInfo(0, 0);
+        InputChannelInfo c1 = new InputChannelInfo(0, 1);
+
+        FetchedChannelState state;
+        try (TestSpillWriter writer = new TestSpillWriter(tempDir)) {
+            writer.writeRecord(c0, bytes(1), 1);
+            writer.writeRecord(c1, bytes(2), 1);
+            state = writer.getChannelState();
+        }
+
+        try (FetchedChannelStateReader root = state.reader()) {
+            // Consume and commit first segment
+            SpillSegment first = root.nextSegment().orElseThrow(AssertionError::new);
+            readAll(first.bodyStream());
+            first.commit();
+
+            // Snapshot must start from second segment
+            try (FetchedChannelStateReader snap = root.snapshot().reader()) {
+                List<InputChannelInfo> channels = new ArrayList<>();
+                Optional<SpillSegment> next;
+                while ((next = snap.nextSegment()).isPresent()) {
+                    SpillSegment seg = next.get();
+                    channels.add(seg.channelInfo());
+                    readAll(seg.bodyStream());
+                }
+                assertThat(channels).containsExactly(c1);
+            }
+        }
+    }
+
+    @Test
+    void testSnapshotFromMidSegmentStartsAtCommittedByteOffset() throws Exception {
+        InputChannelInfo ch = new InputChannelInfo(0, 0);
+
+        FetchedChannelState state;
+        try (TestSpillWriter writer = new TestSpillWriter(tempDir)) {
+            // A single record on one channel -> one segment
+            writer.writeRecord(ch, bytes(10, 11), 2);
+            state = writer.getChannelState();
+        }
+        // Verify: one file with one segment
+        assertThat(state.files()).hasSize(1);
+
+        try (FetchedChannelStateReader root = state.reader()) {
+            SpillSegment seg = root.nextSegment().orElseThrow(AssertionError::new);
+            int fullLength = seg.length();
+            InputStream body = seg.bodyStream();
+
+            // Read only 1 byte without committing, then snapshot — snapshot should start from 0
+            // (no bytes committed yet).
+            body.read();
+
+            try (FetchedChannelStateReader snapBeforeCommit = root.snapshot().reader()) {
+                SpillSegment snapSeg =
+                        snapBeforeCommit.nextSegment().orElseThrow(AssertionError::new);
+                assertThat(snapSeg.length()).isEqualTo(fullLength);
+                readAll(snapSeg.bodyStream());
+            }
+
+            // Read rest of body and commit
+            readAll(body);
+            seg.commit();
+
+            // After commit the snapshot must be empty
+            try (FetchedChannelStateReader snapAfterCommit = root.snapshot().reader()) {
+                assertThat(snapAfterCommit.nextSegment()).isEmpty();
+            }
+        }
+    }
+
+    @Test
+    void testSnapshotAfterPartialCommitReadsRemainingBodyTail() throws Exception {
+        InputChannelInfo ch = new InputChannelInfo(0, 0);
+
+        FetchedChannelState state;
+        try (TestSpillWriter writer = new TestSpillWriter(tempDir)) {
+            // One segment whose body is a single pass-through blob of known bytes.
+            writer.writePassThrough(ch, bytes(1, 2, 3, 4, 5, 6, 7, 8), 0, 8);
+            state = writer.getChannelState();
+        }
+        assertThat(state.files()).hasSize(1);
+
+        try (FetchedChannelStateReader root = state.reader()) {
+            SpillSegment seg = root.nextSegment().orElseThrow(AssertionError::new);
+            int fullLength = seg.length();
+            InputStream body = seg.bodyStream();
+
+            // Read and commit only a 3-byte prefix.
+            byte[] prefix = new byte[3];
+            assertThat(body.read(prefix)).isEqualTo(3);
+            seg.commit();
+
+            // Snapshot must resume mid-segment and yield exactly the remaining tail bytes.
+            try (FetchedChannelStateReader snap = root.snapshot().reader()) {
+                SpillSegment snapSeg = snap.nextSegment().orElseThrow(AssertionError::new);
+                assertThat(snapSeg.channelInfo()).isEqualTo(ch);
+                assertThat(snapSeg.length()).isEqualTo(fullLength - 3);
+                byte[] tail = readAll(snapSeg.bodyStream());
+                assertThat(tail).isEqualTo(bytes(4, 5, 6, 7, 8));
+                assertThat(snap.nextSegment()).isEmpty();
+            }
+        }
+    }
+
+    @Test
+    void testSnapshotResumesPartialSegmentAcrossFileBoundary() throws Exception {
+        InputChannelInfo c0 = new InputChannelInfo(0, 0);
+        InputChannelInfo c1 = new InputChannelInfo(0, 1);
+
+        // Tiny rotation threshold so the two segments land in separate files.
+        FetchedChannelState state;
+        try (TestSpillWriter writer = new TestSpillWriter(tempDir, 1L)) {
+            writer.writePassThrough(c0, bytes(1, 2, 3, 4), 0, 4);
+            writer.writePassThrough(c1, bytes(5, 6, 7), 0, 3);
+            state = writer.getChannelState();
+        }
+        assertThat(state.files()).hasSize(2);
+
+        try (FetchedChannelStateReader root = state.reader()) {
+            SpillSegment first = root.nextSegment().orElseThrow(AssertionError::new);
+            // Commit a 1-byte prefix of the file-0 segment.
+            assertThat(first.bodyStream().read(new byte[1])).isEqualTo(1);
+            first.commit();
+
+            try (FetchedChannelStateReader snap = root.snapshot().reader()) {
+                SpillSegment resumed = snap.nextSegment().orElseThrow(AssertionError::new);
+                assertThat(resumed.channelInfo()).isEqualTo(c0);
+                assertThat(readAll(resumed.bodyStream())).isEqualTo(bytes(2, 3, 4));
+
+                // Crossing into file 1 must reset the skip to 0.
+                SpillSegment following = snap.nextSegment().orElseThrow(AssertionError::new);
+                assertThat(following.channelInfo()).isEqualTo(c1);
+                assertThat(readAll(following.bodyStream())).isEqualTo(bytes(5, 6, 7));
+
+                assertThat(snap.nextSegment()).isEmpty();
+            }
+        }
+    }
+
+    @Test
+    void testRootDrainViaRepeatedCommitsTerminatesAndFinalSnapshotEmpty() throws Exception {
+        InputChannelInfo c0 = new InputChannelInfo(0, 0);
+        InputChannelInfo c1 = new InputChannelInfo(0, 1);
+
+        FetchedChannelState state;
+        try (TestSpillWriter writer = new TestSpillWriter(tempDir)) {
+            writer.writePassThrough(c0, bytes(1, 2), 0, 2);
+            writer.writePassThrough(c1, bytes(3, 4, 5), 0, 3);
+            state = writer.getChannelState();
+        }
+
+        try (FetchedChannelStateReader root = state.reader()) {
+            int count = 0;
+            Optional<SpillSegment> next;
+            while ((next = root.nextSegment()).isPresent()) {
+                SpillSegment seg = next.get();
+                readAll(seg.bodyStream());
+                seg.commit();
+                count++;
+            }
+            assertThat(count).isEqualTo(2);
+
+            // After draining everything, a snapshot must have nothing left.
+            try (FetchedChannelStateReader snap = root.snapshot().reader()) {
+                assertThat(snap.nextSegment()).isEmpty();
+            }
+        }
+    }
+
+    // -------------------------------------------------------------------------------------------
+    // Fail-loud on truncated segment
+    // -------------------------------------------------------------------------------------------
+
+    @Test
+    void testBodyThrowsEOFExceptionOnTruncatedFile() throws Exception {
+        InputChannelInfo ch = new InputChannelInfo(0, 0);
+
+        FetchedChannelState state;
+        try (TestSpillWriter writer = new TestSpillWriter(tempDir)) {
+            writer.writeRecord(ch, bytes(1, 2, 3, 4, 5, 6, 7, 8), 8);
+            state = writer.getChannelState();
+        }
+
+        // Truncate the spill file to just the segment header so the body is missing.
+        Path spill = state.files().get(0);
+        byte[] headerOnly = Files.readAllBytes(spill);
+        // Keep only the first SEGMENT_HEADER_BYTES bytes (the header), discard body.
+        Files.write(
+                spill,
+                java.util.Arrays.copyOf(headerOnly, AbstractSpillingHandler.SEGMENT_HEADER_BYTES),
+                StandardOpenOption.TRUNCATE_EXISTING);
+
+        try (FetchedChannelStateReader reader = state.reader()) {
+            SpillSegment seg = reader.nextSegment().orElseThrow(AssertionError::new);
+            // bufferLength from header says > 0 bytes, but file has nothing after the header.
+            assertThatThrownBy(() -> readAll(seg.bodyStream()))
+                    .isInstanceOfAny(EOFException.class, IOException.class);
+        }
+    }
+
+    // -------------------------------------------------------------------------------------------
+    // Reference counting: acquire/release via reader lifecycle
+    // -------------------------------------------------------------------------------------------
+
+    @Test
+    void testReaderAcquiresAndReleasesRefCount() throws Exception {
+        InputChannelInfo ch = new InputChannelInfo(0, 0);
+
+        FetchedChannelState state;
+        try (TestSpillWriter writer = new TestSpillWriter(tempDir)) {
+            writer.writeRecord(ch, bytes(1), 1);
+            state = writer.getChannelState();
+        }
+
+        Path spill = state.files().get(0);
+
+        FetchedChannelStateReader reader = state.reader();
+        // Drop the handoff grant so the reader's grant is the only one outstanding.
+        state.release();
+        assertThat(Files.exists(spill)).isTrue();
+
+        reader.close();
+
+        // After closing the only reader, the file is cleaned up.
+        assertThat(Files.exists(spill)).isFalse();
+    }
+
+    @Test
+    void testSnapshotKeepsFilesAliveUntilSnapshotClosed() throws Exception {
+        InputChannelInfo ch = new InputChannelInfo(0, 0);
+
+        FetchedChannelState state;
+        try (TestSpillWriter writer = new TestSpillWriter(tempDir)) {
+            writer.writeRecord(ch, bytes(1), 1);
+            state = writer.getChannelState();
+        }
+
+        Path spill = state.files().get(0);
+
+        FetchedChannelStateReader root = state.reader();
+        FetchedChannelStateReader snap = root.snapshot().reader();
+        // Drop the handoff grant so only the two reader grants remain outstanding.
+        state.release();
+
+        root.close(); // One grant released; file must still exist because snap holds another.
+        assertThat(Files.exists(spill)).isTrue();
+
+        snap.close(); // Last grant released; file must be deleted.
+        assertThat(Files.exists(spill)).isFalse();
+    }
+
+    // -------------------------------------------------------------------------------------------
+    // New behaviour: first-call exemption, fail-loud on body not consumed, empty reader
+    // -------------------------------------------------------------------------------------------
+
+    @Test
+    void testFirstNextSegmentCallDoesNotRequirePreviousBodyConsumed() throws Exception {
+        // The "previous body must be fully consumed" rule must not fire on the very first call
+        // because there is no previous segment.
+        FetchedChannelState state;
+        try (TestSpillWriter writer = new TestSpillWriter(tempDir)) {
+            writer.writeRecord(new InputChannelInfo(0, 0), bytes(1, 2), 2);
+            state = writer.getChannelState();
+        }
+
+        try (FetchedChannelStateReader reader = state.reader()) {
+            // Must not throw on the first call even though no body has been consumed before.
+            Optional<SpillSegment> seg = reader.nextSegment();
+            assertThat(seg).isPresent();
+        }
+    }
+
+    @Test
+    void testNextSegmentThrowsWhenPreviousBodyNotFullyConsumed() throws Exception {
+        InputChannelInfo ch = new InputChannelInfo(0, 0);
+        InputChannelInfo ch2 = new InputChannelInfo(0, 1);
+
+        FetchedChannelState state;
+        try (TestSpillWriter writer = new TestSpillWriter(tempDir)) {
+            writer.writeRecord(ch, bytes(1, 2, 3, 4), 4);
+            writer.writeRecord(ch2, bytes(5, 6), 2);
+            state = writer.getChannelState();
+        }
+
+        try (FetchedChannelStateReader reader = state.reader()) {
+            SpillSegment seg = reader.nextSegment().orElseThrow(AssertionError::new);
+            // Read only part of the body — do not exhaust it.
+            seg.bodyStream().read();
+
+            // Advancing to the next segment while the previous body is not fully consumed must fail
+            // loud.
+            assertThatThrownBy(reader::nextSegment)
+                    .isInstanceOf(IllegalStateException.class)
+                    .hasMessageContaining("Previous segment body not fully consumed");
+        }
+    }
+
+    @Test
+    void testEmptyReaderNextSegmentReturnsEmptyAndCloseIsClean() throws Exception {
+        FetchedChannelState state = new FetchedChannelState(Collections.emptyList());
+        try (FetchedChannelStateReader reader = state.reader()) {
+            // First call on an empty reader must return empty without throwing.
+            assertThat(reader.nextSegment()).isEmpty();
+            // Closing must not throw.
+        }
+    }
+
+    @Test
+    void testEmptyReaderHandsOutIndependentInstancesSoCloseDoesNotLeak() throws Exception {
+        // emptyReader() is obtained and closed once per checkpoint. close() is single-use (it flips
+        // the closed flag permanently), so each call must yield a fresh instance; otherwise the
+        // first consumer's close would make every later consumer's nextSegment() fail loud.
+        FetchedChannelStateReader first = FetchedChannelStateReader.emptyReader();
+        assertThat(first.nextSegment()).isEmpty();
+        first.close();
+
+        FetchedChannelStateReader second = FetchedChannelStateReader.emptyReader();
+        assertThat(second).isNotSameAs(first);
+        // Must still work after the previously obtained empty reader was closed.
+        assertThat(second.nextSegment()).isEmpty();
+        second.close();
+    }
+
+    // -------------------------------------------------------------------------------------------
+    // Helpers
+    // -------------------------------------------------------------------------------------------
+
+    private static byte[] readAll(InputStream in) throws IOException {
+        java.io.ByteArrayOutputStream out = new java.io.ByteArrayOutputStream();
+        byte[] buf = new byte[256];
+        int n;
+        while ((n = in.read(buf)) != -1) {
+            out.write(buf, 0, n);
+        }
+        return out.toByteArray();
+    }
+
+    private static byte[] bytes(int... values) {
+        byte[] arr = new byte[values.length];
+        for (int i = 0; i < values.length; i++) {
+            arr[i] = (byte) values[i];
+        }
+        return arr;
+    }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateRefCountTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateRefCountTest.java
new file mode 100644
index 0000000000000..e75d0009c2f65
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateRefCountTest.java
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.List;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Verifies the reference counter on {@link FetchedChannelState}: acquire/release pairing,
+ * zero-triggered file deletion, idempotency, abort-path equivalence, and forced {@link
+ * FetchedChannelState#close()} cleanup.
+ */
+class FetchedChannelStateRefCountTest {
+
+    @TempDir Path tempDir;
+
+    @Test
+    void testAcquireReleaseCountsMatch() throws IOException {
+        FetchedChannelState state = newStateWithData();
+        List<Path> files = state.files();
+
+        // The produced state already holds one handoff grant.
+        state.acquire();
+        assertFilesExist(files, true);
+
+        state.release();
+        assertFilesExist(files, true);
+
+        state.release();
+        assertFilesExist(files, false);
+    }
+
+    @Test
+    void testReachingZeroDeletesFiles() throws IOException {
+        FetchedChannelState state = newStateWithData();
+        List<Path> files = state.files();
+        // The produced state already holds one handoff grant.
+        assertFilesExist(files, true);
+
+        state.release();
+        assertFilesExist(files, false);
+    }
+
+    @Test
+    void testReleaseAfterZeroIsNoOp() throws IOException {
+        FetchedChannelState state = newStateWithData();
+        List<Path> files = state.files();
+        // Release the single handoff grant the produced state already holds.
+        state.release();
+        assertFilesExist(files, false);
+
+        // Extra releases past zero must be a no-op.
+        state.release();
+        state.release();
+        assertFilesExist(files, false);
+    }
+
+    @Test
+    void testAbortPathReleasesViaSameRoute() throws IOException {
+        FetchedChannelState state = newStateWithData();
+        List<Path> files = state.files();
+
+        // The produced state already holds one handoff grant.
+        state.acquire();
+        state.acquire();
+
+        state.release();
+        state.release();
+        assertFilesExist(files, true);
+
+        state.release();
+        assertFilesExist(files, false);
+    }
+
+    @Test
+    void testForceCloseStillCleansFiles() throws IOException {
+        FetchedChannelState state = newStateWithData();
+        List<Path> files = state.files();
+        // The produced state already holds one handoff grant.
+        state.acquire();
+        assertFilesExist(files, true);
+
+        state.close();
+        assertFilesExist(files, false);
+
+        // Double close must be a no-op.
+        state.close();
+
+        // Late release after close must not re-delete or throw.
+        state.release();
+        assertFilesExist(files, false);
+    }
+
+    @Test
+    void testReaderAcquiresAndReleasesOnClose() throws IOException {
+        FetchedChannelState state = newStateWithData();
+        List<Path> files = state.files();
+
+        // Opening a reader acquires one grant; the produced state already holds the handoff grant.
+        FetchedChannelStateReader reader = state.reader();
+        // Drop the handoff grant so the reader's grant is the only one outstanding.
+        state.release();
+        assertFilesExist(files, true);
+
+        // Closing releases that grant, which reaches zero and deletes files.
+        reader.close();
+        assertFilesExist(files, false);
+    }
+
+    // -------------------------------------------------------------------------------------------
+    // Helpers
+    // -------------------------------------------------------------------------------------------
+
+    private FetchedChannelState newStateWithData() throws IOException {
+        InputChannelInfo ch = new InputChannelInfo(0, 0);
+        try (TestSpillWriter writer = new TestSpillWriter(tempDir)) {
+            writer.writeRecord(ch, new byte[] {1, 2, 3}, 3);
+            writer.writeRecord(new InputChannelInfo(0, 1), new byte[] {4, 5}, 2);
+            return writer.getChannelState();
+        }
+    }
+
+    private static void assertFilesExist(List<Path> files, boolean expected) {
+        for (Path file : files) {
+            assertThat(Files.exists(file))
+                    .as("file " + file + " exists=" + expected)
+                    .isEqualTo(expected);
+        }
+    }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateTest.java
new file mode 100644
index 0000000000000..c80656b65536f
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateTest.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.Arrays;
+import java.util.Collections;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Tests for {@link FetchedChannelState} lifecycle and file list management. */
+class FetchedChannelStateTest {
+
+    @TempDir Path tempDir;
+
+    @Test
+    void testInitialStateIsEmpty() {
+        FetchedChannelState state = new FetchedChannelState(Collections.emptyList());
+        assertThat(state.files()).isEmpty();
+        assertThat(state.isClosed()).isFalse();
+    }
+
+    @Test
+    void testFileListPreservesOrder() throws IOException {
+        Path file0 = tempDir.resolve("spill-0.bin");
+        Path file1 = tempDir.resolve("spill-1.bin");
+
+        try (FetchedChannelState state = new FetchedChannelState(Arrays.asList(file0, file1))) {
+            assertThat(state.files()).containsExactly(file0, file1);
+        }
+    }
+
+    @Test
+    void testFilesListIsUnmodifiable() throws IOException {
+        try (FetchedChannelState state =
+                new FetchedChannelState(Collections.singletonList(tempDir.resolve("f0.bin")))) {
+            assertThatThrownBy(() -> state.files().add(tempDir.resolve("f1.bin")))
+                    .isInstanceOf(UnsupportedOperationException.class);
+        }
+    }
+
+    @Test
+    void testAcquireReleaseDoesNotDeleteFilesBeforeLastRelease() throws IOException {
+        Path realFile = tempDir.resolve("spill-0.bin");
+        assertThat(realFile.toFile().createNewFile()).isTrue();
+        FetchedChannelState state = new FetchedChannelState(Collections.singletonList(realFile));
+
+        state.acquire();
+        state.acquire();
+
+        state.release();
+        // File must still exist after first release.
+        assertThat(realFile.toFile()).exists();
+
+        state.release();
+        // Last release should delete the file.
+        assertThat(realFile.toFile()).doesNotExist();
+        assertThat(state.isClosed()).isTrue();
+    }
+
+    @Test
+    void testCloseDeletesAllFiles() throws IOException {
+        Path file0 = tempDir.resolve("f0.bin");
+        Path file1 = tempDir.resolve("f1.bin");
+        assertThat(file0.toFile().createNewFile()).isTrue();
+        assertThat(file1.toFile().createNewFile()).isTrue();
+
+        FetchedChannelState state = new FetchedChannelState(Arrays.asList(file0, file1));
+
+        state.close();
+
+        assertThat(file0.toFile()).doesNotExist();
+        assertThat(file1.toFile()).doesNotExist();
+        assertThat(state.isClosed()).isTrue();
+    }
+
+    @Test
+    void testCloseIsIdempotent() throws IOException {
+        FetchedChannelState state = new FetchedChannelState(Collections.emptyList());
+        state.close();
+        assertThat(state.isClosed()).isTrue();
+        // Second close must not throw.
+        state.close();
+        assertThat(state.isClosed()).isTrue();
+    }
+
+    @Test
+    void testCloseAfterReleaseIsIdempotent() throws IOException {
+        FetchedChannelState state = new FetchedChannelState(Collections.emptyList());
+        state.acquire();
+        state.release();
+        assertThat(state.isClosed()).isTrue();
+        // close() after last release must be a no-op (no double-delete attempt).
+        state.close();
+        assertThat(state.isClosed()).isTrue();
+    }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelRequestBufferBlockingHeapFallbackRemovedTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelRequestBufferBlockingHeapFallbackRemovedTest.java
new file mode 100644
index 0000000000000..1b9854361b05e
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelRequestBufferBlockingHeapFallbackRemovedTest.java
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.consumer;
+
+import org.apache.flink.metrics.SimpleCounter;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet;
+import org.apache.flink.runtime.memory.MemoryManager;
+
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Test;
+
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+class RecoveredInputChannelRequestBufferBlockingHeapFallbackRemovedTest {
+
+    private NetworkBufferPool pool;
+
+    @AfterEach
+    void tearDown() {
+        if (pool != null) {
+            pool.destroy();
+            pool = null;
+        }
+    }
+
+    @Test
+    void testBufferPoolExhaustedBlocksRatherThanHeapAllocate() throws Exception {
+        int totalSegments = 4;
+        pool = new NetworkBufferPool(totalSegments, MemoryManager.DEFAULT_PAGE_SIZE);
+        RecoveredInputChannel channel = buildChannel(pool, totalSegments);
+
+        for (int i = 0; i < totalSegments; i++) {
+            channel.requestBufferBlocking();
+        }
+
+        CountDownLatch entered = new CountDownLatch(1);
+        AtomicReference<Buffer> result = new AtomicReference<>();
+        Thread blocker =
+                new Thread(
+                        () -> {
+                            try {
+                                entered.countDown();
+                                result.set(channel.requestBufferBlocking());
+                            } catch (Exception ignored) {
+                                // Expected: the test interrupts this thread after its assertion.
+                            }
+                        },
+                        "blocking-requester");
+        blocker.start();
+
+        assertThat(entered.await(5, TimeUnit.SECONDS)).isTrue();
+        Thread.sleep(200);
+        assertThat(result.get()).as("buffer should not have been allocated").isNull();
+
+        blocker.interrupt();
+        blocker.join(5_000);
+    }
+
+    @Test
+    void testFilterOnPathTakesSameRouteAsFilterOff() throws Exception {
+        int exclusivePerChannel = 1;
+        int totalSegments = 4;
+        pool = new NetworkBufferPool(totalSegments, MemoryManager.DEFAULT_PAGE_SIZE);
+
+        Buffer filterOnBuf = buildChannel(pool, exclusivePerChannel, true).requestBufferBlocking();
+        Buffer filterOffBuf =
+                buildChannel(pool, exclusivePerChannel, false).requestBufferBlocking();
+
+        // Both must come from the pool — the BufferManager-owned recycler, not the
+        // FreeingBufferRecycler the heap-fallback used.
+        assertThat(filterOnBuf.getMemorySegment()).isNotNull();
+        assertThat(filterOffBuf.getMemorySegment()).isNotNull();
+        assertThat(filterOnBuf.getRecycler().getClass().getName())
+                .doesNotContain("FreeingBufferRecycler");
+        assertThat(filterOffBuf.getRecycler().getClass().getName())
+                .doesNotContain("FreeingBufferRecycler");
+
+        filterOnBuf.recycleBuffer();
+        filterOffBuf.recycleBuffer();
+    }
+
+    private RecoveredInputChannel buildChannel(
+            NetworkBufferPool segmentProvider, int exclusivePerChannel) {
+        return buildChannel(segmentProvider, exclusivePerChannel, true);
+    }
+
+    private RecoveredInputChannel buildChannel(
+            NetworkBufferPool segmentProvider,
+            int exclusivePerChannel,
+            boolean checkpointingDuringRecoveryEnabled) {
+        try {
+            SingleInputGate inputGate =
+                    new SingleInputGateBuilder()
+                            .setSegmentProvider(segmentProvider)
+                            .setCheckpointingDuringRecoveryEnabled(
+                                    checkpointingDuringRecoveryEnabled)
+                            .build();
+            return new RecoveredInputChannel(
+                    inputGate,
+                    0,
+                    new ResultPartitionID(),
+                    new ResultSubpartitionIndexSet(0),
+                    0,
+                    0,
+                    new SimpleCounter(),
+                    new SimpleCounter(),
+                    exclusivePerChannel) {
+                @Override
+                protected InputChannel toInputChannelInternal(boolean needsRecovery) {
+                    throw new AssertionError("not expected during this test");
+                }
+            };
+        } catch (Exception e) {
+            throw new AssertionError("channel construction failed", e);
+        }
+    }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java
index f64a4d9fb9cac..1eb305cf10a4d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java
@@ -53,10 +53,12 @@
 import org.apache.flink.util.function.SupplierWithException;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
 import
java.io.IOException; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; +import java.nio.file.Path; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -78,6 +80,8 @@ /** ChannelPersistenceITCase. */ class ChannelPersistenceITCase { + @TempDir static Path tmpDir; + private static final Random RANDOM = new Random(System.currentTimeMillis()); private static final JobID JOB_ID = new JobID(); private static final JobVertexID JOB_VERTEX_ID = new JobVertexID(); @@ -120,7 +124,10 @@ void testReadWritten() throws Exception { try { int numChannels = 1; InputGate gate = buildGate(networkBufferPool, numChannels); - reader.readInputData(new InputGate[] {gate}, RecordFilterContext.disabled()); + reader.readInputData( + new InputGate[] {gate}, + RecordFilterContext.disabled( + new String[] {tmpDir.toAbsolutePath().toString()})); assertThat(collectBytes(gate::pollNext, BufferOrEvent::getBuffer)) .isEqualTo(inputChannelInfoData); diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java index 60494f7acbe5c..27477ca8c6ef9 100644 --- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java @@ -36,9 +36,10 @@ class RecordFilterContextTest { @Test void testDisabledContextHasNoGates() { - RecordFilterContext disabled = RecordFilterContext.disabled(); + RecordFilterContext disabled = RecordFilterContext.disabled(new String[] {"/tmp"}); assertThat(disabled.getNumberOfGates()).isEqualTo(0); assertThat(disabled.isCheckpointingDuringRecoveryEnabled()).isFalse(); + assertThat(disabled.getTmpDirectories()).containsExactly("/tmp"); } @Test @@ -72,7 +73,7 @@ void testGetInputConfigThrowsForInvalidIndex() { 
InflightDataRescalingDescriptor.NO_RESCALE, 0, 128, - null, + new String[] {"/tmp"}, false, MemoryManager.DEFAULT_PAGE_SIZE); @@ -83,18 +84,29 @@ void testGetInputConfigThrowsForInvalidIndex() { } @Test - void testNullTmpDirectoriesConvertedToEmptyArray() { - RecordFilterContext context = - new RecordFilterContext( - new RecordFilterContext.InputFilterConfig[0], - InflightDataRescalingDescriptor.NO_RESCALE, - 0, - 128, - null, - false, - MemoryManager.DEFAULT_PAGE_SIZE); - - assertThat(context.getTmpDirectories()).isNotNull().isEmpty(); + void testNullOrEmptyTmpDirectoriesRejected() { + assertThatThrownBy( + () -> + new RecordFilterContext( + new RecordFilterContext.InputFilterConfig[0], + InflightDataRescalingDescriptor.NO_RESCALE, + 0, + 128, + null, + false, + MemoryManager.DEFAULT_PAGE_SIZE)) + .isInstanceOf(NullPointerException.class); + assertThatThrownBy( + () -> + new RecordFilterContext( + new RecordFilterContext.InputFilterConfig[0], + InflightDataRescalingDescriptor.NO_RESCALE, + 0, + 128, + new String[0], + false, + MemoryManager.DEFAULT_PAGE_SIZE)) + .isInstanceOf(IllegalArgumentException.class); } @Test @@ -111,7 +123,7 @@ void testIsAmbiguousWhenDisabled() { descriptor, 0, 128, - null, + new String[] {"/tmp"}, false, MemoryManager.DEFAULT_PAGE_SIZE); @@ -132,7 +144,7 @@ void testIsAmbiguousWhenEnabled() { descriptor, 0, 128, - null, + new String[] {"/tmp"}, true, MemoryManager.DEFAULT_PAGE_SIZE); @@ -152,7 +164,7 @@ void testIsAmbiguousForNonAmbiguousSubtask() { descriptor, 0, 128, - null, + new String[] {"/tmp"}, true, MemoryManager.DEFAULT_PAGE_SIZE); @@ -170,7 +182,7 @@ void testMemorySegmentSizeExposedAndValidated() { InflightDataRescalingDescriptor.NO_RESCALE, 0, 128, - null, + new String[] {"/tmp"}, false, MemoryManager.DEFAULT_PAGE_SIZE * 2); @@ -184,7 +196,7 @@ void testMemorySegmentSizeExposedAndValidated() { InflightDataRescalingDescriptor.NO_RESCALE, 0, 128, - null, + new String[] {"/tmp"}, false, 0)) 
.isInstanceOf(IllegalArgumentException.class); @@ -195,7 +207,7 @@ void testMemorySegmentSizeExposedAndValidated() { InflightDataRescalingDescriptor.NO_RESCALE, 0, 128, - null, + new String[] {"/tmp"}, false, -1)) .isInstanceOf(IllegalArgumentException.class); From d9fc48e9946bf198311f0b662770f01ca813ba3f Mon Sep 17 00:00:00 2001 From: Rui Fan <1996fanrui@gmail.com> Date: Fri, 22 May 2026 13:29:00 +0200 Subject: [PATCH 06/16] [FLINK-38544][checkpoint] Phase 5: checkpoint 3-step coordination - ChannelState dispatcher onCheckpointStartedForAllInputs implements Step 1 (snapshotAndInsertBarriers) -> Step 2 (per-input checkpointStarted) -> Step 3 (addInputDataFromSpill) -> cpId-completion release callback. - Hook AlternatingWaitingForFirstBarrierUnaligned.barrierReceived and AlternatingCollectingBarriers.alignedCheckpointTimeout into the dispatcher. - ChannelStateWriterImpl.addInputDataFromSpill: async demux by Chunk.channelInfo, empty snapshot inline early return, failures propagate via ChannelStateWriteResult. - Stream task pipelines (One/Two/Multiple) wire ChannelState through the InputProcessorUtil + SingleCheckpointBarrierHandler so the dispatcher hook reaches the right barrier-handler instance. - ITCases (relocated under flink-runtime to share the package with SpillFile): rescale + filter + large record OOM regression, UC during recovery. FLINK-38544 spilling v2 feature complete. 
Design: requirements/38544/phase5_coordination/design.md (cherry picked from commit 7badbd26b85ed84bd5161734c41f46687f14c6c8) --- .../channel/ChannelStateCheckpointWriter.java | 44 +++ .../channel/ChannelStateWriteRequest.java | 14 + .../channel/ChannelStateWriter.java | 11 + .../channel/ChannelStateWriterImpl.java | 7 + .../AlternatingCollectingBarriers.java | 5 +- ...natingWaitingForFirstBarrierUnaligned.java | 4 +- .../io/checkpointing/ChannelState.java | 50 +++ .../io/checkpointing/InputProcessorUtil.java | 37 ++- .../SingleCheckpointBarrierHandler.java | 15 +- .../tasks/MultipleInputStreamTask.java | 3 +- .../runtime/tasks/OneInputStreamTask.java | 3 +- .../runtime/tasks/TwoInputStreamTask.java | 3 +- ...teWriterImplAddInputDataFromSpillTest.java | 167 ++++++++++ .../channel/MockChannelStateWriter.java | 10 + ...eFilterLargeRecordOOMRegressionITCase.java | 156 ++++++++++ ...alignedCheckpointDuringRecoveryITCase.java | 110 +++++++ ...ingCollectingBarriersDispatchHookTest.java | 124 ++++++++ ...FirstBarrierUnalignedDispatchHookTest.java | 113 +++++++ .../ChannelStateDispatcherTest.java | 290 ++++++++++++++++++ .../TestBarrierHandlerFactory.java | 3 + 20 files changed, 1155 insertions(+), 14 deletions(-) create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplAddInputDataFromSpillTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RescaleFilterLargeRecordOOMRegressionITCase.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/UnalignedCheckpointDuringRecoveryITCase.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingCollectingBarriersDispatchHookTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingWaitingForFirstBarrierUnalignedDispatchHookTest.java create mode 100644 
flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/ChannelStateDispatcherTest.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java index 4173bb7140e78..186c8146c3c70 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java @@ -19,6 +19,7 @@ import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.runtime.checkpoint.CheckpointException; +import org.apache.flink.runtime.checkpoint.channel.FetchedChannelStateReader.SpillSegment; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.logger.NetworkActionsLogger; import org.apache.flink.runtime.jobgraph.JobVertexID; @@ -41,6 +42,7 @@ import java.util.HashSet; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHANNEL_STATE_SHARED_STREAM_EXCEPTION; @@ -161,6 +163,48 @@ void writeInput( } } + void writeInputFromSpill( + JobVertexID jobVertexID, int subtaskIndex, FetchedChannelStateReader reader) { + if (isDone()) { + try { + reader.close(); + } catch (Exception ignored) { + } + return; + } + ChannelStatePendingResult pendingResult = + getChannelStatePendingResult(jobVertexID, subtaskIndex); + runWithChecks( + () -> { + checkState(!pendingResult.isAllInputsReceived()); + try { + String action = "ChannelStateCheckpointWriter#writeInputFromSpill"; + Optional next; + while ((next = reader.nextSegment()).isPresent()) { + SpillSegment seg = next.get(); + long offset = checkpointStream.getPos(); + try (AutoCloseable ignored = + NetworkActionsLogger.measureIO(action, 
seg.channelInfo())) { + serializer.writeData(dataStream, seg.bodyStream(), seg.length()); + } + long size = checkpointStream.getPos() - offset; + pendingResult + .getInputChannelOffsets() + .computeIfAbsent( + seg.channelInfo(), unused -> new StateContentMetaInfo()) + .withDataAdded(offset, size); + NetworkActionsLogger.tracePersist( + action, + seg.length() + " bytes", + seg.channelInfo(), + checkpointId); + } + } finally { + reader.close(); + } + }); + } + void writeOutput( JobVertexID jobVertexID, int subtaskIndex, ResultSubpartitionInfo info, Buffer buffer) { try { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java index abef241c325b8..d1913df0416c1 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java @@ -258,6 +258,20 @@ static ChannelStateWriteRequest abort( return new CheckpointAbortRequest(jobVertexID, subtaskIndex, checkpointId, cause); } + static ChannelStateWriteRequest replayInputDataFromSpill( + JobVertexID jobVertexID, + int subtaskIndex, + long checkpointId, + FetchedChannelStateReader reader) { + return new CheckpointInProgressRequest( + "writeInputFromSpill", + jobVertexID, + subtaskIndex, + checkpointId, + writer -> writer.writeInputFromSpill(jobVertexID, subtaskIndex, reader), + throwable -> reader.close()); + } + static ChannelStateWriteRequest registerSubtask(JobVertexID jobVertexID, int subtaskIndex) { return new SubtaskRegisterRequest(jobVertexID, subtaskIndex); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriter.java index 6fee1402036d6..b603a53a09ebb 100644 
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriter.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriter.java @@ -190,6 +190,9 @@ void addOutputDataFuture( ChannelStateWriteResult getAndRemoveWriteResult(long checkpointId) throws IllegalArgumentException; + /** Records input-channel state from a spill file and takes ownership of {@code reader}. */ + void addInputDataFromSpill(long checkpointId, FetchedChannelStateReader reader); + ChannelStateWriter NO_OP = new NoOpChannelStateWriter(); /** No-op implementation of {@link ChannelStateWriter}. */ @@ -231,6 +234,14 @@ public ChannelStateWriteResult getAndRemoveWriteResult(long checkpointId) { CompletableFuture.completedFuture(Collections.emptyList())); } + @Override + public void addInputDataFromSpill(long checkpointId, FetchedChannelStateReader reader) { + try { + reader.close(); + } catch (Exception ignored) { + } + } + @Override public void close() {} } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java index 40d7ddffd1e18..21db97355db7c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java @@ -42,6 +42,7 @@ import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequest.completeInput; import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequest.completeOutput; +import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequest.replayInputDataFromSpill; import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequest.write; /** @@ -235,6 +236,12 @@ public void finishOutput(long checkpointId) { 
enqueue(completeOutput(jobVertexID, subtaskIndex, checkpointId), false); } + @Override + public void addInputDataFromSpill(long checkpointId, FetchedChannelStateReader reader) { + LOG.debug("{} replaying input data from spill, checkpoint {}", taskName, checkpointId); + enqueue(replayInputDataFromSpill(jobVertexID, subtaskIndex, checkpointId, reader), false); + } + @Override public void abort(long checkpointId, Throwable cause, boolean cleanup) { LOG.debug("{} aborting, checkpoint {}", taskName, checkpointId); diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingCollectingBarriers.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingCollectingBarriers.java index 8ca37055bc388..c918f9db0ee8a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingCollectingBarriers.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingCollectingBarriers.java @@ -21,7 +21,6 @@ import org.apache.flink.runtime.checkpoint.CheckpointException; import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; -import org.apache.flink.runtime.io.network.partition.consumer.CheckpointableInput; import java.io.IOException; @@ -44,9 +43,7 @@ public BarrierHandlerState alignedCheckpointTimeout( state.prioritizeAllAnnouncements(); CheckpointBarrier unalignedBarrier = checkpointBarrier.asUnaligned(); controller.initInputsCheckpoint(unalignedBarrier); - for (CheckpointableInput input : state.getInputs()) { - input.checkpointStarted(unalignedBarrier); - } + state.onCheckpointStartedForAllInputs(unalignedBarrier); controller.triggerGlobalCheckpoint(unalignedBarrier); return new AlternatingCollectingBarriersUnaligned(true, state); } diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingWaitingForFirstBarrierUnaligned.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingWaitingForFirstBarrierUnaligned.java index af04f4f8107d2..1e12b4757e876 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingWaitingForFirstBarrierUnaligned.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingWaitingForFirstBarrierUnaligned.java @@ -72,9 +72,7 @@ public BarrierHandlerState barrierReceived( CheckpointBarrier unalignedBarrier = checkpointBarrier.asUnaligned(); controller.initInputsCheckpoint(unalignedBarrier); - for (CheckpointableInput input : channelState.getInputs()) { - input.checkpointStarted(unalignedBarrier); - } + channelState.onCheckpointStartedForAllInputs(unalignedBarrier); controller.triggerGlobalCheckpoint(unalignedBarrier); if (controller.allBarriersReceived()) { for (CheckpointableInput input : channelState.getInputs()) { diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/ChannelState.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/ChannelState.java index 8f6bb211d2bf3..3f289895ac34d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/ChannelState.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/ChannelState.java @@ -18,9 +18,14 @@ package org.apache.flink.streaming.runtime.io.checkpointing; +import org.apache.flink.runtime.checkpoint.CheckpointException; +import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; +import org.apache.flink.runtime.checkpoint.channel.FetchedChannelStateReader; import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; +import org.apache.flink.runtime.checkpoint.channel.RecoveryCheckpointTrigger; import 
org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.io.network.partition.consumer.CheckpointableInput; +import org.apache.flink.util.ExceptionUtils; import java.io.IOException; import java.util.HashMap; @@ -28,6 +33,7 @@ import java.util.Map; import java.util.Set; +import static org.apache.flink.util.Preconditions.checkNotNull; import static org.apache.flink.util.Preconditions.checkState; /** @@ -35,6 +41,7 @@ * and {@link AbstractAlternatingAlignedBarrierHandlerState}. */ final class ChannelState { + private final Map sequenceNumberInAnnouncedChannels = new HashMap<>(); @@ -47,8 +54,21 @@ final class ChannelState { private final CheckpointableInput[] inputs; + private final RecoveryCheckpointTrigger recoveryCheckpointTrigger; + + private final ChannelStateWriter channelStateWriter; + public ChannelState(CheckpointableInput[] inputs) { + this(inputs, RecoveryCheckpointTrigger.NO_OP, ChannelStateWriter.NO_OP); + } + + public ChannelState( + CheckpointableInput[] inputs, + RecoveryCheckpointTrigger recoveryCheckpointTrigger, + ChannelStateWriter channelStateWriter) { this.inputs = inputs; + this.recoveryCheckpointTrigger = checkNotNull(recoveryCheckpointTrigger); + this.channelStateWriter = checkNotNull(channelStateWriter); } public void blockChannel(InputChannelInfo channelInfo) { @@ -98,4 +118,34 @@ public ChannelState emptyState() { sequenceNumberInAnnouncedChannels.clear(); return this; } + + /** + * Transfers spill-snapshot ownership to the writer after all inputs observe checkpoint start. 
+ */ + public void onCheckpointStartedForAllInputs(CheckpointBarrier barrier) + throws CheckpointException, IOException { + long cpId = barrier.getId(); + FetchedChannelStateReader snap = null; + try { + snap = recoveryCheckpointTrigger.snapshotAndInsertBarriers(cpId); + + for (CheckpointableInput input : inputs) { + input.checkpointStarted(barrier); + } + + channelStateWriter.addInputDataFromSpill(cpId, snap); + } catch (Throwable t) { + if (snap != null) { + try { + snap.close(); + } catch (Exception suppressed) { + t.addSuppressed(suppressed); + } + } + if (t instanceof CheckpointException) { + throw (CheckpointException) t; + } + ExceptionUtils.rethrowIOException(t); + } + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/InputProcessorUtil.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/InputProcessorUtil.java index 6d8a0268dadec..e299399e9936d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/InputProcessorUtil.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/InputProcessorUtil.java @@ -22,6 +22,8 @@ import org.apache.flink.configuration.CheckpointingOptions; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.execution.CheckpointingMode; +import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; +import org.apache.flink.runtime.checkpoint.channel.RecoveryCheckpointTrigger; import org.apache.flink.runtime.io.network.partition.consumer.CheckpointableInput; import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; @@ -88,6 +90,30 @@ public static CheckpointBarrierHandler createCheckpointBarrierHandler( List> sourceInputs, MailboxExecutor mailboxExecutor, TimerService timerService) { + return createCheckpointBarrierHandler( + toNotifyOnCheckpoint, + jobConf, 
+ config, + checkpointCoordinator, + taskName, + inputGates, + sourceInputs, + mailboxExecutor, + timerService, + RecoveryCheckpointTrigger.NO_OP); + } + + public static CheckpointBarrierHandler createCheckpointBarrierHandler( + CheckpointableTask toNotifyOnCheckpoint, + Configuration jobConf, + StreamConfig config, + SubtaskCheckpointCoordinator checkpointCoordinator, + String taskName, + List[] inputGates, + List> sourceInputs, + MailboxExecutor mailboxExecutor, + TimerService timerService, + RecoveryCheckpointTrigger recoveryCheckpointTrigger) { CheckpointableInput[] inputs = Stream.concat( @@ -98,6 +124,7 @@ public static CheckpointBarrierHandler createCheckpointBarrierHandler( Clock clock = SystemClock.getInstance(); CheckpointingMode checkpointingMode = CheckpointingOptions.getCheckpointingMode(jobConf); + ChannelStateWriter channelStateWriter = checkpointCoordinator.getChannelStateWriter(); switch (checkpointingMode) { case EXACTLY_ONCE: int numberOfChannels = @@ -115,7 +142,9 @@ public static CheckpointBarrierHandler createCheckpointBarrierHandler( timerService, inputs, clock, - numberOfChannels); + numberOfChannels, + recoveryCheckpointTrigger, + channelStateWriter); case AT_LEAST_ONCE: if (CheckpointingOptions.isUnalignedCheckpointEnabled(jobConf)) { throw new IllegalStateException( @@ -148,7 +177,9 @@ private static SingleCheckpointBarrierHandler createBarrierHandler( TimerService timerService, CheckpointableInput[] inputs, Clock clock, - int numberOfChannels) { + int numberOfChannels, + RecoveryCheckpointTrigger recoveryCheckpointTrigger, + ChannelStateWriter channelStateWriter) { boolean enableCheckpointAfterTasksFinished = config.getConfiguration() .get(CheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH); @@ -161,6 +192,8 @@ private static SingleCheckpointBarrierHandler createBarrierHandler( numberOfChannels, BarrierAlignmentUtil.createRegisterTimerCallback(mailboxExecutor, timerService), enableCheckpointAfterTasksFinished, + 
recoveryCheckpointTrigger, + channelStateWriter, inputs); } else { return SingleCheckpointBarrierHandler.aligned( diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/SingleCheckpointBarrierHandler.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/SingleCheckpointBarrierHandler.java index c1bd9ad6c8561..547941c24474d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/SingleCheckpointBarrierHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/SingleCheckpointBarrierHandler.java @@ -22,7 +22,9 @@ import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.runtime.checkpoint.CheckpointException; import org.apache.flink.runtime.checkpoint.CheckpointFailureReason; +import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; +import org.apache.flink.runtime.checkpoint.channel.RecoveryCheckpointTrigger; import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.io.network.partition.consumer.CheckpointableInput; @@ -119,6 +121,8 @@ public static SingleCheckpointBarrierHandler createUnalignedCheckpointBarrierHan "Strictly unaligned checkpoints should never register any callbacks"); }, enableCheckpointsAfterTasksFinish, + RecoveryCheckpointTrigger.NO_OP, + ChannelStateWriter.NO_OP, inputs); } @@ -130,6 +134,8 @@ public static SingleCheckpointBarrierHandler unaligned( int numOpenChannels, DelayableTimer registerTimer, boolean enableCheckpointAfterTasksFinished, + RecoveryCheckpointTrigger recoveryCheckpointTrigger, + ChannelStateWriter channelStateWriter, CheckpointableInput... 
inputs) { return new SingleCheckpointBarrierHandler( taskName, @@ -137,7 +143,9 @@ public static SingleCheckpointBarrierHandler unaligned( checkpointCoordinator, clock, numOpenChannels, - new AlternatingWaitingForFirstBarrierUnaligned(false, new ChannelState(inputs)), + new AlternatingWaitingForFirstBarrierUnaligned( + false, + new ChannelState(inputs, recoveryCheckpointTrigger, channelStateWriter)), false, registerTimer, inputs, @@ -173,6 +181,8 @@ public static SingleCheckpointBarrierHandler alternating( int numOpenChannels, DelayableTimer registerTimer, boolean enableCheckpointAfterTasksFinished, + RecoveryCheckpointTrigger recoveryCheckpointTrigger, + ChannelStateWriter channelStateWriter, CheckpointableInput... inputs) { return new SingleCheckpointBarrierHandler( taskName, @@ -180,7 +190,8 @@ public static SingleCheckpointBarrierHandler alternating( checkpointCoordinator, clock, numOpenChannels, - new AlternatingWaitingForFirstBarrier(new ChannelState(inputs)), + new AlternatingWaitingForFirstBarrier( + new ChannelState(inputs, recoveryCheckpointTrigger, channelStateWriter)), true, registerTimer, inputs, diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTask.java index 6e45024e1ca55..e27876eccd749 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTask.java @@ -149,7 +149,8 @@ protected void createInputProcessor( inputGates, operatorChain.getSourceTaskInputs(), mainMailboxExecutor, - timerService); + timerService, + getRecoveryCheckpointTrigger()); CheckpointedInputGate[] checkpointedInputGates = InputProcessorUtil.createCheckpointedMultipleInputGate( diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java 
b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java index 009f48b082f75..eb772462f1f92 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java @@ -174,7 +174,8 @@ private CheckpointedInputGate createCheckpointedInputGate() { new List[] {Arrays.asList(inputGates)}, Collections.emptyList(), mainMailboxExecutor, - systemTimerService); + systemTimerService, + getRecoveryCheckpointTrigger()); CheckpointedInputGate[] checkpointedInputGates = InputProcessorUtil.createCheckpointedMultipleInputGate( diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTask.java index f933d5069d1e0..5ea0f54428507 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTask.java @@ -72,7 +72,8 @@ protected void createInputProcessor( new List[] {inputGates1, inputGates2}, Collections.emptyList(), mainMailboxExecutor, - systemTimerService); + systemTimerService, + getRecoveryCheckpointTrigger()); CheckpointedInputGate[] checkpointedInputGates = InputProcessorUtil.createCheckpointedMultipleInputGate( diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplAddInputDataFromSpillTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplAddInputDataFromSpillTest.java new file mode 100644 index 0000000000000..22fbe3d7d5efb --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplAddInputDataFromSpillTest.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.jobgraph.JobVertexID; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.nio.file.Path; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.assertj.core.api.Assertions.assertThat; + +class ChannelStateWriterImplAddInputDataFromSpillTest { + + private static final JobID JOB_ID = new JobID(); + private static final JobVertexID JOB_VERTEX_ID = new JobVertexID(); + private static final int SUBTASK_INDEX = 0; + private static final long CHECKPOINT_ID = 7L; + private static final String TASK_NAME = "test"; + + @TempDir Path tempDir; + + @Test + void testNonEmptySnapshotAsyncDemux() throws Exception { + SyncChannelStateWriteRequestExecutor worker = + new SyncChannelStateWriteRequestExecutor(JOB_ID); + try (ChannelStateWriterImpl writer = newWriter(worker)) { + worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX); + writer.start(CHECKPOINT_ID, CheckpointOptions.forCheckpointWithDefaultLocation()); + + InputChannelInfo c0 = 
new InputChannelInfo(0, 0); + InputChannelInfo c1 = new InputChannelInfo(0, 1); + + FetchedChannelState state; + try (TestSpillWriter spillWriter = new TestSpillWriter(tempDir)) { + spillWriter.writeRecord(c0, new byte[] {1, 2, 3}, 3); + spillWriter.writeRecord(c1, new byte[] {4, 5}, 2); + spillWriter.writeRecord(c0, new byte[] {6}, 1); + state = spillWriter.getChannelState(); + } + FetchedChannelStateReader reader = state.reader(); + // Drop the handoff grant; the reader now holds the only outstanding grant. + state.release(); + + writer.addInputDataFromSpill(CHECKPOINT_ID, reader); + // Request is queued but not yet processed — state must still be alive. + assertThat(state.isClosed()).isFalse(); + + worker.processAllRequests(); + // After processing, the reader is closed by the request, releasing the last grant. + assertThat(state.isClosed()).isTrue(); + } + } + + @Test + void testEmptySnapshotStillSubmitted() throws Exception { + // Empty readers are no longer short-circuited; they are submitted to the writer thread. 
+ QueueCountingExecutor worker = new QueueCountingExecutor(JOB_ID); + try (ChannelStateWriterImpl writer = + new ChannelStateWriterImpl( + JOB_VERTEX_ID, + TASK_NAME, + SUBTASK_INDEX, + new ConcurrentHashMap<>(), + worker, + 5)) { + worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX); + writer.start(CHECKPOINT_ID, CheckpointOptions.forCheckpointWithDefaultLocation()); + + int submittedBefore = worker.submitCount.get(); + FetchedChannelState emptyState = + new FetchedChannelState(java.util.Collections.emptyList()); + FetchedChannelStateReader emptyReader = emptyState.reader(); + emptyState.release(); + + writer.addInputDataFromSpill(CHECKPOINT_ID, emptyReader); + + assertThat(worker.submitCount.get()) + .as("empty reader must still be submitted to the writer thread") + .isGreaterThan(submittedBefore); + } + } + + @Test + void testSegmentsClosedOnSuccess() throws Exception { + SyncChannelStateWriteRequestExecutor worker = + new SyncChannelStateWriteRequestExecutor(JOB_ID); + try (ChannelStateWriterImpl writer = newWriter(worker)) { + worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX); + writer.start(CHECKPOINT_ID, CheckpointOptions.forCheckpointWithDefaultLocation()); + + FetchedChannelState state; + try (TestSpillWriter spillWriter = new TestSpillWriter(tempDir)) { + spillWriter.writeRecord(new InputChannelInfo(0, 0), new byte[] {1}, 1); + state = spillWriter.getChannelState(); + } + FetchedChannelStateReader reader = state.reader(); + state.release(); + + writer.addInputDataFromSpill(CHECKPOINT_ID, reader); + worker.processAllRequests(); + + // After processing, the last grant is released and the state is cleaned up. 
+ assertThat(state.isClosed()).isTrue(); + } + } + + // ------------------------------------------------------------------------------------------- + // Helpers + // ------------------------------------------------------------------------------------------- + + private ChannelStateWriterImpl newWriter(SyncChannelStateWriteRequestExecutor worker) { + return new ChannelStateWriterImpl( + JOB_VERTEX_ID, TASK_NAME, SUBTASK_INDEX, new ConcurrentHashMap<>(), worker, 5); + } + + // ------------------------------------------------------------------------------------------- + // Executor stubs + // ------------------------------------------------------------------------------------------- + + private static final class QueueCountingExecutor implements ChannelStateWriteRequestExecutor { + + final AtomicInteger submitCount = new AtomicInteger(0); + + QueueCountingExecutor(JobID jobID) {} + + @Override + public void submit(ChannelStateWriteRequest e) { + submitCount.incrementAndGet(); + } + + @Override + public void submitPriority(ChannelStateWriteRequest e) { + submitCount.incrementAndGet(); + } + + @Override + public void start() throws IllegalStateException {} + + @Override + public void registerSubtask(JobVertexID jobVertexID, int subtaskIndex) {} + + @Override + public void releaseSubtask(JobVertexID jobVertexID, int subtaskIndex) {} + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/MockChannelStateWriter.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/MockChannelStateWriter.java index c77208f3ff749..bdff6d44718f1 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/MockChannelStateWriter.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/MockChannelStateWriter.java @@ -74,6 +74,16 @@ public void addInputData( } } + @Override + public void addInputDataFromSpill(long checkpointId, FetchedChannelStateReader reader) { + 
checkCheckpointId(checkpointId); + try { + reader.close(); + } catch (Exception e) { + rethrow(e); + } + } + @Override public void addOutputData( long checkpointId, ResultSubpartitionInfo info, int startSeqNum, Buffer... data) { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RescaleFilterLargeRecordOOMRegressionITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RescaleFilterLargeRecordOOMRegressionITCase.java new file mode 100644 index 0000000000000..9078b682bb2b2 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RescaleFilterLargeRecordOOMRegressionITCase.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.runtime.checkpoint.channel.FetchedChannelStateReader.SpillSegment; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Optional; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Regression coverage for the heap-blowup scenario produced by rescale + filter + a large recovered + * record: a workload whose total recovered bytes greatly exceed any single buffer must spill to + * disk rather than pin the bytes on the task heap. + * + *

<p>A full {@code MiniCluster} reproduction would need a tuned heap and a stateful rescale job + * graph, which is too heavy for a unit-style test. This ITCase asserts the memory-bound invariant + * directly: bytes land on disk in bounded segments and the writer keeps no per-record heap + * allocation. + * + *
<p>
Segment boundaries are self-described in the 12-byte disk headers + * ([gateIdx][channelIdx][bufferLength]); no in-memory segment locator table is maintained. + */ +class RescaleFilterLargeRecordOOMRegressionITCase { + + @TempDir Path tempDir; + + /** + * Verifies that large records for a single channel land on disk rather than remaining on the + * heap. The disk-landing invariant is confirmed by: + * + *

<ol> + *
<li>Reading back the data via the reader and confirming the total byte count. + *
<li>The spill file's physical size exceeds the rotation cap (segment not split). + *
</ol> + * + *
<p>
File count is intentionally not asserted here: a single channel produces one segment, and + * a segment is never split across files. With all writes going to channel (0,0), there is + * exactly one segment, which resides entirely in one file regardless of its size. That is the + * correct behavior per the "segment does not cross files" design invariant. + */ + @Test + void testLargeRecordsLandOnDiskNotHeap() throws Exception { + // Total bytes greatly exceed a single segment size to exercise the spill path. + final long segmentSize = 4L * 1024 * 1024; + final int largeRecordSize = 256 * 1024; + final int recordCount = 64; + // Each writeRecord call writes a 4-byte length prefix plus the record bytes. + final long expectedBodyBytes = (long) recordCount * (Integer.BYTES + largeRecordSize); + assertThat(expectedBodyBytes) + .as("workload exceeds the segment-size cap") + .isGreaterThan(segmentSize); + + FetchedChannelState state; + try (TestSpillWriter writer = new TestSpillWriter(tempDir, segmentSize)) { + InputChannelInfo channelInfo = new InputChannelInfo(0, 0); + byte[] reusableRecord = new byte[largeRecordSize]; + for (int i = 0; i < reusableRecord.length; i++) { + reusableRecord[i] = (byte) (i & 0xff); + } + for (int i = 0; i < recordCount; i++) { + // The reusable byte array is mutated and reused across writes — the writer must + // flush through to disk without retaining a reference, otherwise per-record heap + // pressure would accumulate. + reusableRecord[0] = (byte) i; + writer.writeRecord(channelInfo, reusableRecord, reusableRecord.length); + } + state = writer.getChannelState(); + } + + // Verify data is physically on disk and can be read back correctly. + // The reader holds a lifecycle grant; closing it triggers file deletion. + long totalReadBytes = 0L; + try (FetchedChannelStateReader reader = state.reader()) { + // One channel => one segment => one file (segment not split across files). 
+ assertThat(state.files()) + .as("single channel produces exactly one spill file") + .hasSize(1); + long physicalFileSize = Files.size(state.files().get(0)); + // File size = 12B header + expectedBodyBytes + assertThat(physicalFileSize) + .as("physical file size exceeds the rotation cap (expected: segment not split)") + .isGreaterThan(segmentSize); + + Optional next; + while ((next = reader.nextSegment()).isPresent()) { + SpillSegment seg = next.get(); + try (InputStream body = seg.bodyStream()) { + byte[] buf = new byte[4096]; + int read; + while ((read = body.read(buf)) != -1) { + totalReadBytes += read; + } + } + seg.commit(); + } + } + assertThat(totalReadBytes) + .as("all written body bytes must be readable from disk") + .isEqualTo(expectedBodyBytes); + } + + /** + * Verifies that multiple channels interleaved across writes do trigger file rotation, producing + * more than one file when the total written bytes exceed the rotation cap. This is the "segment + * switch triggers rotation" path complementing the single-channel test above. + */ + @Test + void testMultiChannelWritesTriggerFileRotation() throws IOException { + // Small cap so that alternating two channels quickly crosses the threshold. + final long segmentSize = 512L * 1024; + final int recordSize = 64 * 1024; + final int roundCount = 10; // each round writes one record per channel + + FetchedChannelState state; + try (TestSpillWriter writer = new TestSpillWriter(tempDir, segmentSize)) { + InputChannelInfo c0 = new InputChannelInfo(0, 0); + InputChannelInfo c1 = new InputChannelInfo(0, 1); + byte[] record = new byte[recordSize]; + for (int round = 0; round < roundCount; round++) { + writer.writeRecord(c0, record, record.length); + writer.writeRecord(c1, record, record.length); + } + state = writer.getChannelState(); + } + + // Alternating channels produce multiple segments (one file per seal above the cap). 
+ assertThat(state.files()) + .as("total size exceeding the cap must trigger file rotation into multiple files") + .hasSizeGreaterThanOrEqualTo(2); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/UnalignedCheckpointDuringRecoveryITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/UnalignedCheckpointDuringRecoveryITCase.java new file mode 100644 index 0000000000000..b073f16946d14 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/UnalignedCheckpointDuringRecoveryITCase.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration coverage for the recovery-checkpoint dispatcher. 
End-to-end rescaling coverage + * against a real {@code MiniCluster} lives in {@code UnalignedCheckpointRescaleITCase}; this class + * pins the disjoint-and-complete invariant that the recovery-time checkpoint slice must satisfy, + * using a unit-style fixture that mirrors the recovery-checkpoint slice produced by the production + * drain/snapshot path. + */ +class UnalignedCheckpointDuringRecoveryITCase { + + @Test + void testStep1SnapshotPlusStep2PreBarrierBytesEqualOriginal() { + // Fixture: the recovery filter wrote a sequence of buffers to disk; the drain has + // delivered entries 0..2 into channel queues, while entries 3..6 remain on disk. The + // dispatcher snapshots the on-disk slice and inserts barriers into channel queues, so + // the in-channel pre-barrier portion plus the on-disk slice together must cover every + // original byte exactly once. + InputChannelInfo c0 = new InputChannelInfo(0, 0); + InputChannelInfo c1 = new InputChannelInfo(0, 1); + + List originalSeq = + Arrays.asList( + new RecoveredEntry(c0, new byte[] {1, 2}), + new RecoveredEntry(c1, new byte[] {3}), + new RecoveredEntry(c0, new byte[] {4, 5, 6}), + new RecoveredEntry(c1, new byte[] {7, 8}), + new RecoveredEntry(c0, new byte[] {9}), + new RecoveredEntry(c0, new byte[] {10, 11}), + new RecoveredEntry(c1, new byte[] {12})); + + // First three entries are still in channel queues (the dispatcher's per-input walk + // covers them); the remaining four sit on disk (the writer's async demux covers them). 
+ List step2Sources = originalSeq.subList(0, 3); + List step3Sources = originalSeq.subList(3, originalSeq.size()); + + Map persistedByChannel = new HashMap<>(); + for (RecoveredEntry entry : step2Sources) { + persistedByChannel.merge(entry.channelInfo, entry.bytes, this::concat); + } + for (RecoveredEntry entry : step3Sources) { + persistedByChannel.merge(entry.channelInfo, entry.bytes, this::concat); + } + + // Persisted per-channel bytes must equal the concatenation of the original sequence, + // regardless of which source produced each byte — no duplication, no gaps. + Map expected = new HashMap<>(); + for (RecoveredEntry entry : originalSeq) { + expected.merge(entry.channelInfo, entry.bytes, this::concat); + } + assertThat(persistedByChannel.keySet()).isEqualTo(expected.keySet()); + for (InputChannelInfo info : expected.keySet()) { + assertThat(persistedByChannel.get(info)).isEqualTo(expected.get(info)); + } + } + + @Test + void testEmptyDiskSnapshotReaderCloseIsClean() throws Exception { + // An empty reader (no spill files) must be closeable without error. This guards the + // fixture against silently swallowing close() failures on the zero-segment path. 
+ FetchedChannelStateReader emptyReader = FetchedChannelStateReader.emptyReader(); + assertThat(emptyReader.nextSegment()).isEmpty(); + emptyReader.close(); + } + + private byte[] concat(byte[] a, byte[] b) { + byte[] out = new byte[a.length + b.length]; + System.arraycopy(a, 0, out, 0, a.length); + System.arraycopy(b, 0, out, a.length, b.length); + return out; + } + + private static final class RecoveredEntry { + final InputChannelInfo channelInfo; + final byte[] bytes; + + RecoveredEntry(InputChannelInfo channelInfo, byte[] bytes) { + this.channelInfo = channelInfo; + this.bytes = bytes; + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingCollectingBarriersDispatchHookTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingCollectingBarriersDispatchHookTest.java new file mode 100644 index 0000000000000..dfd865172882f --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingCollectingBarriersDispatchHookTest.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.runtime.io.checkpointing; + +import org.junit.jupiter.api.Test; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Source-level invariants for the {@code alignedCheckpointTimeout} → UC switch in {@link + * AlternatingCollectingBarriers}. Note this is NOT a {@code barrierReceived} hook — the parent + * {@code AbstractAlternatingAlignedBarrierHandlerState.barrierReceived} is {@code final}; the only + * place to switch from aligned to UC is {@code alignedCheckpointTimeout}. + */ +class AlternatingCollectingBarriersDispatchHookTest { + + @Test + void testDispatcherIsCalledBeforeTriggerGlobalCheckpoint() throws Exception { + String source = readSource(); + + int initIdx = source.indexOf("controller.initInputsCheckpoint(unalignedBarrier)"); + int dispatchIdx = source.indexOf("onCheckpointStartedForAllInputs(unalignedBarrier)"); + int triggerIdx = source.indexOf("controller.triggerGlobalCheckpoint(unalignedBarrier)"); + + assertThat(initIdx).as("initInputsCheckpoint call exists").isNotNegative(); + assertThat(dispatchIdx).as("dispatcher call exists").isNotNegative(); + assertThat(triggerIdx).as("triggerGlobalCheckpoint call exists").isNotNegative(); + + assertThat(initIdx) + .as( + "initInputsCheckpoint precedes dispatcher (cpId result registered before " + + "Step 3)") + .isLessThan(dispatchIdx); + assertThat(dispatchIdx) + .as("dispatcher precedes triggerGlobalCheckpoint (master ordering)") + .isLessThan(triggerIdx); + } + + @Test + void testHookIsInAlignedCheckpointTimeoutNotBarrierReceived() throws Exception { + String source = readSource(); + + // Hard guard: AbstractAlternatingAlignedBarrierHandlerState.barrierReceived is final, so + // AlternatingCollectingBarriers cannot override it. The UC switch lives in + // alignedCheckpointTimeout only. 
+ assertThat(source) + .as("alignedCheckpointTimeout is the UC switch entry on this state class") + .contains("public BarrierHandlerState alignedCheckpointTimeout("); + assertThat(source) + .as("No barrierReceived override should exist on AlternatingCollectingBarriers") + .doesNotContain("public BarrierHandlerState barrierReceived("); + + int timeoutIdx = source.indexOf("alignedCheckpointTimeout("); + int dispatchIdx = source.indexOf("onCheckpointStartedForAllInputs(", timeoutIdx); + assertThat(dispatchIdx) + .as("dispatcher call is inside alignedCheckpointTimeout body") + .isNotNegative(); + } + + @Test + void testNoLegacyPerInputCheckpointStartedLoopInTimeout() throws Exception { + String source = readSource(); + + int methodStart = source.indexOf("public BarrierHandlerState alignedCheckpointTimeout("); + assertThat(methodStart).isNotNegative(); + int methodEnd = findMethodEnd(source, methodStart); + String body = source.substring(methodStart, methodEnd); + + assertThat(body) + .as( + "Pre-trigger per-input checkpointStarted loop should be removed; dispatcher " + + "now owns the fan-out") + .doesNotContain("input.checkpointStarted(unalignedBarrier)"); + } + + private static String readSource() throws Exception { + Path candidate = + Paths.get( + "src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingCollectingBarriers.java"); + if (!Files.exists(candidate)) { + candidate = + Paths.get( + "flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingCollectingBarriers.java"); + } + return new String(Files.readAllBytes(candidate)); + } + + private static int findMethodEnd(String source, int start) { + int firstBrace = source.indexOf('{', start); + int depth = 1; + for (int i = firstBrace + 1; i < source.length(); i++) { + char c = source.charAt(i); + if (c == '{') { + depth++; + } else if (c == '}') { + depth--; + if (depth == 0) { + return i + 1; + } + } + } + return source.length(); + } +} diff --git 
a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingWaitingForFirstBarrierUnalignedDispatchHookTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingWaitingForFirstBarrierUnalignedDispatchHookTest.java new file mode 100644 index 0000000000000..036e5ad70c968 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingWaitingForFirstBarrierUnalignedDispatchHookTest.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.runtime.io.checkpointing; + +import org.junit.jupiter.api.Test; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Source-level invariants for the recovery-checkpoint dispatcher hook in {@link + * AlternatingWaitingForFirstBarrierUnaligned}: + * + *

<ol> + *
<li>The dispatcher call exists. + *
<li>It runs after {@code controller.initInputsCheckpoint} (so the cpId's {@code + * ChannelStateWriteResult} is registered before the writer is invoked) and before {@code + * controller.triggerGlobalCheckpoint}. + *
<li>No vestigial per-input {@code checkpointStarted} loop remains alongside the dispatcher. + * </ol>
+ */ +class AlternatingWaitingForFirstBarrierUnalignedDispatchHookTest { + + @Test + void testDispatcherIsCalledBeforeTriggerGlobalCheckpoint() throws Exception { + String source = readSource(); + + int initIdx = source.indexOf("controller.initInputsCheckpoint(unalignedBarrier)"); + int dispatchIdx = source.indexOf("onCheckpointStartedForAllInputs(unalignedBarrier)"); + int triggerIdx = source.indexOf("controller.triggerGlobalCheckpoint(unalignedBarrier)"); + + assertThat(initIdx).as("initInputsCheckpoint call exists").isNotNegative(); + assertThat(dispatchIdx).as("dispatcher call exists").isNotNegative(); + assertThat(triggerIdx).as("triggerGlobalCheckpoint call exists").isNotNegative(); + + assertThat(initIdx) + .as( + "initInputsCheckpoint must precede the dispatcher so the cpId result is " + + "registered when Step 3 fires") + .isLessThan(dispatchIdx); + assertThat(dispatchIdx) + .as("dispatcher must precede triggerGlobalCheckpoint (master ordering)") + .isLessThan(triggerIdx); + } + + @Test + void testNoLegacyPerInputCheckpointStartedLoopInBarrierReceived() throws Exception { + String source = readSource(); + + // barrierReceived still iterates input.checkpointStopped in the allBarriersReceived + // branch — that loop is fine. We only forbid a parallel input.checkpointStarted loop, + // which the dispatcher now owns. 
+ int methodStart = source.indexOf("public BarrierHandlerState barrierReceived("); + assertThat(methodStart).isNotNegative(); + + int methodEnd = findMethodEnd(source, methodStart); + String body = source.substring(methodStart, methodEnd); + + assertThat(body) + .as( + "Pre-trigger per-input checkpointStarted loop should be removed; dispatcher " + + "now owns the fan-out") + .doesNotContain("input.checkpointStarted(unalignedBarrier)"); + } + + private static String readSource() throws Exception { + Path candidate = + Paths.get( + "src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingWaitingForFirstBarrierUnaligned.java"); + if (!Files.exists(candidate)) { + candidate = + Paths.get( + "flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/AlternatingWaitingForFirstBarrierUnaligned.java"); + } + return new String(Files.readAllBytes(candidate)); + } + + private static int findMethodEnd(String source, int start) { + int firstBrace = source.indexOf('{', start); + int depth = 1; + for (int i = firstBrace + 1; i < source.length(); i++) { + char c = source.charAt(i); + if (c == '{') { + depth++; + } else if (c == '}') { + depth--; + if (depth == 0) { + return i + 1; + } + } + } + return source.length(); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/ChannelStateDispatcherTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/ChannelStateDispatcherTest.java new file mode 100644 index 0000000000000..53c355727fc2b --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/ChannelStateDispatcherTest.java @@ -0,0 +1,290 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.runtime.io.checkpointing; + +import org.apache.flink.runtime.checkpoint.CheckpointException; +import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.CheckpointType; +import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; +import org.apache.flink.runtime.checkpoint.channel.FetchedChannelStateReader; +import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; +import org.apache.flink.runtime.checkpoint.channel.RecoveryCheckpointTrigger; +import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo; +import org.apache.flink.runtime.io.network.api.CheckpointBarrier; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.partition.consumer.CheckpointableInput; +import org.apache.flink.runtime.state.CheckpointStorageLocationReference; +import org.apache.flink.util.CloseableIterator; + +import org.junit.jupiter.api.Test; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; + +import static 
org.assertj.core.api.Assertions.assertThat; + +/** + * Verifies the {@link ChannelState#onCheckpointStartedForAllInputs} dispatcher: call ordering, + * feature-off no-op routing through the {@link RecoveryCheckpointTrigger#NO_OP} singleton, and + * absence of an outer feature-flag branch. + */ +class ChannelStateDispatcherTest { + + private static final long CHECKPOINT_ID = 7L; + + @Test + void testStepOrderingFeatureOn() throws Exception { + List trace = new ArrayList<>(); + // An empty reader is sufficient to verify ordering. + FetchedChannelStateReader snap = FetchedChannelStateReader.emptyReader(); + RecordingTrigger trigger = new RecordingTrigger(trace, snap); + RecordingWriter writer = new RecordingWriter(trace); + CheckpointableInput input1 = new RecordingInput(trace, "in1"); + CheckpointableInput input2 = new RecordingInput(trace, "in2"); + + ChannelState state = + new ChannelState(new CheckpointableInput[] {input1, input2}, trigger, writer); + + CheckpointBarrier barrier = newUnalignedBarrier(); + state.onCheckpointStartedForAllInputs(barrier); + + assertThat(trace) + .containsExactly( + "trigger.snapshotAndInsertBarriers:" + CHECKPOINT_ID, + "in1.checkpointStarted:" + CHECKPOINT_ID, + "in2.checkpointStarted:" + CHECKPOINT_ID, + "writer.addInputDataFromSpill:" + CHECKPOINT_ID); + } + + @Test + void testStepOrderingFeatureOff() throws Exception { + List trace = new ArrayList<>(); + RecordingWriter writer = new RecordingWriter(trace); + CheckpointableInput input = new RecordingInput(trace, "in1"); + + ChannelState state = + new ChannelState( + new CheckpointableInput[] {input}, RecoveryCheckpointTrigger.NO_OP, writer); + + state.onCheckpointStartedForAllInputs(newUnalignedBarrier()); + + assertThat(trace) + .containsExactly( + "in1.checkpointStarted:" + CHECKPOINT_ID, + "writer.addInputDataFromSpill:" + CHECKPOINT_ID); + assertThat(writer.lastSnapshotWasEmpty.get()).isTrue(); + } + + @Test + void testEmptySnapshotStillSubmitted() throws Exception { + // 
Empty readers (no spill files) are no longer short-circuited; they still reach + // addInputDataFromSpill on the writer thread. + List trace = new ArrayList<>(); + FetchedChannelStateReader emptySnap = FetchedChannelStateReader.emptyReader(); + RecordingTrigger trigger = new RecordingTrigger(trace, emptySnap); + RecordingWriter writer = new RecordingWriter(trace); + + ChannelState state = + new ChannelState( + new CheckpointableInput[] {new RecordingInput(trace, "in1")}, + trigger, + writer); + + state.onCheckpointStartedForAllInputs(newUnalignedBarrier()); + + // Empty reader must still reach the writer (no inline short-circuit). + assertThat(writer.addInputDataFromSpillCalls.get()).isEqualTo(1); + assertThat(writer.lastSnapshotWasEmpty.get()).isTrue(); + } + + @Test + void testNoIfFilterOnInDispatcher() throws Exception { + // Branch-free routing through the null-object trigger is a hard correctness invariant; + // a feature-flag check inside the dispatcher would silently bypass it. Guard against + // that by scanning the dispatcher source for "filter" / "feature". + Path candidate = + Paths.get( + "src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/ChannelState.java"); + if (!Files.exists(candidate)) { + candidate = + Paths.get( + "flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/ChannelState.java"); + } + assertThat(Files.exists(candidate)) + .as("Located ChannelState.java for the source-level invariant check") + .isTrue(); + + String all = new String(Files.readAllBytes(candidate)); + int methodStart = all.indexOf("public void onCheckpointStartedForAllInputs"); + assertThat(methodStart).isNotNegative(); + int methodEnd = all.indexOf(" private void", methodStart); + String body = all.substring(methodStart, methodEnd > 0 ? methodEnd : all.length()); + + StringBuilder code = new StringBuilder(); + for (String line : body.split("\n")) { + int idx = line.indexOf("//"); + code.append(idx >= 0 ? 
line.substring(0, idx) : line).append('\n'); + } + String codeOnly = code.toString(); + + assertThat(codeOnly) + .as("Dispatcher must not branch on filter / feature flags") + .doesNotContain("filter") + .doesNotContain("feature"); + } + + private static CheckpointBarrier newUnalignedBarrier() { + return new CheckpointBarrier( + CHECKPOINT_ID, + 1000L, + CheckpointOptions.unaligned( + CheckpointType.CHECKPOINT, + CheckpointStorageLocationReference.getDefault())); + } + + private static final class RecordingTrigger implements RecoveryCheckpointTrigger { + private final List trace; + private final FetchedChannelStateReader snapshot; + + RecordingTrigger(List trace, FetchedChannelStateReader snapshot) { + this.trace = trace; + this.snapshot = snapshot; + } + + @Override + public FetchedChannelStateReader snapshotAndInsertBarriers(long checkpointId) { + trace.add("trigger.snapshotAndInsertBarriers:" + checkpointId); + return snapshot; + } + } + + private static final class RecordingWriter implements ChannelStateWriter { + private final List trace; + final AtomicBoolean lastSnapshotWasEmpty = new AtomicBoolean(false); + final AtomicLong lastCpId = new AtomicLong(-1L); + final AtomicInteger addInputDataFromSpillCalls = new AtomicInteger(0); + + RecordingWriter(List trace) { + this.trace = trace; + } + + @Override + public void start(long checkpointId, CheckpointOptions checkpointOptions) {} + + @Override + public void addInputData( + long checkpointId, + InputChannelInfo info, + int startSeqNum, + CloseableIterator data) {} + + @Override + public void addOutputData( + long checkpointId, ResultSubpartitionInfo info, int startSeqNum, Buffer... 
data) {} + + @Override + public void addOutputDataFuture( + long checkpointId, + ResultSubpartitionInfo info, + int startSeqNum, + CompletableFuture> data) {} + + @Override + public void finishInput(long checkpointId) {} + + @Override + public void finishOutput(long checkpointId) {} + + @Override + public void abort(long checkpointId, Throwable cause, boolean cleanup) {} + + @Override + public ChannelStateWriteResult getAndRemoveWriteResult(long checkpointId) { + return ChannelStateWriteResult.EMPTY; + } + + @Override + public void addInputDataFromSpill(long checkpointId, FetchedChannelStateReader reader) { + trace.add("writer.addInputDataFromSpill:" + checkpointId); + lastCpId.set(checkpointId); + addInputDataFromSpillCalls.incrementAndGet(); + try { + // Peek whether the reader has any segments by attempting the first advance. + // The first nextSegment() call is exempt from the "previous body consumed" rule. + lastSnapshotWasEmpty.set(reader.nextSegment().isEmpty()); + reader.close(); + } catch (Exception ignored) { + } + } + + @Override + public void close() {} + } + + private static final class RecordingInput implements CheckpointableInput { + + private final List trace; + private final String name; + + RecordingInput(List trace, String name) { + this.trace = trace; + this.name = name; + } + + @Override + public void blockConsumption(InputChannelInfo channelInfo) {} + + @Override + public void resumeConsumption(InputChannelInfo channelInfo) {} + + @Override + public List getChannelInfos() { + return Collections.emptyList(); + } + + @Override + public int getNumberOfInputChannels() { + return 0; + } + + @Override + public void checkpointStarted(CheckpointBarrier barrier) throws CheckpointException { + trace.add(name + ".checkpointStarted:" + barrier.getId()); + } + + @Override + public void checkpointStopped(long cancelledCheckpointId) {} + + @Override + public int getInputGateIndex() { + return 0; + } + + @Override + public void convertToPriorityEvent(int 
channelIndex, int sequenceNumber) {} + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/TestBarrierHandlerFactory.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/TestBarrierHandlerFactory.java index 1b345ebe0d2a0..8c281c04a6b13 100644 --- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/TestBarrierHandlerFactory.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/TestBarrierHandlerFactory.java @@ -20,6 +20,7 @@ import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; import org.apache.flink.runtime.checkpoint.channel.RecordingChannelStateWriter; +import org.apache.flink.runtime.checkpoint.channel.RecoveryCheckpointTrigger; import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.streaming.runtime.tasks.TestSubtaskCheckpointCoordinator; @@ -72,6 +73,8 @@ public SingleCheckpointBarrierHandler create( inputGate.getNumberOfInputChannels(), actionRegistration, enableCheckpointsAfterTasksFinish, + RecoveryCheckpointTrigger.NO_OP, + stateWriter, inputGate); } } From be9cfbf9be2b9834d686c0b04c2aade48e45db5c Mon Sep 17 00:00:00 2001 From: Roman Khachatryan Date: Fri, 26 Jun 2026 19:05:38 +0200 Subject: [PATCH 07/16] [FLINK-38544] Remove stale recovery comment blocks Delete the verbose explanatory comments describing the old channel-state recovery wiring (the recoverySetupCompleteFuture javadoc, the allOf-vs-thenRun race essay, the setCheckpointingDuringRecoveryEnabled note). These annotate code that the subsequent async-recovery rewrite replaces; removing them up front keeps the rewrite diff free of comment-deletion churn. 
--- .../streaming/runtime/tasks/StreamTask.java | 25 ------------------- 1 file changed, 25 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index 7c30746ba4bf3..f1571ce7d4936 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -309,14 +309,6 @@ public abstract class StreamTask> /** TODO it might be replaced by the global IO executor on TaskManager level future. */ private final ExecutorService channelIOExecutor; - /** - * Completed (on the {@code channelIOExecutor}) once recovery setup finishes, carrying the - * resolved checkpoint trigger: the spill drainer when recovery carries channel state, otherwise - * {@link RecoveryCheckpointTrigger#NO_OP}. Two consumers ride on this single completion: the - * barrier handler, built before the drainer exists, holds the future and reads the trigger - * lazily via {@code getNow} once a checkpoint fires; and gate conversion waits on its - * completion to run {@code requestPartitions()} (buffer filtering is done by then). - */ private final CompletableFuture recoverySetupCompleteFuture = new CompletableFuture<>(); @@ -900,9 +892,6 @@ private CompletableFuture restoreStateAndGates( boolean checkpointingDuringRecoveryEnabled = CheckpointingOptions.isCheckpointingDuringRecoveryEnabled(getJobConfiguration()); - // Must set the flag on input gates BEFORE starting the async read task, because - // finishReadRecoveredState() reads this flag to decide whether to enqueue the legacy - // end-of-state sentinel. 
for (IndexedInputGate inputGate : inputGates) { inputGate.setCheckpointingDuringRecoveryEnabled(checkpointingDuringRecoveryEnabled); } @@ -918,25 +907,11 @@ private CompletableFuture restoreStateAndGates( checkpointingDuringRecoveryEnabled, physicalChannelsFuture)); - // We wait for all input channel state to recover before we go into RUNNING state, and thus - // start checkpointing. If we implement incremental checkpointing of input channel state - // we must make sure it supports CheckpointType#FULL_CHECKPOINT. List> recoveredFutures = checkpointingDuringRecoveryEnabled ? wireGateConversionWithCheckpointing(inputGates, physicalChannelsFuture) : wireGateConversion(inputGates); - // Return allOf future instead of thenRun future. thenRun() returns a NEW future that - // completes only after the callback finishes. CompletableFuture executes thenRun callbacks - // synchronously on the thread that calls complete(). When recoveredFutures contains - // recoverySetupCompleteFuture (checkpointingDuringRecovery enabled), complete() is called - // on channelIOExecutor (in recoverChannelState), so thenRun(suspend) also runs on - // channelIOExecutor. suspend() sends a poison mail, and the mailbox thread can pick it up - // and exit runMailboxLoop() before the thenRun future completes — causing - // checkState(isDone) to fail. With stateConsumedFuture (the default), complete() runs on - // the mailbox thread itself, so thenRun(suspend) blocks the loop from processing the poison - // mail until the future completes — no race. Returning allOf future avoids the issue - // entirely. 
CompletableFuture allRecoveredFuture = CompletableFuture.allOf(recoveredFutures.toArray(new CompletableFuture[0])); allRecoveredFuture.thenRun(mailboxProcessor::suspend); From c6a95d966ee997160b68db6c6915afa0e6507aba Mon Sep 17 00:00:00 2001 From: Roman Khachatryan Date: Fri, 26 Jun 2026 19:07:05 +0200 Subject: [PATCH 08/16] [FLINK-38544] Asynchronous channel-state recovery Rewrite StreamTask channel-state recovery to run asynchronously on the channelIOExecutor: split restoreStateAndGates into recoverChannelsWithCheckpointing / recoverChannelsWithoutCheckpointing, threading the recovery checkpoint trigger and the fetched-state drainer through the new SequentialChannelStateReader / FetchedChannelState(Drainer) / RecoveryCheckpointTrigger interfaces. Also release channel state before returning in SequentialChannelStateReaderImpl. --- .../channel/FetchedChannelStateDrainer.java | 13 +- .../channel/RecoveryCheckpointTrigger.java | 9 + .../channel/SequentialChannelStateReader.java | 17 +- .../SequentialChannelStateReaderImpl.java | 16 +- ...reditBasedSequenceNumberingViewReader.java | 1 + .../streaming/runtime/tasks/StreamTask.java | 271 +++++++----------- .../ChannelIOExecutorDrainSubmissionTest.java | 4 +- ...hedChannelStateDrainerConcurrencyTest.java | 3 +- .../FetchedChannelStateDrainerTest.java | 3 +- .../SequentialChannelStateReaderImplTest.java | 6 + .../tasks/TaskCheckpointingBehaviourTest.java | 2 + 11 files changed, 147 insertions(+), 198 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainer.java index ce58cd4cea8de..950a65e60b01e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainer.java @@ -29,7 +29,6 @@ import 
java.util.List; import java.util.Map; import java.util.Optional; -import java.util.concurrent.CompletableFuture; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -46,9 +45,10 @@ public final class FetchedChannelStateDrainer implements RecoveryCheckpointTrigg private final FetchedChannelStateReader rootReader; - private final CompletableFuture resolvedChannelsFuture; + private final ResolvedChannels channels; private final Object lock = new Object(); + private final FetchedChannelState channelState; /** * Set under {@link #lock} once {@link #drain()} has consumed every segment. After that the @@ -59,10 +59,10 @@ public final class FetchedChannelStateDrainer implements RecoveryCheckpointTrigg private boolean drainFinished; public FetchedChannelStateDrainer( - FetchedChannelState channelState, - CompletableFuture> channelsFuture) { + FetchedChannelState channelState, List channels) { + this.channelState = channelState; this.rootReader = checkNotNull(channelState).reader(); - this.resolvedChannelsFuture = checkNotNull(channelsFuture).thenApply(ResolvedChannels::new); + this.channels = new ResolvedChannels(channels); } private static final class ResolvedChannels { @@ -90,7 +90,7 @@ private static final class ResolvedChannels { * pair is locked to guarantee atomicity with snapshot. 
*/ public void drain() throws IOException, InterruptedException { - ResolvedChannels channels = resolvedChannelsFuture.join(); + channelState.release(); Optional next; while ((next = rootReader.nextSegment()).isPresent()) { SpillSegment seg = next.get(); @@ -177,7 +177,6 @@ private static int fill(Buffer buf, InputStream in, int remaining) throws IOExce @Override public FetchedChannelStateReader snapshotAndInsertBarriers(long checkpointId) throws IOException { - ResolvedChannels channels = resolvedChannelsFuture.join(); // Barrier insertion and snapshot must occur within the same critical section so that the // snapshot's committed position reflects exactly the drain position at the moment barriers diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointTrigger.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointTrigger.java index 01053ad836358..d2f08736f856a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointTrigger.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointTrigger.java @@ -33,4 +33,13 @@ public interface RecoveryCheckpointTrigger { /** Returns an empty reader (no spill files, so no segments) and inserts no barriers. 
*/ RecoveryCheckpointTrigger NO_OP = checkpointId -> FetchedChannelStateReader.emptyReader(); + + RecoveryCheckpointTrigger NOT_READY = + ign -> { + throw new IllegalStateException("RecoveryCheckpointTrigger is not ready yet"); + }; + RecoveryCheckpointTrigger FAILING = + ign -> { + throw new IllegalStateException("Triggering checkpoints is not possible"); + }; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java index 7fea56b5ace06..88296c517b015 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java @@ -35,15 +35,13 @@ public interface SequentialChannelStateReader extends AutoCloseable { * @param inputGates The input gates to recover state for. * @param filterContext The filter context containing input configs and rescaling info. */ - void readInputData(InputGate[] inputGates, RecordFilterContext filterContext) + Optional readInputData( + InputGate[] inputGates, RecordFilterContext filterContext) throws IOException, InterruptedException; void readOutputData(ResultPartitionWriter[] writers, boolean notifyAndBlockOnCompletion) throws IOException, InterruptedException; - /** Returns the {@link FetchedChannelState} produced by {@link #readInputData}, if any. 
*/ - Optional getProducedChannelState(); - @Override void close() throws Exception; @@ -51,18 +49,15 @@ void readOutputData(ResultPartitionWriter[] writers, boolean notifyAndBlockOnCom new SequentialChannelStateReader() { @Override - public void readInputData( - InputGate[] inputGates, RecordFilterContext filterContext) {} + public Optional readInputData( + InputGate[] inputGates, RecordFilterContext filterContext) { + return Optional.empty(); + } @Override public void readOutputData( ResultPartitionWriter[] writers, boolean notifyAndBlockOnCompletion) {} - @Override - public Optional getProducedChannelState() { - return Optional.empty(); - } - @Override public void close() {} }; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java index 161968b6d4739..263335bc0c7c8 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java @@ -65,7 +65,8 @@ public SequentialChannelStateReaderImpl(TaskStateSnapshot taskStateSnapshot) { } @Override - public void readInputData(InputGate[] inputGates, RecordFilterContext filterContext) + public Optional readInputData( + InputGate[] inputGates, RecordFilterContext filterContext) throws IOException, InterruptedException { // Create filtering handler if filtering is needed @@ -85,7 +86,7 @@ public void readInputData(InputGate[] inputGates, RecordFilterContext filterCont filterContext.getMemorySegmentSize(), filterContext.getTmpDirectories()); try (ChannelStateFilteringHandler ignored = filteringHandler) { - try { + try (stateHandler) { read( stateHandler, groupByDelegate( @@ -102,18 +103,13 @@ public void readInputData(InputGate[] inputGates, RecordFilterContext filterCont 
!filteringHandler.hasPartialData(), "Not all data has been fully consumed during filtering"); } - } finally { - stateHandler.close(); } - this.producedChannelState = stateHandler.getProducedChannelState(); + // stateHandler.close() (above) has flushed the filter writer and published the + // produced spill file, so read getProducedChannelState() after the close completes. + return Optional.ofNullable(stateHandler.getProducedChannelState()); } } - @Override - public Optional getProducedChannelState() { - return Optional.ofNullable(producedChannelState); - } - @Override public void readOutputData(ResultPartitionWriter[] writers, boolean notifyAndBlockOnCompletion) throws IOException, InterruptedException { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedSequenceNumberingViewReader.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedSequenceNumberingViewReader.java index 1ac2687bf3946..c1e00eca802d9 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedSequenceNumberingViewReader.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedSequenceNumberingViewReader.java @@ -48,6 +48,7 @@ * non-emptiness, similar to the {@link LocalInputChannel}. */ class CreditBasedSequenceNumberingViewReader + // local input channel changes? 
implements BufferAvailabilityListener, NetworkSequenceViewReader { private final Object requestLock = new Object(); diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index f1571ce7d4936..5f6b9dcbe5e9a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -146,6 +146,7 @@ import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; @@ -164,6 +165,7 @@ import static org.apache.flink.util.ExceptionUtils.firstOrSuppressed; import static org.apache.flink.util.Preconditions.checkState; import static org.apache.flink.util.concurrent.FutureUtils.assertNoException; +import static org.apache.flink.util.concurrent.FutureUtils.completeAll; /** * Base class for all streaming tasks. A task is the unit of local processing that is deployed and @@ -309,8 +311,8 @@ public abstract class StreamTask> /** TODO it might be replaced by the global IO executor on TaskManager level future. 
*/ private final ExecutorService channelIOExecutor; - private final CompletableFuture recoverySetupCompleteFuture = - new CompletableFuture<>(); + private RecoveryCheckpointTrigger recoveryCheckpointTrigger = + RecoveryCheckpointTrigger.NOT_READY; // ======================================================== // Final checkpoint / savepoint @@ -850,6 +852,12 @@ void restoreInternal() throws Exception { allGatesRecoveredFuture.isDone(), "Mailbox loop interrupted before recovery was finished."); + try { + allGatesRecoveredFuture.get(); + } catch (ExecutionException e) { + ExceptionUtils.rethrowException(e.getCause()); + } + // we recovered all the gates, we can close the channel IO executor as it is no longer // needed channelIOExecutor.shutdown(); @@ -889,180 +897,124 @@ private CompletableFuture restoreStateAndGates( INITIALIZE_STATE_DURATION, initializeStateEndTs - readOutputDataTs); IndexedInputGate[] inputGates = getEnvironment().getAllInputGates(); - boolean checkpointingDuringRecoveryEnabled = - CheckpointingOptions.isCheckpointingDuringRecoveryEnabled(getJobConfiguration()); + CompletableFuture recoveryCompletionFuture = + CheckpointingOptions.isCheckpointingDuringRecoveryEnabled(getJobConfiguration()) + ? 
recoverChannelsWithCheckpointing(reader, inputGates) + : recoverChannelsWithoutCheckpointing(reader, inputGates); - for (IndexedInputGate inputGate : inputGates) { - inputGate.setCheckpointingDuringRecoveryEnabled(checkpointingDuringRecoveryEnabled); - } + recoveryCompletionFuture.whenComplete((ign, throwable) -> mailboxProcessor.suspend()); + return recoveryCompletionFuture; + } - final CompletableFuture> physicalChannelsFuture = - new CompletableFuture<>(); + private CompletableFuture recoverChannelsWithCheckpointing( + SequentialChannelStateReader reader, IndexedInputGate[] inputGates) { + return setRecoveryCheckpointTrigger(RecoveryCheckpointTrigger.NOT_READY) + .thenApplyAsync(ign -> fetchChannelState(reader, inputGates), channelIOExecutor) + .thenCompose( + state -> + requestPartitions(inputGates, state.isPresent()) + .thenApply(channels -> buildDrainer(state, channels))) + .thenCompose(this::drainThroughCheckpointTrigger) + .thenRun(() -> setRecoveryCheckpointTrigger(RecoveryCheckpointTrigger.NO_OP)); + } - channelIOExecutor.execute( - () -> - recoverChannelState( - reader, - inputGates, - checkpointingDuringRecoveryEnabled, - physicalChannelsFuture)); + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") // intentional: simplify call-site + private Optional buildDrainer( + Optional state, List channels) { + return state.map(s -> new FetchedChannelStateDrainer(s, channels)); + } - List> recoveredFutures = - checkpointingDuringRecoveryEnabled - ? 
wireGateConversionWithCheckpointing(inputGates, physicalChannelsFuture) - : wireGateConversion(inputGates); + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") // intentional: simplify call-site + private CompletableFuture drainThroughCheckpointTrigger( + Optional drainer) { + if (drainer.isEmpty()) { + return FutureUtils.completedVoidFuture(); + } + FetchedChannelStateDrainer d = drainer.get(); + return setRecoveryCheckpointTrigger(d).thenRunAsync(() -> drain(d), channelIOExecutor); + } - CompletableFuture allRecoveredFuture = - CompletableFuture.allOf(recoveredFutures.toArray(new CompletableFuture[0])); - allRecoveredFuture.thenRun(mailboxProcessor::suspend); - return allRecoveredFuture; + private CompletableFuture setRecoveryCheckpointTrigger( + RecoveryCheckpointTrigger trigger) { + CompletableFuture future = new CompletableFuture<>(); + mainMailboxExecutor.execute( + () -> { + recoveryCheckpointTrigger = trigger; + future.complete(null); + }, + "update recoveryCheckpointTrigger to " + trigger); + return future; } - /** - * Runs on the {@code channelIOExecutor}: reads input channel state, wires the spill drainer - * when checkpointing-during-recovery is enabled, and drains recovered buffers into the physical - * channels. Setup failures complete {@code physicalChannelsFuture} exceptionally so the - * recovery mailbox loop stops waiting; the drain phase keeps its own handler because the future - * is already completed by then. 
- */ - private void recoverChannelState( - SequentialChannelStateReader reader, - IndexedInputGate[] inputGates, - boolean checkpointingDuringRecoveryEnabled, - @Nullable CompletableFuture> physicalChannelsFuture) { - FetchedChannelStateDrainer drainer = null; + private CompletableFuture recoverChannelsWithoutCheckpointing( + SequentialChannelStateReader reader, IndexedInputGate[] inputGates) { + recoveryCheckpointTrigger = RecoveryCheckpointTrigger.NO_OP; + List> futures = new ArrayList<>(); + futures.add( + CompletableFuture.runAsync( + () -> readInputChannelState(reader, inputGates), channelIOExecutor)); + for (InputGate inputGate : inputGates) { + futures.add(inputGate.getStateConsumedFuture()); + } + return completeAll(futures).thenRun(() -> requestPartitions(inputGates, false)); + } + + private void readInputChannelState( + SequentialChannelStateReader reader, IndexedInputGate[] inputGates) { try { - reader.readInputData(inputGates, createRecordFilterContext()); - - if (checkpointingDuringRecoveryEnabled) { - Optional producedChannelState = - reader.getProducedChannelState(); - boolean needsRecovery = producedChannelState.isPresent(); - for (IndexedInputGate gate : inputGates) { - gate.setNeedsRecovery(needsRecovery); - } - if (needsRecovery) { - FetchedChannelState channelState = producedChannelState.get(); - drainer = new FetchedChannelStateDrainer(channelState, physicalChannelsFuture); - channelState.release(); - } - } + checkState(reader.readInputData(inputGates, createRecordFilterContext()).isEmpty()); for (IndexedInputGate gate : inputGates) { - gate.finishReadRecoveredState(); + gate.finishReadRecoveredState(); // this is called from IO thread - is that fine? } - // Recovery setup is done: resolve the trigger for the barrier handler (the drainer when - // recovery carries channel state, NO_OP otherwise) and, by the same completion, release - // gate conversion. 
Completed before any checkpoint can fire during recovery, so the - // handler reads it via getNow. - recoverySetupCompleteFuture.complete( - drainer != null ? drainer : RecoveryCheckpointTrigger.NO_OP); - } catch (Throwable t) { + } catch (Exception e) { asyncExceptionHandler.handleAsyncException( - "Unable to set up recovered channel state", t); - recoverySetupCompleteFuture.completeExceptionally(t); - if (checkpointingDuringRecoveryEnabled) { - if (drainer == null) { - try { - Optional producedChannelState = - reader.getProducedChannelState(); - if (producedChannelState.isPresent()) { - producedChannelState.get().release(); - } - } catch (Throwable ignored) { - // Preserve the original recovery failure. - } - } - physicalChannelsFuture.completeExceptionally(t); - } - return; + "Unable to set up recovered channel state", e); } + } - if (drainer == null) { - return; - } + private Optional fetchChannelState( + SequentialChannelStateReader reader, IndexedInputGate[] inputGates) { try { - drainer.drain(); + return reader.readInputData(inputGates, createRecordFilterContext()); } catch (Throwable t) { asyncExceptionHandler.handleAsyncException( - "Unable to drain recovered channel state", t); - } finally { - try { - drainer.close(); - } catch (Throwable closeError) { - asyncExceptionHandler.handleAsyncException( - "Unable to close FetchedChannelStateDrainer after drain", closeError); - } + "Unable to set up recovered channel state", t); + return Optional.empty(); } } - /** - * Wires each gate's {@code requestPartitions()} to run on the mailbox once its state-consumed - * trigger fires. Used when checkpointing-during-recovery is disabled, so no physical-channel - * conversion needs to be tracked. - * - *

Returns the futures the recovery mailbox loop must await before transitioning to RUNNING. - */ - private List> wireGateConversion(IndexedInputGate[] inputGates) { - List> recoveredFutures = new ArrayList<>(inputGates.length); - for (InputGate inputGate : inputGates) { - CompletableFuture requestPartitionsTrigger = inputGate.getStateConsumedFuture(); - recoveredFutures.add(requestPartitionsTrigger); - requestPartitionsTrigger.thenRun( - () -> - mainMailboxExecutor.execute( - inputGate::requestPartitions, "Input gate request partitions")); + private void drain(FetchedChannelStateDrainer drainer) { + try (FetchedChannelStateDrainer ignored = drainer) { + try { + drainer.drain(); + } catch (Throwable t) { + asyncExceptionHandler.handleAsyncException( + "Unable to drain recovered channel state", t); + } + } catch (Throwable closeError) { + asyncExceptionHandler.handleAsyncException( + "Unable to close FetchedChannelStateDrainer after drain", closeError); } - return recoveredFutures; } - /** - * Wires each gate's {@code requestPartitions()} to run on the mailbox once recovery setup - * completes, and aggregates the per-gate completions into {@code physicalChannelsFuture}. Used - * when checkpointing-during-recovery is enabled. The trigger stays synchronous (no {@code - * *Async}): completing on the {@code channelIOExecutor} that fired {@code - * recoverySetupCompleteFuture} would let the poison mail outrun the suspend callback. - * - *

Returns the futures the recovery mailbox loop must await before transitioning to RUNNING. - */ - private List> wireGateConversionWithCheckpointing( - IndexedInputGate[] inputGates, - CompletableFuture> physicalChannelsFuture) { - List> recoveredFutures = new ArrayList<>(inputGates.length); - // Keep the recovery mailbox loop alive until physical channels are converted; otherwise a - // checkpoint barrier mail could block on the channels future that only a later conversion - // mail can complete. - if (inputGates.length > 0) { - recoveredFutures.add(physicalChannelsFuture); - } - recoveredFutures.add(recoverySetupCompleteFuture); - CompletableFuture gateConverted = new CompletableFuture<>(); - recoverySetupCompleteFuture.thenRun( - () -> - mainMailboxExecutor.execute( - () -> { - try { - for (InputGate inputGate : inputGates) { - inputGate.requestPartitions(); - } - gateConverted.complete(null); - } catch (Throwable t) { - gateConverted.completeExceptionally(t); - throw t; - } - }, - "Input gate request partitions")); - gateConverted - .thenApply(ignored -> collectPhysicalChannels(inputGates)) - .whenComplete( - (physicalChannels, failure) -> { - if (failure != null) { - physicalChannelsFuture.completeExceptionally(failure); - } else { - physicalChannelsFuture.complete(physicalChannels); - } - }); - return recoveredFutures; + private CompletableFuture> requestPartitions( + IndexedInputGate[] inputGates, boolean needsRecovery) { + CompletableFuture> future = new CompletableFuture<>(); + mainMailboxExecutor.execute( + () -> { + try { + for (InputGate inputGate : inputGates) { + inputGate.requestPartitions(needsRecovery); + } + future.complete(collectPhysicalChannels(inputGates)); + } catch (Throwable t) { + future.completeExceptionally(t); + throw t; + } + }, + "Input gate request partitions"); + return future; } private static List collectPhysicalChannels(InputGate[] inputGates) { @@ -1079,19 +1031,10 @@ private static List collectPhysicalChannels(InputGate[] return 
channels; } - /** - * Returns a trigger that resolves the real implementation lazily: the barrier handler is built - * before the spill drainer exists, so this defers reading {@link #recoverySetupCompleteFuture} - * until a checkpoint actually fires, by which point recovery setup has completed it. Resolving - * at construction time would block; an unresolved future at snapshot time means an invariant - * broke, so it fails loud. - */ public RecoveryCheckpointTrigger getRecoveryCheckpointTrigger() { return cpId -> { - checkState( - recoverySetupCompleteFuture.isDone(), - "Recovery checkpoint trigger is not resolved at checkpoint start."); - return recoverySetupCompleteFuture.getNow(null).snapshotAndInsertBarriers(cpId); + checkState(mailboxProcessor.isMailboxThread()); + return recoveryCheckpointTrigger.snapshotAndInsertBarriers(cpId); }; } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelIOExecutorDrainSubmissionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelIOExecutorDrainSubmissionTest.java index b011d282cb78c..a143e3735306a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelIOExecutorDrainSubmissionTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelIOExecutorDrainSubmissionTest.java @@ -54,7 +54,7 @@ void testFilterOnSubmitsDrainAfterConversion() throws Exception { List all = new ArrayList<>(); all.add(chan); FetchedChannelStateDrainer drainer = - new FetchedChannelStateDrainer(state, CompletableFuture.completedFuture(all)); + new FetchedChannelStateDrainer(state, all); ExecutorService channelIOExecutor = Executors.newSingleThreadExecutor(); try { @@ -121,7 +121,7 @@ public void onRecoveredStateConsumed() {} List all = new ArrayList<>(); all.add(chan); FetchedChannelStateDrainer drainer = - new FetchedChannelStateDrainer(state, CompletableFuture.completedFuture(all)); + new 
FetchedChannelStateDrainer(state, all); CountDownLatch handlerCalled = new CountDownLatch(1); AtomicReference captured = new AtomicReference<>(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainerConcurrencyTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainerConcurrencyTest.java index 32f83760c0104..cd845e2fbf2a1 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainerConcurrencyTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainerConcurrencyTest.java @@ -35,7 +35,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; @@ -80,7 +79,7 @@ void testDrainAndSnapshotConcurrentAtomicity() throws Exception { all.add(chan1); FetchedChannelStateDrainer drainer = - new FetchedChannelStateDrainer(state, CompletableFuture.completedFuture(all)); + new FetchedChannelStateDrainer(state, all); ExecutorService io = Executors.newSingleThreadExecutor(); AtomicReference drainError = new AtomicReference<>(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainerTest.java index c78d5b67f4770..d791926e6a0ab 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainerTest.java @@ -37,7 +37,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Optional; -import java.util.concurrent.CompletableFuture; import static 
org.assertj.core.api.Assertions.assertThat; @@ -353,7 +352,7 @@ private FetchedChannelStateDrainer newDrainer( for (int i = 0; i < infoChannelPairs.length; i += 2) { all.add((RecoverableInputChannel) infoChannelPairs[i + 1]); } - return new FetchedChannelStateDrainer(state, CompletableFuture.completedFuture(all)); + return new FetchedChannelStateDrainer(state, all); } private static long extractRecoveryBarrierCheckpointId(Buffer buffer) throws IOException { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java index b3dedda0c72b9..e11379624bf81 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java @@ -152,6 +152,12 @@ void testReadPermutedState() throws Exception { gates, RecordFilterContext.disabled( new String[] {tmpDir.toAbsolutePath().toString()})); + // Mirror the production legacy path (StreamTask#readInputChannelState): the + // EndOfInputChannelStateEvent sentinel that completes stateConsumedFuture is + // enqueued by finishReadRecoveredState(), not by readInputData() itself. 
+ for (InputGate gate : gates) { + gate.finishReadRecoveredState(); + } assertBuffersEquals(inputChannelsData, collectBuffers(gates)); assertConsumed(gates); }); diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java index 098d86587f93e..b064511e6e6f6 100644 --- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java @@ -94,6 +94,7 @@ import org.apache.flink.util.concurrent.Executors; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -112,6 +113,7 @@ * checks correct working of different policies how tasks deal with checkpoint failures (fail task, * decline checkpoint and continue). */ +@Timeout(120) class TaskCheckpointingBehaviourTest { private static final OneShotLatch IN_CHECKPOINT_LATCH = new OneShotLatch(); From 7aa50f35a9942c677a7b11e45d5822c4968531e8 Mon Sep 17 00:00:00 2001 From: Roman Khachatryan Date: Fri, 26 Jun 2026 20:42:40 +0200 Subject: [PATCH 09/16] [FLINK-38544][checkpoint] Fix recovery race: don't gate suspend on channel-state read The async-recovery rewrite put the channel-state read future (runAsync(readInputChannelState)) into the completeAll(...) set that gates the recovery-completion suspend() poison mail. That future is never already-complete, so suspend() was deferred past the start of the restore mailbox loop. 
The loop then ran the default action (record processing) before recovery finished: - records were processed before gate conversion/requestPartitions, losing them (MultipleInputStreamTaskTest: 10 -> 7); - processInput hitting END_OF_INPUT during restore called suspend() itself, exiting the loop before the recovery future was done, so restoreInternal's checkState(allGatesRecoveredFuture.isDone()) threw "Mailbox loop interrupted before recovery was finished" (StreamTaskTest.testProcessWith*); and - on downscale, recovery could stall entirely (UnalignedCheckpointRescaleITCase hang). Run readInputChannelState as a fire-and-forget feeder instead, and gate suspend()/recoveryCompletionFuture on the gates' stateConsumedFutures only. A stateConsumedFuture completes only once the consumer drains the end-of-state sentinel that readInputChannelState pushes, so gating on it already implies the read finished, and requestPartitions still runs after the read. For a task with no recovered state the futures are already complete, so suspend() is enqueued before the loop runs the default action -- restoring the pre-rewrite ordering. Unlike suspending the default action, this does not block the consumer, so recovery that must drain real channel state still completes. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- .../flink/streaming/runtime/tasks/StreamTask.java | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index 5f6b9dcbe5e9a..52d5f0a5b8f4b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -949,10 +949,17 @@ private CompletableFuture setRecoveryCheckpointTrigger( private CompletableFuture recoverChannelsWithoutCheckpointing( SequentialChannelStateReader reader, IndexedInputGate[] inputGates) { recoveryCheckpointTrigger = RecoveryCheckpointTrigger.NO_OP; + // Feed recovered channel state on the IO thread. This is intentionally NOT part of the + // completion gate below: a gate's stateConsumedFuture only completes once the consumer (the + // default action, running in the restore mailbox loop) has drained the end-of-state + // sentinel that readInputChannelState pushes, so gating on stateConsumedFuture already + // implies the read has finished. Including the read future in the gate would defer + // suspend() past the restore loop's first default action, letting input be processed before + // recovery completes -- causing record loss and "Mailbox loop interrupted before recovery + // was finished". Read failures are surfaced via asyncExceptionHandler inside + // readInputChannelState. 
+ channelIOExecutor.execute(() -> readInputChannelState(reader, inputGates)); List> futures = new ArrayList<>(); - futures.add( - CompletableFuture.runAsync( - () -> readInputChannelState(reader, inputGates), channelIOExecutor)); for (InputGate inputGate : inputGates) { futures.add(inputGate.getStateConsumedFuture()); } From af95a22c0aba5494ad8e14bef890b01cbc342aae Mon Sep 17 00:00:00 2001 From: Roman Khachatryan Date: Fri, 26 Jun 2026 21:09:48 +0200 Subject: [PATCH 10/16] [FLINK-38544][checkpoint] Don't require an IOManager for the disabled record-filter context When checkpointing-during-recovery is disabled, the RecordFilterContext never spills, so it needs no spill directories. createRecordFilterContext nevertheless dereferenced getEnvironment().getIOManager().getSpillingDirectoriesPaths(), which NPEs in minimal environments that return a null IOManager (e.g. DummyEnvironment-based StreamTaskTest cases). On the channelIOExecutor during recovery that NPE was routed to handleAsyncException -> failExternally, which DummyEnvironment rejects with UnsupportedOperationException, escalating to the fatal error handler and crashing the surefire fork. Pass an empty spill-directory array for the disabled context instead of touching the IOManager. The async-recovery rewrite introduced this IOManager access during recovery; these StreamTaskTest cases passed before it. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- .../apache/flink/streaming/runtime/tasks/StreamTask.java | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index 52d5f0a5b8f4b..7e1e0215f0e20 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -2106,8 +2106,11 @@ protected RecordFilterContext createRecordFilterContext() { boolean checkpointingDuringRecoveryEnabled = CheckpointingOptions.isCheckpointingDuringRecoveryEnabled(getJobConfiguration()); if (!checkpointingDuringRecoveryEnabled) { - return RecordFilterContext.disabled( - getEnvironment().getIOManager().getSpillingDirectoriesPaths()); + // A disabled context never spills, so it needs no spill directories. Don't touch the + // IOManager here -- recovery then works even in minimal environments that don't provide + // one (e.g. DummyEnvironment-based tests), instead of NPEing and routing the failure to + // failExternally. + return RecordFilterContext.disabled(new String[0]); } ClassLoader cl = getUserCodeClassLoader(); From c3c0b1e85abc2cb00e68ea7b037955c16971cb1a Mon Sep 17 00:00:00 2001 From: Roman Khachatryan Date: Fri, 26 Jun 2026 19:07:37 +0200 Subject: [PATCH 11/16] [FLINK-38544] Remove recovery flags from the gate API Drop the needsRecovery and checkpointingDuringRecoveryEnabled fields and their accessors from the input-gate API. needsRecovery is now passed as a parameter (requestPartitions(boolean) / convertRecoveredInputChannels(boolean)) and checkpointing-during-recovery is read from the job config at the call site instead of being pushed onto each gate. Updates the gate mocks and test builders accordingly. 
--- .../partition/consumer/IndexedInputGate.java | 14 ------ .../network/partition/consumer/InputGate.java | 9 ++++ .../consumer/RecoveredInputChannel.java | 14 +----- .../partition/consumer/SingleInputGate.java | 36 +++++---------- .../partition/consumer/UnionInputGate.java | 7 ++- .../taskmanager/InputGateWithMetrics.java | 25 ++-------- .../LocalRecoveredInputChannelTest.java | 2 +- ...BufferBlockingHeapFallbackRemovedTest.java | 18 ++------ .../consumer/RecoveredInputChannelTest.java | 46 ++++--------------- .../RemoteRecoveredInputChannelTest.java | 2 +- .../consumer/SingleInputGateBuilder.java | 18 -------- .../runtime/io/MockIndexedInputGate.java | 16 ------- .../streaming/runtime/io/MockInputGate.java | 16 ------- .../AlignedCheckpointsMassiveRandomTest.java | 16 ------- 14 files changed, 47 insertions(+), 192 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/IndexedInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/IndexedInputGate.java index 95191f7175573..5daa277cd9b34 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/IndexedInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/IndexedInputGate.java @@ -74,18 +74,4 @@ public void convertToPriorityEvent(int channelIndex, int sequenceNumber) throws public abstract ResultPartitionType getConsumedPartitionType(); public abstract void triggerDebloating(); - - /** Sets whether unaligned checkpointing during recovery is enabled. */ - public abstract void setCheckpointingDuringRecoveryEnabled(boolean enabled); - - /** Returns whether unaligned checkpointing during recovery is enabled. */ - public abstract boolean isCheckpointingDuringRecoveryEnabled(); - - /** - * Sets whether converted physical channels start in recovery. Must be published before the - * buffer-filtering completion future is completed. 
- */ - public abstract void setNeedsRecovery(boolean enabled); - - public abstract boolean needsRecovery(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java index b70c284a21a61..142f45cf9bcc5 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java @@ -198,6 +198,15 @@ public String toString() { public abstract void requestPartitions() throws IOException; + /** + * Requests the partitions. {@code needsRecovery} controls whether converted physical channels + * start in recovery (i.e. with no credit / no floating buffers until the recovered channel + * state has been drained). The default implementation ignores the flag. + */ + public void requestPartitions(boolean needsRecovery) throws IOException { + requestPartitions(); + } + public abstract CompletableFuture getStateConsumedFuture(); public abstract void finishReadRecoveredState() throws IOException; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java index 82ae84a6c91f2..d291e4ea8e996 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java @@ -104,11 +104,6 @@ public void setChannelStateWriter(ChannelStateWriter channelStateWriter) { } public final InputChannel toInputChannel(boolean needsRecovery) throws IOException { - if (!inputGate.isCheckpointingDuringRecoveryEnabled()) { - Preconditions.checkState( - stateConsumedFuture.isDone(), 
"recovered state is not fully consumed"); - } - // With checkpointing-during-recovery, data is spilled instead of queued here. synchronized (receivedBuffers) { Preconditions.checkState(receivedBuffers.isEmpty(), "Received buffer should be empty."); @@ -165,14 +160,9 @@ public void onRecoveredStateBuffer(Buffer buffer) { } public void finishReadRecoveredState() throws IOException { - // In legacy recovery, adding the sentinel must be atomic under receivedBuffers lock to - // ensure the sentinel is enqueued before any concurrent reader can observe an empty queue - // and miss the EndOfInputChannelStateEvent that completes stateConsumedFuture. synchronized (receivedBuffers) { - if (!inputGate.isCheckpointingDuringRecoveryEnabled()) { - onRecoveredStateBuffer( - EventSerializer.toBuffer(EndOfInputChannelStateEvent.INSTANCE, false)); - } + onRecoveredStateBuffer( + EventSerializer.toBuffer(EndOfInputChannelStateEvent.INSTANCE, false)); } bufferManager.releaseFloatingBuffers(); LOG.debug("{}/{} finished recovering input.", inputGate.getOwningTaskName(), channelInfo); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java index 50e359b4d4e32..54b0bd928e417 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java @@ -243,9 +243,6 @@ public class SingleInputGate extends IndexedInputGate { */ private final int[] endOfPartitions; - private volatile boolean checkpointingDuringRecoveryEnabled = false; - - private volatile boolean needsRecovery = false; public SingleInputGate( String owningTaskName, @@ -334,27 +331,12 @@ public CompletableFuture getStateConsumedFuture() { } @Override - public void setCheckpointingDuringRecoveryEnabled(boolean 
enabled) { - this.checkpointingDuringRecoveryEnabled = enabled; - } - - @Override - public boolean isCheckpointingDuringRecoveryEnabled() { - return checkpointingDuringRecoveryEnabled; - } - - @Override - public void setNeedsRecovery(boolean enabled) { - this.needsRecovery = enabled; - } - - @Override - public boolean needsRecovery() { - return needsRecovery; + public void requestPartitions() { + requestPartitions(false); } @Override - public void requestPartitions() { + public void requestPartitions(boolean needsRecovery) { synchronized (requestLock) { if (!requestedPartitionsFlag) { if (closeFuture.isDone()) { @@ -373,7 +355,7 @@ public void requestPartitions() { numInputChannels, numberOfInputChannels)); } - convertRecoveredInputChannels(); + convertRecoveredInputChannels(needsRecovery); internalRequestPartitions(); } @@ -387,12 +369,18 @@ public void requestPartitions() { } } + @VisibleForTesting + public void convertRecoveredInputChannels() { + convertRecoveredInputChannels(false); + } + /** * Converts all {@link RecoveredInputChannel}s to their real channel types ({@link - * LocalInputChannel} or {@link RemoteInputChannel}). + * LocalInputChannel} or {@link RemoteInputChannel}). {@code needsRecovery} controls whether the + * converted physical channels start in recovery. 
*/ @VisibleForTesting - public void convertRecoveredInputChannels() { + public void convertRecoveredInputChannels(boolean needsRecovery) { LOG.debug("Converting recovered input channels ({} channels)", getNumberOfInputChannels()); for (Map inputChannelsForCurrentPartition : inputChannels.values()) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java index fff12185156d4..b7a708f38f27e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java @@ -359,8 +359,13 @@ public CompletableFuture getStateConsumedFuture() { @Override public void requestPartitions() throws IOException { + requestPartitions(false); + } + + @Override + public void requestPartitions(boolean needsRecovery) throws IOException { for (InputGate inputGate : inputGatesByGateIndex.values()) { - inputGate.requestPartitions(); + inputGate.requestPartitions(needsRecovery); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java index 9b3abd611b78f..5d1ca9c680b81 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java @@ -130,6 +130,11 @@ public void requestPartitions() throws IOException { inputGate.requestPartitions(); } + @Override + public void requestPartitions(boolean needsRecovery) throws IOException { + inputGate.requestPartitions(needsRecovery); + } + @Override public void setChannelStateWriter(ChannelStateWriter channelStateWriter) { 
inputGate.setChannelStateWriter(channelStateWriter); @@ -165,26 +170,6 @@ public void finishReadRecoveredState() throws IOException { inputGate.finishReadRecoveredState(); } - @Override - public void setCheckpointingDuringRecoveryEnabled(boolean enabled) { - inputGate.setCheckpointingDuringRecoveryEnabled(enabled); - } - - @Override - public boolean isCheckpointingDuringRecoveryEnabled() { - return inputGate.isCheckpointingDuringRecoveryEnabled(); - } - - @Override - public void setNeedsRecovery(boolean enabled) { - inputGate.setNeedsRecovery(enabled); - } - - @Override - public boolean needsRecovery() { - return inputGate.needsRecovery(); - } - private BufferOrEvent updateMetrics(BufferOrEvent bufferOrEvent) { int incomingDataSize = bufferOrEvent.getSize(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannelTest.java index a2083468b811a..9559ef316c5a1 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannelTest.java @@ -31,7 +31,7 @@ class LocalRecoveredInputChannelTest { @Test void testToInputChannelRequiresEmptyRecoveredBuffers() throws Exception { SingleInputGate inputGate = - new SingleInputGateBuilder().setCheckpointingDuringRecoveryEnabled(true).build(); + new SingleInputGateBuilder().build(); LocalRecoveredInputChannel recoveredChannel = InputChannelBuilder.newBuilder() .setStateWriter(ChannelStateWriter.NO_OP) diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelRequestBufferBlockingHeapFallbackRemovedTest.java 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelRequestBufferBlockingHeapFallbackRemovedTest.java index 1b9854361b05e..32247c0cc7799 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelRequestBufferBlockingHeapFallbackRemovedTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelRequestBufferBlockingHeapFallbackRemovedTest.java @@ -85,9 +85,8 @@ void testFilterOnPathTakesSameRouteAsFilterOff() throws Exception { int totalSegments = 4; pool = new NetworkBufferPool(totalSegments, MemoryManager.DEFAULT_PAGE_SIZE); - Buffer filterOnBuf = buildChannel(pool, exclusivePerChannel, true).requestBufferBlocking(); - Buffer filterOffBuf = - buildChannel(pool, exclusivePerChannel, false).requestBufferBlocking(); + Buffer filterOnBuf = buildChannel(pool, exclusivePerChannel).requestBufferBlocking(); + Buffer filterOffBuf = buildChannel(pool, exclusivePerChannel).requestBufferBlocking(); // Both must come from the pool — the BufferManager-owned recycler, not the // FreeingBufferRecycler the heap-fallback used. 
@@ -104,20 +103,9 @@ void testFilterOnPathTakesSameRouteAsFilterOff() throws Exception { private RecoveredInputChannel buildChannel( NetworkBufferPool segmentProvider, int exclusivePerChannel) { - return buildChannel(segmentProvider, exclusivePerChannel, true); - } - - private RecoveredInputChannel buildChannel( - NetworkBufferPool segmentProvider, - int exclusivePerChannel, - boolean checkpointingDuringRecoveryEnabled) { try { SingleInputGate inputGate = - new SingleInputGateBuilder() - .setSegmentProvider(segmentProvider) - .setCheckpointingDuringRecoveryEnabled( - checkpointingDuringRecoveryEnabled) - .build(); + new SingleInputGateBuilder().setSegmentProvider(segmentProvider).build(); return new RecoveredInputChannel( inputGate, 0, diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java index f606bb3eec37c..ba78b5c2aed6b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java @@ -60,31 +60,19 @@ void testCheckpointStartImpossible() { } @Test - void testToInputChannelAllowedWhenBufferFilteringCompleteAndConfigEnabled() throws IOException { - // When config is enabled, conversion is allowed after finishReadRecoveredState() - // without requiring stateConsumedFuture to be done. 
- TestableRecoveredInputChannel channel = buildTestableChannel(true); - - channel.finishReadRecoveredState(); - assertThat(channel.getStateConsumedFuture()).isNotDone(); - - // Conversion should now succeed (no exception) - InputChannel converted = channel.toInputChannel(true); - assertThat(converted).isNotNull(); - } - - @Test - void testToInputChannelAllowedWhenStateConsumedAndConfigDisabled() throws IOException { - // When config is disabled, conversion requires stateConsumedFuture to be done. + void testToInputChannelRejectedWhileRecoveredStateUnconsumed() throws IOException { + // Conversion is rejected while recovered state is still queued: finishReadRecoveredState() + // enqueues the EndOfInputChannelStateEvent sentinel, so receivedBuffers is non-empty until + // it is consumed. The empty-queue check thus also guarantees stateConsumedFuture is done. TestableRecoveredInputChannel channel = buildTestableChannel(false); channel.finishReadRecoveredState(); assertThat(channel.getStateConsumedFuture()).isNotDone(); - // Conversion should fail because stateConsumedFuture is not done + // Conversion fails because the sentinel is still queued. assertThatThrownBy(() -> channel.toInputChannel(true)) .isInstanceOf(IllegalStateException.class) - .hasMessageContaining("recovered state is not fully consumed"); + .hasMessageContaining("Received buffer should be empty"); // Consuming the EndOfInputChannelStateEvent should complete the future. // getNextBuffer() returns empty when it encounters the event internally. 
@@ -121,23 +109,9 @@ void testStateConsumedFutureCompletesAfterLegacySentinelIsConsumed() throws IOEx assertThat(channel.getStateConsumedFuture()).isDone(); } - @Test - void testStateConsumedFutureDoesNotCompleteWithoutLegacySentinel() throws IOException { - RecoveredInputChannel channel = buildChannel(true); - - channel.finishReadRecoveredState(); - - assertThat(channel.getNextBuffer()).isNotPresent(); - assertThat(channel.getStateConsumedFuture()).isNotDone(); - } - private RecoveredInputChannel buildChannel(boolean checkpointingDuringRecoveryEnabled) { try { - SingleInputGate inputGate = - new SingleInputGateBuilder() - .setCheckpointingDuringRecoveryEnabled( - checkpointingDuringRecoveryEnabled) - .build(); + SingleInputGate inputGate = new SingleInputGateBuilder().build(); return new RecoveredInputChannel( inputGate, 0, @@ -161,11 +135,7 @@ protected InputChannel toInputChannelInternal(boolean needsRecovery) { private TestableRecoveredInputChannel buildTestableChannel( boolean checkpointingDuringRecoveryEnabled) { try { - SingleInputGate inputGate = - new SingleInputGateBuilder() - .setCheckpointingDuringRecoveryEnabled( - checkpointingDuringRecoveryEnabled) - .build(); + SingleInputGate inputGate = new SingleInputGateBuilder().build(); return new TestableRecoveredInputChannel(inputGate); } catch (Exception e) { throw new AssertionError("channel creation failed", e); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannelTest.java index 938d040093e32..339ccced04b75 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannelTest.java @@ -31,7 +31,7 @@ class RemoteRecoveredInputChannelTest 
{ @Test void testToInputChannelRequiresEmptyRecoveredBuffers() throws Exception { SingleInputGate inputGate = - new SingleInputGateBuilder().setCheckpointingDuringRecoveryEnabled(true).build(); + new SingleInputGateBuilder().build(); RemoteRecoveredInputChannel recoveredChannel = InputChannelBuilder.newBuilder() .setStateWriter(ChannelStateWriter.NO_OP) diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java index a1d1afbfc27c7..e4a4c289dc6e8 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java @@ -83,10 +83,6 @@ public class SingleInputGateBuilder { private TieredStorageConsumerClient tieredStorageConsumerClient = null; - private boolean isCheckpointingDuringRecoveryEnabled = false; - - private boolean isNeedsRecovery = false; - public SingleInputGateBuilder setPartitionProducerStateProvider( PartitionProducerStateProvider partitionProducerStateProvider) { @@ -171,16 +167,6 @@ public SingleInputGateBuilder setTieredStorageConsumerClient( return this; } - public SingleInputGateBuilder setCheckpointingDuringRecoveryEnabled(boolean enabled) { - this.isCheckpointingDuringRecoveryEnabled = enabled; - return this; - } - - public SingleInputGateBuilder setNeedsRecovery(boolean enabled) { - this.isNeedsRecovery = enabled; - return this; - } - public SingleInputGate build() { SingleInputGate gate = new SingleInputGate( @@ -196,9 +182,6 @@ public SingleInputGate build() { bufferSize, createThroughputCalculator.apply(bufferDebloatConfiguration), maybeCreateBufferDebloater(gateIndex)); - // Propagate before channel construction so RecoverableInputChannel implementations read - // the intended flag in their 
constructor and initialise their recovery state correctly. - gate.setNeedsRecovery(isNeedsRecovery); if (channelFactory != null) { gate.setInputChannels( IntStream.range(0, numberOfChannels) @@ -212,7 +195,6 @@ public SingleInputGate build() { .toArray(InputChannel[]::new)); } gate.setTieredStorageService(null, tieredStorageConsumerClient, null); - gate.setCheckpointingDuringRecoveryEnabled(isCheckpointingDuringRecoveryEnabled); return gate; } diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java index 1c58310cdbef1..53bba67f7a2be 100644 --- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java @@ -142,20 +142,4 @@ public ResultPartitionType getConsumedPartitionType() { @Override public void triggerDebloating() {} - - @Override - public void setCheckpointingDuringRecoveryEnabled(boolean enabled) {} - - @Override - public boolean isCheckpointingDuringRecoveryEnabled() { - return false; - } - - @Override - public void setNeedsRecovery(boolean enabled) {} - - @Override - public boolean needsRecovery() { - return false; - } } diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java index 1d5729231378a..47e3b79a77fc1 100644 --- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java @@ -204,20 +204,4 @@ public int getGateIndex() { public List getUnfinishedChannels() { return Collections.emptyList(); } - - @Override - public void setCheckpointingDuringRecoveryEnabled(boolean enabled) {} - - @Override - public boolean 
isCheckpointingDuringRecoveryEnabled() { - return false; - } - - @Override - public void setNeedsRecovery(boolean enabled) {} - - @Override - public boolean needsRecovery() { - return false; - } } diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java index 2448a2ac524f6..ad792d221c2ef 100644 --- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java @@ -286,21 +286,5 @@ public int getGateIndex() { public List getUnfinishedChannels() { return Collections.emptyList(); } - - @Override - public void setCheckpointingDuringRecoveryEnabled(boolean enabled) {} - - @Override - public boolean isCheckpointingDuringRecoveryEnabled() { - return false; - } - - @Override - public void setNeedsRecovery(boolean enabled) {} - - @Override - public boolean needsRecovery() { - return false; - } } } From 8aeb1ad24ae705ab405f4fe4d5522ea26bf78e82 Mon Sep 17 00:00:00 2001 From: Roman Khachatryan Date: Fri, 26 Jun 2026 19:07:51 +0200 Subject: [PATCH 12/16] [FLINK-38544] Decline checkpoint as TASK_NOT_READY when recovery barrier missing When collectPreRecoveryBarrier finds no during-recovery sentinel for a checkpoint in a still-recovering channel, decline with CHECKPOINT_DECLINED_TASK_NOT_READY (not an IOException). That reason is not counted against the tolerable-failure threshold, so the checkpoint is deferred and retried; the recovered buffers stay queued and are captured by a later checkpoint, so no in-flight data is lost. 
--- .../partition/consumer/LocalInputChannel.java | 19 ++++++++++++----- .../consumer/RemoteInputChannel.java | 19 ++++++++++++----- .../consumer/LocalInputChannelTest.java | 20 ++++++++++++------ .../consumer/RemoteInputChannelTest.java | 21 ++++++++++++------- 4 files changed, 56 insertions(+), 23 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java index e0ce62009d5dc..05071d4cadb68 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java @@ -282,7 +282,8 @@ private int nextRecoverySequenceNumber() { * @throws IOException if no sentinel matching {@code checkpointId} is found (the snapshot * protocol guarantees one must be present while the channel is in recovery). */ - private List collectPreRecoveryBarrier(long checkpointId) throws IOException { + private List collectPreRecoveryBarrier(long checkpointId) + throws IOException, CheckpointException { assert Thread.holdsLock(recoveredBuffers); List retained = new ArrayList<>(); try { @@ -303,11 +304,19 @@ private List collectPreRecoveryBarrier(long checkpointId) throws IOExcep throw e; } releaseRetainedBuffers(retained); - throw new IOException( - "Missing RecoveryCheckpointBarrier for checkpoint " + // The during-recovery sentinel for this checkpoint was never inserted into this channel + // (the recovery checkpoint trigger had already transitioned away from the drainer while + // this channel was still in recovery). 
The channel is simply not ready to snapshot + // recovered state for this checkpoint yet, so decline as TASK_NOT_READY: that is not + // counted against the tolerable-failure threshold, so the checkpoint is deferred and + // retried instead of failing the job. The recovered buffers remain queued and are + // captured by a later checkpoint, so no in-flight data is lost. + throw new CheckpointException( + "RecoveryCheckpointBarrier for checkpoint " + checkpointId - + " in recoveredBuffers for channel " - + getChannelInfo()); + + " not yet present in channel " + + getChannelInfo(), + CheckpointFailureReason.CHECKPOINT_DECLINED_TASK_NOT_READY); } private static void releaseRetainedBuffers(List retained) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java index 0f7b582ab3c23..1bb2401d4a543 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java @@ -928,7 +928,8 @@ public void checkpointStarted(CheckpointBarrier barrier) throws CheckpointExcept * protocol guarantees one must be present while the channel is in recovery). 
*/ @GuardedBy("receivedBuffers") - private List collectPreRecoveryBarrier(long checkpointId) throws IOException { + private List collectPreRecoveryBarrier(long checkpointId) + throws IOException, CheckpointException { assert Thread.holdsLock(receivedBuffers); List retained = new ArrayList<>(); SequenceBuffer sentinel = null; @@ -954,11 +955,19 @@ private List collectPreRecoveryBarrier(long checkpointId) throws IOExcep } if (sentinel == null) { releaseRetainedBuffers(retained); - throw new IOException( - "Missing RecoveryCheckpointBarrier for checkpoint " + // The during-recovery sentinel for this checkpoint was never inserted into this channel + // (the recovery checkpoint trigger had already transitioned away from the drainer while + // this channel was still in recovery). The channel is simply not ready to snapshot + // recovered state for this checkpoint yet, so decline as TASK_NOT_READY: that is not + // counted against the tolerable-failure threshold, so the checkpoint is deferred and + // retried instead of failing the job. The recovered buffers remain queued and are + // captured by a later checkpoint, so no in-flight data is lost. + throw new CheckpointException( + "RecoveryCheckpointBarrier for checkpoint " + checkpointId - + " in receivedBuffers for channel " - + getChannelInfo()); + + " not yet present in channel " + + getChannelInfo(), + CheckpointFailureReason.CHECKPOINT_DECLINED_TASK_NOT_READY); } // receivedBuffers is a PrioritizedDeque whose iterator() is read-only; remove the matched // sentinel by identity through its mutable removal API. 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java index 9861d28024c07..a84c5940c9df7 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java @@ -20,6 +20,7 @@ import org.apache.flink.metrics.SimpleCounter; import org.apache.flink.runtime.checkpoint.CheckpointException; +import org.apache.flink.runtime.checkpoint.CheckpointFailureReason; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.CheckpointType; import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; @@ -1196,7 +1197,7 @@ void testCheckpointStartedScansRecoveredBuffersUpToBarrier() throws Exception { } @Test - void testCheckpointStartedFailsWhenRecoveryBarrierIsMissing() throws Exception { + void testCheckpointStartedDeclinesAsNotReadyWhenRecoveryBarrierIsMissing() throws Exception { SingleInputGate inputGate = new SingleInputGateBuilder().build(); RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); LocalInputChannel channel = newPushOnlyLocalChannel(inputGate, stateWriter); @@ -1210,12 +1211,19 @@ void testCheckpointStartedFailsWhenRecoveryBarrierIsMissing() throws Exception { CheckpointOptions.unaligned(CheckpointType.CHECKPOINT, getDefault()); stateWriter.start(1L, options); + // A missing RecoveryCheckpointBarrier means the channel is not yet ready to snapshot + // recovered state for this checkpoint, so it declines as TASK_NOT_READY (not a fatal + // CHECKPOINT_DECLINED): the checkpoint is deferred/retried and the recovered buffer is + // neither dropped nor persisted. 
assertThatThrownBy(() -> channel.checkpointStarted(new CheckpointBarrier(1L, 0L, options))) - .isInstanceOf(CheckpointException.class) - .hasMessageContaining("Failed to extract recovered buffers for checkpoint 1") - .hasRootCauseMessage( - "Missing RecoveryCheckpointBarrier for checkpoint 1 in recoveredBuffers for channel " - + channel.getChannelInfo()); + .isInstanceOfSatisfying( + CheckpointException.class, + e -> + assertThat(e.getCheckpointFailureReason()) + .isEqualTo( + CheckpointFailureReason + .CHECKPOINT_DECLINED_TASK_NOT_READY)) + .hasMessageContaining("not yet present in channel"); assertThat(b1.refCnt()).isEqualTo(refCntBefore); assertThat(stateWriter.getAddedInput().get(channel.getChannelInfo())).isEmpty(); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java index 58223dea191ae..1bd6489dc9d33 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java @@ -23,6 +23,7 @@ import org.apache.flink.core.testutils.OneShotLatch; import org.apache.flink.metrics.SimpleCounter; import org.apache.flink.runtime.checkpoint.CheckpointException; +import org.apache.flink.runtime.checkpoint.CheckpointFailureReason; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.CheckpointType; import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; @@ -2345,7 +2346,6 @@ void testMoreAvailableNoneWhenLastRecoveredBufferAndDrainNotFinished() throws Ex new SingleInputGateBuilder() .setBufferPoolFactory(networkBufferPool.createBufferPool(1, 4)) .setSegmentProvider(networkBufferPool) - .setNeedsRecovery(true) .setChannelFactory( (builder, gate) -> 
builder.setNeedsRecovery(true).buildRemoteChannel(gate)) @@ -2426,7 +2426,7 @@ void testCheckpointStartedScansRecoveredBuffersUpToBarrier() throws Exception { } @Test - void testCheckpointStartedFailsWhenRecoveryBarrierIsMissing() throws Exception { + void testCheckpointStartedDeclinesAsNotReadyWhenRecoveryBarrierIsMissing() throws Exception { SingleInputGate inputGate = createSingleInputGate(1); RecordingChannelStateWriter stateWriter = new RecordingChannelStateWriter(); RemoteInputChannel channel = @@ -2442,13 +2442,20 @@ void testCheckpointStartedFailsWhenRecoveryBarrierIsMissing() throws Exception { stateWriter.start(1L, UNALIGNED); + // A missing RecoveryCheckpointBarrier means the channel is not yet ready to snapshot + // recovered state for this checkpoint, so it declines as TASK_NOT_READY (not a fatal + // CHECKPOINT_DECLINED): the checkpoint is deferred/retried and the recovered buffer is + // neither dropped nor persisted. assertThatThrownBy( () -> channel.checkpointStarted(new CheckpointBarrier(1L, 0L, UNALIGNED))) - .isInstanceOf(CheckpointException.class) - .hasMessageContaining("Failed to extract recovered buffers for checkpoint 1") - .hasRootCauseMessage( - "Missing RecoveryCheckpointBarrier for checkpoint 1 in receivedBuffers for channel " - + channel.getChannelInfo()); + .isInstanceOfSatisfying( + CheckpointException.class, + e -> + assertThat(e.getCheckpointFailureReason()) + .isEqualTo( + CheckpointFailureReason + .CHECKPOINT_DECLINED_TASK_NOT_READY)) + .hasMessageContaining("not yet present in channel"); assertThat(b1.refCnt()).isEqualTo(refCntBefore); assertThat(stateWriter.getAddedInput().get(channel.getChannelInfo())).isEmpty(); } From 6fd18abe42eb0082c6f708971e7ac23e494041f1 Mon Sep 17 00:00:00 2001 From: Roman Khachatryan Date: Fri, 26 Jun 2026 19:07:59 +0200 Subject: [PATCH 13/16] [FLINK-38544] Tolerate empty tmp dirs in disabled RecordFilterContext A disabled RecordFilterContext never spills, so it does not need spill directories. 
Only require non-empty tmpDirectories when checkpointing-during-recovery is enabled; otherwise tolerate a null/empty value (e.g. an environment without IOManager spilling directories). --- .../io/recovery/RecordFilterContext.java | 12 +++++-- .../io/recovery/RecordFilterContextTest.java | 35 +++++++++++++++++-- 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java index 165e66a278d9d..ae0e364ab712f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java @@ -131,9 +131,17 @@ public RecordFilterContext( this.rescalingDescriptor = checkNotNull(rescalingDescriptor); this.subtaskIndex = subtaskIndex; this.maxParallelism = maxParallelism; - checkArgument(checkNotNull(tmpDirectories).length > 0, "tmpDirectories must not be empty"); - this.tmpDirectories = tmpDirectories.clone(); this.checkpointingDuringRecoveryEnabled = checkpointingDuringRecoveryEnabled; + if (checkpointingDuringRecoveryEnabled) { + // tmpDirectories are only used by the spilling (checkpointing-during-recovery) path. + checkArgument( + checkNotNull(tmpDirectories).length > 0, "tmpDirectories must not be empty"); + this.tmpDirectories = tmpDirectories.clone(); + } else { + // A disabled context never spills, so it needs no spill directories; tolerate a + // null/empty value (e.g. an environment without IOManager spilling directories). + this.tmpDirectories = tmpDirectories == null ? 
new String[0] : tmpDirectories.clone(); + } checkArgument( memorySegmentSize > 0, "memorySegmentSize must be positive: %s", memorySegmentSize); this.memorySegmentSize = memorySegmentSize; diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java index 27477ca8c6ef9..bd17c8d33f0db 100644 --- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java @@ -84,7 +84,9 @@ void testGetInputConfigThrowsForInvalidIndex() { } @Test - void testNullOrEmptyTmpDirectoriesRejected() { + void testEnabledContextRejectsNullOrEmptyTmpDirectories() { + // When checkpointing-during-recovery is enabled, the spilling path needs spill + // directories, so null/empty tmpDirectories are rejected. assertThatThrownBy( () -> new RecordFilterContext( @@ -93,7 +95,7 @@ void testNullOrEmptyTmpDirectoriesRejected() { 0, 128, null, - false, + true, MemoryManager.DEFAULT_PAGE_SIZE)) .isInstanceOf(NullPointerException.class); assertThatThrownBy( @@ -104,11 +106,38 @@ void testNullOrEmptyTmpDirectoriesRejected() { 0, 128, new String[0], - false, + true, MemoryManager.DEFAULT_PAGE_SIZE)) .isInstanceOf(IllegalArgumentException.class); } + @Test + void testDisabledContextToleratesNullOrEmptyTmpDirectories() { + // A disabled context never spills, so it needs no spill directories: null/empty are + // tolerated and normalized to an empty array. 
+ RecordFilterContext fromNull = + new RecordFilterContext( + new RecordFilterContext.InputFilterConfig[0], + InflightDataRescalingDescriptor.NO_RESCALE, + 0, + 128, + null, + false, + MemoryManager.DEFAULT_PAGE_SIZE); + assertThat(fromNull.getTmpDirectories()).isEmpty(); + + RecordFilterContext fromEmpty = + new RecordFilterContext( + new RecordFilterContext.InputFilterConfig[0], + InflightDataRescalingDescriptor.NO_RESCALE, + 0, + 128, + new String[0], + false, + MemoryManager.DEFAULT_PAGE_SIZE); + assertThat(fromEmpty.getTmpDirectories()).isEmpty(); + } + @Test void testIsAmbiguousWhenDisabled() { // Create a rescaling descriptor with an ambiguous subtask (oldSubtask 0 is ambiguous) From 49354466f59332d3e666a3a01aa9f31124752b06 Mon Sep 17 00:00:00 2001 From: Roman Khachatryan Date: Fri, 26 Jun 2026 19:08:12 +0200 Subject: [PATCH 14/16] [FLINK-38544] Review notes (DROP before merge) Inline // review / // review nit / // review todo annotations left for the reviewer; drop this commit before merge. 
--- .../runtime/checkpoint/channel/FetchedChannelState.java | 2 +- .../checkpoint/channel/FetchedChannelStateDrainer.java | 4 ++-- .../checkpoint/channel/RecoveryCheckpointBarrier.java | 1 + .../netty/CreditBasedSequenceNumberingViewReader.java | 2 +- .../io/network/partition/consumer/LocalInputChannel.java | 1 + .../network/partition/consumer/RecoverableInputChannel.java | 4 ++++ .../io/network/partition/consumer/RemoteInputChannel.java | 5 +++-- .../io/network/partition/consumer/SingleInputGate.java | 3 +++ .../org/apache/flink/streaming/runtime/tasks/StreamTask.java | 2 +- 9 files changed, 17 insertions(+), 7 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelState.java index 68116cad9fec4..5e42b8217b9dc 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelState.java @@ -107,7 +107,7 @@ public void release() throws IOException { if (refCount.decrementAndGet() == 0) { if (cleanedUp.compareAndSet(false, true)) { closed = true; - deleteAllFiles(); + deleteAllFiles(); // review todo: threads / call-sites } } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainer.java index 950a65e60b01e..58ebb1d733be1 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainer.java @@ -89,7 +89,7 @@ private static final class ResolvedChannels { *

Disk reads and buffer allocations happen outside the lock; only the "deliver + commit" * pair is locked to guarantee atomicity with snapshot. */ - public void drain() throws IOException, InterruptedException { + public void drain() throws IOException, InterruptedException { // review: enforce called once channelState.release(); Optional next; while ((next = rootReader.nextSegment()).isPresent()) { @@ -161,7 +161,7 @@ private static int fill(Buffer buf, InputStream in, int remaining) throws IOExce } // Do not use try-with-resources: ChannelStateByteBuffer.close() recycles the buffer, // but the buffer is still owned by the caller here. - ChannelStateByteBuffer view = ChannelStateByteBuffer.wrap(buf); + ChannelStateByteBuffer view = ChannelStateByteBuffer.wrap(buf); // review: try-without ? return view.writeBytes(in, remaining); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointBarrier.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointBarrier.java index 984bfd42312f6..d263168b0119c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointBarrier.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveryCheckpointBarrier.java @@ -23,6 +23,7 @@ import org.apache.flink.runtime.event.RuntimeEvent; /** Task-local event marking the recovery-state cut for a recovery checkpoint. 
*/ +// review nit: checkpoint during recovery @Internal public final class RecoveryCheckpointBarrier extends RuntimeEvent { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedSequenceNumberingViewReader.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedSequenceNumberingViewReader.java index c1e00eca802d9..9498353f4d09c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedSequenceNumberingViewReader.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedSequenceNumberingViewReader.java @@ -47,7 +47,7 @@ *

It also keeps track of available buffers and notifies the outbound handler about * non-emptiness, similar to the {@link LocalInputChannel}. */ -class CreditBasedSequenceNumberingViewReader +class CreditBasedSequenceNumberingViewReader // review nit: can we split THIS commit into remote and // local input channel changes? implements BufferAvailabilityListener, NetworkSequenceViewReader { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java index 05071d4cadb68..b4a7e2bd20b1a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java @@ -162,6 +162,7 @@ public LocalInputChannel( this.inRecovery = needsRecovery; this.bufferManager = needsRecovery + // review nit: false for consistency? ? new BufferManager(inputGate.getMemorySegmentProvider(), this, 0, true) : null; this.networkBuffersPerChannel = networkBuffersPerChannel; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoverableInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoverableInputChannel.java index 40a4ab27a0908..9ccca71237f1e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoverableInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoverableInputChannel.java @@ -24,6 +24,9 @@ import java.io.IOException; /** Physical input channel that can receive recovered buffers pushed by the spill drain. */ +// review: please rephrase: physical, spill, drain +// review: should this be in the previous commit? 
according to THIS commit message +// review todo impl @Internal public interface RecoverableInputChannel { @@ -59,5 +62,6 @@ public interface RecoverableInputChannel { * recovery, release any upstream events held back during recovery, and reopen the upstream so * live data may flow again. */ + // review: is this guaranteed to be called after finishRecoveredBufferDelivery? void onRecoveredStateConsumed() throws IOException; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java index 1bb2401d4a543..e272aef9efa27 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java @@ -781,14 +781,14 @@ public void onBuffer(Buffer buffer, int sequenceNumber, int backlog, int subpart firstPriorityEvent = addPriorityBuffer(sequenceBuffer); recycleBuffer = false; } else { - if (inRecovery) { + if (inRecovery) { // review nit: merge if // The upstream has no credit until recovery delivery finishes, so it can // only // send events here, never data buffers. Stash ordinary events so they are // consumed after the recovered buffers; data buffers are a protocol // violation. checkState( - !buffer.isBuffer(), + !buffer.isBuffer(), // review todo: check what events can be sent "Received live data buffer during recovery on channel %s", getChannelInfo()); recoveryEventStash.add(sequenceBuffer); @@ -878,6 +878,7 @@ private void checkAnnouncedOnlyOnce(SequenceBuffer sequenceBuffer) { * before the matching RecoveryCheckpointBarrier sentinel; after recovery, uses the normal * remote-channel barrier sequence tracking and persists overtaken live buffers. 
*/ + // review todo public void checkpointStarted(CheckpointBarrier barrier) throws CheckpointException { try { List toPersist; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java index 54b0bd928e417..979a674751698 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java @@ -243,6 +243,9 @@ public class SingleInputGate extends IndexedInputGate { */ private final int[] endOfPartitions; + // review: this should be a parameter to inputGate.finishReadRecoveredState() or a new method + // finishFetchState() + // private volatile boolean checkpointingDuringRecoveryEnabled = false; public SingleInputGate( String owningTaskName, diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index 7e1e0215f0e20..4557ca59cb8ee 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -984,7 +984,7 @@ private Optional fetchChannelState( SequentialChannelStateReader reader, IndexedInputGate[] inputGates) { try { return reader.readInputData(inputGates, createRecordFilterContext()); - } catch (Throwable t) { + } catch (Throwable t) { // review: don't catch errors asyncExceptionHandler.handleAsyncException( "Unable to set up recovered channel state", t); return Optional.empty(); From 55748845577965ed23e8047e379213550eb178b2 Mon Sep 17 00:00:00 2001 From: Roman Khachatryan Date: Sat, 27 Jun 2026 08:13:16 +0200 Subject: [PATCH 15/16] [FLINK-38544] Fix spotless formatting violations in recovery 
channel-state tests CI (build 76428) failed the spotless-check on flink-runtime for four recovery test files. Applied spotless:apply (google-java-format); formatting-only changes. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../channel/ChannelIOExecutorDrainSubmissionTest.java | 6 ++---- .../channel/FetchedChannelStateDrainerConcurrencyTest.java | 3 +-- .../partition/consumer/LocalRecoveredInputChannelTest.java | 3 +-- .../partition/consumer/RemoteRecoveredInputChannelTest.java | 3 +-- 4 files changed, 5 insertions(+), 10 deletions(-) diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelIOExecutorDrainSubmissionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelIOExecutorDrainSubmissionTest.java index a143e3735306a..7bbc7e6e8f5d4 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelIOExecutorDrainSubmissionTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelIOExecutorDrainSubmissionTest.java @@ -53,8 +53,7 @@ void testFilterOnSubmitsDrainAfterConversion() throws Exception { CapturingChannel chan = new CapturingChannel(cInfo); List all = new ArrayList<>(); all.add(chan); - FetchedChannelStateDrainer drainer = - new FetchedChannelStateDrainer(state, all); + FetchedChannelStateDrainer drainer = new FetchedChannelStateDrainer(state, all); ExecutorService channelIOExecutor = Executors.newSingleThreadExecutor(); try { @@ -120,8 +119,7 @@ public void onRecoveredStateConsumed() {} List all = new ArrayList<>(); all.add(chan); - FetchedChannelStateDrainer drainer = - new FetchedChannelStateDrainer(state, all); + FetchedChannelStateDrainer drainer = new FetchedChannelStateDrainer(state, all); CountDownLatch handlerCalled = new CountDownLatch(1); AtomicReference captured = new AtomicReference<>(); diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainerConcurrencyTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainerConcurrencyTest.java index cd845e2fbf2a1..63824a02d8c81 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainerConcurrencyTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/FetchedChannelStateDrainerConcurrencyTest.java @@ -78,8 +78,7 @@ void testDrainAndSnapshotConcurrentAtomicity() throws Exception { all.add(chan0); all.add(chan1); - FetchedChannelStateDrainer drainer = - new FetchedChannelStateDrainer(state, all); + FetchedChannelStateDrainer drainer = new FetchedChannelStateDrainer(state, all); ExecutorService io = Executors.newSingleThreadExecutor(); AtomicReference drainError = new AtomicReference<>(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannelTest.java index 9559ef316c5a1..005c86d3ed011 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannelTest.java @@ -30,8 +30,7 @@ class LocalRecoveredInputChannelTest { @Test void testToInputChannelRequiresEmptyRecoveredBuffers() throws Exception { - SingleInputGate inputGate = - new SingleInputGateBuilder().build(); + SingleInputGate inputGate = new SingleInputGateBuilder().build(); LocalRecoveredInputChannel recoveredChannel = InputChannelBuilder.newBuilder() .setStateWriter(ChannelStateWriter.NO_OP) diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannelTest.java index 339ccced04b75..afa6d6fc6d8b8 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannelTest.java @@ -30,8 +30,7 @@ class RemoteRecoveredInputChannelTest { @Test void testToInputChannelRequiresEmptyRecoveredBuffers() throws Exception { - SingleInputGate inputGate = - new SingleInputGateBuilder().build(); + SingleInputGate inputGate = new SingleInputGateBuilder().build(); RemoteRecoveredInputChannel recoveredChannel = InputChannelBuilder.newBuilder() .setStateWriter(ChannelStateWriter.NO_OP) From b22e7b4e33ae114c3bd1e96e90606ffa97bcdddd Mon Sep 17 00:00:00 2001 From: Roman Khachatryan Date: Sat, 27 Jun 2026 15:45:38 +0200 Subject: [PATCH 16/16] [FLINK-38544][checkpoint] Convert recovered gates per-gate during recovery recoverChannelsWithoutCheckpointing deferred requestPartitions/conversion until completeAll(all gates' stateConsumedFutures), so no gate converted until every gate's recovered state had been drained. A selective-reading multi-input operator only drains the *selected* input's end-of-state sentinel, so an unselected gate never drained (it is read only after conversion) while conversion waited for it to drain first -- a circular wait that deadlocked the restore mailbox loop (parked in processMailsWhenDefaultActionUnavailable -> mailbox.take()). This hung StreamTaskSelectiveReadingITCase and other multi-input recovery in CI (build 76435: tests/table/python groups, watchdog-killed). 
Trigger each gate's requestPartitions(false) off its own getStateConsumedFuture() (restoring the pre-rewrite per-gate behavior from d9fc48), so a drained gate converts immediately and the reader can progress; suspend() remains gated on completeAll(futures). The flag-ON path (recoverChannelsWithCheckpointing) is unaffected -- it is drainer-driven, not gated on the consumer draining. Validated: StreamTaskSelectiveReadingITCase clean over 75 runs under JDK-17 load (was hanging on iteration 1); TaskCheckpointingBehaviourTest, recovered-channel tests and MultipleInputStreamTaskTest green. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../streaming/runtime/tasks/StreamTask.java | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index 4557ca59cb8ee..39c68db32b9b5 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -961,9 +961,21 @@ private CompletableFuture recoverChannelsWithoutCheckpointing( channelIOExecutor.execute(() -> readInputChannelState(reader, inputGates)); List> futures = new ArrayList<>(); for (InputGate inputGate : inputGates) { - futures.add(inputGate.getStateConsumedFuture()); + CompletableFuture stateConsumed = inputGate.getStateConsumedFuture(); + futures.add(stateConsumed); + // Convert and request partitions for each gate as soon as ITS OWN recovered state is + // drained, independent of the other gates. 
Deferring conversion until every gate has + // drained deadlocks selective-reading multi-input operators: such an operator only + // drains the selected input's end-of-state sentinel, so an unselected gate would never + // drain (it is read only after conversion) while conversion would wait for it to drain + // first. suspend() is still gated on completeAll(futures) below. + stateConsumed.thenRun( + () -> + mainMailboxExecutor.execute( + () -> inputGate.requestPartitions(false), + "Input gate request partitions")); } - return completeAll(futures).thenRun(() -> requestPartitions(inputGates, false)); + return completeAll(futures); } private void readInputChannelState(